use std::io::{Read, Seek, SeekFrom};
use std::path::Path;
use blake3::Hasher;
use oxicode::{Decode, Encode};
use crate::engine::{EngineConfig, InferenceEngine};
use crate::error::{RuntimeError, RuntimeResult};
use crate::sampling::SamplerConfig;
/// Required value of [`EngineSnapshot::magic`]; `EngineSnapshot::deserialize` rejects any
/// snapshot whose magic field differs.
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"OXISNAP1";
/// Default number of bytes (8 MiB) hashed at each end of the model file by
/// `ModelFingerprint::compute`.
const DEFAULT_PROBE_SIZE: u32 = 8 * 1024 * 1024;
/// Cheap identity check for a model file: size, modification time, and BLAKE3 hashes of a
/// probe window at the start and at the end of the file.
///
/// The middle of the file is not hashed. For files larger than `2 * probe_size`, an in-place
/// edit confined to the middle that keeps the size and mtime unchanged is therefore not
/// detected.
///
/// NOTE(review): this type derives `Encode`/`Decode` and is embedded in [`EngineSnapshot`].
/// Reordering or retyping fields presumably changes the serialized format. Bump
/// `EngineSnapshot::VERSION` if so.
#[derive(Debug, Clone, PartialEq, Encode, Decode)]
pub struct ModelFingerprint {
    /// File size in bytes, taken from the file metadata.
    pub file_size: u64,
    /// Modification time in whole seconds since the UNIX epoch.
    /// It is `0` when unavailable or earlier than the epoch.
    pub mtime_secs: i64,
    /// BLAKE3 hash of the first `min(probe_size, file_size)` bytes.
    pub head_hash: [u8; 32],
    /// BLAKE3 hash of the last `min(probe_size, file_size)` bytes.
    pub tail_hash: [u8; 32],
    /// Probe window size used when computing the hashes; `verify` reuses it.
    pub probe_size: u32,
}
impl ModelFingerprint {
    /// Fingerprints the file at `path` using [`DEFAULT_PROBE_SIZE`] bytes at each end.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened, stat'ed, seeked, or fully read.
    pub fn compute(path: &Path) -> RuntimeResult<Self> {
        Self::compute_with_probe(path, DEFAULT_PROBE_SIZE)
    }

    /// Fingerprints the file at `path`, hashing a window of at most `probe_size` bytes at the
    /// start of the file and another at the end.
    ///
    /// Both windows are `min(probe_size, file_size)` bytes long. For files no larger than
    /// `probe_size`, both hashes therefore cover the whole file. The modification time falls
    /// back to `0` when it is unavailable or earlier than the UNIX epoch.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened, stat'ed, seeked, or fully read.
    pub fn compute_with_probe(path: &Path, probe_size: u32) -> RuntimeResult<Self> {
        let mut file = std::fs::File::open(path)?;
        let meta = file.metadata()?;
        let file_size = meta.len();

        let mtime_secs = match meta.modified() {
            Ok(modified) => match modified.duration_since(std::time::SystemTime::UNIX_EPOCH) {
                Ok(since_epoch) => since_epoch.as_secs() as i64,
                Err(_) => 0,
            },
            Err(_) => 0,
        };

        // Head and tail windows share one length; the tail window ends exactly at EOF.
        let window = u64::from(probe_size).min(file_size);
        let head_hash = Self::hash_span(&mut file, 0, window)?;
        let tail_hash = Self::hash_span(&mut file, file_size - window, window)?;

        Ok(Self {
            file_size,
            mtime_secs,
            head_hash,
            tail_hash,
            probe_size,
        })
    }

    /// Reads exactly `len` bytes starting at byte `offset` and returns their BLAKE3 hash.
    fn hash_span(file: &mut std::fs::File, offset: u64, len: u64) -> RuntimeResult<[u8; 32]> {
        let mut buf = vec![0u8; len as usize];
        file.seek(SeekFrom::Start(offset))?;
        file.read_exact(&mut buf)?;
        let mut hasher = Hasher::new();
        hasher.update(&buf);
        Ok(*hasher.finalize().as_bytes())
    }

    /// Recomputes the fingerprint of `path` with this fingerprint's probe size and compares.
    ///
    /// # Errors
    /// - I/O errors from recomputing the fingerprint.
    /// - [`RuntimeError::ModelFingerprintMismatch`] if any component differs.
    pub fn verify(&self, path: &Path) -> RuntimeResult<()> {
        let found = Self::compute_with_probe(path, self.probe_size)?;
        if found != *self {
            return Err(RuntimeError::ModelFingerprintMismatch {
                expected: self.display(),
                found: found.display(),
                detail: format!(
                    "model file '{}' has been modified or replaced since the snapshot was taken",
                    path.display()
                ),
            });
        }
        Ok(())
    }

    /// Short human-readable summary of the fingerprint.
    ///
    /// The summary shows size, mtime, and each hash abbreviated to its first and last four
    /// bytes (eight hex digits each).
    pub fn display(&self) -> String {
        let hex = |bytes: &[u8]| -> String { bytes.iter().map(|b| format!("{b:02x}")).collect() };
        format!(
            "size={} mtime={} head={}...{} tail={}...{}",
            self.file_size,
            self.mtime_secs,
            hex(&self.head_hash[..4]),
            hex(&self.head_hash[28..]),
            hex(&self.tail_hash[..4]),
            hex(&self.tail_hash[28..]),
        )
    }
}
/// Serialized contents of an attention KV cache.
///
/// `InferenceEngine::snapshot` produces it via the KV cache's `to_payload()`.
/// `InferenceEngine::resume` applies it via `restore_from_payload()`.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably make field order part of the
/// snapshot format. Reordering or retyping fields likely requires bumping
/// `EngineSnapshot::VERSION`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct KvStatePayload {
    /// Cached key data. Presumably one inner vector per layer; TODO confirm against the KV cache.
    pub keys: Vec<Vec<f32>>,
    /// Cached value data, presumably laid out like `keys`; TODO confirm.
    pub values: Vec<Vec<f32>>,
    /// Presumably the number of positions currently filled in the cache; confirm.
    pub seq_len: usize,
    /// Presumably the number of layers the cache covers; confirm.
    pub num_layers: usize,
    /// Presumably the cache capacity in positions; confirm.
    pub max_seq_len: usize,
    /// Presumably the width of one key/value row; confirm.
    pub kv_dim: usize,
}
/// Serialized state-space (SSM) layer state, as carried by the `Mamba2` and `Jamba` variants
/// of [`SequenceStatePayload`].
#[derive(Debug, Clone, Encode, Decode)]
pub struct SsmStatePayload {
    /// SSM state buffers. Presumably one entry per SSM layer; TODO confirm.
    pub ssm_states: Vec<Vec<f32>>,
    /// Step counter at snapshot time. Presumably the number of tokens processed; confirm.
    pub step: usize,
}
/// Per-sequence model state stored in a snapshot, one variant per architecture family.
///
/// Within this module, `InferenceEngine::snapshot` only produces `Attention`, and
/// `InferenceEngine::resume` only restores `Attention` state.
#[derive(Debug, Clone, Encode, Decode)]
pub enum SequenceStatePayload {
    /// Pure attention model: KV cache only.
    Attention(KvStatePayload),
    /// Mamba2 model: SSM state only.
    Mamba2(SsmStatePayload),
    /// Hybrid model carrying both attention KV state and SSM state.
    Jamba {
        attention: KvStatePayload,
        ssm: SsmStatePayload,
    },
}
/// Sampler settings captured in a snapshot.
///
/// How `InferenceEngine::snapshot` populates the non-config fields:
/// - `rng_state` holds the configured seed (or `0`), not a live RNG position.
/// - `mirostat_mu` is derived as `2.0 * mirostat_tau`.
///
/// `InferenceEngine::resume` rebuilds a `SamplerConfig` from the remaining fields and does not
/// read `rng_state` or `mirostat_mu`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct SamplerStatePayload {
    pub rng_state: u64,
    pub mirostat_mu: f32,
    pub temperature: f32,
    pub top_k: usize,
    pub top_p: f32,
    pub min_p: f32,
    pub repetition_penalty: f32,
    pub repetition_penalty_window: usize,
    pub seed: Option<u64>,
    /// Copied from / restored to `SamplerConfig::mirostat`.
    pub mirostat_mode: u8,
    pub mirostat_tau: f32,
    pub mirostat_eta: f32,
}
/// Grammar constraint captured in a snapshot.
///
/// It is stored as source text; `InferenceEngine::resume` re-parses it with `Grammar::parse`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct GrammarStatePayload {
    /// Grammar source text, copied from the sampler's grammar `source`.
    pub grammar_source: String,
}
/// Complete, self-describing snapshot of a single inference engine, encoded with `oxicode`.
///
/// NOTE(review): the derived `Encode`/`Decode` presumably make field order part of the format.
/// Bump [`EngineSnapshot::VERSION`] if fields are reordered, retyped, added, or removed.
#[derive(Debug, Clone, Encode, Decode)]
pub struct EngineSnapshot {
    /// Must equal [`SNAPSHOT_MAGIC`]; checked by `deserialize`.
    pub magic: [u8; 8],
    /// Must equal [`EngineSnapshot::VERSION`]; checked by `deserialize`.
    pub version: u32,
    /// Architecture string from the model config at snapshot time.
    pub arch_id: String,
    /// Model path at snapshot time. `InferenceEngine::resume` uses the path passed to it
    /// instead, after verifying `model_fingerprint`.
    pub model_path: String,
    /// Tokenizer path from the engine config, if any.
    pub tokenizer_path: Option<String>,
    /// Fingerprint of the model file; verified against the model on resume.
    pub model_fingerprint: ModelFingerprint,
    /// Token history. `InferenceEngine::snapshot` currently writes it empty, and
    /// `InferenceEngine::resume` does not read it.
    pub tokens: Vec<u32>,
    /// Per-sequence model state (KV cache and/or SSM state).
    pub sequence_state: SequenceStatePayload,
    /// Sampler settings.
    pub sampler_state: SamplerStatePayload,
    /// Grammar constraint, if one was configured.
    pub grammar_state: Option<GrammarStatePayload>,
    /// Model config's maximum context length; passed back as `context_size` on resume.
    pub max_context_length: usize,
    /// Engine worker thread count.
    pub num_threads: usize,
    /// Engine prefill chunk size.
    pub prefill_chunk_size: usize,
}
impl EngineSnapshot {
    /// Format version written by `InferenceEngine::snapshot` and required by
    /// [`EngineSnapshot::deserialize`].
    pub const VERSION: u32 = 1;

    /// Encodes this snapshot with `oxicode`.
    ///
    /// # Errors
    /// Returns [`RuntimeError::SnapshotIncompatible`] if encoding fails.
    pub fn serialize(&self) -> RuntimeResult<Vec<u8>> {
        match oxicode::encode_to_vec(self) {
            Ok(encoded) => Ok(encoded),
            Err(e) => Err(RuntimeError::SnapshotIncompatible {
                detail: format!("serialization failed: {e}"),
            }),
        }
    }

    /// Decodes a snapshot and validates its magic bytes and format version.
    ///
    /// Any bytes after the encoded snapshot are ignored.
    ///
    /// # Errors
    /// Returns [`RuntimeError::SnapshotIncompatible`] in three cases:
    /// - decoding fails;
    /// - the magic bytes differ from [`SNAPSHOT_MAGIC`];
    /// - the version differs from [`EngineSnapshot::VERSION`].
    pub fn deserialize(bytes: &[u8]) -> RuntimeResult<Self> {
        let snap = match oxicode::decode_from_slice::<Self>(bytes) {
            Ok((decoded, _consumed)) => decoded,
            Err(e) => {
                return Err(RuntimeError::SnapshotIncompatible {
                    detail: format!("deserialization failed: {e}"),
                })
            }
        };

        if snap.magic != *SNAPSHOT_MAGIC {
            return Err(RuntimeError::SnapshotIncompatible {
                detail: "invalid snapshot magic bytes".to_string(),
            });
        }

        if snap.version == Self::VERSION {
            Ok(snap)
        } else {
            Err(RuntimeError::SnapshotIncompatible {
                detail: format!(
                    "snapshot version {} is not supported (expected {})",
                    snap.version,
                    Self::VERSION
                ),
            })
        }
    }
}
/// Magic bytes written at the start of an encoded [`SpeculativeEngineSnapshot`].
pub const SPEC_SNAPSHOT_MAGIC: &[u8; 8] = b"OXISPEC1";
/// Container version written by `SpeculativeEngineSnapshot::encode`.
/// `decode` rejects any other value.
const SPEC_SNAPSHOT_VERSION: u32 = 1;
/// Snapshot of a speculative-decoding setup: a target engine, a draft engine, and the
/// speculative loop's own state.
///
/// Unlike [`EngineSnapshot`], this type is encoded by hand (`encode` / `decode`) into a
/// little-endian container rather than via derived `Encode`/`Decode`.
#[derive(Debug, Clone)]
pub struct SpeculativeEngineSnapshot {
    /// Snapshot of the target engine.
    pub target_snapshot: EngineSnapshot,
    /// Snapshot of the draft engine.
    pub draft_snapshot: EngineSnapshot,
    /// Presumably the number of draft tokens proposed per step; confirm with the caller.
    pub num_speculative: usize,
    /// Optional seed for the speculative loop.
    pub spec_seed: Option<u64>,
    /// Accepted token history, stored in order.
    pub accepted_tokens: Vec<u32>,
    /// RNG state of the speculative loop, stored verbatim.
    pub rng_state: u64,
}
impl SpeculativeEngineSnapshot {
pub fn encode(&self) -> RuntimeResult<Vec<u8>> {
let target_bytes = self.target_snapshot.serialize()?;
let draft_bytes = self.draft_snapshot.serialize()?;
let seed_bytes = if self.spec_seed.is_some() {
9usize
} else {
1usize
};
let capacity = 8
+ 4
+ 8
+ target_bytes.len()
+ 8
+ draft_bytes.len()
+ 8
+ seed_bytes
+ 8
+ self.accepted_tokens.len() * 4
+ 8;
let mut buf: Vec<u8> = Vec::with_capacity(capacity);
buf.extend_from_slice(SPEC_SNAPSHOT_MAGIC);
buf.extend_from_slice(&SPEC_SNAPSHOT_VERSION.to_le_bytes());
buf.extend_from_slice(&(target_bytes.len() as u64).to_le_bytes());
buf.extend_from_slice(&target_bytes);
buf.extend_from_slice(&(draft_bytes.len() as u64).to_le_bytes());
buf.extend_from_slice(&draft_bytes);
buf.extend_from_slice(&(self.num_speculative as u64).to_le_bytes());
match self.spec_seed {
None => buf.push(0x00),
Some(seed) => {
buf.push(0x01);
buf.extend_from_slice(&seed.to_le_bytes());
}
}
buf.extend_from_slice(&(self.accepted_tokens.len() as u64).to_le_bytes());
for &tok in &self.accepted_tokens {
buf.extend_from_slice(&tok.to_le_bytes());
}
buf.extend_from_slice(&self.rng_state.to_le_bytes());
Ok(buf)
}
pub fn decode(bytes: &[u8]) -> RuntimeResult<Self> {
let mut pos = 0usize;
macro_rules! read_exact {
($n:expr, $label:expr) => {{
let end = pos + $n;
if end > bytes.len() {
return Err(RuntimeError::SpecSnapshotIncompatible(format!(
"truncated: expected {} bytes for {} at offset {}",
$n, $label, pos
)));
}
let slice = &bytes[pos..end];
pos = end;
slice
}};
}
let magic = read_exact!(8, "magic");
if magic != SPEC_SNAPSHOT_MAGIC {
return Err(RuntimeError::SpecSnapshotIncompatible(format!(
"invalid magic bytes: expected {:?}, got {:?}",
SPEC_SNAPSHOT_MAGIC, magic
)));
}
let version = u32::from_le_bytes(
read_exact!(4, "version")
.try_into()
.expect("slice is exactly 4 bytes"),
);
if version != SPEC_SNAPSHOT_VERSION {
return Err(RuntimeError::SpecSnapshotIncompatible(format!(
"unsupported version {version} (expected {SPEC_SNAPSHOT_VERSION})"
)));
}
let target_len = u64::from_le_bytes(
read_exact!(8, "target_len")
.try_into()
.expect("slice is exactly 8 bytes"),
) as usize;
let target_raw = read_exact!(target_len, "target_bytes");
let target_snapshot = EngineSnapshot::deserialize(target_raw).map_err(|e| {
RuntimeError::SpecSnapshotIncompatible(format!("target snapshot corrupt: {e}"))
})?;
let draft_len = u64::from_le_bytes(
read_exact!(8, "draft_len")
.try_into()
.expect("slice is exactly 8 bytes"),
) as usize;
let draft_raw = read_exact!(draft_len, "draft_bytes");
let draft_snapshot = EngineSnapshot::deserialize(draft_raw).map_err(|e| {
RuntimeError::SpecSnapshotIncompatible(format!("draft snapshot corrupt: {e}"))
})?;
let num_speculative = u64::from_le_bytes(
read_exact!(8, "num_speculative")
.try_into()
.expect("slice is exactly 8 bytes"),
) as usize;
let has_seed = read_exact!(1, "has_seed")[0];
let spec_seed = if has_seed == 0x01 {
let seed_bytes = read_exact!(8, "seed");
Some(u64::from_le_bytes(
seed_bytes.try_into().expect("slice is exactly 8 bytes"),
))
} else {
None
};
let accepted_len = u64::from_le_bytes(
read_exact!(8, "accepted_len")
.try_into()
.expect("slice is exactly 8 bytes"),
) as usize;
let mut accepted_tokens = Vec::with_capacity(accepted_len);
for _ in 0..accepted_len {
let tok = u32::from_le_bytes(
read_exact!(4, "accepted_token")
.try_into()
.expect("slice is exactly 4 bytes"),
);
accepted_tokens.push(tok);
}
let rng_state = u64::from_le_bytes(
read_exact!(8, "rng_state")
.try_into()
.expect("slice is exactly 8 bytes"),
);
let _ = pos;
Ok(Self {
target_snapshot,
draft_snapshot,
num_speculative,
spec_seed,
accepted_tokens,
rng_state,
})
}
pub fn fingerprint(&self) -> RuntimeResult<[u8; 32]> {
let encoded = self.encode()?;
Ok(*Hasher::new().update(&encoded).finalize().as_bytes())
}
}
impl InferenceEngine {
    /// Captures the engine's current state as an encoded [`EngineSnapshot`].
    ///
    /// The snapshot records:
    /// - a [`ModelFingerprint`] of the configured model file;
    /// - the KV-cache contents (as [`SequenceStatePayload::Attention`]);
    /// - the sampler settings and the grammar source, if any.
    ///
    /// Current limitations:
    /// - `tokens` is always written empty.
    /// - `sampler_state.rng_state` holds the configured seed (or `0`), not a live RNG position.
    ///
    /// # Errors
    /// - [`RuntimeError::ModelNotLoaded`] if no model or KV cache is loaded.
    /// - I/O errors while fingerprinting the model file.
    /// - [`RuntimeError::SnapshotIncompatible`] if encoding fails.
    pub fn snapshot(&self) -> RuntimeResult<Vec<u8>> {
        let model_config = self.model_config().ok_or(RuntimeError::ModelNotLoaded)?;
        let kv_cache = self.kv_cache_ref().ok_or(RuntimeError::ModelNotLoaded)?;
        let model_path = Path::new(self.config().model_path.as_str());
        let model_fingerprint = ModelFingerprint::compute(model_path)?;
        let sequence_state = SequenceStatePayload::Attention(kv_cache.to_payload());
        let sampler_cfg = &self.config().sampler;
        let sampler_state = SamplerStatePayload {
            rng_state: sampler_cfg.seed.unwrap_or(0),
            mirostat_mu: 2.0 * sampler_cfg.mirostat_tau,
            temperature: sampler_cfg.temperature,
            top_k: sampler_cfg.top_k,
            top_p: sampler_cfg.top_p,
            min_p: sampler_cfg.min_p,
            repetition_penalty: sampler_cfg.repetition_penalty,
            repetition_penalty_window: sampler_cfg.repetition_penalty_window,
            seed: sampler_cfg.seed,
            mirostat_mode: sampler_cfg.mirostat,
            mirostat_tau: sampler_cfg.mirostat_tau,
            mirostat_eta: sampler_cfg.mirostat_eta,
        };
        let grammar_state = sampler_cfg.grammar.as_ref().map(|g| GrammarStatePayload {
            grammar_source: g.source.clone(),
        });
        let snap = EngineSnapshot {
            magic: *SNAPSHOT_MAGIC,
            version: EngineSnapshot::VERSION,
            arch_id: model_config.architecture.clone(),
            model_path: self.config().model_path.clone(),
            tokenizer_path: self.config().tokenizer_path.clone(),
            model_fingerprint,
            // Token history is not tracked here yet; `resume` does not read it either.
            tokens: Vec::new(),
            sequence_state,
            sampler_state,
            grammar_state,
            max_context_length: model_config.max_context_length,
            num_threads: self.config().num_threads,
            prefill_chunk_size: self.config().prefill_chunk_size,
        };
        snap.serialize()
    }

    /// Rebuilds an engine from bytes produced by [`InferenceEngine::snapshot`].
    ///
    /// The model is loaded from `model_path`, which may differ from the path recorded in the
    /// snapshot. The file must match the snapshot's [`ModelFingerprint`]. Sampler fields not
    /// captured in the snapshot are reset to the fixed values below.
    ///
    /// # Errors
    /// - [`RuntimeError::SnapshotIncompatible`] if the bytes are not a valid snapshot, or if
    ///   the snapshot carries `Mamba2`/`Jamba` state, which this engine cannot restore.
    /// - [`RuntimeError::ModelFingerprintMismatch`] if `model_path` is not the snapshotted model.
    /// - [`RuntimeError::ModelLoadError`] if the grammar fails to re-parse or `model_path` is
    ///   not valid UTF-8.
    /// - Errors from loading the model or restoring the KV cache.
    pub fn resume(bytes: &[u8], model_path: &Path) -> RuntimeResult<Self> {
        use crate::sampling::grammar::Grammar;
        use std::sync::Arc;
        let snap = EngineSnapshot::deserialize(bytes)?;

        // Only attention KV state can be restored into this engine. Reject other variants
        // before the expensive model load instead of silently resuming with empty state.
        let kv_payload = match &snap.sequence_state {
            SequenceStatePayload::Attention(kv) => kv,
            SequenceStatePayload::Mamba2(_) => {
                return Err(RuntimeError::SnapshotIncompatible {
                    detail: "Mamba2 sequence state cannot be restored by InferenceEngine::resume"
                        .to_string(),
                })
            }
            SequenceStatePayload::Jamba { .. } => {
                return Err(RuntimeError::SnapshotIncompatible {
                    detail: "Jamba sequence state cannot be restored by InferenceEngine::resume"
                        .to_string(),
                })
            }
        };

        snap.model_fingerprint.verify(model_path)?;
        let mut sampler_config = SamplerConfig {
            temperature: snap.sampler_state.temperature,
            top_k: snap.sampler_state.top_k,
            top_p: snap.sampler_state.top_p,
            min_p: snap.sampler_state.min_p,
            repetition_penalty: snap.sampler_state.repetition_penalty,
            repetition_penalty_window: snap.sampler_state.repetition_penalty_window,
            seed: snap.sampler_state.seed,
            mirostat: snap.sampler_state.mirostat_mode,
            mirostat_tau: snap.sampler_state.mirostat_tau,
            mirostat_eta: snap.sampler_state.mirostat_eta,
            grammar: None,
            token_vocab: None,
            // The fields below are not captured in the snapshot and are reset to fixed values.
            logit_bias: std::collections::HashMap::new(),
            banned_tokens: Vec::new(),
            dry_multiplier: 0.0,
            dry_base: 1.75,
            dry_allowed_length: 2,
            xtc_threshold: 0.0,
            xtc_probability: 0.5,
            typical_p: 1.0,
            top_a: 0.0,
            eta_cutoff: 0.0,
            epsilon_cutoff: 0.0,
        };
        if let Some(gs) = &snap.grammar_state {
            let grammar =
                Grammar::parse(&gs.grammar_source).map_err(|e| RuntimeError::ModelLoadError {
                    message: format!("failed to re-parse grammar from snapshot: {e}"),
                })?;
            sampler_config.grammar = Some(Arc::new(grammar));
        }
        let config = EngineConfig {
            model_path: model_path
                .to_str()
                .ok_or_else(|| RuntimeError::ModelLoadError {
                    message: "model path contains non-UTF-8 characters".to_string(),
                })?
                .to_string(),
            tokenizer_path: snap.tokenizer_path.clone(),
            context_size: Some(snap.max_context_length),
            num_threads: snap.num_threads,
            sampler: sampler_config,
            prefill_chunk_size: snap.prefill_chunk_size,
            offload_policy: crate::offload::OffloadPolicy::None,
        };
        let mut engine = Self::new(config);
        engine.load_model()?;
        let kv = engine.kv_cache_mut().ok_or(RuntimeError::ModelNotLoaded)?;
        kv.restore_from_payload(kv_payload)?;
        Ok(engine)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Temporary file that is removed on drop, so fingerprint tests clean up even when an
    /// assertion fails.
    ///
    /// The process id is embedded in the file name so that concurrent test processes (e.g.
    /// two `cargo test` runs) do not overwrite each other's fixtures.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(name: &str) -> Self {
            let file_name = format!("{}_{}", std::process::id(), name);
            Self(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the test failed before writing it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Small, fully populated snapshot used as a baseline by the tests below.
    fn make_minimal_snapshot() -> EngineSnapshot {
        EngineSnapshot {
            magic: *SNAPSHOT_MAGIC,
            version: EngineSnapshot::VERSION,
            arch_id: "llama".to_string(),
            model_path: "/tmp/test.gguf".to_string(),
            tokenizer_path: None,
            model_fingerprint: ModelFingerprint {
                file_size: 1024,
                mtime_secs: 1_000_000,
                head_hash: [0u8; 32],
                tail_hash: [1u8; 32],
                probe_size: DEFAULT_PROBE_SIZE,
            },
            tokens: vec![1, 2, 3],
            sequence_state: SequenceStatePayload::Attention(KvStatePayload {
                keys: vec![vec![0.0f32; 4]],
                values: vec![vec![0.0f32; 4]],
                seq_len: 1,
                num_layers: 1,
                max_seq_len: 512,
                kv_dim: 4,
            }),
            sampler_state: SamplerStatePayload {
                rng_state: 42,
                mirostat_mu: 5.0,
                temperature: 0.7,
                top_k: 40,
                top_p: 0.9,
                min_p: 0.0,
                repetition_penalty: 1.1,
                repetition_penalty_window: 64,
                seed: Some(42),
                mirostat_mode: 0,
                mirostat_tau: 5.0,
                mirostat_eta: 0.1,
            },
            grammar_state: None,
            max_context_length: 512,
            num_threads: 4,
            prefill_chunk_size: 512,
        }
    }

    #[test]
    fn roundtrip_serialize_deserialize() {
        let snap = make_minimal_snapshot();
        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");
        assert_eq!(restored.arch_id, "llama");
        assert_eq!(restored.tokens, vec![1, 2, 3]);
        assert_eq!(restored.version, EngineSnapshot::VERSION);
        assert_eq!(&restored.magic, SNAPSHOT_MAGIC);
    }

    #[test]
    fn bad_magic_rejected() {
        let snap = make_minimal_snapshot();
        let mut bytes = snap.serialize().expect("serialize");
        if !bytes.is_empty() {
            bytes[0] ^= 0xFF;
        }
        let result = EngineSnapshot::deserialize(&bytes);
        // Both a decode failure and a magic mismatch surface as SnapshotIncompatible.
        assert!(
            matches!(result, Err(RuntimeError::SnapshotIncompatible { .. })),
            "corrupted bytes must return SnapshotIncompatible"
        );
    }

    #[test]
    fn incompatible_version_rejected() {
        let mut snap = make_minimal_snapshot();
        snap.version = 9999;
        let bytes = snap.serialize().expect("serialize");
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SnapshotIncompatible { .. })),
            "invalid version must return SnapshotIncompatible"
        );
    }

    #[test]
    fn model_fingerprint_compute_and_verify() {
        let tmp = TempFile::new("oxillama_snap_test_fingerprint.gguf");
        let path = tmp.path();
        std::fs::write(path, vec![0xABu8; 100 * 1024]).expect("write test file");
        let fp = ModelFingerprint::compute(path).expect("compute fingerprint");
        assert_eq!(fp.file_size, 100 * 1024);
        fp.verify(path).expect("verify same file");
        std::fs::write(path, vec![0xCDu8; 100 * 1024]).expect("write modified file");
        assert!(
            fp.verify(path).is_err(),
            "fingerprint verify must fail after file modification"
        );
    }

    #[test]
    fn fingerprint_mismatch_error_type() {
        let tmp_a = TempFile::new("oxillama_snap_fp_a.gguf");
        let tmp_b = TempFile::new("oxillama_snap_fp_b.gguf");
        std::fs::write(tmp_a.path(), vec![0xAAu8; 10_000]).expect("write A");
        std::fs::write(tmp_b.path(), vec![0xBBu8; 10_000]).expect("write B");
        let fp_a = ModelFingerprint::compute(tmp_a.path()).expect("compute A");
        let result = fp_a.verify(tmp_b.path());
        assert!(
            matches!(result, Err(RuntimeError::ModelFingerprintMismatch { .. })),
            "mismatch must return ModelFingerprintMismatch"
        );
    }

    #[test]
    fn kv_state_payload_roundtrip_in_snapshot() {
        let kv = KvStatePayload {
            keys: vec![vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]],
            values: vec![vec![9.0f32, 10.0, 11.0, 12.0], vec![13.0, 14.0, 15.0, 16.0]],
            seq_len: 1,
            num_layers: 2,
            max_seq_len: 512,
            kv_dim: 4,
        };
        let mut snap = make_minimal_snapshot();
        snap.sequence_state = SequenceStatePayload::Attention(kv.clone());
        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");
        if let SequenceStatePayload::Attention(restored_kv) = restored.sequence_state {
            assert_eq!(restored_kv.keys, kv.keys);
            assert_eq!(restored_kv.values, kv.values);
            assert_eq!(restored_kv.seq_len, kv.seq_len);
            assert_eq!(restored_kv.num_layers, kv.num_layers);
        } else {
            panic!("expected Attention sequence state payload");
        }
    }

    /// Speculative snapshot wrapping two copies of [`make_minimal_snapshot`].
    fn make_spec_snapshot(accepted: Vec<u32>, rng_state: u64) -> SpeculativeEngineSnapshot {
        SpeculativeEngineSnapshot {
            target_snapshot: make_minimal_snapshot(),
            draft_snapshot: make_minimal_snapshot(),
            num_speculative: 4,
            spec_seed: Some(0xdeadbeef),
            accepted_tokens: accepted,
            rng_state,
        }
    }

    #[test]
    fn spec_snapshot_roundtrip() {
        let original = make_spec_snapshot(vec![10u32, 20, 30], 0x00c0_ffee_cafe_babe_u64);
        let bytes = original.encode().expect("encode must succeed");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode must succeed");
        assert_eq!(restored.num_speculative, 4);
        assert_eq!(restored.spec_seed, Some(0xdeadbeef));
        assert_eq!(restored.accepted_tokens, vec![10u32, 20, 30]);
        assert_eq!(restored.rng_state, 0x00c0_ffee_cafe_babe_u64);
        assert_eq!(restored.target_snapshot.arch_id, "llama");
        assert_eq!(restored.draft_snapshot.arch_id, "llama");
    }

    #[test]
    fn spec_snapshot_rejects_wrong_magic() {
        let snap = make_spec_snapshot(vec![], 42);
        let mut bytes = snap.encode().expect("encode");
        if bytes.len() >= 8 {
            bytes[0] ^= 0xFF;
        }
        let result = SpeculativeEngineSnapshot::decode(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SpecSnapshotIncompatible(_))),
            "wrong magic must return SpecSnapshotIncompatible, got {result:?}"
        );
    }

    #[test]
    fn spec_snapshot_rejects_truncated() {
        let snap = make_spec_snapshot(vec![1u32, 2], 99);
        let bytes = snap.encode().expect("encode");
        let truncated = &bytes[..12.min(bytes.len())];
        let result = SpeculativeEngineSnapshot::decode(truncated);
        assert!(result.is_err(), "truncated bytes must return Err, got Ok");
    }

    #[test]
    fn spec_snapshot_preserves_accepted_history() {
        let history = vec![1u32, 2, 3, 4, 5, 100, 200, 65535];
        let snap = make_spec_snapshot(history.clone(), 0);
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert_eq!(
            restored.accepted_tokens, history,
            "accepted token history must be identical after roundtrip"
        );
    }

    #[test]
    fn spec_snapshot_none_seed_roundtrip() {
        let mut snap = make_spec_snapshot(vec![], 7);
        snap.spec_seed = None;
        let bytes = snap.encode().expect("encode");
        let restored = SpeculativeEngineSnapshot::decode(&bytes).expect("decode");
        assert!(
            restored.spec_seed.is_none(),
            "None seed must round-trip as None"
        );
    }

    #[test]
    fn spec_snapshot_fingerprint_is_deterministic() {
        let snap = make_spec_snapshot(vec![42u32], 0xbeef);
        let fp1 = snap.fingerprint().expect("fingerprint 1");
        let fp2 = snap.fingerprint().expect("fingerprint 2");
        assert_eq!(fp1, fp2, "fingerprint must be deterministic");
    }
}