1pub mod config;
52pub mod error;
53pub mod handle;
54pub mod layer;
55pub mod model;
56pub mod ptx_kernels;
57pub mod tokenizer;
58pub mod weights;
59
60pub use config::{GptConfig, LlamaConfig};
63pub use error::{LmError, LmResult};
64pub use handle::{LmHandle, SmVersion};
65pub use layer::{
66 LayerKvCache, LayerNorm, LearnedPositionalEmbedding, MlpFfn, MultiHeadAttention, PastKvCache,
67 RmsNorm, RotaryEmbedding, SwiGluFfn, TokenEmbedding,
68};
69pub use model::{Gpt2Model, LlamaModel};
70pub use tokenizer::{BpeBuilder, BpeTokenizer, Vocab};
71pub use weights::{ModelWeights, WeightTensor};
72
73#[cfg(test)]
76mod tests {
77 use super::*;
78
79 #[test]
82 fn e2e_gpt2_tiny_forward() {
83 let cfg = GptConfig::tiny();
85 let m = Gpt2Model::new(cfg).expect("tiny GptConfig should produce a valid Gpt2Model");
86 let token_ids: Vec<u32> = vec![0, 3, 7, 2, 5];
87 let (logits, kv) = m
88 .forward(&token_ids, None)
89 .expect("5-token GPT-2 forward should succeed");
90 assert_eq!(logits.len(), 5 * 16);
92 assert_eq!(kv.past_len(), 5);
93 assert_eq!(kv.n_layers(), 2);
94 assert!(logits.iter().all(|&v| v.abs() < 1e-6));
96 }
97
98 #[test]
101 fn e2e_llama_tiny_forward() {
102 let cfg = LlamaConfig::tiny();
103 let m = LlamaModel::new(cfg).expect("tiny LlamaConfig should produce a valid LlamaModel");
104 let token_ids: Vec<u32> = vec![0, 1, 2, 3];
105 let (logits, kv) = m
106 .forward(&token_ids, None)
107 .expect("4-token LLaMA forward should succeed");
108 assert_eq!(logits.len(), 4 * 16);
109 assert_eq!(kv.past_len(), 4);
110 assert_eq!(kv.n_layers(), 2);
111 }
112
113 #[test]
116 fn e2e_gpt2_incremental_decode_consistent() {
117 let m = Gpt2Model::new(GptConfig::tiny())
120 .expect("tiny GptConfig for incremental decode test should be valid");
121
122 let full_ids = vec![1u32, 2, 3];
124 let (logits_full, _) = m
125 .forward(&full_ids, None)
126 .expect("full 3-token GPT-2 forward should succeed");
127 let vs = m.config.vocab_size;
128 let last_full = logits_full[2 * vs..].to_vec();
129
130 let (_, kv0) = m
132 .forward(&[1u32], None)
133 .expect("incremental token-1 GPT-2 forward should succeed");
134 let (_, kv1) = m
135 .forward(&[2u32], Some(&kv0))
136 .expect("incremental token-2 GPT-2 with cache should succeed");
137 let (logits_3, _) = m
138 .forward(&[3u32], Some(&kv1))
139 .expect("incremental token-3 GPT-2 with cache should succeed");
140
141 assert_eq!(logits_3.len(), vs);
142 for (&full_v, &incr_v) in last_full.iter().zip(logits_3.iter()) {
143 assert!(
144 (full_v - incr_v).abs() < 1e-4,
145 "GPT-2 incremental mismatch: full={full_v} incr={incr_v}"
146 );
147 }
148 }
149
150 #[test]
153 fn e2e_llama_incremental_decode_consistent() {
154 let m = LlamaModel::new(LlamaConfig::tiny())
155 .expect("tiny LlamaConfig for incremental decode test should be valid");
156
157 let full_ids = vec![0u32, 5, 10];
158 let (logits_full, _) = m
159 .forward(&full_ids, None)
160 .expect("full 3-token LLaMA forward should succeed");
161 let vs = m.config.vocab_size;
162 let last_full = logits_full[2 * vs..].to_vec();
163
164 let (_, kv0) = m
165 .forward(&[0u32], None)
166 .expect("incremental token-0 LLaMA forward should succeed");
167 let (_, kv1) = m
168 .forward(&[5u32], Some(&kv0))
169 .expect("incremental token-5 LLaMA with cache should succeed");
170 let (logits_3, _) = m
171 .forward(&[10u32], Some(&kv1))
172 .expect("incremental token-10 LLaMA with cache should succeed");
173
174 for (&full_v, &incr_v) in last_full.iter().zip(logits_3.iter()) {
175 assert!(
176 (full_v - incr_v).abs() < 1e-4,
177 "LLaMA incremental mismatch: full={full_v} incr={incr_v}"
178 );
179 }
180 }
181
182 #[test]
185 fn e2e_bpe_encode_decode_roundtrip() {
186 let t = BpeBuilder::new()
188 .add_merge(b"h", b"e") .add_merge(b"l", b"l") .add_merge(b"he", b"ll") .add_merge(b"hell", b"o") .build()
193 .expect("BpeBuilder with 4 chained hello merges should succeed");
194
195 let original = "hello";
196 let ids = t.encode(original).expect("encoding 'hello' should succeed");
197 let decoded = t
198 .decode(&ids)
199 .expect("decoding 'hello' token ids should produce valid UTF-8");
200 assert_eq!(
201 &decoded, original,
202 "BPE round-trip failed: '{original}' → {ids:?} → '{decoded}'"
203 );
204 assert_eq!(
206 ids,
207 vec![259u32],
208 "Expected full merge to one token, got {ids:?}"
209 );
210 }
211
212 #[test]
215 fn e2e_rms_norm_and_layer_norm_correctness() {
216 use crate::layer::{LayerNorm, RmsNorm};
217
218 let dim = 8;
219 let x: Vec<f32> = (0..dim).map(|i| i as f32 - 3.5).collect();
221 let rms_norm = RmsNorm::new(dim, 1e-8).expect("dim=8 RmsNorm should be valid");
225 let rms_out = rms_norm
226 .forward(&x, 1)
227 .expect("1-token RmsNorm forward with matching dim should succeed");
228 let expected_rms = 1.0 / (x.iter().map(|&v| v * v).sum::<f32>() / dim as f32 + 1e-8).sqrt();
229 for (&o, &xi) in rms_out.iter().zip(x.iter()) {
230 assert!(
231 (o - xi * expected_rms).abs() < 1e-5,
232 "RMSNorm out[i]={o} expected {}",
233 xi * expected_rms
234 );
235 }
236
237 let ln = LayerNorm::new(dim, 1e-8).expect("dim=8 LayerNorm should be valid");
239 let ln_out = ln
240 .forward(&x, 1)
241 .expect("1-token LayerNorm forward with matching dim should succeed");
242 let mu: f32 = ln_out.iter().sum::<f32>() / dim as f32;
243 let var: f32 = ln_out.iter().map(|&v| (v - mu) * (v - mu)).sum::<f32>() / dim as f32;
244 assert!(mu.abs() < 1e-5, "LayerNorm mean={mu}");
245 assert!((var - 1.0).abs() < 1e-4, "LayerNorm var={var}");
246 }
247
248 #[test]
251 fn e2e_ptx_kernels_all_sm_versions() {
252 use crate::ptx_kernels::*;
253 let sms = [75u32, 80, 86, 90, 100, 120];
254 for sm in sms {
255 let p1 = embedding_forward_ptx(sm);
256 let p2 = rope_apply_ptx(sm);
257 let p3 = silu_gate_ptx(sm);
258 let p4 = rms_norm_ptx(sm);
259 let p5 = causal_attn_softmax_ptx(sm);
260 for (name, ptx) in [
261 ("embedding_forward", &p1),
262 ("rope_apply", &p2),
263 ("silu_gate", &p3),
264 ("rms_norm", &p4),
265 ("causal_attn_softmax", &p5),
266 ] {
267 let target = format!("sm_{sm}");
268 assert!(
269 ptx.contains(&target),
270 "SM {sm}: kernel '{name}' missing target directive"
271 );
272 }
273 }
274 }
275
276 #[test]
279 fn e2e_llama_gqa_multistep_decode() {
280 let m = LlamaModel::new(LlamaConfig::tiny())
281 .expect("tiny LlamaConfig for GQA multistep test should be valid");
282 let prefill_ids = vec![0u32, 1, 2, 3];
284 let (_, kv) = m
285 .forward(&prefill_ids, None)
286 .expect("4-token prefill LLaMA forward should succeed");
287 assert_eq!(kv.past_len(), 4);
288
289 let mut cur_kv = kv;
291 for step_tok in [4u32, 5, 6] {
292 let (logits, new_kv) = m
293 .forward(&[step_tok], Some(&cur_kv))
294 .expect("single-step LLaMA decode should succeed");
295 assert_eq!(logits.len(), m.config.vocab_size);
296 cur_kv = new_kv;
297 }
298 assert_eq!(cur_kv.past_len(), 7);
299 }
300
301 #[test]
304 fn e2e_vocab_special_token_roundtrip() {
305 use std::collections::HashMap;
306 let tokens = vec![vec![b'a'], vec![b'b'], vec![1u8, 0], vec![2u8, 0]];
307 let special: HashMap<String, u32> = [("<bos>".into(), 2u32), ("<eos>".into(), 3u32)]
308 .into_iter()
309 .collect();
310 let v = Vocab::from_tokens(tokens, special)
311 .expect("4-token vocabulary with BOS/EOS specials should succeed");
312 assert_eq!(v.special_id("<bos>"), Some(2));
313 assert_eq!(v.special_id("<eos>"), Some(3));
314 assert_eq!(v.bytes_to_id(b"a"), Some(0));
315 assert_eq!(
316 v.decode_token(0).expect("token 0 should decode to 'a'"),
317 "a"
318 );
319 }
320
321 #[test]
324 fn e2e_gpt2_greedy_decode_loop() {
325 let m = Gpt2Model::new(GptConfig::tiny())
327 .expect("tiny GptConfig for greedy decode loop test should be valid");
328 let mut token_ids = vec![0u32]; let (_, mut kv) = m
330 .forward(&token_ids, None)
331 .expect("initial GPT-2 forward for greedy decode should succeed");
332
333 for _ in 0..4 {
334 let last_tok = *token_ids
335 .last()
336 .expect("token_ids is never empty during greedy decode loop");
337 let (next_tok, new_kv) = m
338 .next_token(&[last_tok], Some(&kv))
339 .expect("greedy next_token step should succeed");
340 token_ids.push(next_tok);
341 kv = new_kv;
342 }
343
344 assert_eq!(token_ids.len(), 5);
346 for &t in &token_ids {
348 assert!((t as usize) < m.config.vocab_size);
349 }
350 }
351}