use crate::tensor::Tensor;
use crate::transformer::TransformerBlock;
use crate::weights::ModelWeights;
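
/// A GPT-style decoder-only transformer language model: token and learned
/// position embeddings, a stack of transformer blocks, a final layer norm,
/// and an output projection tied to the token-embedding matrix.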
pub struct GPTModel {
_n_layers: usize,
_n_vocab: usize,
pub n_ctx: usize,
_hidden_size: usize,
_n_heads: usize,
_intermediate_size: usize,
pub blocks: Vec<TransformerBlock>,
}
impl GPTModel {
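/// Constructs a model configuration with `n_layers` transformer blocks.
/// The actual parameters are supplied separately through [`ModelWeights`]
/// at inference time.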
pub fn new(
n_layers: usize,
n_vocab: usize,
n_ctx: usize,
hidden_size: usize,
n_heads: usize,
intermediate_size: usize,
) -> Self {
let blocks = (0..n_layers)
.map(|_| TransformerBlock::new(hidden_size, n_heads, intermediate_size))
.collect();
Self {
_n_layers: n_layers,
_n_vocab: n_vocab,
n_ctx,
_hidden_size: hidden_size,
_n_heads: n_heads,
_intermediate_size: intermediate_size,
blocks,
}
}
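
/// Runs a full forward pass over `token_ids` and returns the logits as a
/// `[1, seq_len, n_vocab]` tensor.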
pub fn forward(&self, token_ids: &[u32], weights: &ModelWeights) -> Tensor {
let _batch_size = 1; // this implementation always runs with a batch size of 1
let seq_len = token_ids.len();
let token_embeds = self.get_token_embeddings(token_ids, weights);
let pos_embeds = self.get_position_embeddings(seq_len, weights);
let mut x = token_embeds.add(&pos_embeds);
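// Pass the embeddings through each transformer block in order.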
for (layer_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, weights, layer_idx);
}
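// Apply the final layer norm to the output of the last block.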
let ln_f_weight = &weights.tensors["ln_f.weight"];
let ln_f_bias = &weights.tensors["ln_f.bias"];
let normalized = x.layer_norm(ln_f_weight, ln_f_bias, 1e-5);
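// Output projection: reuse the transposed token-embedding matrix (weight
// tying), so logits have shape [1, seq_len, n_vocab].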
let wte = &weights.tensors["wte.weight"];
let wte_t = Tensor {
data: wte.data.t().to_owned(),
};
normalized.matmul(&wte_t)
}
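
/// Gathers the embedding row of `wte.weight` for each token ID, producing a
/// `[1, seq_len, hidden_size]` tensor.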
fn get_token_embeddings(&self, token_ids: &[u32], weights: &ModelWeights) -> Tensor {
let wte = &weights.tensors["wte.weight"];
let vocab_size = wte.shape()[0];
let hidden_size = wte.shape()[1];
let seq_len = token_ids.len();
let mut embedding_data = Vec::with_capacity(seq_len * hidden_size);
let wte_data = wte.data.as_slice().unwrap();
for &token_id in token_ids {
let token_id = token_id as usize;
assert!(
token_id < vocab_size,
"Token ID {} out of vocabulary range",
token_id
);
let start_idx = token_id * hidden_size;
let end_idx = start_idx + hidden_size;
embedding_data.extend_from_slice(&wte_data[start_idx..end_idx]);
}
Tensor::from_shape(&[1, seq_len, hidden_size], embedding_data)
}
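
/// Slices the first `seq_len` rows of the learned position-embedding table
/// `wpe.weight` into a `[1, seq_len, hidden_size]` tensor.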
fn get_position_embeddings(&self, seq_len: usize, weights: &ModelWeights) -> Tensor {
let wpe = &weights.tensors["wpe.weight"];
let max_pos = wpe.shape()[0];
let hidden_size = wpe.shape()[1];
assert!(
seq_len <= max_pos,
"Sequence length {} exceeds maximum position {}",
seq_len,
max_pos
);
let mut pos_embedding_data = Vec::with_capacity(seq_len * hidden_size);
let wpe_data = wpe.data.as_slice().unwrap();
for pos in 0..seq_len {
let start_idx = pos * hidden_size;
let end_idx = start_idx + hidden_size;
pos_embedding_data.extend_from_slice(&wpe_data[start_idx..end_idx]);
}
Tensor::from_shape(&[1, seq_len, hidden_size], pos_embedding_data)
}
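
/// Runs a forward pass and greedily selects the highest-logit token at the
/// last position of the sequence.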
pub fn generate_next_token(&self, token_ids: &[u32], weights: &ModelWeights) -> u32 {
let logits = self.forward(token_ids, weights);
let seq_len = token_ids.len();
let vocab_size = logits.shape()[2];
let last_logits_start = (seq_len - 1) * vocab_size;
let last_logits_end = last_logits_start + vocab_size;
let logits_data = logits.data.as_slice().unwrap();
let last_logits = &logits_data[last_logits_start..last_logits_end];
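// Greedy decoding: find the vocabulary index with the largest logit.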
let mut max_idx = 0;
let mut max_val = last_logits[0];
for (idx, &val) in last_logits.iter().enumerate() {
if val > max_val {
max_val = val;
max_idx = idx;
}
}
max_idx as u32
}
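
/// Autoregressively generates up to `max_new_tokens` tokens after
/// `prompt_tokens`, truncating the context to the last `n_ctx` tokens when it
/// grows too long. Returns the prompt followed by the generated tokens.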
pub fn generate(
&self,
prompt_tokens: &[u32],
max_new_tokens: usize,
weights: &ModelWeights,
) -> Vec<u32> {
let mut tokens = prompt_tokens.to_vec();
for _ in 0..max_new_tokens {
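// Keep only the most recent `n_ctx` tokens as context.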
let context_tokens = if tokens.len() > self.n_ctx {
&tokens[tokens.len() - self.n_ctx..]
} else {
&tokens
};
let next_token = self.generate_next_token(context_tokens, weights);
tokens.push(next_token);
}
tokens
}
}