pub mod config;
pub mod error;
pub mod handle;
pub mod layer;
pub mod model;
pub mod ptx_kernels;
pub mod tokenizer;
pub mod weights;
pub use config::{GptConfig, LlamaConfig};
pub use error::{LmError, LmResult};
pub use handle::{LmHandle, SmVersion};
pub use layer::{
LayerKvCache, LayerNorm, LearnedPositionalEmbedding, MlpFfn, MultiHeadAttention, PastKvCache,
RmsNorm, RotaryEmbedding, SwiGluFfn, TokenEmbedding,
};
pub use model::{Gpt2Model, LlamaModel};
pub use tokenizer::{BpeBuilder, BpeTokenizer, Vocab};
pub use weights::{ModelWeights, WeightTensor};
#[cfg(test)]
mod tests {
use super::*;
/// End-to-end smoke test: a tiny GPT-2 model runs a 5-token forward pass.
///
/// Verifies the flattened logits length, the number of cached positions and
/// layers reported by the returned KV cache, and that every logit is
/// numerically zero.
#[test]
fn e2e_gpt2_tiny_forward() {
    let model = Gpt2Model::new(GptConfig::tiny())
        .expect("tiny GptConfig should produce a valid Gpt2Model");

    let prompt = [0u32, 3, 7, 2, 5];
    let outcome = model.forward(&prompt, None);
    let (logits, cache) = outcome.expect("5-token GPT-2 forward should succeed");

    // 5 prompt tokens times 16 logits per token.
    // NOTE(review): 16 is presumably the tiny config's vocab_size — confirm.
    assert_eq!(logits.len(), 5 * 16);
    // One cached position per prompt token, across 2 layers.
    assert_eq!(cache.past_len(), 5);
    assert_eq!(cache.n_layers(), 2);
    // NOTE(review): all-zero logits presumably follow from how the tiny model's
    // weights are initialised — confirm against Gpt2Model::new.
    assert!(logits.iter().all(|&v| v.abs() < 1e-6));
}
/// End-to-end smoke test: a tiny LLaMA model runs a 4-token forward pass.
///
/// Verifies the flattened logits length and the number of cached positions
/// and layers reported by the returned KV cache.
#[test]
fn e2e_llama_tiny_forward() {
    let model = LlamaModel::new(LlamaConfig::tiny())
        .expect("tiny LlamaConfig should produce a valid LlamaModel");

    let prompt = [0u32, 1, 2, 3];
    let outcome = model.forward(&prompt, None);
    let (logits, cache) = outcome.expect("4-token LLaMA forward should succeed");

    // 4 prompt tokens times 16 logits per token.
    // NOTE(review): 16 is presumably the tiny config's vocab_size — confirm.
    assert_eq!(logits.len(), 4 * 16);
    // One cached position per prompt token, across 2 layers.
    assert_eq!(cache.past_len(), 4);
    assert_eq!(cache.n_layers(), 2);
}
/// Verifies that GPT-2 token-by-token decoding through the KV cache produces
/// the same final-position logits as one full-sequence forward pass.
///
/// The sequence `[1, 2, 3]` is first run in a single call; it is then replayed
/// one token at a time, threading the returned cache into the next call. The
/// logits for the last token must agree within `1e-4`.
#[test]
fn e2e_gpt2_incremental_decode_consistent() {
    let m = Gpt2Model::new(GptConfig::tiny())
        .expect("tiny GptConfig for incremental decode test should be valid");
    let vs = m.config.vocab_size;

    // Reference: the whole sequence in one call.
    let full_ids = vec![1u32, 2, 3];
    let seq_len = full_ids.len();
    let (logits_full, _) = m
        .forward(&full_ids, None)
        .expect("full 3-token GPT-2 forward should succeed");
    // Pin the layout before slicing: `zip` below stops at the shorter side, so
    // a short (or over-long) reference would otherwise go unnoticed.
    assert_eq!(
        logits_full.len(),
        seq_len * vs,
        "full GPT-2 forward should return one vocab-sized row of logits per token"
    );
    let last_full = &logits_full[(seq_len - 1) * vs..];

    // Incremental: one token per call, each reusing the previous call's cache.
    let (_, kv0) = m
        .forward(&[1u32], None)
        .expect("incremental token-1 GPT-2 forward should succeed");
    let (_, kv1) = m
        .forward(&[2u32], Some(&kv0))
        .expect("incremental token-2 GPT-2 with cache should succeed");
    let (logits_3, _) = m
        .forward(&[3u32], Some(&kv1))
        .expect("incremental token-3 GPT-2 with cache should succeed");
    assert_eq!(logits_3.len(), vs);

    for (&full_v, &incr_v) in last_full.iter().zip(logits_3.iter()) {
        assert!(
            (full_v - incr_v).abs() < 1e-4,
            "GPT-2 incremental mismatch: full={full_v} incr={incr_v}"
        );
    }
}
/// Verifies that LLaMA token-by-token decoding through the KV cache produces
/// the same final-position logits as one full-sequence forward pass.
///
/// The sequence `[0, 5, 10]` is first run in a single call; it is then
/// replayed one token at a time, threading the returned cache into the next
/// call. The logits for the last token must agree within `1e-4`.
#[test]
fn e2e_llama_incremental_decode_consistent() {
    let m = LlamaModel::new(LlamaConfig::tiny())
        .expect("tiny LlamaConfig for incremental decode test should be valid");
    let vs = m.config.vocab_size;

    // Reference: the whole sequence in one call.
    let full_ids = vec![0u32, 5, 10];
    let seq_len = full_ids.len();
    let (logits_full, _) = m
        .forward(&full_ids, None)
        .expect("full 3-token LLaMA forward should succeed");
    // Pin the layout before slicing: `zip` below stops at the shorter side, so
    // a short (or over-long) reference would otherwise go unnoticed.
    assert_eq!(
        logits_full.len(),
        seq_len * vs,
        "full LLaMA forward should return one vocab-sized row of logits per token"
    );
    let last_full = &logits_full[(seq_len - 1) * vs..];

    // Incremental: one token per call, each reusing the previous call's cache.
    let (_, kv0) = m
        .forward(&[0u32], None)
        .expect("incremental token-0 LLaMA forward should succeed");
    let (_, kv1) = m
        .forward(&[5u32], Some(&kv0))
        .expect("incremental token-5 LLaMA with cache should succeed");
    let (logits_3, _) = m
        .forward(&[10u32], Some(&kv1))
        .expect("incremental token-10 LLaMA with cache should succeed");
    // Same guard as the GPT-2 variant: without it an empty incremental result
    // would make the comparison loop pass vacuously.
    assert_eq!(logits_3.len(), vs);

    for (&full_v, &incr_v) in last_full.iter().zip(logits_3.iter()) {
        assert!(
            (full_v - incr_v).abs() < 1e-4,
            "LLaMA incremental mismatch: full={full_v} incr={incr_v}"
        );
    }
}
/// Round-trips the string "hello" through a BPE tokenizer built from four
/// chained merge rules and checks that it collapses to a single token.
#[test]
fn e2e_bpe_encode_decode_roundtrip() {
    // Merge rules, in order: h+e, l+l, he+ll, hell+o.
    // The chain is kept as one expression so the builder temporary lives
    // until `build()` has run.
    let tokenizer = BpeBuilder::new()
        .add_merge(b"h", b"e")
        .add_merge(b"l", b"l")
        .add_merge(b"he", b"ll")
        .add_merge(b"hell", b"o")
        .build()
        .expect("BpeBuilder with 4 chained hello merges should succeed");

    let original = "hello";
    let ids = tokenizer
        .encode(original)
        .expect("encoding 'hello' should succeed");
    let decoded = tokenizer
        .decode(&ids)
        .expect("decoding 'hello' token ids should produce valid UTF-8");

    // Decoding must reproduce the input text exactly.
    assert_eq!(
        &decoded, original,
        "BPE round-trip failed: '{original}' → {ids:?} → '{decoded}'"
    );
    // NOTE(review): id 259 presumably means the four merges are numbered
    // 256..=259 after 256 single-byte tokens — confirm against BpeBuilder.
    assert_eq!(
        ids,
        vec![259u32],
        "Expected full merge to one token, got {ids:?}"
    );
}
/// Numerically checks `RmsNorm` and `LayerNorm` on a single 8-element row.
///
/// * RMSNorm output is compared element-wise against
///   `x * 1 / sqrt(mean(x^2) + eps)`.
/// * LayerNorm output must have mean ~0 and (population) variance ~1.
///
/// Both outputs are length-checked first, so the numeric checks cannot pass
/// on an empty or truncated result.
#[test]
fn e2e_rms_norm_and_layer_norm_correctness() {
    use crate::layer::{LayerNorm, RmsNorm};

    let dim = 8;
    // Input row: -3.5, -2.5, ..., 3.5 (symmetric around zero).
    let x: Vec<f32> = (0..dim).map(|i| i as f32 - 3.5).collect();

    // --- RMSNorm -----------------------------------------------------------
    let rms_norm = RmsNorm::new(dim, 1e-8).expect("dim=8 RmsNorm should be valid");
    let rms_out = rms_norm
        .forward(&x, 1)
        .expect("1-token RmsNorm forward with matching dim should succeed");
    // `zip` below stops at the shorter side, so pin the length explicitly.
    assert_eq!(
        rms_out.len(),
        x.len(),
        "RMSNorm output length should match input length"
    );
    // NOTE(review): the reference applies no learned gain, so this presumably
    // relies on RmsNorm::new starting with a unit weight — confirm.
    let expected_rms = 1.0 / (x.iter().map(|&v| v * v).sum::<f32>() / dim as f32 + 1e-8).sqrt();
    for (i, (&o, &xi)) in rms_out.iter().zip(x.iter()).enumerate() {
        let expected = xi * expected_rms;
        assert!(
            (o - expected).abs() < 1e-5,
            "RMSNorm out[{i}]={o} expected {expected}"
        );
    }

    // --- LayerNorm ---------------------------------------------------------
    let ln = LayerNorm::new(dim, 1e-8).expect("dim=8 LayerNorm should be valid");
    let ln_out = ln
        .forward(&x, 1)
        .expect("1-token LayerNorm forward with matching dim should succeed");
    // The statistics below divide by `dim`, which is only meaningful when the
    // output really has `dim` elements.
    assert_eq!(
        ln_out.len(),
        x.len(),
        "LayerNorm output length should match input length"
    );
    let mu: f32 = ln_out.iter().sum::<f32>() / dim as f32;
    let var: f32 = ln_out.iter().map(|&v| (v - mu) * (v - mu)).sum::<f32>() / dim as f32;
    assert!(mu.abs() < 1e-5, "LayerNorm mean={mu}");
    assert!((var - 1.0).abs() < 1e-4, "LayerNorm var={var}");
}
/// Generates every PTX kernel for a range of SM versions and checks that each
/// generated source mentions the matching `sm_<version>` string.
#[test]
fn e2e_ptx_kernels_all_sm_versions() {
    use crate::ptx_kernels::*;

    for sm in [75u32, 80, 86, 90, 100, 120] {
        // The substring every kernel generated for this SM must contain.
        let target = format!("sm_{sm}");

        // All five kernels are generated up front, then checked one by one.
        let kernels = [
            ("embedding_forward", embedding_forward_ptx(sm)),
            ("rope_apply", rope_apply_ptx(sm)),
            ("silu_gate", silu_gate_ptx(sm)),
            ("rms_norm", rms_norm_ptx(sm)),
            ("causal_attn_softmax", causal_attn_softmax_ptx(sm)),
        ];

        for (name, ptx) in &kernels {
            assert!(
                ptx.contains(&target),
                "SM {sm}: kernel '{name}' missing target directive"
            );
        }
    }
}
/// Prefills a tiny LLaMA model with four tokens, then decodes three more one
/// at a time, checking that the KV cache grows by one position per step.
///
/// NOTE(review): the test name suggests the tiny config exercises
/// grouped-query attention — confirm against `LlamaConfig::tiny`.
#[test]
fn e2e_llama_gqa_multistep_decode() {
    let model = LlamaModel::new(LlamaConfig::tiny())
        .expect("tiny LlamaConfig for GQA multistep test should be valid");
    let vocab = model.config.vocab_size;

    // Prefill: four prompt tokens in a single call, no prior cache.
    let (_, mut cache) = model
        .forward(&[0u32, 1, 2, 3], None)
        .expect("4-token prefill LLaMA forward should succeed");
    assert_eq!(cache.past_len(), 4);

    // Decode tokens 4, 5 and 6, feeding back the cache from the previous step.
    for tok in 4u32..=6 {
        let step = model.forward(&[tok], Some(&cache));
        let (logits, next_cache) = step.expect("single-step LLaMA decode should succeed");
        // A single-token step yields exactly one vocab-sized row of logits.
        assert_eq!(logits.len(), vocab);
        cache = next_cache;
    }

    // 4 prefill positions + 3 decode steps.
    assert_eq!(cache.past_len(), 7);
}
/// Builds a four-entry vocabulary with two named special tokens and checks
/// lookups in each direction: special name to id, bytes to id, id to text.
#[test]
fn e2e_vocab_special_token_roundtrip() {
    use std::collections::HashMap;

    // Ids 0 and 1 are the bytes "a" and "b"; ids 2 and 3 are two-byte entries
    // that the special-token map below refers to by name.
    let tokens: Vec<Vec<u8>> = vec![b"a".to_vec(), b"b".to_vec(), vec![1, 0], vec![2, 0]];

    let mut special: HashMap<String, u32> = HashMap::new();
    special.insert(String::from("<bos>"), 2);
    special.insert(String::from("<eos>"), 3);

    let v = Vocab::from_tokens(tokens, special)
        .expect("4-token vocabulary with BOS/EOS specials should succeed");

    // Special names resolve to the ids they were registered with.
    assert_eq!(v.special_id("<bos>"), Some(2));
    assert_eq!(v.special_id("<eos>"), Some(3));

    // Ordinary tokens map bytes -> id and id -> text.
    assert_eq!(v.bytes_to_id(b"a"), Some(0));
    let decoded = v.decode_token(0).expect("token 0 should decode to 'a'");
    assert_eq!(decoded, "a");
}
/// Runs a short greedy decode loop on a tiny GPT-2 model and checks that it
/// produces four additional in-vocabulary tokens.
///
/// Every step feeds only the most recent token plus the cache returned by the
/// previous step. The cache starts out empty so the prompt token is consumed
/// exactly once: prefilling the prompt with `forward` and then passing the
/// same token to `next_token` together with that cache would present the
/// first token to the model twice.
#[test]
fn e2e_gpt2_greedy_decode_loop() {
    let m = Gpt2Model::new(GptConfig::tiny())
        .expect("tiny GptConfig for greedy decode loop test should be valid");

    // Single-token prompt; nothing has been cached yet.
    let mut token_ids = vec![0u32];
    let mut kv = None;

    for _ in 0..4 {
        let last_tok = *token_ids
            .last()
            .expect("token_ids is never empty during greedy decode loop");
        // `kv` holds the cache returned by the previous step (none on the first).
        let (next_tok, new_kv) = m
            .next_token(&[last_tok], kv.as_ref())
            .expect("greedy next_token step should succeed");
        token_ids.push(next_tok);
        kv = Some(new_kv);
    }

    // One prompt token plus four generated tokens.
    assert_eq!(token_ids.len(), 5);
    for &t in &token_ids {
        assert!(
            (t as usize) < m.config.vocab_size,
            "token id {t} is outside the model vocabulary"
        );
    }
}
}