Skip to main content

oxicuda_lm/
lib.rs

1//! `oxicuda-lm` — Large Language Model inference primitives.
2//!
3//! This crate provides the **model-layer abstractions** for LLM inference:
4//! a BPE tokenizer, transformer layer building blocks with KV-cache, and
5//! complete GPT-2 and LLaMA-2/3 model implementations.
6//!
7//! # Architecture overview
8//!
9//! ```text
//!  ┌──────────────────────────────────────────────────────┐
//!  │                      oxicuda-lm                      │
//!  │                                                      │
//!  │  ┌──────────────┐  ┌──────────────────────────────┐  │
//!  │  │ tokenizer    │  │  layer                       │  │
//!  │  │              │  │  ┌────────────────────────┐  │  │
//!  │  │ BpeTokenizer │  │  │ TokenEmbedding         │  │  │
//!  │  │ Vocab        │  │  │ RotaryEmbedding (RoPE) │  │  │
//!  │  └──────────────┘  │  │ MultiHeadAttention     │  │  │
//!  │                    │  │   + LayerKvCache       │  │  │
//!  │  ┌──────────────┐  │  │ MlpFfn / SwiGluFfn     │  │  │
//!  │  │ config       │  │  │ RmsNorm / LayerNorm    │  │  │
//!  │  │              │  │  │ GptBlock / LlamaBlock  │  │  │
//!  │  │ GptConfig    │  │  │ PastKvCache            │  │  │
//!  │  │ LlamaConfig  │  │  └────────────────────────┘  │  │
//!  │  └──────────────┘  └──────────────────────────────┘  │
//!  │                                                      │
//!  │  ┌────────────────────────────────────────────────┐  │
//!  │  │                     model                      │  │
//!  │  │  Gpt2Model  ─── forward → logits + cache       │  │
//!  │  │  LlamaModel ─── forward → logits + cache       │  │
//!  │  └────────────────────────────────────────────────┘  │
//!  │                                                      │
//!  │  ┌────────────────────────────────────────────────┐  │
//!  │  │  ptx_kernels (5 GPU kernel PTX strings)        │  │
//!  │  │  weights (ModelWeights, WeightTensor)          │  │
//!  │  └────────────────────────────────────────────────┘  │
//!  └──────────────────────────────────────────────────────┘
38//! ```
39//!
40//! # Design
41//!
42//! - **Pure Rust**: no C/CUDA SDK at compile time.
43//! - **CPU reference implementations**: all forward passes are pure-Rust
44//!   CPU implementations suitable for testing.  GPU acceleration is provided
45//!   by the PTX kernel strings (see [`ptx_kernels`]) once a CUDA driver is
46//!   available at runtime.
47//! - **No unwrap()** in library code.
48//! - **KV cache**: all attention layers return an updated [`layer::PastKvCache`]
49//!   so incremental decoding is fully supported.
50
// ── Public modules ──────────────────────────────────────────────────────────

// Model hyper-parameter configurations (`GptConfig`, `LlamaConfig`).
pub mod config;
// Crate error type and result alias (`LmError`, `LmResult`).
pub mod error;
// Home of `LmHandle` and `SmVersion` (re-exported below).
pub mod handle;
// Transformer building blocks: embeddings, RoPE, attention with KV cache,
// feed-forward networks and normalisation layers (see the crate-level diagram).
pub mod layer;
// End-to-end models (`Gpt2Model`, `LlamaModel`).
pub mod model;
// PTX source strings for the GPU kernels; each generator takes an SM number
// and emits a module targeting `sm_<N>` (exercised by the tests below).
pub mod ptx_kernels;
// BPE tokenizer, its builder, and the vocabulary (`BpeBuilder`, `BpeTokenizer`, `Vocab`).
pub mod tokenizer;
// Weight containers (`ModelWeights`, `WeightTensor`).
pub mod weights;

// ── Convenient top-level re-exports ─────────────────────────────────────────
// Lets callers write e.g. `oxicuda_lm::Gpt2Model` instead of
// `oxicuda_lm::model::Gpt2Model`.

pub use config::{GptConfig, LlamaConfig};
pub use error::{LmError, LmResult};
pub use handle::{LmHandle, SmVersion};
pub use layer::{
    LayerKvCache, LayerNorm, LearnedPositionalEmbedding, MlpFfn, MultiHeadAttention, PastKvCache,
    RmsNorm, RotaryEmbedding, SwiGluFfn, TokenEmbedding,
};
pub use model::{Gpt2Model, LlamaModel};
pub use tokenizer::{BpeBuilder, BpeTokenizer, Vocab};
pub use weights::{ModelWeights, WeightTensor};
72
73// ─── Integration tests ───────────────────────────────────────────────────────
74
#[cfg(test)]
mod tests {
    // The crate-root re-exports bring every public type used below into scope.
    use super::*;

    // ── E2E 1: GPT-2 tiny forward pass ───────────────────────────────────

    #[test]
    fn e2e_gpt2_tiny_forward() {
        // Minimal GPT-2: 2 layers, 2 heads, hidden=8, vocab=16, max_pos=32
        let cfg = GptConfig::tiny();
        let m = Gpt2Model::new(cfg).expect("tiny GptConfig should produce a valid Gpt2Model");
        let vs = m.config.vocab_size;
        let token_ids: Vec<u32> = vec![0, 3, 7, 2, 5];
        let (logits, kv) = m
            .forward(&token_ids, None)
            .expect("5-token GPT-2 forward should succeed");
        // logits shape: [seq_len × vocab_size]
        assert_eq!(logits.len(), 5 * vs);
        assert_eq!(kv.past_len(), 5);
        assert_eq!(kv.n_layers(), 2);
        // With zero weights, logits should all be zero (embeddings are zero).
        assert!(logits.iter().all(|&v| v.abs() < 1e-6));
    }

    // ── E2E 2: LLaMA tiny forward pass ───────────────────────────────────

    #[test]
    fn e2e_llama_tiny_forward() {
        let cfg = LlamaConfig::tiny();
        let m = LlamaModel::new(cfg).expect("tiny LlamaConfig should produce a valid LlamaModel");
        let vs = m.config.vocab_size;
        let token_ids: Vec<u32> = vec![0, 1, 2, 3];
        let (logits, kv) = m
            .forward(&token_ids, None)
            .expect("4-token LLaMA forward should succeed");
        // logits shape: [seq_len × vocab_size]
        assert_eq!(logits.len(), 4 * vs);
        assert_eq!(kv.past_len(), 4);
        assert_eq!(kv.n_layers(), 2);
    }

    // ── E2E 3: GPT-2 incremental decode consistency ───────────────────────

    #[test]
    fn e2e_gpt2_incremental_decode_consistent() {
        // Full 3-token pass vs token-by-token with KV cache:
        // the last-position logits must match.
        // NOTE(review): if `tiny()` models are zero-initialised (E2E 1 asserts
        // all-zero GPT-2 logits), both sides are all zeros and this comparison
        // cannot detect cache bugs. Consider loading non-trivial weights.
        let m = Gpt2Model::new(GptConfig::tiny())
            .expect("tiny GptConfig for incremental decode test should be valid");

        // Full pass
        let full_ids = vec![1u32, 2, 3];
        let (logits_full, _) = m
            .forward(&full_ids, None)
            .expect("full 3-token GPT-2 forward should succeed");
        let vs = m.config.vocab_size;
        assert_eq!(logits_full.len(), 3 * vs);
        let last_full = logits_full[2 * vs..].to_vec();

        // Incremental
        let (_, kv0) = m
            .forward(&[1u32], None)
            .expect("incremental token-1 GPT-2 forward should succeed");
        let (_, kv1) = m
            .forward(&[2u32], Some(&kv0))
            .expect("incremental token-2 GPT-2 with cache should succeed");
        let (logits_3, _) = m
            .forward(&[3u32], Some(&kv1))
            .expect("incremental token-3 GPT-2 with cache should succeed");

        // `zip` stops at the shorter side, so pin both lengths first.
        assert_eq!(logits_3.len(), vs);
        for (&full_v, &incr_v) in last_full.iter().zip(logits_3.iter()) {
            assert!(
                (full_v - incr_v).abs() < 1e-4,
                "GPT-2 incremental mismatch: full={full_v} incr={incr_v}"
            );
        }
    }

    // ── E2E 4: LLaMA incremental decode consistency ───────────────────────

    #[test]
    fn e2e_llama_incremental_decode_consistent() {
        // NOTE(review): same caveat as E2E 3 -- with all-zero weights this
        // comparison is trivially satisfied.
        let m = LlamaModel::new(LlamaConfig::tiny())
            .expect("tiny LlamaConfig for incremental decode test should be valid");

        let full_ids = vec![0u32, 5, 10];
        let (logits_full, _) = m
            .forward(&full_ids, None)
            .expect("full 3-token LLaMA forward should succeed");
        let vs = m.config.vocab_size;
        assert_eq!(logits_full.len(), 3 * vs);
        let last_full = logits_full[2 * vs..].to_vec();

        let (_, kv0) = m
            .forward(&[0u32], None)
            .expect("incremental token-0 LLaMA forward should succeed");
        let (_, kv1) = m
            .forward(&[5u32], Some(&kv0))
            .expect("incremental token-5 LLaMA with cache should succeed");
        let (logits_3, _) = m
            .forward(&[10u32], Some(&kv1))
            .expect("incremental token-10 LLaMA with cache should succeed");

        // Without this, an empty `logits_3` would make the zip loop pass vacuously.
        assert_eq!(logits_3.len(), vs);
        for (&full_v, &incr_v) in last_full.iter().zip(logits_3.iter()) {
            assert!(
                (full_v - incr_v).abs() < 1e-4,
                "LLaMA incremental mismatch: full={full_v} incr={incr_v}"
            );
        }
    }

    // ── E2E 5: BPE tokenizer encode/decode round-trip ─────────────────────

    #[test]
    fn e2e_bpe_encode_decode_roundtrip() {
        // Build a small BPE tokenizer on top of 256 byte tokens.
        let t = BpeBuilder::new()
            .add_merge(b"h", b"e") // "he" → id 256
            .add_merge(b"l", b"l") // "ll" → id 257
            .add_merge(b"he", b"ll") // "hell" → id 258
            .add_merge(b"hell", b"o") // "hello" → id 259
            .build()
            .expect("BpeBuilder with 4 chained hello merges should succeed");

        let original = "hello";
        let ids = t.encode(original).expect("encoding 'hello' should succeed");
        let decoded = t
            .decode(&ids)
            .expect("decoding 'hello' token ids should produce valid UTF-8");
        assert_eq!(
            &decoded, original,
            "BPE round-trip failed: '{original}' → {ids:?} → '{decoded}'"
        );
        // Should be fully merged to a single token
        assert_eq!(
            ids,
            vec![259u32],
            "Expected full merge to one token, got {ids:?}"
        );
    }

    // ── E2E 6: RMSNorm and LayerNorm correctness ──────────────────────────

    #[test]
    fn e2e_rms_norm_and_layer_norm_correctness() {
        let dim = 8;
        // Symmetric input around zero:
        // [-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
        let x: Vec<f32> = (0..dim).map(|i| i as f32 - 3.5).collect();

        // RMSNorm with weight=1: sum(x^2) = 42, mean(x^2) = 42 / 8 = 5.25,
        // rms = sqrt(5.25) ≈ 2.291, so out[i] = x[i] / rms.
        let rms_norm = RmsNorm::new(dim, 1e-8).expect("dim=8 RmsNorm should be valid");
        let rms_out = rms_norm
            .forward(&x, 1)
            .expect("1-token RmsNorm forward with matching dim should succeed");
        let inv_rms = 1.0 / (x.iter().map(|&v| v * v).sum::<f32>() / dim as f32 + 1e-8).sqrt();
        for (&o, &xi) in rms_out.iter().zip(x.iter()) {
            assert!(
                (o - xi * inv_rms).abs() < 1e-5,
                "RMSNorm out[i]={o} expected {}",
                xi * inv_rms
            );
        }

        // LayerNorm with weight=1, bias=0: output should have mean≈0, var≈1
        let ln = LayerNorm::new(dim, 1e-8).expect("dim=8 LayerNorm should be valid");
        let ln_out = ln
            .forward(&x, 1)
            .expect("1-token LayerNorm forward with matching dim should succeed");
        let mu: f32 = ln_out.iter().sum::<f32>() / dim as f32;
        let var: f32 = ln_out.iter().map(|&v| (v - mu) * (v - mu)).sum::<f32>() / dim as f32;
        assert!(mu.abs() < 1e-5, "LayerNorm mean={mu}");
        assert!((var - 1.0).abs() < 1e-4, "LayerNorm var={var}");
    }

    // ── E2E 7: PTX kernels for all SM versions ────────────────────────────

    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        use crate::ptx_kernels::*;
        let sms = [75u32, 80, 86, 90, 100, 120];
        for sm in sms {
            let p1 = embedding_forward_ptx(sm);
            let p2 = rope_apply_ptx(sm);
            let p3 = silu_gate_ptx(sm);
            let p4 = rms_norm_ptx(sm);
            let p5 = causal_attn_softmax_ptx(sm);
            for (name, ptx) in [
                ("embedding_forward", &p1),
                ("rope_apply", &p2),
                ("silu_gate", &p3),
                ("rms_norm", &p4),
                ("causal_attn_softmax", &p5),
            ] {
                let target = format!("sm_{sm}");
                assert!(
                    ptx.contains(&target),
                    "SM {sm}: kernel '{name}' missing target directive"
                );
            }
        }
    }

    // ── E2E 8: GQA with LLaMA-3 style (4Q / 2KV) multi-step decode ────────

    #[test]
    fn e2e_llama_gqa_multistep_decode() {
        let m = LlamaModel::new(LlamaConfig::tiny())
            .expect("tiny LlamaConfig for GQA multistep test should be valid");
        // Prefill phase: 4 tokens
        let prefill_ids = vec![0u32, 1, 2, 3];
        let (_, kv) = m
            .forward(&prefill_ids, None)
            .expect("4-token prefill LLaMA forward should succeed");
        assert_eq!(kv.past_len(), 4);

        // Decode: 3 more tokens one at a time
        let mut cur_kv = kv;
        for step_tok in [4u32, 5, 6] {
            let (logits, new_kv) = m
                .forward(&[step_tok], Some(&cur_kv))
                .expect("single-step LLaMA decode should succeed");
            assert_eq!(logits.len(), m.config.vocab_size);
            cur_kv = new_kv;
        }
        assert_eq!(cur_kv.past_len(), 7);
    }

    // ── E2E 9: Vocab special token round-trip ────────────────────────────

    #[test]
    fn e2e_vocab_special_token_roundtrip() {
        use std::collections::HashMap;
        let tokens = vec![vec![b'a'], vec![b'b'], vec![1u8, 0], vec![2u8, 0]];
        let special: HashMap<String, u32> = [("<bos>".into(), 2u32), ("<eos>".into(), 3u32)]
            .into_iter()
            .collect();
        let v = Vocab::from_tokens(tokens, special)
            .expect("4-token vocabulary with BOS/EOS specials should succeed");
        assert_eq!(v.special_id("<bos>"), Some(2));
        assert_eq!(v.special_id("<eos>"), Some(3));
        assert_eq!(v.bytes_to_id(b"a"), Some(0));
        assert_eq!(
            v.decode_token(0).expect("token 0 should decode to 'a'"),
            "a"
        );
    }

    // ── E2E 10: GPT-2 next_token greedy decode loop ───────────────────────

    #[test]
    fn e2e_gpt2_greedy_decode_loop() {
        // Greedily generate 4 tokens after a 1-token prompt. Each token is fed
        // through the model exactly once, and the KV cache is threaded between
        // steps, so the model sees the sequence [start, t1, t2, t3].
        let m = Gpt2Model::new(GptConfig::tiny())
            .expect("tiny GptConfig for greedy decode loop test should be valid");
        let mut token_ids = vec![0u32]; // start token
        // Empty cache: the first step consumes the start token itself, so it is
        // never cached twice.
        let mut kv = None;

        for _ in 0..4 {
            let last_tok = *token_ids
                .last()
                .expect("token_ids is never empty during greedy decode loop");
            let (next_tok, new_kv) = m
                .next_token(&[last_tok], kv.as_ref())
                .expect("greedy next_token step should succeed");
            token_ids.push(next_tok);
            kv = Some(new_kv);
        }

        // Should have generated 5 tokens total (1 initial + 4 decoded)
        assert_eq!(token_ids.len(), 5);
        // Every token except the newest one has been processed exactly once.
        let kv = kv.expect("the decode loop runs at least once");
        assert_eq!(kv.past_len(), 4);
        // All generated token ids must be in vocab range
        for &t in &token_ids {
            assert!((t as usize) < m.config.vocab_size);
        }
    }
}