use crate::config::GptConfig;
use crate::error::{LmError, LmResult};
use crate::layer::{
embedding::{LearnedPositionalEmbedding, TokenEmbedding},
norm::LayerNorm,
transformer::{GptBlock, PastKvCache},
};
/// Decoder-only GPT-2 style language model: token and learned positional
/// embeddings, a stack of transformer blocks, and a final layer norm.
///
/// The output projection is not stored separately; the forward pass reuses
/// `token_embed`'s weight table to turn hidden states into vocabulary logits.
#[derive(Debug, Clone)]
pub struct Gpt2Model {
/// Configuration the model was built from (validated in `Gpt2Model::new`).
pub config: GptConfig,
/// Token embedding table, built from `vocab_size` and `n_embd`; also read
/// by the output head.
pub token_embed: TokenEmbedding,
/// Learned positional embedding, built from `n_positions` and `n_embd`.
pub pos_embed: LearnedPositionalEmbedding,
/// Transformer blocks, applied in order; `n_layers` of them are created.
pub blocks: Vec<GptBlock>,
/// Layer norm applied to the last block's output before the output head.
pub ln_f: LayerNorm,
}
impl Gpt2Model {
/// Builds a model from `config`.
///
/// The configuration is validated first; afterwards the token embedding,
/// positional embedding, `n_layers` transformer blocks and the final layer
/// norm are constructed in that order.
///
/// # Errors
///
/// Returns the first error reported by `GptConfig::validate` or by any of the
/// component constructors.
pub fn new(config: GptConfig) -> LmResult<Self> {
    config.validate()?;

    let wte = TokenEmbedding::new(config.vocab_size, config.n_embd)?;
    let wpe = LearnedPositionalEmbedding::new(config.n_positions, config.n_embd)?;

    // Every block shares the same hyper-parameters; stop at the first failure.
    let mut layers = Vec::new();
    for _ in 0..config.n_layers {
        let layer = GptBlock::new(
            config.n_embd,
            config.n_heads,
            config.ffn_intermediate,
            config.layer_norm_eps,
        )?;
        layers.push(layer);
    }

    let final_norm = LayerNorm::new(config.n_embd, config.layer_norm_eps)?;

    Ok(Self {
        config,
        token_embed: wte,
        pos_embed: wpe,
        blocks: layers,
        ln_f: final_norm,
    })
}
/// Runs the model over `token_ids` and returns the logits together with a
/// newly built key/value cache.
///
/// # Arguments
///
/// * `token_ids` - tokens to process in this call; must not be empty.
/// * `past_kv` - optional cache from an earlier call. Its `past_len()` offsets
///   the positional embeddings and counts toward the `n_positions` limit.
///
/// # Returns
///
/// `(logits, new_kv)`. `logits` holds `token_ids.len() * vocab_size` values,
/// one contiguous row of `vocab_size` scores per input position. `new_kv` is a
/// fresh cache whose layers are filled with whatever each block returned.
///
/// # Errors
///
/// * `LmError::EmptyInput` if `token_ids` is empty.
/// * `LmError::SequenceTooLong` if `past_len + token_ids.len()` exceeds
///   `config.n_positions`.
/// * Any error from the embeddings, the cache layer lookup, a transformer
///   block, the final layer norm or the output head.
pub fn forward(
    &self,
    token_ids: &[u32],
    past_kv: Option<&PastKvCache>,
) -> LmResult<(Vec<f32>, PastKvCache)> {
    let seq_len = token_ids.len();
    if seq_len == 0 {
        return Err(LmError::EmptyInput {
            context: "Gpt2Model::forward token_ids",
        });
    }

    let past_len = past_kv.map_or(0, |c| c.past_len());
    let total_len = past_len + seq_len;
    if total_len > self.config.n_positions {
        return Err(LmError::SequenceTooLong {
            total_len,
            max_pos: self.config.n_positions,
        });
    }

    // Input representation: token embedding plus positional embedding,
    // added element-wise. Positions start at `past_len`.
    let tok_emb = self.token_embed.forward(token_ids)?;
    let pos_emb = self.pos_embed.forward(seq_len, past_len)?;
    let mut h: Vec<f32> = tok_emb
        .iter()
        .zip(pos_emb.iter())
        .map(|(&t, &p)| t + p)
        .collect();

    let n_kv_heads = self.config.n_heads;
    let head_dim = self.config.head_dim();
    let mut new_kv = PastKvCache::new(self.config.n_layers, n_kv_heads, head_dim);

    for (layer_idx, block) in self.blocks.iter().enumerate() {
        // A supplied cache must cover every layer. Dropping a lookup failure
        // here would run this layer without its history while positions are
        // still offset by `past_len`, so the error is propagated instead.
        let past_layer = match past_kv {
            Some(cache) => Some(cache.layer(layer_idx)?),
            None => None,
        };
        let (block_out, layer_kv) = block.forward(&h, seq_len, past_layer)?;
        h = block_out;
        *new_kv.layer_mut(layer_idx)? = layer_kv;
    }

    let h = self.ln_f.forward(&h, seq_len)?;
    let logits = self.lm_head(&h, seq_len)?;
    Ok((logits, new_kv))
}
/// Greedy single-step decoding: runs `forward` and picks the highest-scoring
/// vocabulary id at the last input position.
///
/// # Returns
///
/// The chosen token id and the cache produced by `forward`.
///
/// # Errors
///
/// Propagates any error from `forward` (including `LmError::EmptyInput` for an
/// empty `token_ids`) or from `argmax_f32`.
pub fn next_token(
    &self,
    token_ids: &[u32],
    past_kv: Option<&PastKvCache>,
) -> LmResult<(u32, PastKvCache)> {
    let (logits, new_kv) = self.forward(token_ids, past_kv)?;

    // `forward` rejects empty input, so there is at least one row here.
    let vocab = self.config.vocab_size;
    let offset = (token_ids.len() - 1) * vocab;
    let final_row = &logits[offset..offset + vocab];

    argmax_f32(final_row).map(|tok| (tok, new_kv))
}
/// Returns a formula-based count of the model's scalar parameters.
///
/// The count is derived from the stored dimensions rather than by measuring
/// the weight buffers. It adds up:
///
/// * the token embedding (`vocab_size * embed_dim`),
/// * the positional embedding (`max_positions * embed_dim`),
/// * the final layer norm (two vectors of `ln_f.dim`),
/// * for every block: two layer norms sized by `ln_1.dim` (two vectors each),
///   four `n_embd x n_embd` attention matrices with `4 * n_embd` bias values,
///   and two feed-forward matrices between `n_embd` and `ffn_intermediate`
///   with their biases.
pub fn n_params(&self) -> usize {
    let width = self.config.n_embd;
    let inner = self.config.ffn_intermediate;

    let mut in_blocks: usize = 0;
    for block in &self.blocks {
        let norms = 2 * block.ln_1.dim * 2;
        let attention = 4 * width * width + 3 * width + width;
        let feed_forward = inner * width + inner + width * inner + width;
        in_blocks += norms + attention + feed_forward;
    }

    let outside_blocks = self.token_embed.vocab_size * self.token_embed.embed_dim
        + self.pos_embed.max_positions * self.pos_embed.embed_dim
        + self.ln_f.dim * 2;

    outside_blocks + in_blocks
}
/// Weight-tied output head: projects hidden states onto the vocabulary by
/// taking the dot product of each hidden row with each token-embedding row.
///
/// # Arguments
///
/// * `h` - hidden states, `seq_len` consecutive rows of `n_embd` values.
/// * `seq_len` - number of rows in `h`.
///
/// # Returns
///
/// `seq_len * vocab_size` logits, one contiguous row per position.
///
/// # Errors
///
/// `LmError::DimensionMismatch` if `h.len() != seq_len * n_embd`, or if the
/// token embedding table holds fewer than `vocab_size * n_embd` values.
fn lm_head(&self, h: &[f32], seq_len: usize) -> LmResult<Vec<f32>> {
    let hd = self.config.n_embd;
    let vs = self.config.vocab_size;
    if h.len() != seq_len * hd {
        return Err(LmError::DimensionMismatch {
            expected: seq_len * hd,
            got: h.len(),
        });
    }

    let embed = &self.token_embed.weight.data;
    // The weight buffer can be replaced from outside the embedding type (the
    // tests in this file do so). Report a too-short table as an error rather
    // than panicking on the row slices below.
    if embed.len() < vs * hd {
        return Err(LmError::DimensionMismatch {
            expected: vs * hd,
            got: embed.len(),
        });
    }

    let mut logits = Vec::with_capacity(seq_len * vs);
    for t in 0..seq_len {
        let h_row = &h[t * hd..(t + 1) * hd];
        logits.extend((0..vs).map(|v| {
            let emb_row = &embed[v * hd..(v + 1) * hd];
            h_row
                .iter()
                .zip(emb_row.iter())
                .map(|(&hi, &ei)| hi * ei)
                .sum::<f32>()
        }));
    }
    Ok(logits)
}
}
/// Returns the index of the largest value in `x`.
///
/// The scan starts from index 0 with a running maximum of negative infinity
/// and only moves on a strictly greater value. Consequently the first of
/// several equal maxima wins, NaN entries are never selected, and a slice
/// containing nothing greater than negative infinity yields index 0.
///
/// # Errors
///
/// `LmError::EmptyInput` if `x` is empty.
pub(crate) fn argmax_f32(x: &[f32]) -> LmResult<u32> {
    if x.is_empty() {
        return Err(LmError::EmptyInput {
            context: "argmax_f32",
        });
    }

    let mut winner = 0usize;
    let mut peak = f32::NEG_INFINITY;
    for (pos, &val) in x.iter().enumerate() {
        if val > peak {
            winner = pos;
            peak = val;
        }
    }
    Ok(winner as u32)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GptConfig;

    /// Builds a model from `GptConfig::tiny()`, panicking if construction fails.
    fn tiny_model() -> Gpt2Model {
        let cfg = GptConfig::tiny();
        Gpt2Model::new(cfg).expect("tiny GptConfig should produce a valid Gpt2Model")
    }

    /// The tiny configuration is expected to yield two transformer blocks.
    #[test]
    fn gpt2_model_constructs() {
        let model = tiny_model();
        assert_eq!(model.blocks.len(), 2);
    }

    /// Three input tokens give three rows of logits and a cache of length 3.
    #[test]
    fn gpt2_forward_output_shape() {
        let model = tiny_model();
        let outcome = model.forward(&[0, 1, 2], None);
        let (logits, cache) = outcome.expect("3-token GPT-2 forward should succeed");

        assert_eq!(logits.len(), 3 * model.config.vocab_size);
        assert_eq!(cache.n_layers(), 2);
        assert_eq!(cache.past_len(), 3);
    }

    /// An empty token slice is rejected with `EmptyInput`.
    #[test]
    fn gpt2_forward_empty_error() {
        let model = tiny_model();
        let outcome = model.forward(&[], None);
        assert!(matches!(outcome, Err(LmError::EmptyInput { .. })));
    }

    /// Five tokens against a four-position limit must be rejected.
    #[test]
    fn gpt2_forward_sequence_too_long_error() {
        let mut cfg = GptConfig::tiny();
        cfg.n_positions = 4;
        let model = Gpt2Model::new(cfg).expect("modified tiny GptConfig should still be valid");

        let ids = [0u32, 1, 2, 3, 4];
        let outcome = model.forward(&ids, None);
        assert!(matches!(outcome, Err(LmError::SequenceTooLong { .. })));
    }

    /// Prefill with two tokens, then decode one more using the returned cache.
    #[test]
    fn gpt2_forward_kv_cache_incremental() {
        let model = tiny_model();

        let (_, prefill_cache) = model
            .forward(&[0, 1], None)
            .expect("prefill 2-token GPT-2 forward should succeed");

        let (step_logits, step_cache) = model
            .forward(&[2], Some(&prefill_cache))
            .expect("incremental decode with cache should succeed");

        assert_eq!(step_logits.len(), model.config.vocab_size);
        assert_eq!(step_cache.past_len(), 3);
    }

    /// The greedy token must be a valid vocabulary index.
    #[test]
    fn gpt2_next_token_returns_valid_id() {
        let model = tiny_model();
        let (token, _) = model
            .next_token(&[0], None)
            .expect("next_token on valid GPT-2 should succeed");
        assert!((token as usize) < model.config.vocab_size);
    }

    /// Overwrites the embedding table with ones and checks that the logits
    /// change accordingly, which shows the head reads the shared table.
    /// NOTE(review): the near-zero expectation presumably relies on the final
    /// layer norm producing zero-mean rows with its initial parameters — confirm.
    #[test]
    fn gpt2_weight_tied_lm_head() {
        let mut model = Gpt2Model::new(GptConfig::tiny())
            .expect("tiny GptConfig for weight-tie test should be valid");

        let table_len = model.config.vocab_size * model.config.n_embd;
        model.token_embed.weight.data = vec![1.0_f32; table_len];

        let (logits, _) = model
            .forward(&[0], None)
            .expect("single-token weight-tied GPT-2 forward should succeed");
        assert!(logits.iter().all(|&v| v.abs() < 1e-5));
    }

    /// The parameter count of a constructed model is non-zero.
    #[test]
    fn gpt2_n_params_positive() {
        let model = tiny_model();
        assert!(model.n_params() > 0);
    }

    /// Decoding token 1 with a cache built from token 0 must match the last
    /// row of a single two-token pass, within a small tolerance.
    #[test]
    fn gpt2_incremental_vs_full_last_position() {
        let model = tiny_model();

        let (logits_full, _) = model
            .forward(&[0, 1], None)
            .expect("full 2-token GPT-2 forward should succeed");
        let last_full = &logits_full[model.config.vocab_size..];

        let (_, first_cache) = model
            .forward(&[0], None)
            .expect("incremental token-0 GPT-2 forward should succeed");
        let (logits_incr, _) = model
            .forward(&[1], Some(&first_cache))
            .expect("incremental token-1 GPT-2 with cache should succeed");

        let pairs = last_full.iter().copied().zip(logits_incr.iter().copied());
        for (full_v, incr_v) in pairs {
            assert!(
                (full_v - incr_v).abs() < 1e-4,
                "full={full_v} incr={incr_v}"
            );
        }
    }

    /// Picks the index of the largest element.
    #[test]
    fn argmax_f32_correct() {
        let middle_wins = argmax_f32(&[0.1, 0.9, 0.5]);
        assert_eq!(
            middle_wins.expect("non-empty slice argmax should succeed"),
            1
        );

        let first_wins = argmax_f32(&[5.0, 3.0]);
        assert_eq!(
            first_wins.expect("non-empty slice argmax should succeed"),
            0
        );
    }

    /// An empty slice is an error, not index 0.
    #[test]
    fn argmax_f32_empty_error() {
        let outcome = argmax_f32(&[]);
        assert!(matches!(outcome, Err(LmError::EmptyInput { .. })));
    }
}