1use llama_kv::{KVLayout, LayerKVCache};
9
10#[derive(Debug, thiserror::Error)]
12pub enum RuntimeError {
13 #[error("prompt must contain at least 2 tokens")]
14 PromptTooShort,
15 #[error("invalid token id: {0}")]
16 InvalidToken(i32),
17 #[error("kv error: {0}")]
18 Kv(#[from] llama_kv::KVError),
19}
20
/// Outcome of comparing the full-forward logits with the KV-cache-path logits.
#[derive(Debug, Clone, Copy)]
pub struct KvEquivalenceReport {
    /// Largest absolute element-wise difference between the two logit vectors.
    pub max_abs_diff: f32,
}
26
27pub struct RuntimeVerifier {
29 model: ToyModel,
30}
31
32impl Default for RuntimeVerifier {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl RuntimeVerifier {
39 pub fn new() -> Self {
40 Self {
41 model: ToyModel::default(),
42 }
43 }
44
45 pub fn verify_kv_equivalence(
48 &self,
49 prompt: &[i32],
50 ) -> Result<KvEquivalenceReport, RuntimeError> {
51 if prompt.len() < 2 {
52 return Err(RuntimeError::PromptTooShort);
53 }
54
55 let full = self.model.full_forward(prompt)?;
56 let kv = self.model.prefill_then_decode(prompt, false)?;
57 Ok(KvEquivalenceReport {
58 max_abs_diff: max_abs_diff(&full, &kv),
59 })
60 }
61
62 pub fn verify_with_off_by_one_bug(
65 &self,
66 prompt: &[i32],
67 ) -> Result<KvEquivalenceReport, RuntimeError> {
68 if prompt.len() < 2 {
69 return Err(RuntimeError::PromptTooShort);
70 }
71
72 let full = self.model.full_forward(prompt)?;
73 let kv_bug = self.model.prefill_then_decode(prompt, true)?;
74 Ok(KvEquivalenceReport {
75 max_abs_diff: max_abs_diff(&full, &kv_bug),
76 })
77 }
78}
79
/// Tiny fixed-weight model used to exercise the KV-cache path.
#[derive(Debug)]
struct ToyModel {
    /// One 2-component embedding per token id; 8 ids are defined.
    embeddings: [[f32; 2]; 8],
    /// Two rows of 8 weights mapping the 2-component context onto 8 logits.
    out_proj: [[f32; 8]; 2],
}
85
86impl Default for ToyModel {
87 fn default() -> Self {
88 Self {
89 embeddings: [
90 [0.4, -0.2],
91 [0.1, 0.9],
92 [0.8, 0.2],
93 [-0.5, 0.7],
94 [0.3, -0.9],
95 [-0.2, -0.3],
96 [0.6, 0.4],
97 [-0.7, 0.5],
98 ],
99 out_proj: [
100 [0.3, -0.1, 0.2, 0.5, -0.4, 0.1, 0.2, -0.3],
101 [-0.2, 0.6, -0.3, 0.1, 0.4, -0.5, 0.2, 0.3],
102 ],
103 }
104 }
105}
106
107impl ToyModel {
108 fn full_forward(&self, prompt: &[i32]) -> Result<Vec<f32>, RuntimeError> {
109 let seq_len = prompt.len();
110 let mut keys = Vec::with_capacity(seq_len * 2);
111 let mut values = Vec::with_capacity(seq_len * 2);
112
113 for (pos, &tok) in prompt.iter().enumerate() {
114 let mut kv = self.token_vec(tok)?;
115 apply_position_rotation(&mut kv, pos);
116 keys.extend_from_slice(&kv);
117 values.extend_from_slice(&kv);
118 }
119
120 let mut q = self.token_vec(*prompt.last().expect("prompt checked non-empty"))?;
121 apply_position_rotation(&mut q, seq_len - 1);
122
123 let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
124 Ok(project_logits(&ctx, &self.out_proj))
125 }
126
127 fn prefill_then_decode(
128 &self,
129 prompt: &[i32],
130 inject_off_by_one: bool,
131 ) -> Result<Vec<f32>, RuntimeError> {
132 let prefill_len = prompt.len() - 1;
133 let mut cache = LayerKVCache::new(prompt.len(), 1, 2, KVLayout::BySequence);
134
135 let mut k_prefill = Vec::with_capacity(prefill_len * 2);
136 let mut v_prefill = Vec::with_capacity(prefill_len * 2);
137 for (pos, &tok) in prompt[..prefill_len].iter().enumerate() {
138 let mut kv = self.token_vec(tok)?;
139 apply_position_rotation(&mut kv, pos);
140 k_prefill.extend_from_slice(&kv);
141 v_prefill.extend_from_slice(&kv);
142 }
143 cache.write_prefill(&k_prefill, &v_prefill, prefill_len)?;
144
145 let last = *prompt.last().expect("prompt checked len >= 2");
146 let mut q = self.token_vec(last)?;
147 let decode_pos = if inject_off_by_one {
148 prefill_len + 1
149 } else {
150 prefill_len
151 };
152 apply_position_rotation(&mut q, decode_pos);
153
154 let mut kv_last = self.token_vec(last)?;
155 apply_position_rotation(&mut kv_last, decode_pos);
156 cache.append_token(&kv_last, &kv_last)?;
157
158 let seq_len = cache.seq_len;
159 let keys = cache.k[..seq_len * 2].to_vec();
160 let values = cache.v[..seq_len * 2].to_vec();
161 let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
162 Ok(project_logits(&ctx, &self.out_proj))
163 }
164
165 fn token_vec(&self, token: i32) -> Result<[f32; 2], RuntimeError> {
166 let idx = usize::try_from(token).map_err(|_| RuntimeError::InvalidToken(token))?;
167 self.embeddings
168 .get(idx)
169 .copied()
170 .ok_or(RuntimeError::InvalidToken(token))
171 }
172}
173
/// Rotates the 2-component vector `v` in place by `position * 0.15` radians.
fn apply_position_rotation(v: &mut [f32; 2], position: usize) {
    let (s, c) = (position as f32 * 0.15).sin_cos();
    let [x, y] = *v;
    *v = [x * c - y * s, x * s + y * c];
}
182
183fn attention_single_head(
184 q: &[f32; 2],
185 keys: &[f32],
186 values: &[f32],
187 seq_len: usize,
188 dim: usize,
189) -> [f32; 2] {
190 let scale = 1.0 / (dim as f32).sqrt();
191
192 let mut scores = vec![0.0f32; seq_len];
193 for t in 0..seq_len {
194 let k = &keys[t * dim..t * dim + dim];
195 scores[t] = (q[0] * k[0] + q[1] * k[1]) * scale;
196 }
197 let probs = softmax(&scores);
198
199 let mut out = [0.0f32; 2];
200 for (t, &p) in probs.iter().enumerate() {
201 let v = &values[t * dim..t * dim + dim];
202 out[0] += p * v[0];
203 out[1] += p * v[1];
204 }
205 out
206}
207
/// Projects the 2-component context onto 8 logits:
/// `logit[i] = ctx[0] * out_proj[0][i] + ctx[1] * out_proj[1][i]`.
fn project_logits(ctx: &[f32; 2], out_proj: &[[f32; 8]; 2]) -> Vec<f32> {
    let [row0, row1] = out_proj;
    row0.iter()
        .zip(row1.iter())
        .map(|(&w0, &w1)| ctx[0] * w0 + ctx[1] * w1)
        .collect()
}
215
/// Softmax over `scores`, computed after subtracting the largest score.
///
/// The exponentials are normalised only when their sum is strictly positive;
/// otherwise they are returned as computed. An empty input yields an empty vector.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let peak = scores
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &s| acc.max(s));

    let mut weights = Vec::with_capacity(scores.len());
    weights.extend(scores.iter().map(|&s| (s - peak).exp()));

    let total = weights.iter().sum::<f32>();
    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
227
/// Largest absolute element-wise difference between `a` and `b`.
///
/// This feeds an equivalence check, so results that cannot be compared must
/// never look like agreement:
///
/// * slices of different lengths yield `f32::INFINITY` rather than being
///   silently truncated to the shorter one;
/// * any NaN difference (a NaN input, or `inf - inf`) yields NaN, which fails
///   every `<=` / `>` threshold test, instead of being dropped by `f32::max`.
///
/// Two empty slices yield `0.0`.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    let mut worst = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        let gap = (x - y).abs();
        if gap.is_nan() {
            // `f32::max` would ignore this value, hiding a diverged output.
            return f32::NAN;
        }
        worst = worst.max(gap);
    }
    worst
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// The cached decode path must reproduce the full forward pass.
    #[test]
    fn kv_true_test_equivalence_holds() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_kv_equivalence(&prompt).unwrap();
        assert!(
            report.max_abs_diff <= 1e-5,
            "expected <= 1e-5, got {}",
            report.max_abs_diff
        );
    }

    /// Rotating the decoded token at the wrong position must show up in the report.
    #[test]
    fn kv_true_test_detects_off_by_one_bug() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_with_off_by_one_bug(&prompt).unwrap();
        assert!(
            report.max_abs_diff > 1e-4,
            "off-by-one bug should be detectable, got {}",
            report.max_abs_diff
        );
    }

    /// A single token leaves nothing to prefill, so it is rejected.
    #[test]
    fn verify_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    /// An empty prompt must produce an error, not a panic.
    #[test]
    fn verify_rejects_empty_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    /// The bug-injecting entry point applies the same length guard.
    #[test]
    fn off_by_one_variant_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_with_off_by_one_bug(&[]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
        let err = verifier.verify_with_off_by_one_bug(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    /// Token ids outside the 8-row embedding table are reported with their value.
    #[test]
    fn verify_rejects_out_of_range_and_negative_tokens() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[1, 8]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(8)));
        let err = verifier.verify_kv_equivalence(&[-1, 2]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(-1)));
    }
}