use axonml_autograd::Variable;
use axonml_nn::{Linear, Module, Parameter};
use axonml_tensor::Tensor;
use crate::attention::KVCacheEntry;
use crate::config::GPT2Config;
use crate::embedding::GPT2Embedding;
use crate::transformer::TransformerDecoder;
/// GPT-2 transformer backbone: token embedding followed by a stack of
/// transformer decoder blocks. Produces hidden states; see `GPT2LMHead`
/// for the language-modeling head that projects them to vocabulary logits.
#[derive(Debug)]
pub struct GPT2 {
// Owned copy of the hyperparameters this model was built from.
pub config: GPT2Config,
// Token (+ positional) embedding; "wte" follows the original GPT-2 naming.
pub wte: GPT2Embedding,
// Decoder block stack; "h" follows the GPT-2 checkpoint naming.
pub h: TransformerDecoder,
}
impl GPT2 {
    /// Builds the backbone described by `config`. The config is cloned into
    /// the model, so the caller keeps ownership of its copy.
    pub fn new(config: &GPT2Config) -> Self {
        let embedding = GPT2Embedding::new(
            config.vocab_size,
            config.n_ctx,
            config.n_embd,
            config.dropout,
        );
        let decoder = TransformerDecoder::new(
            config.n_layer,
            config.n_embd,
            config.n_head,
            config.n_ctx,
            config.dropout,
            config.layer_norm_eps,
            &config.activation,
        );
        Self {
            config: config.clone(),
            wte: embedding,
            h: decoder,
        }
    }

    /// Preset: GPT-2 small.
    pub fn small() -> Self {
        Self::new(&GPT2Config::small())
    }

    /// Preset: GPT-2 medium.
    pub fn medium() -> Self {
        Self::new(&GPT2Config::medium())
    }

    /// Preset: GPT-2 large.
    pub fn large() -> Self {
        Self::new(&GPT2Config::large())
    }

    /// Preset: GPT-2 XL.
    pub fn xl() -> Self {
        Self::new(&GPT2Config::xl())
    }

    /// Preset: a tiny configuration (handy for tests).
    pub fn tiny() -> Self {
        Self::new(&GPT2Config::tiny())
    }

    /// Embeds `input_ids` and runs the decoder stack, returning the final
    /// hidden states.
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        self.h.forward(&self.wte.forward_ids(input_ids))
    }

    /// Forward pass that threads a KV cache alongside the hidden states.
    ///
    /// NOTE(review): the cache is currently passed through unchanged and is
    /// NOT consumed to shorten computation — the full sequence is re-encoded
    /// via `forward_ids` on every call. True incremental decoding would need
    /// cache support inside `TransformerDecoder`.
    pub fn forward_with_past(
        &self,
        input_ids: &Tensor<u32>,
        past_key_values: Option<Vec<(Tensor<f32>, Tensor<f32>)>>,
    ) -> (Variable, Vec<KVCacheEntry>) {
        let hidden_states = self.forward_ids(input_ids);
        let cache = match past_key_values {
            Some(entries) => entries,
            None => Vec::new(),
        };
        (hidden_states, cache)
    }
}
impl Module for GPT2 {
    /// Runs only the decoder stack; `input` must already be embedded hidden
    /// states (use `forward_ids` to start from raw token ids).
    fn forward(&self, input: &Variable) -> Variable {
        self.h.forward(input)
    }

    /// All trainable parameters: embedding parameters first, then the
    /// decoder stack's.
    fn parameters(&self) -> Vec<Parameter> {
        self.wte
            .parameters()
            .into_iter()
            .chain(self.h.parameters())
            .collect()
    }

    fn train(&mut self) {
        self.wte.train();
        self.h.train();
    }

    fn eval(&mut self) {
        self.wte.eval();
        self.h.eval();
    }
}
/// GPT-2 with a language-modeling head: the `GPT2` backbone followed by a
/// linear projection from hidden size to vocabulary logits.
#[derive(Debug)]
pub struct GPT2LMHead {
// The underlying GPT-2 backbone (embeddings + decoder stack).
pub transformer: GPT2,
// Projects final hidden states `[.., n_embd]` to logits `[.., vocab_size]`.
pub lm_head: Linear,
}
impl GPT2LMHead {
    /// Builds a GPT-2 backbone plus a linear LM head sized by `config`.
    pub fn new(config: &GPT2Config) -> Self {
        let transformer = GPT2::new(config);
        let lm_head = Linear::new(config.n_embd, config.vocab_size);
        Self {
            transformer,
            lm_head,
        }
    }

    /// Preset: GPT-2 small.
    pub fn small() -> Self {
        Self::new(&GPT2Config::small())
    }

    /// Preset: GPT-2 medium.
    pub fn medium() -> Self {
        Self::new(&GPT2Config::medium())
    }

    /// Preset: GPT-2 large.
    pub fn large() -> Self {
        Self::new(&GPT2Config::large())
    }

    /// Preset: GPT-2 XL (added for parity with `GPT2::xl`).
    pub fn xl() -> Self {
        Self::new(&GPT2Config::xl())
    }

    /// Preset: a tiny configuration (handy for tests).
    pub fn tiny() -> Self {
        Self::new(&GPT2Config::tiny())
    }

    /// Runs the backbone on token ids and projects to vocabulary logits
    /// of shape `[batch, seq, vocab_size]`.
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let hidden_states = self.transformer.forward_ids(input_ids);
        self.lm_head.forward(&hidden_states)
    }

    /// Runs the model and computes a next-token cross-entropy loss.
    ///
    /// Logits are shifted left by one time step and labels right by one, so
    /// position `t` predicts token `t + 1`. Returns `(logits, loss)`. For a
    /// single-token sequence there is no next token to predict, so a
    /// constant zero loss (no gradient) is returned.
    pub fn forward_with_loss(
        &self,
        input_ids: &Tensor<u32>,
        labels: &Tensor<u32>,
    ) -> (Variable, Variable) {
        let logits = self.forward_ids(input_ids);
        let shape = logits.data().shape().to_vec();
        let batch_size = shape[0];
        let seq_len = shape[1];
        if seq_len > 1 {
            // Drop the last time step of the logits: it has no target.
            let shift_logits = logits.narrow(1, 0, seq_len - 1);
            // Drop the first label of each row: it has no predictor.
            let labels_vec = labels.to_vec();
            let mut shift_labels_data = Vec::with_capacity(batch_size * (seq_len - 1));
            for b in 0..batch_size {
                for s in 1..seq_len {
                    shift_labels_data.push(labels_vec[b * seq_len + s]);
                }
            }
            let shift_labels =
                Tensor::from_vec(shift_labels_data, &[batch_size, seq_len - 1]).unwrap();
            let loss = Self::cross_entropy_loss(&shift_logits, &shift_labels);
            (logits, loss)
        } else {
            let zero_loss = Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
            (logits, zero_loss)
        }
    }

    /// Flattens `[batch, seq, vocab]` logits to `[batch*seq, vocab]` and
    /// computes cross-entropy against the flattened labels.
    ///
    /// Out-of-vocabulary label ids are clamped to token 0 rather than
    /// masked out — callers are expected to supply valid ids.
    fn cross_entropy_loss(logits: &Variable, labels: &Tensor<u32>) -> Variable {
        use axonml_nn::loss::CrossEntropyLoss;

        let shape = logits.data().shape().to_vec();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let vocab_size = shape[2];
        let logits_flat = logits.reshape(&[batch_size * seq_len, vocab_size]);
        // CrossEntropyLoss consumes class indices encoded as f32 values.
        let valid_labels: Vec<f32> = labels
            .to_vec()
            .iter()
            .map(|&l| if (l as usize) < vocab_size { l as f32 } else { 0.0f32 })
            .collect();
        let target_var = Variable::new(
            Tensor::from_vec(valid_labels, &[batch_size * seq_len]).unwrap(),
            false,
        );
        CrossEntropyLoss::new().compute(&logits_flat, &target_var)
    }

    /// Rebuilds a row-major `[batch, current_len]` id buffer with one new
    /// token appended to each row, preserving the row-major layout.
    fn append_next_tokens(
        current_ids: Vec<u32>,
        next_tokens: &[u32],
        current_len: usize,
    ) -> Vec<u32> {
        let batch_size = next_tokens.len();
        let mut out = Vec::with_capacity(batch_size * (current_len + 1));
        for b in 0..batch_size {
            out.extend_from_slice(&current_ids[b * current_len..(b + 1) * current_len]);
            out.push(next_tokens[b]);
        }
        out
    }

    /// Autoregressive sampling with temperature and optional top-k
    /// filtering. Appends up to `max_new_tokens` tokens per row, stopping
    /// early once the context window `n_ctx` is full.
    pub fn generate(
        &self,
        input_ids: &Tensor<u32>,
        max_new_tokens: usize,
        temperature: f32,
        top_k: Option<usize>,
    ) -> Tensor<u32> {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let batch_size = input_ids.shape()[0];
        let mut current_len = input_ids.shape()[1];
        let mut current_ids: Vec<u32> = input_ids.to_vec().to_vec();
        for _ in 0..max_new_tokens {
            if current_len >= self.transformer.config.n_ctx {
                break;
            }
            let input_tensor =
                Tensor::from_vec(current_ids.clone(), &[batch_size, current_len]).unwrap();
            let logits = self.forward_ids(&input_tensor);
            // Copy the logits out once per step (previously this full copy
            // was repeated inside the per-batch loop).
            let logits_vec = logits.data().to_vec();
            let vocab_size = self.transformer.config.vocab_size;
            let last_logits_start = (current_len - 1) * vocab_size;
            let mut next_tokens = Vec::with_capacity(batch_size);
            for b in 0..batch_size {
                let batch_offset = b * current_len * vocab_size;
                let mut last_logits: Vec<f32> = logits_vec[batch_offset + last_logits_start
                    ..batch_offset + last_logits_start + vocab_size]
                    .to_vec();
                if temperature != 1.0 {
                    for logit in &mut last_logits {
                        *logit /= temperature;
                    }
                }
                if let Some(k) = top_k {
                    // Keep only the k highest logits; everything below the
                    // k-th value is masked to -inf before softmax.
                    let mut indexed: Vec<(usize, f32)> =
                        last_logits.iter().copied().enumerate().collect();
                    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
                    let threshold = if k < indexed.len() {
                        indexed[k].1
                    } else {
                        f32::NEG_INFINITY
                    };
                    for logit in &mut last_logits {
                        if *logit < threshold {
                            *logit = f32::NEG_INFINITY;
                        }
                    }
                }
                // Numerically stable softmax followed by inverse-CDF sampling.
                let max_logit = last_logits
                    .iter()
                    .cloned()
                    .fold(f32::NEG_INFINITY, f32::max);
                let exp_logits: Vec<f32> =
                    last_logits.iter().map(|x| (x - max_logit).exp()).collect();
                let sum_exp: f32 = exp_logits.iter().sum();
                let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
                let mut cumsum = 0.0f32;
                let sample: f32 = rng.r#gen();
                let mut next_token = 0u32;
                for (i, &p) in probs.iter().enumerate() {
                    cumsum += p;
                    if sample < cumsum {
                        next_token = i as u32;
                        break;
                    }
                }
                next_tokens.push(next_token);
            }
            // BUG FIX: the previous code pushed each row's token onto the
            // tail of `current_ids` inside the batch loop, which breaks the
            // row-major [batch, seq] layout whenever batch_size > 1.
            // Rebuild the buffer so every row keeps its own new token.
            current_ids = Self::append_next_tokens(current_ids, &next_tokens, current_len);
            current_len += 1;
        }
        Tensor::from_vec(current_ids, &[batch_size, current_len]).unwrap()
    }

    /// Greedy (argmax) decoding. Appends up to `max_new_tokens` tokens per
    /// row, stopping early once the context window `n_ctx` is full.
    pub fn generate_greedy(&self, input_ids: &Tensor<u32>, max_new_tokens: usize) -> Tensor<u32> {
        let batch_size = input_ids.shape()[0];
        let mut current_len = input_ids.shape()[1];
        let mut current_ids: Vec<u32> = input_ids.to_vec().to_vec();
        for _ in 0..max_new_tokens {
            if current_len >= self.transformer.config.n_ctx {
                break;
            }
            let input_tensor =
                Tensor::from_vec(current_ids.clone(), &[batch_size, current_len]).unwrap();
            let logits = self.forward_ids(&input_tensor);
            // Copy the logits out once per step (previously re-copied per row).
            let logits_vec = logits.data().to_vec();
            let vocab_size = self.transformer.config.vocab_size;
            let mut next_tokens = Vec::with_capacity(batch_size);
            for b in 0..batch_size {
                let last_pos = current_len - 1;
                let offset = (b * current_len + last_pos) * vocab_size;
                let last_logits = &logits_vec[offset..offset + vocab_size];
                let (next_token, _) = last_logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                    .unwrap();
                next_tokens.push(next_token as u32);
            }
            // BUG FIX: same layout bug as in `generate` — rebuild the id
            // buffer so the row-major [batch, seq] layout survives for
            // batch_size > 1.
            current_ids = Self::append_next_tokens(current_ids, &next_tokens, current_len);
            current_len += 1;
        }
        Tensor::from_vec(current_ids, &[batch_size, current_len]).unwrap()
    }
}
impl Module for GPT2LMHead {
    /// Forward on pre-embedded hidden states: decoder stack, then LM head.
    fn forward(&self, input: &Variable) -> Variable {
        let hidden_states = self.transformer.forward(input);
        self.lm_head.forward(&hidden_states)
    }

    /// All trainable parameters: backbone first, then the LM head.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.transformer.parameters();
        params.extend(self.lm_head.parameters());
        params
    }

    fn train(&mut self) {
        self.transformer.train();
        // FIX: also propagate the mode to the head, matching `parameters()`
        // which includes `lm_head`. A no-op for a plain Linear today, but it
        // keeps the module tree consistent.
        self.lm_head.train();
    }

    fn eval(&mut self) {
        self.transformer.eval();
        self.lm_head.eval();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The tiny backbone maps [2, 4] token ids to [2, 4, 128] hidden states.
    #[test]
    fn test_gpt2_tiny_forward() {
        let model = GPT2::tiny();
        let ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap();
        let hidden = model.forward_ids(&ids);
        assert_eq!(hidden.data().shape(), &[2, 4, 128]);
    }

    /// The LM head projects hidden states to per-position vocab logits.
    #[test]
    fn test_gpt2_lm_head() {
        let config = GPT2Config::tiny();
        let model = GPT2LMHead::new(&config);
        let ids = Tensor::from_vec(vec![1u32, 2, 3, 4, 5, 6, 7, 8], &[2, 4]).unwrap();
        let logits = model.forward_ids(&ids);
        assert_eq!(logits.data().shape(), &[2, 4, config.vocab_size]);
    }

    /// Greedy decoding extends a 2-token prompt by 5 tokens.
    #[test]
    fn test_gpt2_generate_greedy() {
        let model = GPT2LMHead::tiny();
        let prompt = Tensor::from_vec(vec![1u32, 2], &[1, 2]).unwrap();
        let generated = model.generate_greedy(&prompt, 5);
        assert_eq!(generated.shape()[1], 7);
    }

    /// Top-k sampling extends a 2-token prompt by 5 tokens.
    #[test]
    fn test_gpt2_generate_sampling() {
        let model = GPT2LMHead::tiny();
        let prompt = Tensor::from_vec(vec![1u32, 2], &[1, 2]).unwrap();
        let generated = model.generate(&prompt, 5, 1.0, Some(50));
        assert_eq!(generated.shape()[1], 7);
    }

    /// Shifted cross-entropy loss is positive for an untrained model.
    #[test]
    fn test_gpt2_loss() {
        let model = GPT2LMHead::tiny();
        let ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[1, 4]).unwrap();
        let labels = Tensor::from_vec(vec![2u32, 3, 4, 5], &[1, 4]).unwrap();
        let (logits, loss) = model.forward_with_loss(&ids, &labels);
        assert_eq!(logits.data().shape()[0], 1);
        assert_eq!(logits.data().shape()[1], 4);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    /// The backbone exposes a non-empty, non-trivial parameter set.
    #[test]
    fn test_gpt2_parameter_count() {
        let model = GPT2::tiny();
        let params = model.parameters();
        assert!(!params.is_empty());
        let total: usize = params.iter().map(|p| p.data().numel()).sum();
        assert!(total > 0);
    }
}