Skip to main content

oxillama_runtime/
snapshot.rs

1//! Snapshot and resume for [`crate::engine::InferenceEngine`] sessions.
2//!
3//! A snapshot serializes the complete live state of an inference session into
4//! a portable opaque byte blob. The blob can be stored on disk, transferred
5//! over a network, or embedded in a database, then later deserialized to
6//! resume inference deterministically from the same position.
7//!
8//! ## What is captured
9//!
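//! - A token-history field (used to reconstruct context position). Note that
//!   `InferenceEngine::snapshot` currently leaves it empty because token
//!   history is not tracked at engine level.
//! - KV cache state for attention-based models. The format also has slots for
//!   SSM hidden states (Mamba-2 / Jamba), but `InferenceEngine::snapshot`
//!   currently only records attention state.
//! - Sampler RNG state and mirostat-v2 mu value. `InferenceEngine::snapshot`
//!   derives these from the configured seed and tau rather than from a live
//!   generation, so mid-generation RNG progress is not preserved.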
14//! - Sampler configuration (temperature, top-k/p, etc.).
15//! - Grammar source string (if constrained sampling is active). On resume the
16//!   grammar state is reset to initial — this is a known limitation.
17//! - A model fingerprint that guards against loading the wrong weights file.
18//! - The architecture identifier and model path.
19//!
20//! ## Format
21//!
22//! Snapshots begin with the 8-byte magic `b"OXISNAP1"` and carry a version
23//! field. The rest is serialized using `oxicode`. Unknown future versions are
24//! rejected rather than silently misinterpreted.
25
26use std::io::{Read, Seek, SeekFrom};
27use std::path::Path;
28
29use blake3::Hasher;
30use oxicode::{Decode, Encode};
31
32use crate::engine::{EngineConfig, InferenceEngine};
33use crate::error::{RuntimeError, RuntimeResult};
34use crate::sampling::SamplerConfig;
35
/// Eight-byte signature stored in the `magic` field of every [`EngineSnapshot`].
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"OXISNAP1";

/// Default number of bytes hashed from each end of a model file (8 MiB head
/// plus 8 MiB tail).
const DEFAULT_PROBE_SIZE: u32 = 8 << 20;
41
// ─── ModelFingerprint ────────────────────────────────────────────────────────

/// Bounded O(constant) fingerprint of a GGUF model file.
///
/// Avoids the O(file-size) cost of hashing the whole file. It reads only
/// `probe_size` bytes from the head and `probe_size` bytes from the tail,
/// clamped to the file size, and hashes each block independently with BLAKE3.
/// I/O is capped at `2 * probe_size` bytes regardless of model size.
///
/// The combination of file size, modification time and the two content hashes
/// detects the following:
/// - truncation or growth of the file;
/// - replacement with a different file;
/// - in-place edits that touch either probed region.
///
/// An edit confined to the unprobed middle of the file that keeps the size
/// unchanged is only caught through `mtime_secs`.
///
/// Equality (`PartialEq`) compares every field, including `mtime_secs`.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably serialize fields
/// positionally, so reordering fields would break existing snapshots.
#[derive(Debug, Clone, PartialEq, Encode, Decode)]
pub struct ModelFingerprint {
    /// Total file size in bytes.
    pub file_size: u64,
    /// File mtime as whole Unix seconds (best-effort; 0 if unavailable or
    /// earlier than the Unix epoch).
    pub mtime_secs: i64,
    /// Blake3 hash of the first `min(probe_size, file_size)` bytes.
    pub head_hash: [u8; 32],
    /// Blake3 hash of the last `min(probe_size, file_size)` bytes.
    pub tail_hash: [u8; 32],
    /// Number of bytes probed from each end of the file.
    pub probe_size: u32,
}
65
66impl ModelFingerprint {
67    /// Compute a fingerprint for the file at `path`.
68    ///
69    /// Reads at most `2 * DEFAULT_PROBE_SIZE` bytes in total.
70    pub fn compute(path: &Path) -> RuntimeResult<Self> {
71        Self::compute_with_probe(path, DEFAULT_PROBE_SIZE)
72    }
73
74    /// Compute a fingerprint with a custom probe size.
75    pub fn compute_with_probe(path: &Path, probe_size: u32) -> RuntimeResult<Self> {
76        let mut file = std::fs::File::open(path)?;
77        let metadata = file.metadata()?;
78        let file_size = metadata.len();
79
80        // Extract mtime as unix seconds (platform-dependent, best-effort).
81        let mtime_secs = {
82            use std::time::SystemTime;
83            metadata
84                .modified()
85                .ok()
86                .and_then(|t| t.duration_since(SystemTime::UNIX_EPOCH).ok())
87                .map(|d| d.as_secs() as i64)
88                .unwrap_or(0)
89        };
90
91        // --- Head hash ---
92        let head_read = (probe_size as u64).min(file_size) as usize;
93        let mut head_buf = vec![0u8; head_read];
94        file.seek(SeekFrom::Start(0))?;
95        file.read_exact(&mut head_buf)?;
96        let head_hash: [u8; 32] = *Hasher::new().update(&head_buf).finalize().as_bytes();
97
98        // --- Tail hash ---
99        // If the file is smaller than 2 * probe_size the head and tail overlap —
100        // that is intentional and still produces a valid fingerprint.
101        let tail_start = file_size.saturating_sub(probe_size as u64);
102        let tail_read = (file_size - tail_start) as usize;
103        let mut tail_buf = vec![0u8; tail_read];
104        file.seek(SeekFrom::Start(tail_start))?;
105        file.read_exact(&mut tail_buf)?;
106        let tail_hash: [u8; 32] = *Hasher::new().update(&tail_buf).finalize().as_bytes();
107
108        Ok(Self {
109            file_size,
110            mtime_secs,
111            head_hash,
112            tail_hash,
113            probe_size,
114        })
115    }
116
117    /// Verify that `path` matches this fingerprint.
118    ///
119    /// Returns `Ok(())` if the file matches, or a
120    /// [`RuntimeError::ModelFingerprintMismatch`] if it does not.
121    pub fn verify(&self, path: &Path) -> RuntimeResult<()> {
122        let actual = Self::compute_with_probe(path, self.probe_size)?;
123        if actual == *self {
124            return Ok(());
125        }
126        Err(RuntimeError::ModelFingerprintMismatch {
127            expected: self.display(),
128            found: actual.display(),
129            detail: format!(
130                "model file '{}' has been modified or replaced since the snapshot was taken",
131                path.display()
132            ),
133        })
134    }
135
136    /// Human-readable display string for error messages.
137    pub fn display(&self) -> String {
138        let head_hex: String = self.head_hash.iter().map(|b| format!("{b:02x}")).collect();
139        let tail_hex: String = self.tail_hash.iter().map(|b| format!("{b:02x}")).collect();
140        format!(
141            "size={} mtime={} head={}...{} tail={}...{}",
142            self.file_size,
143            self.mtime_secs,
144            &head_hex[..8],
145            &head_hex[head_hex.len() - 8..],
146            &tail_hex[..8],
147            &tail_hex[tail_hex.len() - 8..],
148        )
149    }
150}
151
// ─── KvStatePayload ──────────────────────────────────────────────────────────

/// Serializable KV cache state for attention-based models.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably serialize fields
/// positionally, so reordering fields would break existing snapshots.
#[derive(Debug, Clone, Encode, Decode)]
pub struct KvStatePayload {
    /// Per-layer key vectors (compact: only up to `seq_len * kv_dim` floats).
    pub keys: Vec<Vec<f32>>,
    /// Per-layer value vectors (presumably laid out like `keys` — confirm
    /// against the KV cache's `to_payload`).
    pub values: Vec<Vec<f32>>,
    /// Sequence length at snapshot time.
    pub seq_len: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Maximum context length the cache was allocated for.
    pub max_seq_len: usize,
    /// KV dimension per token (num_kv_heads × head_dim).
    pub kv_dim: usize,
}
170
// ─── SsmStatePayload ─────────────────────────────────────────────────────────

/// Serializable SSM recurrent state for Mamba-2 / Jamba models.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably serialize fields
/// positionally, so reordering fields would break existing snapshots.
#[derive(Debug, Clone, Encode, Decode)]
pub struct SsmStatePayload {
    /// Per-layer flattened hidden state vectors.
    /// For Jamba, attention layers have an empty inner vec.
    pub ssm_states: Vec<Vec<f32>>,
    /// Current token step position.
    pub step: usize,
}
182
// ─── SequenceStatePayload ────────────────────────────────────────────────────

/// Union of all possible sequence state variants for serialization.
///
/// The runtime's `EngineSnapshot` carries one of these.  It maps to
/// `SequenceStateSnapshot` in the arch crate for in-process use, but this
/// type adds `Encode + Decode` for wire persistence.
///
/// `InferenceEngine::snapshot` in this module only ever produces the
/// [`SequenceStatePayload::Attention`] variant.
///
/// NOTE(review): the derived encoding presumably identifies variants by their
/// declaration position (bincode-style). Append new variants at the end only.
#[derive(Debug, Clone, Encode, Decode)]
pub enum SequenceStatePayload {
    /// Attention-based (LLaMA, Qwen3, Mistral, Gemma, Phi, …).
    Attention(KvStatePayload),
    /// Pure Mamba-2 SSM.
    Mamba2(SsmStatePayload),
    /// Jamba hybrid: both KV attention positions and SSM states.
    Jamba {
        /// KV attention state.
        attention: KvStatePayload,
        /// SSM recurrent state.
        ssm: SsmStatePayload,
    },
}
204
// ─── SamplerStatePayload ─────────────────────────────────────────────────────

/// Serializable sampler state for snapshot/resume.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably serialize fields
/// positionally, so reordering fields would break existing snapshots.
#[derive(Debug, Clone, Encode, Decode)]
pub struct SamplerStatePayload {
    /// Raw Xorshift64 PRNG state (0 is remapped to 1 on restore).
    ///
    /// NOTE(review): `InferenceEngine::snapshot` in this module fills this
    /// from the configured seed (0 when unset), and `InferenceEngine::resume`
    /// does not read it back.
    pub rng_state: u64,
    /// Mirostat-v2 running surprise estimate (mu).
    ///
    /// NOTE(review): `InferenceEngine::snapshot` sets this to
    /// `2.0 * mirostat_tau`, and `InferenceEngine::resume` does not read it back.
    pub mirostat_mu: f32,
    /// Temperature for logit scaling.
    pub temperature: f32,
    /// Top-K (0 = disabled).
    pub top_k: usize,
    /// Top-P / nucleus threshold.
    pub top_p: f32,
    /// Min-P threshold.
    pub min_p: f32,
    /// Repetition penalty factor (1.0 = no penalty).
    pub repetition_penalty: f32,
    /// Window size for repetition penalty.
    pub repetition_penalty_window: usize,
    /// Optional fixed RNG seed.
    pub seed: Option<u64>,
    /// Mirostat mode: 0 = disabled, 2 = Mirostat v2.
    pub mirostat_mode: u8,
    /// Mirostat target surprise (tau).
    pub mirostat_tau: f32,
    /// Mirostat learning rate (eta).
    pub mirostat_eta: f32,
}
235
// ─── GrammarStatePayload ─────────────────────────────────────────────────────

/// Serializable grammar state.
///
/// Only the grammar source is stored.  On resume the grammar is re-parsed
/// and the state is reset to the initial state.  This is a known limitation:
/// partial grammar progress from before the snapshot is not replayed.
#[derive(Debug, Clone, Encode, Decode)]
pub struct GrammarStatePayload {
    /// Original GBNF grammar source string (copied from the sampler config's
    /// grammar `source` at snapshot time).
    pub grammar_source: String,
}
248
// ─── EngineSnapshot ──────────────────────────────────────────────────────────

/// The complete engine snapshot — opaque to callers outside this module.
///
/// Callers should treat the serialized form as opaque bytes: construct via
/// `InferenceEngine::snapshot()`, persist however is appropriate, then pass
/// the bytes to `InferenceEngine::resume()`.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably serialize fields
/// positionally. Any layout change therefore needs an
/// [`EngineSnapshot::VERSION`] bump.
#[derive(Debug, Clone, Encode, Decode)]
pub struct EngineSnapshot {
    /// Magic bytes: must equal `SNAPSHOT_MAGIC` (checked on deserialize).
    pub magic: [u8; 8],
    /// Format version. Current: [`EngineSnapshot::VERSION`].
    pub version: u32,
    /// Architecture identifier (e.g. `"llama"`, `"qwen3"`, …).
    pub arch_id: String,
    /// Model file path as configured on the engine at snapshot time. It is
    /// copied verbatim from the engine config, so it is not guaranteed to be
    /// absolute.
    pub model_path: String,
    /// Optional explicit tokenizer path (None = auto-detect).
    pub tokenizer_path: Option<String>,
    /// Bounded fingerprint of the model file.
    pub model_fingerprint: ModelFingerprint,
    /// Token IDs processed so far (prompt + generated).
    ///
    /// `InferenceEngine::snapshot` currently leaves this empty because token
    /// history is not tracked at engine level.
    pub tokens: Vec<u32>,
    /// Sequence / KV state at snapshot time.
    pub sequence_state: SequenceStatePayload,
    /// Sampler state at snapshot time.
    pub sampler_state: SamplerStatePayload,
    /// Optional grammar state (None when no grammar is configured).
    pub grammar_state: Option<GrammarStatePayload>,
    /// Context length taken from the loaded model's config at snapshot time.
    /// `InferenceEngine::resume` passes it back as the engine's `context_size`.
    pub max_context_length: usize,
    /// Number of parallel inference threads.
    pub num_threads: usize,
    /// Prefill chunk size.
    pub prefill_chunk_size: usize,
}
285
286impl EngineSnapshot {
287    /// Current snapshot format version.
288    pub const VERSION: u32 = 1;
289
290    /// Serialize this snapshot to bytes using oxicode.
291    pub fn serialize(&self) -> RuntimeResult<Vec<u8>> {
292        oxicode::encode_to_vec(self).map_err(|e| RuntimeError::SnapshotIncompatible {
293            detail: format!("serialization failed: {e}"),
294        })
295    }
296
297    /// Deserialize a snapshot from bytes.
298    ///
299    /// Returns `SnapshotIncompatible` if the bytes cannot be decoded, the
300    /// magic is wrong, or the version is not supported.
301    pub fn deserialize(bytes: &[u8]) -> RuntimeResult<Self> {
302        let (snap, _) = oxicode::decode_from_slice::<Self>(bytes).map_err(|e| {
303            RuntimeError::SnapshotIncompatible {
304                detail: format!("deserialization failed: {e}"),
305            }
306        })?;
307
308        if &snap.magic != SNAPSHOT_MAGIC {
309            return Err(RuntimeError::SnapshotIncompatible {
310                detail: "invalid snapshot magic bytes".to_string(),
311            });
312        }
313
314        if snap.version != Self::VERSION {
315            return Err(RuntimeError::SnapshotIncompatible {
316                detail: format!(
317                    "snapshot version {} is not supported (expected {})",
318                    snap.version,
319                    Self::VERSION
320                ),
321            });
322        }
323
324        Ok(snap)
325    }
326}
327
// ─── SpeculativeEngineSnapshot ───────────────────────────────────────────────

/// Wire-format version written into, and required from, every
/// speculative-engine snapshot envelope.
const SPEC_SNAPSHOT_VERSION: u32 = 1;

/// Eight-byte signature that opens every encoded speculative-engine snapshot.
pub const SPEC_SNAPSHOT_MAGIC: &[u8; 8] = b"OXISPEC1";
335
/// Portable snapshot of a complete [`crate::speculative::SpeculativeEngine`] session.
///
/// Contains individual [`EngineSnapshot`]s for both the target and draft models,
/// plus the speculative-decoding loop state needed to resume deterministically.
///
/// ## Wire format
///
/// ```text
/// [magic: 8 bytes][version: u32 LE][target_len: u64 LE][target_bytes: ...]
/// [draft_len: u64 LE][draft_bytes: ...]
/// [num_speculative: u64 LE][has_seed: u8][seed: u64 LE (if has_seed)]
/// [accepted_len: u64 LE][accepted_tokens: u32 LE × accepted_len]
/// [rng_state: u64 LE]
/// ```
///
/// All multibyte integers are little-endian. `has_seed` is written as `0x00`
/// (no seed) or `0x01` (seed follows). `target_bytes` and `draft_bytes` are
/// the output of [`EngineSnapshot::serialize`].
///
/// The outer envelope is hand-rolled rather than using `oxicode` or
/// `bincode`. This lets the magic header be verified before any heap
/// allocation.
#[derive(Debug, Clone)]
pub struct SpeculativeEngineSnapshot {
    /// Snapshot of the target (large, accurate) model session.
    pub target_snapshot: EngineSnapshot,
    /// Snapshot of the draft (small, fast) model session.
    pub draft_snapshot: EngineSnapshot,
    /// Number of speculative tokens proposed per round (stored as u64 on the wire).
    pub num_speculative: usize,
    /// RNG seed that was used to initialise the accept/reject PRNG.
    pub spec_seed: Option<u64>,
    /// Token IDs accepted during the last speculation round (may be empty).
    pub accepted_tokens: Vec<u32>,
    /// Raw Xorshift64 state for the accept/reject PRNG.
    pub rng_state: u64,
}
369
370impl SpeculativeEngineSnapshot {
371    /// Encode this snapshot into a self-describing byte blob.
372    ///
373    /// The blob starts with [`SPEC_SNAPSHOT_MAGIC`] and can be decoded by
374    /// [`Self::decode`].
375    pub fn encode(&self) -> RuntimeResult<Vec<u8>> {
376        // Serialise the two inner snapshots first so we know their lengths.
377        let target_bytes = self.target_snapshot.serialize()?;
378        let draft_bytes = self.draft_snapshot.serialize()?;
379
380        // Pre-compute capacity: magic(8) + version(4) + target_len(8) + target
381        //   + draft_len(8) + draft + num_spec(8) + has_seed(1) + [seed(8)]
382        //   + accepted_len(8) + accepted * 4 + rng_state(8)
383        let seed_bytes = if self.spec_seed.is_some() {
384            9usize
385        } else {
386            1usize
387        };
388        let capacity = 8
389            + 4
390            + 8
391            + target_bytes.len()
392            + 8
393            + draft_bytes.len()
394            + 8
395            + seed_bytes
396            + 8
397            + self.accepted_tokens.len() * 4
398            + 8;
399
400        let mut buf: Vec<u8> = Vec::with_capacity(capacity);
401
402        // Magic + version
403        buf.extend_from_slice(SPEC_SNAPSHOT_MAGIC);
404        buf.extend_from_slice(&SPEC_SNAPSHOT_VERSION.to_le_bytes());
405
406        // Target snapshot (length-prefixed)
407        buf.extend_from_slice(&(target_bytes.len() as u64).to_le_bytes());
408        buf.extend_from_slice(&target_bytes);
409
410        // Draft snapshot (length-prefixed)
411        buf.extend_from_slice(&(draft_bytes.len() as u64).to_le_bytes());
412        buf.extend_from_slice(&draft_bytes);
413
414        // num_speculative
415        buf.extend_from_slice(&(self.num_speculative as u64).to_le_bytes());
416
417        // Optional seed: 0x00 = absent, 0x01 followed by 8 bytes = present
418        match self.spec_seed {
419            None => buf.push(0x00),
420            Some(seed) => {
421                buf.push(0x01);
422                buf.extend_from_slice(&seed.to_le_bytes());
423            }
424        }
425
426        // accepted_tokens (length-prefixed, each token as u32 LE)
427        buf.extend_from_slice(&(self.accepted_tokens.len() as u64).to_le_bytes());
428        for &tok in &self.accepted_tokens {
429            buf.extend_from_slice(&tok.to_le_bytes());
430        }
431
432        // rng_state
433        buf.extend_from_slice(&self.rng_state.to_le_bytes());
434
435        Ok(buf)
436    }
437
438    /// Decode a [`SpeculativeEngineSnapshot`] from raw bytes.
439    ///
440    /// Returns [`RuntimeError::SpecSnapshotIncompatible`] when the magic bytes
441    /// are wrong, the version is unsupported, or the buffer is truncated.
442    pub fn decode(bytes: &[u8]) -> RuntimeResult<Self> {
443        let mut pos = 0usize;
444
445        /// Read `N` bytes from `bytes` starting at `*pos`, advancing `*pos`.
446        macro_rules! read_exact {
447            ($n:expr, $label:expr) => {{
448                let end = pos + $n;
449                if end > bytes.len() {
450                    return Err(RuntimeError::SpecSnapshotIncompatible(format!(
451                        "truncated: expected {} bytes for {} at offset {}",
452                        $n, $label, pos
453                    )));
454                }
455                let slice = &bytes[pos..end];
456                pos = end;
457                slice
458            }};
459        }
460
461        // Magic
462        let magic = read_exact!(8, "magic");
463        if magic != SPEC_SNAPSHOT_MAGIC {
464            return Err(RuntimeError::SpecSnapshotIncompatible(format!(
465                "invalid magic bytes: expected {:?}, got {:?}",
466                SPEC_SNAPSHOT_MAGIC, magic
467            )));
468        }
469
470        // Version
471        let version = u32::from_le_bytes(
472            read_exact!(4, "version")
473                .try_into()
474                .expect("slice is exactly 4 bytes"),
475        );
476        if version != SPEC_SNAPSHOT_VERSION {
477            return Err(RuntimeError::SpecSnapshotIncompatible(format!(
478                "unsupported version {version} (expected {SPEC_SNAPSHOT_VERSION})"
479            )));
480        }
481
482        // Target snapshot
483        let target_len = u64::from_le_bytes(
484            read_exact!(8, "target_len")
485                .try_into()
486                .expect("slice is exactly 8 bytes"),
487        ) as usize;
488        let target_raw = read_exact!(target_len, "target_bytes");
489        let target_snapshot = EngineSnapshot::deserialize(target_raw).map_err(|e| {
490            RuntimeError::SpecSnapshotIncompatible(format!("target snapshot corrupt: {e}"))
491        })?;
492
493        // Draft snapshot
494        let draft_len = u64::from_le_bytes(
495            read_exact!(8, "draft_len")
496                .try_into()
497                .expect("slice is exactly 8 bytes"),
498        ) as usize;
499        let draft_raw = read_exact!(draft_len, "draft_bytes");
500        let draft_snapshot = EngineSnapshot::deserialize(draft_raw).map_err(|e| {
501            RuntimeError::SpecSnapshotIncompatible(format!("draft snapshot corrupt: {e}"))
502        })?;
503
504        // num_speculative
505        let num_speculative = u64::from_le_bytes(
506            read_exact!(8, "num_speculative")
507                .try_into()
508                .expect("slice is exactly 8 bytes"),
509        ) as usize;
510
511        // Optional seed
512        let has_seed = read_exact!(1, "has_seed")[0];
513        let spec_seed = if has_seed == 0x01 {
514            let seed_bytes = read_exact!(8, "seed");
515            Some(u64::from_le_bytes(
516                seed_bytes.try_into().expect("slice is exactly 8 bytes"),
517            ))
518        } else {
519            None
520        };
521
522        // accepted_tokens
523        let accepted_len = u64::from_le_bytes(
524            read_exact!(8, "accepted_len")
525                .try_into()
526                .expect("slice is exactly 8 bytes"),
527        ) as usize;
528        let mut accepted_tokens = Vec::with_capacity(accepted_len);
529        for _ in 0..accepted_len {
530            let tok = u32::from_le_bytes(
531                read_exact!(4, "accepted_token")
532                    .try_into()
533                    .expect("slice is exactly 4 bytes"),
534            );
535            accepted_tokens.push(tok);
536        }
537
538        // rng_state — last field; pos is advanced but not read after this.
539        let rng_state = u64::from_le_bytes(
540            read_exact!(8, "rng_state")
541                .try_into()
542                .expect("slice is exactly 8 bytes"),
543        );
544        // Suppress unused-assignment lint: pos is consumed on the last read.
545        let _ = pos;
546
547        Ok(Self {
548            target_snapshot,
549            draft_snapshot,
550            num_speculative,
551            spec_seed,
552            accepted_tokens,
553            rng_state,
554        })
555    }
556
557    /// Compute a 32-byte Blake3 fingerprint of the encoded snapshot bytes.
558    ///
559    /// Useful for deduplication and integrity checks without fully decoding
560    /// the snapshot.
561    pub fn fingerprint(&self) -> RuntimeResult<[u8; 32]> {
562        let encoded = self.encode()?;
563        Ok(*Hasher::new().update(&encoded).finalize().as_bytes())
564    }
565}
566
567// ─── InferenceEngine snapshot / resume ───────────────────────────────────────
568
569impl InferenceEngine {
570    /// Capture the full engine state as a portable byte blob.
571    ///
572    /// The returned bytes can be stored on disk, sent over the network, or
573    /// embedded in a database. Pass them to [`InferenceEngine::resume`] to
574    /// resume inference from the same position.
575    ///
576    /// # Limitations
577    ///
578    /// - **Grammar state**: only the grammar source string is stored. On
579    ///   resume the grammar state is reset to its initial state — any partial
580    ///   progress through a grammar constraint is lost.
581    /// - **Sampler state**: the engine creates a new `Sampler` for each
582    ///   `generate()` call. The snapshot captures the config values rather
583    ///   than live RNG state from an in-flight generation.
584    ///
585    /// Returns [`RuntimeError::ModelNotLoaded`] if no model has been loaded.
586    pub fn snapshot(&self) -> RuntimeResult<Vec<u8>> {
587        let model_config = self.model_config().ok_or(RuntimeError::ModelNotLoaded)?;
588        let kv_cache = self.kv_cache_ref().ok_or(RuntimeError::ModelNotLoaded)?;
589
590        // Compute model fingerprint from file on disk.
591        let model_path = Path::new(self.config().model_path.as_str());
592        let model_fingerprint = ModelFingerprint::compute(model_path)?;
593
594        // Build KV state payload.
595        let sequence_state = SequenceStatePayload::Attention(kv_cache.to_payload());
596
597        // Build sampler state from config (engine-level snapshot; live RNG is per-generate).
598        let sampler_cfg = &self.config().sampler;
599        let sampler_state = SamplerStatePayload {
600            rng_state: sampler_cfg.seed.unwrap_or(0),
601            mirostat_mu: 2.0 * sampler_cfg.mirostat_tau,
602            temperature: sampler_cfg.temperature,
603            top_k: sampler_cfg.top_k,
604            top_p: sampler_cfg.top_p,
605            min_p: sampler_cfg.min_p,
606            repetition_penalty: sampler_cfg.repetition_penalty,
607            repetition_penalty_window: sampler_cfg.repetition_penalty_window,
608            seed: sampler_cfg.seed,
609            mirostat_mode: sampler_cfg.mirostat,
610            mirostat_tau: sampler_cfg.mirostat_tau,
611            mirostat_eta: sampler_cfg.mirostat_eta,
612        };
613
614        // Extract grammar source if configured.
615        let grammar_state = sampler_cfg.grammar.as_ref().map(|g| GrammarStatePayload {
616            grammar_source: g.source.clone(),
617        });
618
619        let snap = EngineSnapshot {
620            magic: *SNAPSHOT_MAGIC,
621            version: EngineSnapshot::VERSION,
622            arch_id: model_config.architecture.clone(),
623            model_path: self.config().model_path.clone(),
624            tokenizer_path: self.config().tokenizer_path.clone(),
625            model_fingerprint,
626            tokens: Vec::new(), // token history is not tracked at engine level
627            sequence_state,
628            sampler_state,
629            grammar_state,
630            max_context_length: model_config.max_context_length,
631            num_threads: self.config().num_threads,
632            prefill_chunk_size: self.config().prefill_chunk_size,
633        };
634
635        snap.serialize()
636    }
637
638    /// Resume an inference session from a previously captured snapshot.
639    ///
640    /// 1. Deserializes the snapshot bytes.
641    /// 2. Validates the model fingerprint against `model_path` on disk.
642    /// 3. Loads the model from `model_path`.
643    /// 4. Restores the KV cache state.
644    /// 5. Restores the sampler config.
645    /// 6. If a grammar source was saved, re-parses it (grammar state is reset to initial).
646    ///
647    /// # Errors
648    ///
649    /// - [`RuntimeError::SnapshotIncompatible`] — bytes are not a valid snapshot.
650    /// - [`RuntimeError::ModelFingerprintMismatch`] — model file differs from snapshot.
651    /// - Any error from loading the model.
652    pub fn resume(bytes: &[u8], model_path: &Path) -> RuntimeResult<Self> {
653        use crate::sampling::grammar::Grammar;
654        use std::sync::Arc;
655
656        let snap = EngineSnapshot::deserialize(bytes)?;
657
658        // Validate the model on disk matches the fingerprint.
659        snap.model_fingerprint.verify(model_path)?;
660
661        // Build SamplerConfig from snapshot.
662        let mut sampler_config = SamplerConfig {
663            temperature: snap.sampler_state.temperature,
664            top_k: snap.sampler_state.top_k,
665            top_p: snap.sampler_state.top_p,
666            min_p: snap.sampler_state.min_p,
667            repetition_penalty: snap.sampler_state.repetition_penalty,
668            repetition_penalty_window: snap.sampler_state.repetition_penalty_window,
669            seed: snap.sampler_state.seed,
670            mirostat: snap.sampler_state.mirostat_mode,
671            mirostat_tau: snap.sampler_state.mirostat_tau,
672            mirostat_eta: snap.sampler_state.mirostat_eta,
673            grammar: None,
674            token_vocab: None,
675            // Logit bias, banned tokens, and advanced sampler stages are not
676            // persisted in v1 snapshots; they default to empty/disabled.
677            logit_bias: std::collections::HashMap::new(),
678            banned_tokens: Vec::new(),
679            dry_multiplier: 0.0,
680            dry_base: 1.75,
681            dry_allowed_length: 2,
682            xtc_threshold: 0.0,
683            xtc_probability: 0.5,
684            typical_p: 1.0,
685            top_a: 0.0,
686            eta_cutoff: 0.0,
687            epsilon_cutoff: 0.0,
688        };
689
690        // Re-parse grammar if present (state resets to initial — known limitation).
691        if let Some(gs) = &snap.grammar_state {
692            let grammar =
693                Grammar::parse(&gs.grammar_source).map_err(|e| RuntimeError::ModelLoadError {
694                    message: format!("failed to re-parse grammar from snapshot: {e}"),
695                })?;
696            sampler_config.grammar = Some(Arc::new(grammar));
697        }
698
699        let config = EngineConfig {
700            model_path: model_path
701                .to_str()
702                .ok_or_else(|| RuntimeError::ModelLoadError {
703                    message: "model path contains non-UTF-8 characters".to_string(),
704                })?
705                .to_string(),
706            tokenizer_path: snap.tokenizer_path.clone(),
707            context_size: Some(snap.max_context_length),
708            num_threads: snap.num_threads,
709            sampler: sampler_config,
710            prefill_chunk_size: snap.prefill_chunk_size,
711            offload_policy: crate::offload::OffloadPolicy::None,
712        };
713
714        let mut engine = Self::new(config);
715        engine.load_model()?;
716
717        // Restore KV cache state.
718        if let SequenceStatePayload::Attention(kv_payload) = &snap.sequence_state {
719            let kv = engine.kv_cache_mut().ok_or(RuntimeError::ModelNotLoaded)?;
720            kv.restore_from_payload(kv_payload)?;
721        }
722
723        Ok(engine)
724    }
725}
726
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixture path in the system temp directory that is deleted on drop.
    ///
    /// The process id is part of the file name so concurrent test runs (two
    /// checkouts, or a runner that spawns one process per test) cannot clobber
    /// each other's fixtures. Cleanup lives in `Drop`, so a failing assertion
    /// does not leave stale files in the shared temp directory.
    struct TempFixture(std::path::PathBuf);

    impl TempFixture {
        /// Reserve a path unique to this process and the given `name`; the file
        /// itself is not created here.
        fn new(name: &str) -> Self {
            let file_name = format!("oxillama_snap_{}_{name}.gguf", std::process::id());
            Self(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempFixture {
        fn drop(&mut self) {
            // Best-effort: the file may never have been written if the test
            // failed before creating it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Build a small but fully populated attention-model snapshot.
    ///
    /// The result is deterministic, so tests may call this twice to get an
    /// independent "expected" copy.
    fn make_minimal_snapshot() -> EngineSnapshot {
        EngineSnapshot {
            magic: *SNAPSHOT_MAGIC,
            version: EngineSnapshot::VERSION,
            arch_id: "llama".to_string(),
            model_path: "/tmp/test.gguf".to_string(),
            tokenizer_path: None,
            model_fingerprint: ModelFingerprint {
                file_size: 1024,
                mtime_secs: 1_000_000,
                head_hash: [0u8; 32],
                tail_hash: [1u8; 32],
                probe_size: DEFAULT_PROBE_SIZE,
            },
            tokens: vec![1, 2, 3],
            sequence_state: SequenceStatePayload::Attention(KvStatePayload {
                keys: vec![vec![0.0f32; 4]],
                values: vec![vec![0.0f32; 4]],
                seq_len: 1,
                num_layers: 1,
                max_seq_len: 512,
                kv_dim: 4,
            }),
            sampler_state: SamplerStatePayload {
                rng_state: 42,
                mirostat_mu: 5.0,
                temperature: 0.7,
                top_k: 40,
                top_p: 0.9,
                min_p: 0.0,
                repetition_penalty: 1.1,
                repetition_penalty_window: 64,
                seed: Some(42),
                mirostat_mode: 0,
                mirostat_tau: 5.0,
                mirostat_eta: 0.1,
            },
            grammar_state: None,
            max_context_length: 512,
            num_threads: 4,
            prefill_chunk_size: 512,
        }
    }

    /// `serialize` → `deserialize` must reproduce every field of the snapshot.
    ///
    /// Sampler floats are compared bit-for-bit, because deterministic resumption
    /// depends on exact values.
    #[test]
    fn roundtrip_serialize_deserialize() {
        let bytes = make_minimal_snapshot().serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");
        // Rebuilt rather than reused, so the test does not depend on whether
        // `serialize` borrows or consumes the snapshot.
        let expected = make_minimal_snapshot();

        // Header.
        assert_eq!(&restored.magic, SNAPSHOT_MAGIC);
        assert_eq!(restored.version, EngineSnapshot::VERSION);

        // Model identity.
        assert_eq!(restored.arch_id, "llama");
        assert_eq!(restored.model_path, expected.model_path);
        assert!(restored.tokenizer_path.is_none());
        let (got_fp, want_fp) = (&restored.model_fingerprint, &expected.model_fingerprint);
        assert_eq!(got_fp.file_size, want_fp.file_size);
        assert_eq!(got_fp.mtime_secs, want_fp.mtime_secs);
        assert_eq!(got_fp.head_hash, want_fp.head_hash);
        assert_eq!(got_fp.tail_hash, want_fp.tail_hash);
        assert_eq!(got_fp.probe_size, want_fp.probe_size);

        // Generation progress.
        assert_eq!(restored.tokens, vec![1, 2, 3]);
        assert!(
            matches!(restored.sequence_state, SequenceStatePayload::Attention(_)),
            "sequence state variant must survive the roundtrip"
        );
        assert!(restored.grammar_state.is_none());

        // Sampler state.
        let (got, want) = (&restored.sampler_state, &expected.sampler_state);
        assert_eq!(got.rng_state, want.rng_state);
        assert_eq!(got.seed, want.seed);
        assert_eq!(got.top_k, want.top_k);
        assert_eq!(got.repetition_penalty_window, want.repetition_penalty_window);
        assert_eq!(got.mirostat_mode, want.mirostat_mode);
        assert_eq!(got.mirostat_mu.to_bits(), want.mirostat_mu.to_bits());
        assert_eq!(got.temperature.to_bits(), want.temperature.to_bits());
        assert_eq!(got.top_p.to_bits(), want.top_p.to_bits());
        assert_eq!(got.min_p.to_bits(), want.min_p.to_bits());
        assert_eq!(
            got.repetition_penalty.to_bits(),
            want.repetition_penalty.to_bits()
        );
        assert_eq!(got.mirostat_tau.to_bits(), want.mirostat_tau.to_bits());
        assert_eq!(got.mirostat_eta.to_bits(), want.mirostat_eta.to_bits());

        // Engine configuration.
        assert_eq!(restored.max_context_length, expected.max_context_length);
        assert_eq!(restored.num_threads, expected.num_threads);
        assert_eq!(restored.prefill_chunk_size, expected.prefill_chunk_size);
    }

    /// Corrupting the leading bytes of a serialized snapshot must make
    /// `deserialize` fail.
    #[test]
    fn bad_magic_rejected() {
        let mut bytes = make_minimal_snapshot().serialize().expect("serialize");
        // The format documentation places the 8-byte magic at the start of the
        // blob. Flipping the first byte either breaks decoding outright or trips
        // the magic check; both count as a rejection here.
        assert!(!bytes.is_empty(), "serialized snapshot must not be empty");
        bytes[0] ^= 0xFF;
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(result.is_err(), "corrupted bytes must return Err");
    }

    /// A snapshot carrying an unknown version must be rejected with
    /// `SnapshotIncompatible` rather than silently misinterpreted.
    #[test]
    fn incompatible_version_rejected() {
        let mut snap = make_minimal_snapshot();
        snap.version = 9999;
        let bytes = snap.serialize().expect("serialize");
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SnapshotIncompatible { .. })),
            "invalid version must return SnapshotIncompatible"
        );
    }

    /// A fingerprint verifies against the file it was computed from.
    ///
    /// It must stop verifying once that file's contents change, even when the
    /// size stays the same.
    #[test]
    fn model_fingerprint_compute_and_verify() {
        let file = TempFixture::new("fingerprint");
        std::fs::write(file.path(), vec![0xABu8; 100 * 1024]).expect("write test file");

        let fp = ModelFingerprint::compute(file.path()).expect("compute fingerprint");
        assert_eq!(fp.file_size, 100 * 1024);
        fp.verify(file.path()).expect("verify same file");

        // Same length, different bytes: only content hashing can catch this.
        std::fs::write(file.path(), vec![0xCDu8; 100 * 1024]).expect("write modified file");
        assert!(
            fp.verify(file.path()).is_err(),
            "fingerprint verify must fail after file modification"
        );
    }

    /// Verifying a fingerprint against a different file must report
    /// `ModelFingerprintMismatch` specifically.
    #[test]
    fn fingerprint_mismatch_error_type() {
        let file_a = TempFixture::new("fp_a");
        let file_b = TempFixture::new("fp_b");
        std::fs::write(file_a.path(), vec![0xAAu8; 10_000]).expect("write A");
        std::fs::write(file_b.path(), vec![0xBBu8; 10_000]).expect("write B");

        let fp_a = ModelFingerprint::compute(file_a.path()).expect("compute A");
        let result = fp_a.verify(file_b.path());
        assert!(
            matches!(result, Err(RuntimeError::ModelFingerprintMismatch { .. })),
            "mismatch must return ModelFingerprintMismatch"
        );
    }

    /// A multi-layer KV payload must survive serialization with its data and
    /// all of its shape metadata intact.
    #[test]
    fn kv_state_payload_roundtrip_in_snapshot() {
        let kv = KvStatePayload {
            keys: vec![vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]],
            values: vec![vec![9.0f32, 10.0, 11.0, 12.0], vec![13.0, 14.0, 15.0, 16.0]],
            seq_len: 1,
            num_layers: 2,
            max_seq_len: 512,
            kv_dim: 4,
        };
        let mut snap = make_minimal_snapshot();
        snap.sequence_state = SequenceStatePayload::Attention(kv.clone());

        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");

        if let SequenceStatePayload::Attention(restored_kv) = restored.sequence_state {
            assert_eq!(restored_kv.keys, kv.keys);
            assert_eq!(restored_kv.values, kv.values);
            assert_eq!(restored_kv.seq_len, kv.seq_len);
            assert_eq!(restored_kv.num_layers, kv.num_layers);
            assert_eq!(restored_kv.max_seq_len, kv.max_seq_len);
            assert_eq!(restored_kv.kv_dim, kv.kv_dim);
        } else {
            panic!("expected Attention sequence state payload");
        }
    }

    // ── SpeculativeEngineSnapshot tests ──────────────────────────────────────

    /// Build a speculative snapshot whose target and draft are both the minimal
    /// snapshot, varying only the accepted-token history and the RNG state.
    fn make_spec_snapshot(accepted: Vec<u32>, rng_state: u64) -> SpeculativeEngineSnapshot {
        SpeculativeEngineSnapshot {
            target_snapshot: make_minimal_snapshot(),
            draft_snapshot: make_minimal_snapshot(),
            num_speculative: 4,
            spec_seed: Some(0xdeadbeef),
            accepted_tokens: accepted,
            rng_state,
        }
    }

    /// Full encode → decode roundtrip must preserve all fields.
    #[test]
    fn spec_snapshot_roundtrip() {
        let original = make_spec_snapshot(vec![10u32, 20, 30], 0x00c0_ffee_cafe_babe_u64);
        let bytes = original.encode().expect("encode must succeed");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode must succeed");

        assert_eq!(restored.num_speculative, 4);
        assert_eq!(restored.spec_seed, Some(0xdeadbeef));
        assert_eq!(restored.accepted_tokens, vec![10u32, 20, 30]);
        assert_eq!(restored.rng_state, 0x00c0_ffee_cafe_babe_u64);
        assert_eq!(restored.target_snapshot.arch_id, "llama");
        assert_eq!(restored.draft_snapshot.arch_id, "llama");
    }

    /// Bytes starting with a wrong magic header must return `SpecSnapshotIncompatible`.
    #[test]
    fn spec_snapshot_rejects_wrong_magic() {
        let snap = make_spec_snapshot(vec![], 42);
        let mut bytes = snap.encode().expect("encode");
        // Corrupt the magic header bytes
        if bytes.len() >= 8 {
            bytes[0] ^= 0xFF;
        }
        let result = SpeculativeEngineSnapshot::decode(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SpecSnapshotIncompatible(_))),
            "wrong magic must return SpecSnapshotIncompatible, got {result:?}"
        );
    }

    /// Every strict prefix of a valid encoding must be rejected with an `Err`.
    ///
    /// The specific error variant is not pinned here; only rejection is.
    #[test]
    fn spec_snapshot_rejects_truncated() {
        let snap = make_spec_snapshot(vec![1u32, 2], 99);
        let bytes = snap.encode().expect("encode");
        let full_len = bytes.len();
        assert!(full_len > 12, "encoded spec snapshot unexpectedly short: {full_len} bytes");

        // Cut points: empty input, inside the magic, right after an 8-byte
        // magic, a few bytes further in, midway, and one byte short of complete.
        for cut in [0, 4, 8, 12, full_len / 2, full_len - 1] {
            let result = SpeculativeEngineSnapshot::decode(&bytes[..cut]);
            assert!(
                result.is_err(),
                "decoding a {cut}-byte prefix (of {full_len}) must return Err, got Ok"
            );
        }
    }

    /// Accepted token history must survive a full encode → decode cycle.
    #[test]
    fn spec_snapshot_preserves_accepted_history() {
        let history = vec![1u32, 2, 3, 4, 5, 100, 200, 65535];
        let snap = make_spec_snapshot(history.clone(), 0);
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert_eq!(
            restored.accepted_tokens, history,
            "accepted token history must be identical after roundtrip"
        );
    }

    /// `spec_seed = None` is encoded and decoded faithfully.
    #[test]
    fn spec_snapshot_none_seed_roundtrip() {
        let mut snap = make_spec_snapshot(vec![], 7);
        snap.spec_seed = None;
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert!(
            restored.spec_seed.is_none(),
            "None seed must round-trip as None"
        );
    }

    /// `fingerprint()` is deterministic for the same snapshot content.
    #[test]
    fn spec_snapshot_fingerprint_is_deterministic() {
        let snap = make_spec_snapshot(vec![42u32], 0xbeef);
        let fp1 = snap.fingerprint().expect("fingerprint 1");
        let fp2 = snap.fingerprint().expect("fingerprint 2");
        assert_eq!(fp1, fp2, "fingerprint must be deterministic");
    }

    /// Decoding an encoded snapshot must not change its fingerprint.
    ///
    /// If it did, the roundtrip would be lossy in some fingerprinted field.
    #[test]
    fn spec_snapshot_fingerprint_survives_roundtrip() {
        let snap = make_spec_snapshot(vec![42u32], 0xbeef);
        // Taken before `encode`, so the test works whether `encode` borrows
        // or consumes the snapshot.
        let expected = snap.fingerprint().expect("fingerprint original");
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert_eq!(
            restored.fingerprint().expect("fingerprint restored"),
            expected,
            "fingerprint must be stable across encode → decode"
        );
    }
}