1use std::io::{Read, Seek, SeekFrom};
27use std::path::Path;
28
29use blake3::Hasher;
30use oxicode::{Decode, Encode};
31
32use crate::engine::{EngineConfig, InferenceEngine};
33use crate::error::{RuntimeError, RuntimeResult};
34use crate::sampling::SamplerConfig;
35
/// Eight-byte tag stored in [`EngineSnapshot::magic`]; `EngineSnapshot::deserialize`
/// rejects any snapshot whose magic field differs from this value.
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"OXISNAP1";

/// Probe size used by [`ModelFingerprint::compute`]: 8 MiB hashed at each end of the file.
const DEFAULT_PROBE_SIZE: u32 = 8 * 1024 * 1024;
41
/// Cheap identity check for a model file, built from filesystem metadata plus
/// BLAKE3 hashes of the beginning and end of the file.
///
/// Two fingerprints are considered the same file only when every field is equal
/// (derived `PartialEq`), so a changed modification time alone causes a mismatch.
#[derive(Debug, Clone, PartialEq, Encode, Decode)]
pub struct ModelFingerprint {
    /// File length in bytes, taken from the file's metadata.
    pub file_size: u64,
    /// Modification time in whole seconds since the UNIX epoch; `0` when the
    /// time is unavailable or lies before the epoch.
    pub mtime_secs: i64,
    /// BLAKE3 hash of the first `min(probe_size, file_size)` bytes.
    pub head_hash: [u8; 32],
    /// BLAKE3 hash of the last `min(probe_size, file_size)` bytes.
    pub tail_hash: [u8; 32],
    /// Number of bytes probed at each end when this fingerprint was computed;
    /// `verify` reuses it so the recomputed hashes cover the same regions.
    pub probe_size: u32,
}
65
66impl ModelFingerprint {
67 pub fn compute(path: &Path) -> RuntimeResult<Self> {
71 Self::compute_with_probe(path, DEFAULT_PROBE_SIZE)
72 }
73
74 pub fn compute_with_probe(path: &Path, probe_size: u32) -> RuntimeResult<Self> {
76 let mut file = std::fs::File::open(path)?;
77 let metadata = file.metadata()?;
78 let file_size = metadata.len();
79
80 let mtime_secs = {
82 use std::time::SystemTime;
83 metadata
84 .modified()
85 .ok()
86 .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
87 .map(|d| d.as_secs() as i64)
88 .unwrap_or(0)
89 };
90
91 let head_read = (probe_size as u64).min(file_size) as usize;
93 let mut head_buf = vec![0u8; head_read];
94 file.seek(SeekFrom::Start(0))?;
95 file.read_exact(&mut head_buf)?;
96 let head_hash: [u8; 32] = *Hasher::new().update(&head_buf).finalize().as_bytes();
97
98 let tail_start = file_size.saturating_sub(probe_size as u64);
102 let tail_read = (file_size - tail_start) as usize;
103 let mut tail_buf = vec![0u8; tail_read];
104 file.seek(SeekFrom::Start(tail_start))?;
105 file.read_exact(&mut tail_buf)?;
106 let tail_hash: [u8; 32] = *Hasher::new().update(&tail_buf).finalize().as_bytes();
107
108 Ok(Self {
109 file_size,
110 mtime_secs,
111 head_hash,
112 tail_hash,
113 probe_size,
114 })
115 }
116
117 pub fn verify(&self, path: &Path) -> RuntimeResult<()> {
122 let actual = Self::compute_with_probe(path, self.probe_size)?;
123 if actual == *self {
124 return Ok(());
125 }
126 Err(RuntimeError::ModelFingerprintMismatch {
127 expected: self.display(),
128 found: actual.display(),
129 detail: format!(
130 "model file '{}' has been modified or replaced since the snapshot was taken",
131 path.display()
132 ),
133 })
134 }
135
136 pub fn display(&self) -> String {
138 let head_hex: String = self.head_hash.iter().map(|b| format!("{b:02x}")).collect();
139 let tail_hex: String = self.tail_hash.iter().map(|b| format!("{b:02x}")).collect();
140 format!(
141 "size={} mtime={} head={}...{} tail={}...{}",
142 self.file_size,
143 self.mtime_secs,
144 &head_hex[..8],
145 &head_hex[head_hex.len() - 8..],
146 &tail_hex[..8],
147 &tail_hex[tail_hex.len() - 8..],
148 )
149 }
150}
151
/// Serializable copy of a key/value cache, carried inside
/// [`SequenceStatePayload::Attention`] and [`SequenceStatePayload::Jamba`].
///
/// NOTE(review): the exact layout of `keys`/`values` is defined by the cache's
/// `to_payload` / `restore_from_payload`, which are not visible here.
#[derive(Debug, Clone, Encode, Decode)]
pub struct KvStatePayload {
    /// Cached key data — presumably one inner vector per layer; confirm against the cache.
    pub keys: Vec<Vec<f32>>,
    /// Cached value data — presumably laid out like `keys`; confirm against the cache.
    pub values: Vec<Vec<f32>>,
    /// Sequence length recorded with the cache (presumably positions currently filled).
    pub seq_len: usize,
    /// Layer count recorded with the cache.
    pub num_layers: usize,
    /// Maximum sequence length recorded with the cache.
    pub max_seq_len: usize,
    /// Key/value dimension recorded with the cache.
    pub kv_dim: usize,
}
170
/// Serializable state-space-model state, carried inside
/// [`SequenceStatePayload::Mamba2`] and [`SequenceStatePayload::Jamba`].
#[derive(Debug, Clone, Encode, Decode)]
pub struct SsmStatePayload {
    /// Recurrent state buffers — presumably one inner vector per layer; TODO confirm.
    pub ssm_states: Vec<Vec<f32>>,
    /// Step counter stored with the state (presumably the number of tokens processed).
    pub step: usize,
}
182
/// Per-sequence model state stored in an [`EngineSnapshot`], tagged by the kind
/// of state it carries.
///
/// In this file, `InferenceEngine::snapshot` only ever produces `Attention`, and
/// `InferenceEngine::resume` only restores `Attention`; the other variants are
/// decoded but not applied.
#[derive(Debug, Clone, Encode, Decode)]
pub enum SequenceStatePayload {
    /// Key/value cache state only.
    Attention(KvStatePayload),
    /// State-space-model state only.
    Mamba2(SsmStatePayload),
    /// Both kinds of state together.
    Jamba {
        /// Key/value cache portion.
        attention: KvStatePayload,
        /// State-space-model portion.
        ssm: SsmStatePayload,
    },
}
204
/// Sampler settings (and nominal sampler state) stored in an [`EngineSnapshot`].
///
/// Most fields mirror the same-named `SamplerConfig` fields: `snapshot` copies
/// them out of the config and `resume` copies them back.
#[derive(Debug, Clone, Encode, Decode)]
pub struct SamplerStatePayload {
    /// RNG state slot. `InferenceEngine::snapshot` fills it with the configured
    /// seed (or `0` when unseeded); `resume` does not read it.
    pub rng_state: u64,
    /// Mirostat `mu` slot. `InferenceEngine::snapshot` fills it with
    /// `2.0 * mirostat_tau`; `resume` does not read it.
    pub mirostat_mu: f32,
    /// Copied from / restored to `SamplerConfig::temperature`.
    pub temperature: f32,
    /// Copied from / restored to `SamplerConfig::top_k`.
    pub top_k: usize,
    /// Copied from / restored to `SamplerConfig::top_p`.
    pub top_p: f32,
    /// Copied from / restored to `SamplerConfig::min_p`.
    pub min_p: f32,
    /// Copied from / restored to `SamplerConfig::repetition_penalty`.
    pub repetition_penalty: f32,
    /// Copied from / restored to `SamplerConfig::repetition_penalty_window`.
    pub repetition_penalty_window: usize,
    /// Copied from / restored to `SamplerConfig::seed`.
    pub seed: Option<u64>,
    /// Copied from / restored to `SamplerConfig::mirostat`.
    pub mirostat_mode: u8,
    /// Copied from / restored to `SamplerConfig::mirostat_tau`.
    pub mirostat_tau: f32,
    /// Copied from / restored to `SamplerConfig::mirostat_eta`.
    pub mirostat_eta: f32,
}
235
/// Grammar constraint stored in an [`EngineSnapshot`].
///
/// Only the source text is kept; `InferenceEngine::resume` re-parses it with
/// `Grammar::parse` rather than restoring any parser position.
#[derive(Debug, Clone, Encode, Decode)]
pub struct GrammarStatePayload {
    /// Grammar source text, taken from the configured grammar's `source` field.
    pub grammar_source: String,
}
248
/// Complete serializable snapshot of an inference engine.
///
/// Encoded and decoded with `oxicode`; `magic` and `version` are validated after
/// decoding by [`EngineSnapshot::deserialize`].
#[derive(Debug, Clone, Encode, Decode)]
pub struct EngineSnapshot {
    /// Format tag; must equal [`SNAPSHOT_MAGIC`].
    pub magic: [u8; 8],
    /// Format version; must equal [`EngineSnapshot::VERSION`].
    pub version: u32,
    /// Architecture identifier, copied from the model config's `architecture`.
    pub arch_id: String,
    /// Model path the engine was configured with when the snapshot was taken.
    pub model_path: String,
    /// Tokenizer path from the engine configuration, if any.
    pub tokenizer_path: Option<String>,
    /// Fingerprint of the model file, verified against the model on resume.
    pub model_fingerprint: ModelFingerprint,
    /// Token history slot. `InferenceEngine::snapshot` in this file always
    /// writes an empty vector here.
    pub tokens: Vec<u32>,
    /// Per-sequence model state (KV cache and/or SSM state).
    pub sequence_state: SequenceStatePayload,
    /// Sampler settings at snapshot time.
    pub sampler_state: SamplerStatePayload,
    /// Grammar source, present only when the sampler had a grammar configured.
    pub grammar_state: Option<GrammarStatePayload>,
    /// Maximum context length from the model config; used as `context_size` on resume.
    pub max_context_length: usize,
    /// Thread count from the engine configuration.
    pub num_threads: usize,
    /// Prefill chunk size from the engine configuration.
    pub prefill_chunk_size: usize,
}
285
286impl EngineSnapshot {
287 pub const VERSION: u32 = 1;
289
290 pub fn serialize(&self) -> RuntimeResult<Vec<u8>> {
292 oxicode::encode_to_vec(self).map_err(|e| RuntimeError::SnapshotIncompatible {
293 detail: format!("serialization failed: {e}"),
294 })
295 }
296
297 pub fn deserialize(bytes: &[u8]) -> RuntimeResult<Self> {
302 let (snap, _) = oxicode::decode_from_slice::<Self>(bytes).map_err(|e| {
303 RuntimeError::SnapshotIncompatible {
304 detail: format!("deserialization failed: {e}"),
305 }
306 })?;
307
308 if &snap.magic != SNAPSHOT_MAGIC {
309 return Err(RuntimeError::SnapshotIncompatible {
310 detail: "invalid snapshot magic bytes".to_string(),
311 });
312 }
313
314 if snap.version != Self::VERSION {
315 return Err(RuntimeError::SnapshotIncompatible {
316 detail: format!(
317 "snapshot version {} is not supported (expected {})",
318 snap.version,
319 Self::VERSION
320 ),
321 });
322 }
323
324 Ok(snap)
325 }
326}
327
/// Eight-byte tag written at the start of an encoded [`SpeculativeEngineSnapshot`].
pub const SPEC_SNAPSHOT_MAGIC: &[u8; 8] = b"OXISPEC1";

/// Format version written after the magic; `decode` rejects any other value.
const SPEC_SNAPSHOT_VERSION: u32 = 1;
335
/// Snapshot of a speculative-decoding setup: the target and draft engine
/// snapshots plus the speculative-decoding bookkeeping.
///
/// Unlike [`EngineSnapshot`] this type does not derive `Encode`/`Decode`; it is
/// written and read by the hand-rolled `encode` / `decode` methods.
#[derive(Debug, Clone)]
pub struct SpeculativeEngineSnapshot {
    /// Snapshot of the target engine.
    pub target_snapshot: EngineSnapshot,
    /// Snapshot of the draft engine.
    pub draft_snapshot: EngineSnapshot,
    /// Speculation count (presumably draft tokens proposed per step — TODO confirm).
    pub num_speculative: usize,
    /// Optional seed for the speculative sampler.
    pub spec_seed: Option<u64>,
    /// Accepted token history.
    pub accepted_tokens: Vec<u32>,
    /// RNG state value stored alongside the snapshot.
    pub rng_state: u64,
}
369
370impl SpeculativeEngineSnapshot {
371 pub fn encode(&self) -> RuntimeResult<Vec<u8>> {
376 let target_bytes = self.target_snapshot.serialize()?;
378 let draft_bytes = self.draft_snapshot.serialize()?;
379
380 let seed_bytes = if self.spec_seed.is_some() {
384 9usize
385 } else {
386 1usize
387 };
388 let capacity = 8
389 + 4
390 + 8
391 + target_bytes.len()
392 + 8
393 + draft_bytes.len()
394 + 8
395 + seed_bytes
396 + 8
397 + self.accepted_tokens.len() * 4
398 + 8;
399
400 let mut buf: Vec<u8> = Vec::with_capacity(capacity);
401
402 buf.extend_from_slice(SPEC_SNAPSHOT_MAGIC);
404 buf.extend_from_slice(&SPEC_SNAPSHOT_VERSION.to_le_bytes());
405
406 buf.extend_from_slice(&(target_bytes.len() as u64).to_le_bytes());
408 buf.extend_from_slice(&target_bytes);
409
410 buf.extend_from_slice(&(draft_bytes.len() as u64).to_le_bytes());
412 buf.extend_from_slice(&draft_bytes);
413
414 buf.extend_from_slice(&(self.num_speculative as u64).to_le_bytes());
416
417 match self.spec_seed {
419 None => buf.push(0x00),
420 Some(seed) => {
421 buf.push(0x01);
422 buf.extend_from_slice(&seed.to_le_bytes());
423 }
424 }
425
426 buf.extend_from_slice(&(self.accepted_tokens.len() as u64).to_le_bytes());
428 for &tok in &self.accepted_tokens {
429 buf.extend_from_slice(&tok.to_le_bytes());
430 }
431
432 buf.extend_from_slice(&self.rng_state.to_le_bytes());
434
435 Ok(buf)
436 }
437
438 pub fn decode(bytes: &[u8]) -> RuntimeResult<Self> {
443 let mut pos = 0usize;
444
445 macro_rules! read_exact {
447 ($n:expr, $label:expr) => {{
448 let end = pos + $n;
449 if end > bytes.len() {
450 return Err(RuntimeError::SpecSnapshotIncompatible(format!(
451 "truncated: expected {} bytes for {} at offset {}",
452 $n, $label, pos
453 )));
454 }
455 let slice = &bytes[pos..end];
456 pos = end;
457 slice
458 }};
459 }
460
461 let magic = read_exact!(8, "magic");
463 if magic != SPEC_SNAPSHOT_MAGIC {
464 return Err(RuntimeError::SpecSnapshotIncompatible(format!(
465 "invalid magic bytes: expected {:?}, got {:?}",
466 SPEC_SNAPSHOT_MAGIC, magic
467 )));
468 }
469
470 let version = u32::from_le_bytes(
472 read_exact!(4, "version")
473 .try_into()
474 .expect("slice is exactly 4 bytes"),
475 );
476 if version != SPEC_SNAPSHOT_VERSION {
477 return Err(RuntimeError::SpecSnapshotIncompatible(format!(
478 "unsupported version {version} (expected {SPEC_SNAPSHOT_VERSION})"
479 )));
480 }
481
482 let target_len = u64::from_le_bytes(
484 read_exact!(8, "target_len")
485 .try_into()
486 .expect("slice is exactly 8 bytes"),
487 ) as usize;
488 let target_raw = read_exact!(target_len, "target_bytes");
489 let target_snapshot = EngineSnapshot::deserialize(target_raw).map_err(|e| {
490 RuntimeError::SpecSnapshotIncompatible(format!("target snapshot corrupt: {e}"))
491 })?;
492
493 let draft_len = u64::from_le_bytes(
495 read_exact!(8, "draft_len")
496 .try_into()
497 .expect("slice is exactly 8 bytes"),
498 ) as usize;
499 let draft_raw = read_exact!(draft_len, "draft_bytes");
500 let draft_snapshot = EngineSnapshot::deserialize(draft_raw).map_err(|e| {
501 RuntimeError::SpecSnapshotIncompatible(format!("draft snapshot corrupt: {e}"))
502 })?;
503
504 let num_speculative = u64::from_le_bytes(
506 read_exact!(8, "num_speculative")
507 .try_into()
508 .expect("slice is exactly 8 bytes"),
509 ) as usize;
510
511 let has_seed = read_exact!(1, "has_seed")[0];
513 let spec_seed = if has_seed == 0x01 {
514 let seed_bytes = read_exact!(8, "seed");
515 Some(u64::from_le_bytes(
516 seed_bytes.try_into().expect("slice is exactly 8 bytes"),
517 ))
518 } else {
519 None
520 };
521
522 let accepted_len = u64::from_le_bytes(
524 read_exact!(8, "accepted_len")
525 .try_into()
526 .expect("slice is exactly 8 bytes"),
527 ) as usize;
528 let mut accepted_tokens = Vec::with_capacity(accepted_len);
529 for _ in 0..accepted_len {
530 let tok = u32::from_le_bytes(
531 read_exact!(4, "accepted_token")
532 .try_into()
533 .expect("slice is exactly 4 bytes"),
534 );
535 accepted_tokens.push(tok);
536 }
537
538 let rng_state = u64::from_le_bytes(
540 read_exact!(8, "rng_state")
541 .try_into()
542 .expect("slice is exactly 8 bytes"),
543 );
544 let _ = pos;
546
547 Ok(Self {
548 target_snapshot,
549 draft_snapshot,
550 num_speculative,
551 spec_seed,
552 accepted_tokens,
553 rng_state,
554 })
555 }
556
557 pub fn fingerprint(&self) -> RuntimeResult<[u8; 32]> {
562 let encoded = self.encode()?;
563 Ok(*Hasher::new().update(&encoded).finalize().as_bytes())
564 }
565}
566
567impl InferenceEngine {
570 pub fn snapshot(&self) -> RuntimeResult<Vec<u8>> {
587 let model_config = self.model_config().ok_or(RuntimeError::ModelNotLoaded)?;
588 let kv_cache = self.kv_cache_ref().ok_or(RuntimeError::ModelNotLoaded)?;
589
590 let model_path = Path::new(self.config().model_path.as_str());
592 let model_fingerprint = ModelFingerprint::compute(model_path)?;
593
594 let sequence_state = SequenceStatePayload::Attention(kv_cache.to_payload());
596
597 let sampler_cfg = &self.config().sampler;
599 let sampler_state = SamplerStatePayload {
600 rng_state: sampler_cfg.seed.unwrap_or(0),
601 mirostat_mu: 2.0 * sampler_cfg.mirostat_tau,
602 temperature: sampler_cfg.temperature,
603 top_k: sampler_cfg.top_k,
604 top_p: sampler_cfg.top_p,
605 min_p: sampler_cfg.min_p,
606 repetition_penalty: sampler_cfg.repetition_penalty,
607 repetition_penalty_window: sampler_cfg.repetition_penalty_window,
608 seed: sampler_cfg.seed,
609 mirostat_mode: sampler_cfg.mirostat,
610 mirostat_tau: sampler_cfg.mirostat_tau,
611 mirostat_eta: sampler_cfg.mirostat_eta,
612 };
613
614 let grammar_state = sampler_cfg.grammar.as_ref().map(|g| GrammarStatePayload {
616 grammar_source: g.source.clone(),
617 });
618
619 let snap = EngineSnapshot {
620 magic: *SNAPSHOT_MAGIC,
621 version: EngineSnapshot::VERSION,
622 arch_id: model_config.architecture.clone(),
623 model_path: self.config().model_path.clone(),
624 tokenizer_path: self.config().tokenizer_path.clone(),
625 model_fingerprint,
626 tokens: Vec::new(), sequence_state,
628 sampler_state,
629 grammar_state,
630 max_context_length: model_config.max_context_length,
631 num_threads: self.config().num_threads,
632 prefill_chunk_size: self.config().prefill_chunk_size,
633 };
634
635 snap.serialize()
636 }
637
638 pub fn resume(bytes: &[u8], model_path: &Path) -> RuntimeResult<Self> {
653 use crate::sampling::grammar::Grammar;
654 use std::sync::Arc;
655
656 let snap = EngineSnapshot::deserialize(bytes)?;
657
658 snap.model_fingerprint.verify(model_path)?;
660
661 let mut sampler_config = SamplerConfig {
663 temperature: snap.sampler_state.temperature,
664 top_k: snap.sampler_state.top_k,
665 top_p: snap.sampler_state.top_p,
666 min_p: snap.sampler_state.min_p,
667 repetition_penalty: snap.sampler_state.repetition_penalty,
668 repetition_penalty_window: snap.sampler_state.repetition_penalty_window,
669 seed: snap.sampler_state.seed,
670 mirostat: snap.sampler_state.mirostat_mode,
671 mirostat_tau: snap.sampler_state.mirostat_tau,
672 mirostat_eta: snap.sampler_state.mirostat_eta,
673 grammar: None,
674 token_vocab: None,
675 logit_bias: std::collections::HashMap::new(),
678 banned_tokens: Vec::new(),
679 dry_multiplier: 0.0,
680 dry_base: 1.75,
681 dry_allowed_length: 2,
682 xtc_threshold: 0.0,
683 xtc_probability: 0.5,
684 typical_p: 1.0,
685 top_a: 0.0,
686 eta_cutoff: 0.0,
687 epsilon_cutoff: 0.0,
688 };
689
690 if let Some(gs) = &snap.grammar_state {
692 let grammar =
693 Grammar::parse(&gs.grammar_source).map_err(|e| RuntimeError::ModelLoadError {
694 message: format!("failed to re-parse grammar from snapshot: {e}"),
695 })?;
696 sampler_config.grammar = Some(Arc::new(grammar));
697 }
698
699 let config = EngineConfig {
700 model_path: model_path
701 .to_str()
702 .ok_or_else(|| RuntimeError::ModelLoadError {
703 message: "model path contains non-UTF-8 characters".to_string(),
704 })?
705 .to_string(),
706 tokenizer_path: snap.tokenizer_path.clone(),
707 context_size: Some(snap.max_context_length),
708 num_threads: snap.num_threads,
709 sampler: sampler_config,
710 prefill_chunk_size: snap.prefill_chunk_size,
711 offload_policy: crate::offload::OffloadPolicy::None,
712 };
713
714 let mut engine = Self::new(config);
715 engine.load_model()?;
716
717 if let SequenceStatePayload::Attention(kv_payload) = &snap.sequence_state {
719 let kv = engine.kv_cache_mut().ok_or(RuntimeError::ModelNotLoaded)?;
720 kv.restore_from_payload(kv_payload)?;
721 }
722
723 Ok(engine)
724 }
725}
726
// Unit tests for snapshot encoding/decoding and model fingerprinting.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small, valid `EngineSnapshot` with an attention payload and no
    /// grammar, for use as a fixture.
    fn make_minimal_snapshot() -> EngineSnapshot {
        EngineSnapshot {
            magic: *SNAPSHOT_MAGIC,
            version: EngineSnapshot::VERSION,
            arch_id: "llama".to_string(),
            model_path: "/tmp/test.gguf".to_string(),
            tokenizer_path: None,
            model_fingerprint: ModelFingerprint {
                file_size: 1024,
                mtime_secs: 1_000_000,
                head_hash: [0u8; 32],
                tail_hash: [1u8; 32],
                probe_size: DEFAULT_PROBE_SIZE,
            },
            tokens: vec![1, 2, 3],
            sequence_state: SequenceStatePayload::Attention(KvStatePayload {
                keys: vec![vec![0.0f32; 4]],
                values: vec![vec![0.0f32; 4]],
                seq_len: 1,
                num_layers: 1,
                max_seq_len: 512,
                kv_dim: 4,
            }),
            sampler_state: SamplerStatePayload {
                rng_state: 42,
                mirostat_mu: 5.0,
                temperature: 0.7,
                top_k: 40,
                top_p: 0.9,
                min_p: 0.0,
                repetition_penalty: 1.1,
                repetition_penalty_window: 64,
                seed: Some(42),
                mirostat_mode: 0,
                mirostat_tau: 5.0,
                mirostat_eta: 0.1,
            },
            grammar_state: None,
            max_context_length: 512,
            num_threads: 4,
            prefill_chunk_size: 512,
        }
    }

    // A snapshot survives serialize -> deserialize with its key fields intact.
    #[test]
    fn roundtrip_serialize_deserialize() {
        let snap = make_minimal_snapshot();
        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");
        assert_eq!(restored.arch_id, "llama");
        assert_eq!(restored.tokens, vec![1, 2, 3]);
        assert_eq!(restored.version, EngineSnapshot::VERSION);
        assert_eq!(&restored.magic, SNAPSHOT_MAGIC);
    }

    // Flipping the first encoded byte must make deserialization fail. Only
    // `is_err()` is asserted, so either a decode error or a magic mismatch passes.
    #[test]
    fn bad_magic_rejected() {
        let snap = make_minimal_snapshot();
        let mut bytes = snap.serialize().expect("serialize");
        if bytes.len() > 4 {
            bytes[0] ^= 0xFF;
        }
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(result.is_err(), "corrupted bytes must return Err");
    }

    // A snapshot carrying an unknown version decodes but is then rejected.
    #[test]
    fn incompatible_version_rejected() {
        let mut snap = make_minimal_snapshot();
        snap.version = 9999;
        let bytes = snap.serialize().expect("serialize");
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SnapshotIncompatible { .. })),
            "invalid version must return SnapshotIncompatible"
        );
    }

    // Fingerprint verifies against the unchanged file and fails once the
    // contents are rewritten (same size, different bytes).
    #[test]
    fn model_fingerprint_compute_and_verify() {
        let dir = std::env::temp_dir();
        let path = dir.join("oxillama_snap_test_fingerprint.gguf");
        std::fs::write(&path, vec![0xABu8; 100 * 1024]).expect("write test file");

        let fp = ModelFingerprint::compute(&path).expect("compute fingerprint");
        assert_eq!(fp.file_size, 100 * 1024);
        fp.verify(&path).expect("verify same file");

        std::fs::write(&path, vec![0xCDu8; 100 * 1024]).expect("write modified file");
        assert!(
            fp.verify(&path).is_err(),
            "fingerprint verify must fail after file modification"
        );

        // Best-effort cleanup; a failure to delete does not fail the test.
        let _ = std::fs::remove_file(&path);
    }

    // Verifying a fingerprint against a different file yields the dedicated
    // mismatch error variant.
    #[test]
    fn fingerprint_mismatch_error_type() {
        let dir = std::env::temp_dir();
        let path_a = dir.join("oxillama_snap_fp_a.gguf");
        let path_b = dir.join("oxillama_snap_fp_b.gguf");
        std::fs::write(&path_a, vec![0xAAu8; 10_000]).expect("write A");
        std::fs::write(&path_b, vec![0xBBu8; 10_000]).expect("write B");

        let fp_a = ModelFingerprint::compute(&path_a).expect("compute A");
        let result = fp_a.verify(&path_b);
        assert!(
            matches!(result, Err(RuntimeError::ModelFingerprintMismatch { .. })),
            "mismatch must return ModelFingerprintMismatch"
        );

        // Best-effort cleanup.
        let _ = std::fs::remove_file(&path_a);
        let _ = std::fs::remove_file(&path_b);
    }

    // KV payload contents are preserved exactly through a snapshot round-trip.
    #[test]
    fn kv_state_payload_roundtrip_in_snapshot() {
        let kv = KvStatePayload {
            keys: vec![vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]],
            values: vec![vec![9.0f32, 10.0, 11.0, 12.0], vec![13.0, 14.0, 15.0, 16.0]],
            seq_len: 1,
            num_layers: 2,
            max_seq_len: 512,
            kv_dim: 4,
        };
        let mut snap = make_minimal_snapshot();
        snap.sequence_state = SequenceStatePayload::Attention(kv.clone());

        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");

        if let SequenceStatePayload::Attention(restored_kv) = restored.sequence_state {
            assert_eq!(restored_kv.keys, kv.keys);
            assert_eq!(restored_kv.values, kv.values);
            assert_eq!(restored_kv.seq_len, kv.seq_len);
            assert_eq!(restored_kv.num_layers, kv.num_layers);
        } else {
            panic!("expected Attention sequence state payload");
        }
    }

    /// Builds a speculative snapshot fixture with the given accepted-token
    /// history and RNG state; target and draft are both the minimal snapshot.
    fn make_spec_snapshot(accepted: Vec<u32>, rng_state: u64) -> SpeculativeEngineSnapshot {
        SpeculativeEngineSnapshot {
            target_snapshot: make_minimal_snapshot(),
            draft_snapshot: make_minimal_snapshot(),
            num_speculative: 4,
            spec_seed: Some(0xdeadbeef),
            accepted_tokens: accepted,
            rng_state,
        }
    }

    // All speculative fields and both embedded snapshots survive encode -> decode.
    #[test]
    fn spec_snapshot_roundtrip() {
        let original = make_spec_snapshot(vec![10u32, 20, 30], 0x00c0_ffee_cafe_babe_u64);
        let bytes = original.encode().expect("encode must succeed");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode must succeed");

        assert_eq!(restored.num_speculative, 4);
        assert_eq!(restored.spec_seed, Some(0xdeadbeef));
        assert_eq!(restored.accepted_tokens, vec![10u32, 20, 30]);
        assert_eq!(restored.rng_state, 0x00c0_ffee_cafe_babe_u64);
        assert_eq!(restored.target_snapshot.arch_id, "llama");
        assert_eq!(restored.draft_snapshot.arch_id, "llama");
    }

    // Corrupting the first magic byte is reported as SpecSnapshotIncompatible.
    #[test]
    fn spec_snapshot_rejects_wrong_magic() {
        let snap = make_spec_snapshot(vec![], 42);
        let mut bytes = snap.encode().expect("encode");
        if bytes.len() >= 8 {
            bytes[0] ^= 0xFF;
        }
        let result = SpeculativeEngineSnapshot::decode(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SpecSnapshotIncompatible(_))),
            "wrong magic must return SpecSnapshotIncompatible, got {result:?}"
        );
    }

    // Input cut off right after magic + version (12 bytes) must fail to decode.
    #[test]
    fn spec_snapshot_rejects_truncated() {
        let snap = make_spec_snapshot(vec![1u32, 2], 99);
        let bytes = snap.encode().expect("encode");
        let truncated = &bytes[..12.min(bytes.len())];
        let result = SpeculativeEngineSnapshot::decode(truncated);
        assert!(result.is_err(), "truncated bytes must return Err, got Ok");
    }

    // The accepted-token history is identical after a round-trip.
    #[test]
    fn spec_snapshot_preserves_accepted_history() {
        let history = vec![1u32, 2, 3, 4, 5, 100, 200, 65535];
        let snap = make_spec_snapshot(history.clone(), 0);
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert_eq!(
            restored.accepted_tokens, history,
            "accepted token history must be identical after roundtrip"
        );
    }

    // A missing seed is encoded as the single flag byte and decodes back to None.
    #[test]
    fn spec_snapshot_none_seed_roundtrip() {
        let mut snap = make_spec_snapshot(vec![], 7);
        snap.spec_seed = None;
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert!(
            restored.spec_seed.is_none(),
            "None seed must round-trip as None"
        );
    }

    // Fingerprinting the same snapshot twice yields the same hash.
    #[test]
    fn spec_snapshot_fingerprint_is_deterministic() {
        let snap = make_spec_snapshot(vec![42u32], 0xbeef);
        let fp1 = snap.fingerprint().expect("fingerprint 1");
        let fp2 = snap.fingerprint().expect("fingerprint 2");
        assert_eq!(fp1, fp2, "fingerprint must be deterministic");
    }
}