use crate::config::LlamaConfig;
use crate::error::{LmError, LmResult};
use crate::layer::{
embedding::TokenEmbedding,
norm::RmsNorm,
transformer::{LlamaBlock, PastKvCache},
};
use crate::model::gpt::argmax_f32;
use crate::weights::WeightTensor;
/// A LLaMA-style language model: token embedding, a stack of transformer
/// blocks, a final RMS norm, and an LM head that projects onto the vocabulary.
#[derive(Debug, Clone)]
pub struct LlamaModel {
/// Configuration the model was built from; read for dimensions and limits.
pub config: LlamaConfig,
/// Token embedding, constructed from `vocab_size` and `hidden_dim`.
pub embed: TokenEmbedding,
/// Transformer blocks, applied in order; `new` creates `config.n_layers` of them.
pub blocks: Vec<LlamaBlock>,
/// Normalisation applied to the output of the last block, before the LM head.
pub norm: RmsNorm,
/// LM head weights, created as a zero tensor of shape `[vocab_size, hidden_dim]`
/// and read as `vocab_size` consecutive rows of `hidden_dim` values.
pub lm_head: WeightTensor,
}
impl LlamaModel {
/// Builds a model from `config` with freshly constructed layers.
///
/// The configuration is validated first. The token embedding, one
/// [`LlamaBlock`] per `config.n_layers`, and the final [`RmsNorm`] are then
/// constructed in that order. The LM head starts out as an all-zero
/// `[vocab_size, hidden_dim]` tensor.
///
/// # Errors
///
/// Returns the first error reported by `config.validate()` or by any of the
/// layer constructors.
pub fn new(config: LlamaConfig) -> LmResult<Self> {
    config.validate()?;

    let token_embedding = TokenEmbedding::new(config.vocab_size, config.hidden_dim)?;

    // Build the blocks one at a time so the first failing layer aborts construction.
    let mut layers: Vec<LlamaBlock> = Vec::new();
    for _ in 0..config.n_layers {
        let layer = LlamaBlock::new(
            config.hidden_dim,
            config.n_heads,
            config.n_kv_heads,
            config.intermediate_dim,
            config.max_position_embeddings,
            config.rope_theta,
            config.rms_norm_eps,
        )?;
        layers.push(layer);
    }

    let final_norm = RmsNorm::new(config.hidden_dim, config.rms_norm_eps)?;
    let head = WeightTensor::zeros(&[config.vocab_size, config.hidden_dim]);

    Ok(Self {
        config,
        embed: token_embedding,
        blocks: layers,
        norm: final_norm,
        lm_head: head,
    })
}
/// Runs a forward pass over `token_ids`, optionally continuing from `past_kv`.
///
/// Returns the logits as a flat buffer of `token_ids.len() * vocab_size`
/// values (one row of `vocab_size` logits per input position, in input
/// order) together with a new cache holding the per-layer entries returned
/// by each block.
///
/// # Errors
///
/// * [`LmError::EmptyInput`] if `token_ids` is empty.
/// * [`LmError::SequenceTooLong`] if the cached length plus
///   `token_ids.len()` exceeds `config.max_position_embeddings`.
/// * The cache lookup error if `past_kv` is supplied but has no entry for
///   one of this model's layers. Such a cache used to be silently ignored
///   for the affected layers, which produced logits inconsistent with the
///   rest of the stack.
/// * Any error propagated from the embedding, a block, the final norm, or
///   the LM head projection.
pub fn forward(
    &self,
    token_ids: &[u32],
    past_kv: Option<&PastKvCache>,
) -> LmResult<(Vec<f32>, PastKvCache)> {
    let seq_len = token_ids.len();
    if seq_len == 0 {
        return Err(LmError::EmptyInput {
            context: "LlamaModel::forward token_ids",
        });
    }

    let past_len = past_kv.map_or(0, |c| c.past_len());
    let total_len = past_len + seq_len;
    if total_len > self.config.max_position_embeddings {
        return Err(LmError::SequenceTooLong {
            total_len,
            max_pos: self.config.max_position_embeddings,
        });
    }

    let mut h = self.embed.forward(token_ids)?;

    let n_kv_heads = self.config.n_kv_heads;
    let head_dim = self.config.head_dim();
    let mut new_kv = PastKvCache::new(self.config.n_layers, n_kv_heads, head_dim);

    for (layer_idx, block) in self.blocks.iter().enumerate() {
        // A supplied cache must cover every layer; a failed lookup is an
        // error rather than a silent fallback to "no past" for this layer.
        let past_layer = match past_kv {
            Some(cache) => Some(cache.layer(layer_idx)?),
            None => None,
        };
        let (block_out, layer_kv) = block.forward(&h, seq_len, past_layer)?;
        h = block_out;
        *new_kv.layer_mut(layer_idx)? = layer_kv;
    }

    let h = self.norm.forward(&h, seq_len)?;
    let logits = self.apply_lm_head(&h, seq_len)?;
    Ok((logits, new_kv))
}
/// Runs [`forward`](Self::forward) and picks the next token from the logits
/// of the final input position.
///
/// The last `vocab_size` logits are passed to `argmax_f32`. The chosen
/// token id is returned together with the cache produced by the forward pass.
///
/// # Errors
///
/// Propagates any error from `forward` (including empty input) and from
/// `argmax_f32`.
pub fn next_token(
    &self,
    token_ids: &[u32],
    past_kv: Option<&PastKvCache>,
) -> LmResult<(u32, PastKvCache)> {
    let vocab = self.config.vocab_size;
    let (logits, cache) = self.forward(token_ids, past_kv)?;

    // `forward` has already rejected empty input, so `len() - 1` cannot underflow.
    let offset = (token_ids.len() - 1) * vocab;
    let final_row = &logits[offset..offset + vocab];

    let token = argmax_f32(final_row)?;
    Ok((token, cache))
}
/// Returns the parameter count implied by the configuration.
///
/// The count is computed from `config` dimensions, not from the sizes of
/// the stored tensors. It adds up:
///
/// * the embedding table (`vocab_size * hidden_dim`),
/// * the final norm (`hidden_dim`),
/// * the LM head (`vocab_size * hidden_dim`),
/// * and, for every entry in `blocks`: two norms of `hidden_dim` each, four
///   attention matrices (two `hidden_dim * hidden_dim`, two
///   `kv * hidden_dim` with `kv = n_kv_heads * head_dim`), and three
///   feed-forward matrices of `intermediate_dim * hidden_dim`.
pub fn n_params(&self) -> usize {
    let cfg = &self.config;

    let mut total_in_blocks: usize = 0;
    for _ in &self.blocks {
        let hidden = cfg.hidden_dim;
        let inter = cfg.intermediate_dim;
        let kv_width = cfg.n_kv_heads * cfg.head_dim();

        let norm_weights = 2 * hidden;
        let attention = hidden * hidden + kv_width * hidden + kv_width * hidden + hidden * hidden;
        let feed_forward = inter * hidden + inter * hidden + hidden * inter;

        total_in_blocks += norm_weights + attention + feed_forward;
    }

    let embedding = cfg.vocab_size * cfg.hidden_dim;
    let final_norm = cfg.hidden_dim;
    let head = cfg.vocab_size * cfg.hidden_dim;

    embedding + final_norm + head + total_in_blocks
}
/// Projects hidden states onto the vocabulary using the LM head weights.
///
/// `h` is read as `seq_len` consecutive rows of `hidden_dim` values and the
/// head weights as `vocab_size` consecutive rows of `hidden_dim` values.
/// The logit for position `t` and vocabulary entry `v` is the dot product
/// of hidden row `t` with head row `v`. The result is a flat buffer of
/// `seq_len * vocab_size` values, one row of logits per position.
///
/// # Errors
///
/// Returns [`LmError::DimensionMismatch`] if `h.len()` is not
/// `seq_len * hidden_dim`, or if the head weight buffer holds fewer than
/// `vocab_size * hidden_dim` values. The latter used to panic on slice
/// indexing, which is reachable because `lm_head` is publicly writable.
fn apply_lm_head(&self, h: &[f32], seq_len: usize) -> LmResult<Vec<f32>> {
    let hd = self.config.hidden_dim;
    let vs = self.config.vocab_size;

    if h.len() != seq_len * hd {
        return Err(LmError::DimensionMismatch {
            expected: seq_len * hd,
            got: h.len(),
        });
    }

    let head_data = &self.lm_head.data;
    // Only reject buffers that are too short: those are exactly the inputs
    // that previously panicked, so every call that worked before still works.
    if head_data.len() < vs * hd {
        return Err(LmError::DimensionMismatch {
            expected: vs * hd,
            got: head_data.len(),
        });
    }

    let mut logits = vec![0.0_f32; seq_len * vs];
    for t in 0..seq_len {
        let h_row = &h[t * hd..(t + 1) * hd];
        let l_row = &mut logits[t * vs..(t + 1) * vs];
        for (v, logit) in l_row.iter_mut().enumerate() {
            let w_row = &head_data[v * hd..(v + 1) * hd];
            *logit = h_row
                .iter()
                .zip(w_row.iter())
                .map(|(&hi, &wi)| hi * wi)
                .sum();
        }
    }
    Ok(logits)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::LlamaConfig;

    /// Builds a model from `LlamaConfig::tiny()`; panics if construction fails.
    fn tiny_model() -> LlamaModel {
        LlamaModel::new(LlamaConfig::tiny()).unwrap()
    }

    /// The tiny configuration yields two blocks and two KV heads.
    #[test]
    fn llama_model_constructs() {
        let m = tiny_model();
        assert_eq!(m.blocks.len(), 2);
        assert_eq!(m.config.n_kv_heads, 2);
    }

    /// Logits hold one `vocab_size` row per token and the cache records all tokens.
    #[test]
    fn llama_forward_output_shape() {
        let m = tiny_model();
        let (logits, kv) = m.forward(&[0, 1, 2], None).unwrap();
        assert_eq!(logits.len(), 3 * m.config.vocab_size);
        assert_eq!(kv.past_len(), 3);
        assert_eq!(kv.n_layers(), 2);
    }

    /// An empty token slice is rejected with `EmptyInput`.
    #[test]
    fn llama_forward_empty_error() {
        let m = tiny_model();
        assert!(matches!(
            m.forward(&[], None),
            Err(LmError::EmptyInput { .. })
        ));
    }

    /// Input longer than `max_position_embeddings` is rejected with `SequenceTooLong`.
    #[test]
    fn llama_forward_too_long_error() {
        let mut cfg = LlamaConfig::tiny();
        cfg.max_position_embeddings = 4;
        let m = LlamaModel::new(cfg).unwrap();
        let ids: Vec<u32> = (0..5).collect();
        assert!(matches!(
            m.forward(&ids, None),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    /// Feeding one more token with a cache yields one logit row and a longer cache.
    #[test]
    fn llama_kv_cache_incremental() {
        let m = tiny_model();
        let (_, kv1) = m.forward(&[0, 1], None).unwrap();
        let (logits2, kv2) = m.forward(&[2], Some(&kv1)).unwrap();
        assert_eq!(logits2.len(), m.config.vocab_size);
        assert_eq!(kv2.past_len(), 3);
    }

    /// The greedy next token is a valid vocabulary index.
    #[test]
    fn llama_next_token_valid_id() {
        let m = tiny_model();
        let (tok, _) = m.next_token(&[0], None).unwrap();
        assert!((tok as usize) < m.config.vocab_size);
    }

    /// The reported parameter count is non-zero.
    #[test]
    fn llama_n_params_positive() {
        let m = tiny_model();
        assert!(m.n_params() > 0);
    }

    /// Last-position logits agree between a full pass and a cached incremental pass.
    #[test]
    fn llama_incremental_vs_full_last_position() {
        let m = tiny_model();
        let (logits_full, _) = m.forward(&[0, 1], None).unwrap();
        let last_full = &logits_full[m.config.vocab_size..];
        let (_, kv0) = m.forward(&[0], None).unwrap();
        let (logits_incr, _) = m.forward(&[1], Some(&kv0)).unwrap();
        // `zip` stops at the shorter slice, so check the lengths explicitly;
        // otherwise a shape regression could make the comparison vacuous.
        assert_eq!(last_full.len(), logits_incr.len());
        for (&full_v, &incr_v) in last_full.iter().zip(logits_incr.iter()) {
            assert!(
                (full_v - incr_v).abs() < 1e-4,
                "full={full_v} incr={incr_v}"
            );
        }
    }

    /// With all-ones embedding and LM head weights, the logits are not all zero.
    #[test]
    fn llama_lm_head_ones_gives_nonzero_logits() {
        let mut m = tiny_model();
        m.embed.weight.data = vec![1.0_f32; m.config.vocab_size * m.config.hidden_dim];
        m.lm_head.data = vec![1.0_f32; m.config.vocab_size * m.config.hidden_dim];
        let (logits, _) = m.forward(&[0], None).unwrap();
        let all_zero = logits.iter().all(|&v| v.abs() < 1e-6);
        assert!(!all_zero, "expected non-zero logits with ones weights");
    }

    /// The tiny configuration has a GQA factor of 2 and still runs a forward pass.
    #[test]
    fn llama_gqa_factor_consistent() {
        let m = tiny_model();
        assert_eq!(m.config.gqa_factor(), 2);
        let (_, _) = m.forward(&[0], None).unwrap();
    }
}