1use llama_engine::{
11 DecodeResult, LlamaEngine, LlamaError, ModelHandle, ModelSpec, PrefillResult, Result, Session,
12 TokenId,
13};
14use llama_kv::{KVLayout, LayerKVCache};
15use llama_sampling::{Sampler, SamplingConfig, SamplingStrategy};
16use llama_tokenizer::{Tokenizer, WhitespaceTokenizer};
17use std::sync::Mutex;
18
/// Mock implementation of the `LlamaEngine` trait.
///
/// Tokenization is delegated to a `WhitespaceTokenizer`, and decoding samples
/// from a fixed logits vector through a shared `Sampler`.
pub struct MockEngine {
    /// Tokenizer used by `tokenize` and `detokenize`.
    tokenizer: WhitespaceTokenizer,
    /// Sampler wrapped in a `Mutex` so `decode` can use it through `&self`.
    sampler: Mutex<Sampler>,
}
31
32impl MockEngine {
33 pub fn new() -> Self {
34 Self {
35 tokenizer: WhitespaceTokenizer::new(),
36 sampler: Mutex::new(
37 Sampler::new(SamplingConfig {
38 strategy: SamplingStrategy::Greedy,
39 ..SamplingConfig::default()
40 })
41 .expect("default config is valid"),
42 ),
43 }
44 }
45}
46
47impl Default for MockEngine {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl LlamaEngine for MockEngine {
54 fn load_model(&self, _spec: &ModelSpec) -> Result<ModelHandle> {
55 Ok(ModelHandle)
56 }
57
58 fn tokenize(&self, text: &str) -> Result<Vec<TokenId>> {
59 self.tokenizer
60 .encode(text)
61 .map_err(|e| LlamaError::Tokenization(e.to_string()))
62 }
63
64 fn detokenize(&self, tokens: &[TokenId]) -> Result<String> {
65 self.tokenizer
66 .decode(tokens)
67 .map_err(|e| LlamaError::Tokenization(e.to_string()))
68 }
69
70 fn prefill(&self, _session: &mut Session, tokens: &[TokenId]) -> Result<PrefillResult> {
71 Ok(PrefillResult {
72 tokens_processed: tokens.len(),
73 })
74 }
75
76 fn decode(&self, _session: &mut Session) -> Result<DecodeResult> {
77 let mock_logits = vec![0.1, 0.5, 0.1, 0.1, 0.2];
78 let mut sampler = self.sampler.lock().unwrap();
79 let token = sampler
80 .sample(&mock_logits, &[])
81 .map_err(|e| LlamaError::Inference(format!("{}", e)))?;
82 Ok(DecodeResult {
83 token: token as TokenId,
84 })
85 }
86
87 fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
88 Ok(texts.iter().map(|_| vec![0.0; 128]).collect())
89 }
90}
91
/// Errors reported by [`RuntimeVerifier`] and the toy model behind it.
#[derive(Debug, thiserror::Error)]
pub enum RuntimeError {
    /// The prompt has fewer than the two tokens the verifier requires.
    #[error("prompt must contain at least 2 tokens")]
    PromptTooShort,
    /// A token id is negative or has no entry in the toy embedding table.
    #[error("invalid token id: {0}")]
    InvalidToken(i32),
    /// An error propagated from the `llama_kv` cache (converted via `From`).
    #[error("kv error: {0}")]
    Kv(#[from] llama_kv::KVError),
}
106
/// Result of comparing the full forward pass against the KV-cached path.
#[derive(Debug, Clone, Copy)]
pub struct KvEquivalenceReport {
    /// Largest absolute element-wise difference between the two logit vectors.
    pub max_abs_diff: f32,
}
112
113pub struct RuntimeVerifier {
115 model: ToyModel,
116}
117
118impl Default for RuntimeVerifier {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl RuntimeVerifier {
125 pub fn new() -> Self {
126 Self {
127 model: ToyModel::default(),
128 }
129 }
130
131 pub fn verify_kv_equivalence(
134 &self,
135 prompt: &[i32],
136 ) -> std::result::Result<KvEquivalenceReport, RuntimeError> {
137 if prompt.len() < 2 {
138 return Err(RuntimeError::PromptTooShort);
139 }
140
141 let full = self.model.full_forward(prompt)?;
142 let kv = self.model.prefill_then_decode(prompt, false)?;
143 Ok(KvEquivalenceReport {
144 max_abs_diff: max_abs_diff(&full, &kv),
145 })
146 }
147
148 pub fn verify_with_off_by_one_bug(
151 &self,
152 prompt: &[i32],
153 ) -> std::result::Result<KvEquivalenceReport, RuntimeError> {
154 if prompt.len() < 2 {
155 return Err(RuntimeError::PromptTooShort);
156 }
157
158 let full = self.model.full_forward(prompt)?;
159 let kv_bug = self.model.prefill_then_decode(prompt, true)?;
160 Ok(KvEquivalenceReport {
161 max_abs_diff: max_abs_diff(&full, &kv_bug),
162 })
163 }
164}
165
/// Minimal single-head attention model with hard-coded weights.
#[derive(Debug)]
struct ToyModel {
    /// Embedding table: one 2-component vector per token id (ids `0..8`).
    embeddings: [[f32; 2]; 8],
    /// Output projection: two rows of eight weights, one row per context component.
    out_proj: [[f32; 8]; 2],
}
171
172impl Default for ToyModel {
173 fn default() -> Self {
174 Self {
175 embeddings: [
176 [0.4, -0.2],
177 [0.1, 0.9],
178 [0.8, 0.2],
179 [-0.5, 0.7],
180 [0.3, -0.9],
181 [-0.2, -0.3],
182 [0.6, 0.4],
183 [-0.7, 0.5],
184 ],
185 out_proj: [
186 [0.3, -0.1, 0.2, 0.5, -0.4, 0.1, 0.2, -0.3],
187 [-0.2, 0.6, -0.3, 0.1, 0.4, -0.5, 0.2, 0.3],
188 ],
189 }
190 }
191}
192
193impl ToyModel {
194 fn full_forward(&self, prompt: &[i32]) -> std::result::Result<Vec<f32>, RuntimeError> {
195 let seq_len = prompt.len();
196 let mut keys = Vec::with_capacity(seq_len * 2);
197 let mut values = Vec::with_capacity(seq_len * 2);
198
199 for (pos, &tok) in prompt.iter().enumerate() {
200 let mut kv = self.token_vec(tok)?;
201 apply_position_rotation(&mut kv, pos);
202 keys.extend_from_slice(&kv);
203 values.extend_from_slice(&kv);
204 }
205
206 let mut q = self.token_vec(*prompt.last().expect("prompt checked non-empty"))?;
207 apply_position_rotation(&mut q, seq_len - 1);
208
209 let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
210 Ok(project_logits(&ctx, &self.out_proj))
211 }
212
213 fn prefill_then_decode(
214 &self,
215 prompt: &[i32],
216 inject_off_by_one: bool,
217 ) -> std::result::Result<Vec<f32>, RuntimeError> {
218 let prefill_len = prompt.len() - 1;
219 let mut cache = LayerKVCache::new(prompt.len(), 1, 2, KVLayout::BySequence);
220
221 let mut k_prefill = Vec::with_capacity(prefill_len * 2);
222 let mut v_prefill = Vec::with_capacity(prefill_len * 2);
223 for (pos, &tok) in prompt[..prefill_len].iter().enumerate() {
224 let mut kv = self.token_vec(tok)?;
225 apply_position_rotation(&mut kv, pos);
226 k_prefill.extend_from_slice(&kv);
227 v_prefill.extend_from_slice(&kv);
228 }
229 cache.write_prefill(&k_prefill, &v_prefill, prefill_len)?;
230
231 let last = *prompt.last().expect("prompt checked len >= 2");
232 let mut q = self.token_vec(last)?;
233 let decode_pos = if inject_off_by_one {
234 prefill_len + 1
235 } else {
236 prefill_len
237 };
238 apply_position_rotation(&mut q, decode_pos);
239
240 let mut kv_last = self.token_vec(last)?;
241 apply_position_rotation(&mut kv_last, decode_pos);
242 cache.append_token(&kv_last, &kv_last)?;
243
244 let seq_len = cache.seq_len;
245 let keys = cache.k[..seq_len * 2].to_vec();
246 let values = cache.v[..seq_len * 2].to_vec();
247 let ctx = attention_single_head(&q, &keys, &values, seq_len, 2);
248 Ok(project_logits(&ctx, &self.out_proj))
249 }
250
251 fn token_vec(&self, token: i32) -> std::result::Result<[f32; 2], RuntimeError> {
252 let idx = usize::try_from(token).map_err(|_| RuntimeError::InvalidToken(token))?;
253 self.embeddings
254 .get(idx)
255 .copied()
256 .ok_or(RuntimeError::InvalidToken(token))
257 }
258}
259
/// Rotates the 2-D vector `v` in place by `position * 0.15` radians.
fn apply_position_rotation(v: &mut [f32; 2], position: usize) {
    let angle = position as f32 * 0.15;
    let (s, c) = angle.sin_cos();
    let [x, y] = *v;
    *v = [x * c - y * s, x * s + y * c];
}
268
269fn attention_single_head(
270 q: &[f32; 2],
271 keys: &[f32],
272 values: &[f32],
273 seq_len: usize,
274 dim: usize,
275) -> [f32; 2] {
276 let scale = 1.0 / (dim as f32).sqrt();
277
278 let mut scores = vec![0.0f32; seq_len];
279 for t in 0..seq_len {
280 let k = &keys[t * dim..t * dim + dim];
281 scores[t] = (q[0] * k[0] + q[1] * k[1]) * scale;
282 }
283 let probs = softmax(&scores);
284
285 let mut out = [0.0f32; 2];
286 for (t, &p) in probs.iter().enumerate() {
287 let v = &values[t * dim..t * dim + dim];
288 out[0] += p * v[0];
289 out[1] += p * v[1];
290 }
291 out
292}
293
/// Projects the 2-component context onto eight logits:
/// `logits[i] = ctx[0] * out_proj[0][i] + ctx[1] * out_proj[1][i]`.
fn project_logits(ctx: &[f32; 2], out_proj: &[[f32; 8]; 2]) -> Vec<f32> {
    let (row0, row1) = (&out_proj[0], &out_proj[1]);
    row0.iter()
        .zip(row1.iter())
        .map(|(&w0, &w1)| ctx[0] * w0 + ctx[1] * w1)
        .collect()
}
301
/// Softmax over `scores`, with the maximum subtracted before exponentiating.
///
/// The exponentials are divided by their sum only when that sum is positive;
/// otherwise they are returned as computed. An empty input gives an empty output.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let peak = scores
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &s| acc.max(s));

    let mut weights: Vec<f32> = scores.iter().map(|&s| (s - peak).exp()).collect();
    let total: f32 = weights.iter().sum();

    if total > 0.0 {
        weights.iter_mut().for_each(|w| *w /= total);
    }
    weights
}
313
/// Returns the largest absolute element-wise difference between `a` and `b`.
///
/// Two cases that would otherwise look like a perfect match are made visible:
///
/// * slices of different lengths are not comparable, so the result is
///   `f32::INFINITY`;
/// * if any difference is NaN, the result is NaN rather than the NaN being
///   dropped by `f32::max`.
///
/// Two empty slices give `0.0`.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y).abs())
        .fold(0.0f32, |worst, diff| {
            // `f32::max` ignores NaN operands, so NaN has to be carried explicitly.
            if worst.is_nan() || diff.is_nan() {
                f32::NAN
            } else {
                worst.max(diff)
            }
        })
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding and then decoding a two-word string returns the original text.
    #[test]
    fn mock_engine_tokenize_roundtrip() {
        let engine = MockEngine::new();
        let tokens = engine.tokenize("hello world").unwrap();
        assert_eq!(tokens.len(), 2);

        let text = engine.detokenize(&tokens).unwrap();
        assert_eq!(text, "hello world");
    }

    /// Prefill reports every token as processed and decode yields a token.
    #[test]
    fn mock_engine_prefill_decode() {
        let engine = MockEngine::new();
        let mut session = Session::new();

        let tokens = engine.tokenize("hello world").unwrap();
        let prefill = engine.prefill(&mut session, &tokens).unwrap();
        assert_eq!(prefill.tokens_processed, 2);

        let result = engine.decode(&mut session).unwrap();
        assert!(result.token >= 0);
    }

    /// One 128-component embedding is returned per input text.
    #[test]
    fn mock_engine_embed() {
        let engine = MockEngine::new();
        let embeddings = engine.embed(&["hello", "world"]).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 128);
    }

    /// Compile-time check that the engine can be shared across threads.
    #[test]
    fn mock_engine_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockEngine>();
    }

    /// The cached path matches the full forward pass within tolerance.
    #[test]
    fn kv_true_test_equivalence_holds() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_kv_equivalence(&prompt).unwrap();
        assert!(
            report.max_abs_diff <= 1e-5,
            "expected <= 1e-5, got {}",
            report.max_abs_diff
        );
    }

    /// A decode position shifted by one produces a measurable difference.
    #[test]
    fn kv_true_test_detects_off_by_one_bug() {
        let verifier = RuntimeVerifier::new();
        let prompt = [1, 3, 2, 6];
        let report = verifier.verify_with_off_by_one_bug(&prompt).unwrap();
        assert!(
            report.max_abs_diff > 1e-4,
            "off-by-one bug should be detectable, got {}",
            report.max_abs_diff
        );
    }

    /// A single-token prompt is rejected by the equivalence check.
    #[test]
    fn verify_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_kv_equivalence(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    /// The fault-injecting variant applies the same minimum-length check.
    #[test]
    fn off_by_one_variant_rejects_too_short_prompt() {
        let verifier = RuntimeVerifier::new();
        let err = verifier.verify_with_off_by_one_bug(&[1]).unwrap_err();
        assert!(matches!(err, RuntimeError::PromptTooShort));
    }

    /// Token ids outside the eight-entry embedding table are reported, not indexed.
    #[test]
    fn verify_rejects_out_of_range_token() {
        let verifier = RuntimeVerifier::new();

        let err = verifier.verify_kv_equivalence(&[1, 8]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(8)));

        let err = verifier.verify_kv_equivalence(&[-1, 2]).unwrap_err();
        assert!(matches!(err, RuntimeError::InvalidToken(-1)));
    }
}