Skip to main content

llama_runtime/
lib.rs

//! # llama-runtime
//!
//! Runtime execution and verification helpers for llama.rs.
//!
//! This crate includes a Phase-1 verification harness for LLAMA-006, which checks that
//! `full_forward(prompt)` logits match the logits of `prefill(prompt[:-1])` followed by
//! `decode(last_token)`.
7
8use llama_kv::{KVLayout, LayerKVCache};
9
10/// Errors for runtime verification routines.
11#[derive(Debug, thiserror::Error)]
12pub enum RuntimeError {
13    #[error("prompt must contain at least 2 tokens")]
14    PromptTooShort,
15    #[error("invalid token id: {0}")]
16    InvalidToken(i32),
17    #[error("kv error: {0}")]
18    Kv(#[from] llama_kv::KVError),
19}
20
/// Outcome of comparing full-forward logits against prefill + decode logits.
#[derive(Debug, Clone, Copy)]
pub struct KvEquivalenceReport {
    /// Largest element-wise absolute difference between the two logit vectors.
    pub max_abs_diff: f32,
}
26
27/// Minimal deterministic runtime verifier used for LLAMA-006 true tests.
28pub struct RuntimeVerifier {
29    model: ToyModel,
30}
31
32impl Default for RuntimeVerifier {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl RuntimeVerifier {
39    pub fn new() -> Self {
40        Self {
41            model: ToyModel::default(),
42        }
43    }
44
45    /// True test:
46    /// full_forward(prompt) logits == prefill(prompt[:-1]) + decode(last_token) logits.
47    pub fn verify_kv_equivalence(
48        &self,
49        prompt: &[i32],
50    ) -> Result<KvEquivalenceReport, RuntimeError> {
51        if prompt.len() < 2 {
52            return Err(RuntimeError::PromptTooShort);
53        }
54
55        let full = self.model.full_forward(prompt)?;
56        let kv = self.model.prefill_then_decode(prompt, false)?;
57        Ok(KvEquivalenceReport {
58            max_abs_diff: max_abs_diff(&full, &kv),
59        })
60    }
61
62    /// Same flow as `verify_kv_equivalence` but injects an off-by-one position bug
63    /// in decode. Used to prove the true test catches indexing errors.
64    pub fn verify_with_off_by_one_bug(
65        &self,
66        prompt: &[i32],
67    ) -> Result<KvEquivalenceReport, RuntimeError> {
68        if prompt.len() < 2 {
69            return Err(RuntimeError::PromptTooShort);
70        }
71
72        let full = self.model.full_forward(prompt)?;
73        let kv_bug = self.model.prefill_then_decode(prompt, true)?;
74        Ok(KvEquivalenceReport {
75            max_abs_diff: max_abs_diff(&full, &kv_bug),
76        })
77    }
78}
79
/// Fixed toy model: an 8-entry table of 2-d token embeddings plus a projection
/// from a 2-d context vector to 8 logits.
#[derive(Debug)]
struct ToyModel {
    /// Row `i` is the 2-d embedding of token id `i`.
    embeddings: [[f32; 2]; 8],
    /// Output projection stored as 2 rows of 8 columns.
    out_proj: [[f32; 8]; 2],
}
85
86impl Default for ToyModel {
87    fn default() -> Self {
88        Self {
89            embeddings: [
90                [0.4, -0.2],
91                [0.1, 0.9],
92                [0.8, 0.2],
93                [-0.5, 0.7],
94                [0.3, -0.9],
95                [-0.2, -0.3],
96                [0.6, 0.4],
97                [-0.7, 0.5],
98            ],
99            out_proj: [
100                [0.3, -0.1, 0.2, 0.5, -0.4, 0.1, 0.2, -0.3],
101                [-0.2, 0.6, -0.3, 0.1, 0.4, -0.5, 0.2, 0.3],
102            ],
103        }
104    }
105}
106
107impl ToyModel {
108    fn full_forward(&self, prompt: &[i32]) -> Result<Vec<f32>, RuntimeError> {
109        let seq_len = prompt.len();
110        let mut keys = Vec::with_capacity(seq_len * 2);
111        let mut values = Vec::with_capacity(seq_len * 2);
112
113        for (pos, &tok) in prompt.iter().enumerate() {
114            let mut kv = self.token_vec(tok)?;
115            apply_position_rotation(&mut kv, pos);
116            keys.extend_from_slice(&kv);
117            values.extend_from_slice(&kv);
118        }
119
120        let mut q = self.token_vec(*prompt.last().expect("prompt checked non-empty"))?;
121        apply_position_rotation(&mut q, seq_len - 1);
122
123        let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
124        Ok(project_logits(&ctx, &self.out_proj))
125    }
126
127    fn prefill_then_decode(
128        &self,
129        prompt: &[i32],
130        inject_off_by_one: bool,
131    ) -> Result<Vec<f32>, RuntimeError> {
132        let prefill_len = prompt.len() - 1;
133        let mut cache = LayerKVCache::new(prompt.len(), 1, 2, KVLayout::BySequence);
134
135        let mut k_prefill = Vec::with_capacity(prefill_len * 2);
136        let mut v_prefill = Vec::with_capacity(prefill_len * 2);
137        for (pos, &tok) in prompt[..prefill_len].iter().enumerate() {
138            let mut kv = self.token_vec(tok)?;
139            apply_position_rotation(&mut kv, pos);
140            k_prefill.extend_from_slice(&kv);
141            v_prefill.extend_from_slice(&kv);
142        }
143        cache.write_prefill(&k_prefill, &v_prefill, prefill_len)?;
144
145        let last = *prompt.last().expect("prompt checked len >= 2");
146        let mut q = self.token_vec(last)?;
147        let decode_pos = if inject_off_by_one {
148            prefill_len + 1
149        } else {
150            prefill_len
151        };
152        apply_position_rotation(&mut q, decode_pos);
153
154        let mut kv_last = self.token_vec(last)?;
155        apply_position_rotation(&mut kv_last, decode_pos);
156        cache.append_token(&kv_last, &kv_last)?;
157
158        let seq_len = cache.seq_len;
159        let keys = cache.k[..seq_len * 2].to_vec();
160        let values = cache.v[..seq_len * 2].to_vec();
161        let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
162        Ok(project_logits(&ctx, &self.out_proj))
163    }
164
165    fn token_vec(&self, token: i32) -> Result<[f32; 2], RuntimeError> {
166        let idx = usize::try_from(token).map_err(|_| RuntimeError::InvalidToken(token))?;
167        self.embeddings
168            .get(idx)
169            .copied()
170            .ok_or(RuntimeError::InvalidToken(token))
171    }
172}
173
/// Rotates the 2-d vector `v` in place by `position * 0.15` radians.
///
/// Position 0 leaves `v` unchanged, and the vector's length is preserved.
fn apply_position_rotation(v: &mut [f32; 2], position: usize) {
    let angle = position as f32 * 0.15;
    let (sin_a, cos_a) = angle.sin_cos();
    let [x, y] = *v;
    *v = [x * cos_a - y * sin_a, x * sin_a + y * cos_a];
}
182
183fn attention_single_head(
184    q: &[f32; 2],
185    keys: &[f32],
186    values: &[f32],
187    seq_len: usize,
188    dim: usize,
189) -> [f32; 2] {
190    let scale = 1.0 / (dim as f32).sqrt();
191
192    let mut scores = vec![0.0f32; seq_len];
193    for t in 0..seq_len {
194        let k = &keys[t * dim..t * dim + dim];
195        scores[t] = (q[0] * k[0] + q[1] * k[1]) * scale;
196    }
197    let probs = softmax(&scores);
198
199    let mut out = [0.0f32; 2];
200    for (t, &p) in probs.iter().enumerate() {
201        let v = &values[t * dim..t * dim + dim];
202        out[0] += p * v[0];
203        out[1] += p * v[1];
204    }
205    out
206}
207
/// Maps a 2-d context vector to 8 logits:
/// `logits[i] = ctx[0] * out_proj[0][i] + ctx[1] * out_proj[1][i]`.
fn project_logits(ctx: &[f32; 2], out_proj: &[[f32; 8]; 2]) -> Vec<f32> {
    let [row0, row1] = out_proj;
    row0.iter()
        .zip(row1.iter())
        .map(|(&w0, &w1)| ctx[0] * w0 + ctx[1] * w1)
        .collect()
}
215
/// Numerically stable softmax: subtracts the maximum score before exponentiating.
///
/// Returns an empty vector for empty input. If every score is `-inf` (a fully
/// masked row), returns all zeros instead of the NaNs that `-inf - (-inf)`
/// would otherwise produce. NaN scores still propagate as before.
fn softmax(scores: &[f32]) -> Vec<f32> {
    if scores.iter().all(|&s| s == f32::NEG_INFINITY) {
        // No finite mass to distribute (also covers the empty slice).
        return vec![0.0; scores.len()];
    }

    let max_v = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut exps: Vec<f32> = scores.iter().map(|&s| (s - max_v).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // The max term contributes exp(0) = 1, so this only fails when `sum` is NaN;
    // in that case the unnormalized values are returned unchanged.
    if sum > 0.0 {
        for e in &mut exps {
            *e /= sum;
        }
    }
    exps
}
227
/// Largest element-wise absolute difference between `a` and `b`.
///
/// - Returns `0.0` for two empty slices.
/// - Returns `f32::INFINITY` if the lengths differ, because a truncated
///   comparison would hide missing logits.
/// - Returns `NaN` if any pairwise difference is NaN, so a tolerance check such
///   as `diff <= tol` fails instead of silently passing.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    a.iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).abs())
        // Unlike `f32::max`, which discards NaN operands, keep a NaN once seen.
        .fold(0.0f32, |acc, d| if d.is_nan() || d > acc { d } else { acc })
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Both paths perform the same f32 arithmetic, so they should agree tightly.
    const EQUIV_TOL: f32 = 1e-5;

    #[test]
    fn kv_true_test_equivalence_holds() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_kv_equivalence(&prompt).unwrap();
        assert!(
            report.max_abs_diff <= EQUIV_TOL,
            "expected <= 1e-5, got {}",
            report.max_abs_diff
        );
    }

    #[test]
    fn kv_true_test_equivalence_holds_for_every_prompt_length() {
        let verifier = RuntimeVerifier::new();
        let tokens = [7, 0, 5, 2, 4, 1, 6, 3];
        for len in 2..=tokens.len() {
            let report = verifier.verify_kv_equivalence(&tokens[..len]).unwrap();
            assert!(
                report.max_abs_diff <= EQUIV_TOL,
                "len {}: expected <= 1e-5, got {}",
                len,
                report.max_abs_diff
            );
        }
    }

    #[test]
    fn kv_true_test_detects_off_by_one_bug() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_with_off_by_one_bug(&prompt).unwrap();
        assert!(
            report.max_abs_diff > 1e-4,
            "off-by-one bug should be detectable, got {}",
            report.max_abs_diff
        );
    }

    #[test]
    fn verify_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));

        let empty: [i32; 0] = [];
        let err = verifier.verify_kv_equivalence(&empty).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));

        let err = verifier.verify_with_off_by_one_bug(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    #[test]
    fn verify_rejects_invalid_tokens() {
        let verifier = RuntimeVerifier::new();

        // One past the end of the 8-entry embedding table.
        let err = verifier.verify_kv_equivalence(&[1, 8]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(8)));

        // Negative ids cannot index the table.
        let err = verifier.verify_kv_equivalence(&[-1, 2]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(-1)));
    }

    #[test]
    fn default_matches_new() {
        let prompt = [4, 0, 7];
        let from_new = RuntimeVerifier::new()
            .verify_kv_equivalence(&prompt)
            .unwrap();
        let from_default = RuntimeVerifier::default()
            .verify_kv_equivalence(&prompt)
            .unwrap();
        assert_eq!(from_new.max_abs_diff, from_default.max_abs_diff);
    }
}