Skip to main content

llama_runtime/
lib.rs

1//! # llama-runtime
2//!
3//! Runtime execution and verification helpers for llama.rs.
4//!
5//! This crate includes:
6//! - A `MockEngine` demonstrating the narrow-waist `LlamaEngine` trait
7//! - A Phase-1 verification harness for LLAMA-006:
8//!   `full_forward(prompt)` logits vs `prefill(prompt[:-1]) + decode(last_token)` logits.
9
10use llama_engine::{
11    DecodeResult, LlamaEngine, LlamaError, ModelHandle, ModelSpec, PrefillResult, Result, Session,
12    TokenId,
13};
14use llama_kv::{KVLayout, LayerKVCache};
15use llama_sampling::{Sampler, SamplingConfig, SamplingStrategy};
16use llama_tokenizer::{Tokenizer, WhitespaceTokenizer};
17use std::sync::Mutex;
18
19// ---------------------------------------------------------------------------
20// MockEngine — Milestone A narrow-waist demonstration
21// ---------------------------------------------------------------------------
22
23/// A mock engine implementation for Milestone A.
24///
25/// Uses a simple whitespace tokenizer and greedy sampler to demonstrate
26/// the "narrow waist" API without requiring a real model or MLX backend.
27pub struct MockEngine {
28    tokenizer: WhitespaceTokenizer,
29    sampler: Mutex<Sampler>,
30}
31
32impl MockEngine {
33    pub fn new() -> Self {
34        Self {
35            tokenizer: WhitespaceTokenizer::new(),
36            sampler: Mutex::new(
37                Sampler::new(SamplingConfig {
38                    strategy: SamplingStrategy::Greedy,
39                    ..SamplingConfig::default()
40                })
41                .expect("default config is valid"),
42            ),
43        }
44    }
45}
46
47impl Default for MockEngine {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl LlamaEngine for MockEngine {
54    fn load_model(&self, _spec: &ModelSpec) -> Result<ModelHandle> {
55        Ok(ModelHandle)
56    }
57
58    fn tokenize(&self, text: &str) -> Result<Vec<TokenId>> {
59        self.tokenizer
60            .encode(text)
61            .map_err(|e| LlamaError::Tokenization(e.to_string()))
62    }
63
64    fn detokenize(&self, tokens: &[TokenId]) -> Result<String> {
65        self.tokenizer
66            .decode(tokens)
67            .map_err(|e| LlamaError::Tokenization(e.to_string()))
68    }
69
70    fn prefill(&self, _session: &mut Session, tokens: &[TokenId]) -> Result<PrefillResult> {
71        Ok(PrefillResult {
72            tokens_processed: tokens.len(),
73        })
74    }
75
76    fn decode(&self, _session: &mut Session) -> Result<DecodeResult> {
77        let mock_logits = vec![0.1, 0.5, 0.1, 0.1, 0.2];
78        let mut sampler = self.sampler.lock().unwrap();
79        let token = sampler
80            .sample(&mock_logits, &[])
81            .map_err(|e| LlamaError::Inference(format!("{}", e)))?;
82        Ok(DecodeResult {
83            token: token as TokenId,
84        })
85    }
86
87    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
88        Ok(texts.iter().map(|_| vec![0.0; 128]).collect())
89    }
90}
91
92// ---------------------------------------------------------------------------
93// RuntimeVerifier — LLAMA-006 KV equivalence true test
94// ---------------------------------------------------------------------------
95
96/// Errors for runtime verification routines.
97#[derive(Debug, thiserror::Error)]
98pub enum RuntimeError {
99    #[error("prompt must contain at least 2 tokens")]
100    PromptTooShort,
101    #[error("invalid token id: {0}")]
102    InvalidToken(i32),
103    #[error("kv error: {0}")]
104    Kv(#[from] llama_kv::KVError),
105}
106
/// Outcome of comparing full-forward logits with prefill + decode logits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KvEquivalenceReport {
    /// Largest element-wise absolute difference between the two logit vectors.
    pub max_abs_diff: f32,
}
112
113/// Minimal deterministic runtime verifier used for LLAMA-006 true tests.
114pub struct RuntimeVerifier {
115    model: ToyModel,
116}
117
118impl Default for RuntimeVerifier {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl RuntimeVerifier {
125    pub fn new() -> Self {
126        Self {
127            model: ToyModel::default(),
128        }
129    }
130
131    /// True test:
132    /// full_forward(prompt) logits == prefill(prompt[:-1]) + decode(last_token) logits.
133    pub fn verify_kv_equivalence(
134        &self,
135        prompt: &[i32],
136    ) -> std::result::Result<KvEquivalenceReport, RuntimeError> {
137        if prompt.len() < 2 {
138            return Err(RuntimeError::PromptTooShort);
139        }
140
141        let full = self.model.full_forward(prompt)?;
142        let kv = self.model.prefill_then_decode(prompt, false)?;
143        Ok(KvEquivalenceReport {
144            max_abs_diff: max_abs_diff(&full, &kv),
145        })
146    }
147
148    /// Same flow as `verify_kv_equivalence` but injects an off-by-one position bug
149    /// in decode. Used to prove the true test catches indexing errors.
150    pub fn verify_with_off_by_one_bug(
151        &self,
152        prompt: &[i32],
153    ) -> std::result::Result<KvEquivalenceReport, RuntimeError> {
154        if prompt.len() < 2 {
155            return Err(RuntimeError::PromptTooShort);
156        }
157
158        let full = self.model.full_forward(prompt)?;
159        let kv_bug = self.model.prefill_then_decode(prompt, true)?;
160        Ok(KvEquivalenceReport {
161            max_abs_diff: max_abs_diff(&full, &kv_bug),
162        })
163    }
164}
165
/// Tiny fixed-weight model with an 8-token vocabulary and a 2-D hidden state.
#[derive(Debug)]
struct ToyModel {
    /// One 2-D embedding row per token id (`0..8`).
    embeddings: [[f32; 2]; 8],
    /// Output projection: row `d` holds the weights of hidden dimension `d`
    /// for each of the 8 logits.
    out_proj: [[f32; 8]; 2],
}
171
172impl Default for ToyModel {
173    fn default() -> Self {
174        Self {
175            embeddings: [
176                [0.4, -0.2],
177                [0.1, 0.9],
178                [0.8, 0.2],
179                [-0.5, 0.7],
180                [0.3, -0.9],
181                [-0.2, -0.3],
182                [0.6, 0.4],
183                [-0.7, 0.5],
184            ],
185            out_proj: [
186                [0.3, -0.1, 0.2, 0.5, -0.4, 0.1, 0.2, -0.3],
187                [-0.2, 0.6, -0.3, 0.1, 0.4, -0.5, 0.2, 0.3],
188            ],
189        }
190    }
191}
192
193impl ToyModel {
194    fn full_forward(&self, prompt: &[i32]) -> std::result::Result<Vec<f32>, RuntimeError> {
195        let seq_len = prompt.len();
196        let mut keys = Vec::with_capacity(seq_len * 2);
197        let mut values = Vec::with_capacity(seq_len * 2);
198
199        for (pos, &tok) in prompt.iter().enumerate() {
200            let mut kv = self.token_vec(tok)?;
201            apply_position_rotation(&mut kv, pos);
202            keys.extend_from_slice(&kv);
203            values.extend_from_slice(&kv);
204        }
205
206        let mut q = self.token_vec(*prompt.last().expect("prompt checked non-empty"))?;
207        apply_position_rotation(&mut q, seq_len - 1);
208
209        let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
210        Ok(project_logits(&ctx, &self.out_proj))
211    }
212
213    fn prefill_then_decode(
214        &self,
215        prompt: &[i32],
216        inject_off_by_one: bool,
217    ) -> std::result::Result<Vec<f32>, RuntimeError> {
218        let prefill_len = prompt.len() - 1;
219        let mut cache = LayerKVCache::new(prompt.len(), 1, 2, KVLayout::BySequence);
220
221        let mut k_prefill = Vec::with_capacity(prefill_len * 2);
222        let mut v_prefill = Vec::with_capacity(prefill_len * 2);
223        for (pos, &tok) in prompt[..prefill_len].iter().enumerate() {
224            let mut kv = self.token_vec(tok)?;
225            apply_position_rotation(&mut kv, pos);
226            k_prefill.extend_from_slice(&kv);
227            v_prefill.extend_from_slice(&kv);
228        }
229        cache.write_prefill(&k_prefill, &v_prefill, prefill_len)?;
230
231        let last = *prompt.last().expect("prompt checked len >= 2");
232        let mut q = self.token_vec(last)?;
233        let decode_pos = if inject_off_by_one {
234            prefill_len + 1
235        } else {
236            prefill_len
237        };
238        apply_position_rotation(&mut q, decode_pos);
239
240        let mut kv_last = self.token_vec(last)?;
241        apply_position_rotation(&mut kv_last, decode_pos);
242        cache.append_token(&kv_last, &kv_last)?;
243
244        let seq_len = cache.seq_len;
245        let keys = cache.k[..seq_len * 2].to_vec();
246        let values = cache.v[..seq_len * 2].to_vec();
247        let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
248        Ok(project_logits(&ctx, &self.out_proj))
249    }
250
251    fn token_vec(&self, token: i32) -> std::result::Result<[f32; 2], RuntimeError> {
252        let idx = usize::try_from(token).map_err(|_| RuntimeError::InvalidToken(token))?;
253        self.embeddings
254            .get(idx)
255            .copied()
256            .ok_or(RuntimeError::InvalidToken(token))
257    }
258}
259
/// Rotates the 2-D vector `v` in place by `position * 0.15` radians.
///
/// This is the toy model's stand-in for a rotary position embedding. Callers
/// rotate both keys and queries by their sequence position before attention.
fn apply_position_rotation(v: &mut [f32; 2], position: usize) {
    let (sin, cos) = (position as f32 * 0.15).sin_cos();
    let [x, y] = *v;
    *v = [x * cos - y * sin, x * sin + y * cos];
}
268
269fn attention_single_head(
270    q: &[f32; 2],
271    keys: &[f32],
272    values: &[f32],
273    seq_len: usize,
274    dim: usize,
275) -> [f32; 2] {
276    let scale = 1.0 / (dim as f32).sqrt();
277
278    let mut scores = vec![0.0f32; seq_len];
279    for t in 0..seq_len {
280        let k = &keys[t * dim..t * dim + dim];
281        scores[t] = (q[0] * k[0] + q[1] * k[1]) * scale;
282    }
283    let probs = softmax(&scores);
284
285    let mut out = [0.0f32; 2];
286    for (t, &p) in probs.iter().enumerate() {
287        let v = &values[t * dim..t * dim + dim];
288        out[0] += p * v[0];
289        out[1] += p * v[1];
290    }
291    out
292}
293
/// Projects the 2-D attention context onto 8 logits.
///
/// Logit `i` is `ctx[0] * out_proj[0][i] + ctx[1] * out_proj[1][i]`.
fn project_logits(ctx: &[f32; 2], out_proj: &[[f32; 8]; 2]) -> Vec<f32> {
    let [c0, c1] = *ctx;
    let [row0, row1] = out_proj;
    row0.iter()
        .zip(row1.iter())
        .map(|(&w0, &w1)| c0 * w0 + c1 * w1)
        .collect()
}
301
/// Numerically stabilised softmax: subtracts the maximum score before
/// exponentiating, then normalises by the sum when that sum is positive.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let peak = scores.iter().fold(f32::NEG_INFINITY, |acc, &s| acc.max(s));
    let mut weights: Vec<f32> = scores.iter().map(|&s| (s - peak).exp()).collect();
    let total = weights.iter().sum::<f32>();
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
313
/// Largest element-wise absolute difference between `a` and `b`.
///
/// Return values:
/// - `0.0` for two empty slices.
/// - `f32::INFINITY` when the lengths differ, because a truncated comparison
///   could hide missing logits.
/// - NaN if any difference is NaN (NaN inputs, or equal infinities). A plain
///   `f32::max` fold would discard NaN and let broken logits pass a tolerance
///   check.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        // Once NaN is taken it sticks: `d > NaN` is always false.
        .fold(0.0f32, |acc, d| if d.is_nan() || d > acc { d } else { acc })
}
320
#[cfg(test)]
mod tests {
    use super::*;

    // -- MockEngine tests --

    #[test]
    fn mock_engine_tokenize_roundtrip() {
        let engine = MockEngine::new();
        let tokens = engine.tokenize("hello world").unwrap();
        assert_eq!(tokens.len(), 2);

        let text = engine.detokenize(&tokens).unwrap();
        assert_eq!(text, "hello world");
    }

    #[test]
    fn mock_engine_prefill_decode() {
        let engine = MockEngine::new();
        let mut session = Session::new();

        let tokens = engine.tokenize("hello world").unwrap();
        let prefill = engine.prefill(&mut session, &tokens).unwrap();
        assert_eq!(prefill.tokens_processed, 2);

        let result = engine.decode(&mut session).unwrap();
        assert!(result.token >= 0);
    }

    #[test]
    fn mock_engine_embed() {
        let engine = MockEngine::new();
        let embeddings = engine.embed(&["hello", "world"]).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 128);
    }

    #[test]
    fn mock_engine_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockEngine>();
    }

    // -- RuntimeVerifier tests --

    #[test]
    fn kv_true_test_equivalence_holds() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_kv_equivalence(&prompt).unwrap();
        assert!(
            report.max_abs_diff <= 1e-5,
            "expected <= 1e-5, got {}",
            report.max_abs_diff
        );
    }

    #[test]
    fn kv_true_test_equivalence_holds_for_minimal_prompt() {
        let verifier = RuntimeVerifier::new();
        let report = verifier.verify_kv_equivalence(&[0, 7]).unwrap();
        assert!(
            report.max_abs_diff <= 1e-5,
            "expected <= 1e-5, got {}",
            report.max_abs_diff
        );
    }

    #[test]
    fn kv_true_test_detects_off_by_one_bug() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_with_off_by_one_bug(&prompt).unwrap();
        assert!(
            report.max_abs_diff > 1e-4,
            "off-by-one bug should be detectable, got {}",
            report.max_abs_diff
        );
    }

    #[test]
    fn verify_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    #[test]
    fn verify_rejects_empty_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    #[test]
    fn off_by_one_variant_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_with_off_by_one_bug(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    #[test]
    fn verify_rejects_out_of_range_token() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[1, 8]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(8)));
    }

    #[test]
    fn verify_rejects_negative_token() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[-1, 2]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(-1)));
    }
}