use std::io::{Read, Seek, SeekFrom};
use std::path::Path;
use blake3::Hasher;
use oxicode::{Decode, Encode};
use crate::engine::{EngineConfig, InferenceEngine};
use crate::error::{RuntimeError, RuntimeResult};
use crate::sampling::SamplerConfig;
/// Magic value stored in the `magic` field of every [`EngineSnapshot`].
///
/// `EngineSnapshot::deserialize` rejects any snapshot whose magic differs.
pub const SNAPSHOT_MAGIC: &[u8; 8] = b"OXISNAP1";

/// Default number of bytes hashed at each end of a model file by
/// [`ModelFingerprint::compute`] (8 MiB).
const DEFAULT_PROBE_SIZE: u32 = 8 << 20;
/// Cheap identity check for a model file, embedded in every [`EngineSnapshot`] so that
/// `InferenceEngine::resume` can refuse to restore state onto a different model.
///
/// Instead of hashing the whole file, it records:
/// - the file size;
/// - the modification time;
/// - BLAKE3 hashes of the first and last `probe_size` bytes (see
///   `ModelFingerprint::compute_with_probe`).
///
/// Equality is the derived field-by-field comparison. A differing mtime alone is
/// therefore enough to fail `ModelFingerprint::verify`.
///
/// NOTE(review): field order presumably defines the oxicode-encoded layout. Do not
/// reorder or insert fields without bumping `EngineSnapshot::VERSION`.
#[derive(Debug, Clone, PartialEq, Encode, Decode)]
pub struct ModelFingerprint {
/// Length of the file in bytes.
pub file_size: u64,
/// Modification time in whole seconds since the Unix epoch. `0` when it could not be
/// read or predates the epoch.
pub mtime_secs: i64,
/// BLAKE3 hash of the first `min(probe_size, file_size)` bytes.
pub head_hash: [u8; 32],
/// BLAKE3 hash of the last `min(probe_size, file_size)` bytes.
pub tail_hash: [u8; 32],
/// Window size in bytes hashed at each end. `verify` recomputes with this same value.
pub probe_size: u32,
}
impl ModelFingerprint {
    /// Fingerprints `path` using the default probe window of [`DEFAULT_PROBE_SIZE`] bytes.
    ///
    /// # Errors
    /// Propagates any I/O error from opening or reading the file.
    pub fn compute(path: &Path) -> RuntimeResult<Self> {
        Self::compute_with_probe(path, DEFAULT_PROBE_SIZE)
    }

    /// Fingerprints `path`, hashing at most `probe_size` bytes from each end of the file.
    ///
    /// When the file is no larger than `probe_size`, both the head and tail windows
    /// cover the whole file. A modification time that cannot be read, or that precedes
    /// the Unix epoch, is recorded as `0`.
    ///
    /// # Errors
    /// Propagates I/O errors from opening, stat-ing, seeking or reading the file. This
    /// includes a short read if the file shrinks while it is being fingerprinted.
    pub fn compute_with_probe(path: &Path, probe_size: u32) -> RuntimeResult<Self> {
        let mut file = std::fs::File::open(path)?;
        let meta = file.metadata()?;
        let file_size = meta.len();

        let mtime_secs = match meta.modified() {
            Ok(when) => when
                .duration_since(std::time::SystemTime::UNIX_EPOCH)
                .map(|elapsed| elapsed.as_secs() as i64)
                .unwrap_or(0),
            Err(_) => 0,
        };

        // Both windows share the same length; `window <= file_size`, so the tail offset
        // cannot underflow.
        let window = u64::from(probe_size).min(file_size);
        let head_hash = Self::hash_window(&mut file, 0, window)?;
        let tail_hash = Self::hash_window(&mut file, file_size - window, window)?;

        Ok(Self {
            file_size,
            mtime_secs,
            head_hash,
            tail_hash,
            probe_size,
        })
    }

    /// Reads exactly `len` bytes of `file` starting at byte offset `start` and returns
    /// their BLAKE3 hash.
    fn hash_window(file: &mut std::fs::File, start: u64, len: u64) -> RuntimeResult<[u8; 32]> {
        let mut chunk = vec![0u8; len as usize];
        file.seek(SeekFrom::Start(start))?;
        file.read_exact(&mut chunk)?;
        let mut hasher = Hasher::new();
        hasher.update(&chunk);
        Ok(*hasher.finalize().as_bytes())
    }

    /// Recomputes the fingerprint of `path` using this fingerprint's `probe_size`.
    ///
    /// The result is compared field-by-field with `self`: size, mtime, both hashes and
    /// probe size.
    ///
    /// # Errors
    /// - Returns `RuntimeError::ModelFingerprintMismatch` when any field differs. A changed
    ///   mtime alone counts as a mismatch.
    /// - Propagates I/O errors raised while recomputing.
    pub fn verify(&self, path: &Path) -> RuntimeResult<()> {
        let recomputed = Self::compute_with_probe(path, self.probe_size)?;
        if recomputed != *self {
            return Err(RuntimeError::ModelFingerprintMismatch {
                expected: self.display(),
                found: recomputed.display(),
                detail: format!(
                    "model file '{}' has been modified or replaced since the snapshot was taken",
                    path.display()
                ),
            });
        }
        Ok(())
    }

    /// Renders a short human-readable summary of the fingerprint.
    ///
    /// The summary shows the size, the mtime, and the first and last eight hex digits of
    /// each hash.
    pub fn display(&self) -> String {
        fn hex(bytes: &[u8; 32]) -> String {
            bytes.iter().map(|byte| format!("{byte:02x}")).collect()
        }
        let (head, tail) = (hex(&self.head_hash), hex(&self.tail_hash));
        format!(
            "size={} mtime={} head={}...{} tail={}...{}",
            self.file_size,
            self.mtime_secs,
            &head[..8],
            &head[head.len() - 8..],
            &tail[..8],
            &tail[tail.len() - 8..],
        )
    }
}
/// Serialized attention KV-cache contents.
///
/// Produced via `to_payload()` on the engine's KV cache in `InferenceEngine::snapshot`.
/// Handed to `restore_from_payload` in `InferenceEngine::resume`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct KvStatePayload {
/// Cached keys. Presumably one inner vector per layer, since the tests build `num_layers`
/// entries. TODO confirm the exact layout against the KV cache implementation.
pub keys: Vec<Vec<f32>>,
/// Cached values, assumed to be laid out like `keys`. TODO confirm.
pub values: Vec<Vec<f32>>,
/// Number of populated sequence positions (presumably; TODO confirm).
pub seq_len: usize,
/// Layer count (presumably the outer length of `keys` / `values`; TODO confirm).
pub num_layers: usize,
/// Sequence capacity of the cache the payload was taken from (presumably; TODO confirm).
pub max_seq_len: usize,
/// Per-position key/value width (presumably; TODO confirm).
pub kv_dim: usize,
}
/// Serialized state for SSM (Mamba2-style) layers.
///
/// NOTE(review): nothing in this file constructs this payload, and
/// `InferenceEngine::resume` does not restore it into the engine.
#[derive(Debug, Clone, Encode, Decode)]
pub struct SsmStatePayload {
/// Recurrent states, assumed to be one inner vector per SSM layer. TODO confirm.
pub ssm_states: Vec<Vec<f32>>,
/// Step counter at snapshot time. Its exact meaning is not established here; TODO confirm.
pub step: usize,
}
/// Architecture-specific sequence state carried by an [`EngineSnapshot`].
///
/// `InferenceEngine::snapshot` currently always emits `Attention`.
///
/// NOTE(review): variant order presumably determines the encoded discriminant. Append new
/// variants rather than reordering them.
#[derive(Debug, Clone, Encode, Decode)]
pub enum SequenceStatePayload {
/// KV cache of an attention (transformer) model.
Attention(KvStatePayload),
/// Recurrent state of a Mamba2 model.
Mamba2(SsmStatePayload),
/// Hybrid model state: an attention KV cache plus SSM state.
Jamba {
/// Attention-layer KV cache.
attention: KvStatePayload,
/// SSM-layer state.
ssm: SsmStatePayload,
},
}
/// Sampler configuration captured in an [`EngineSnapshot`].
///
/// `InferenceEngine::resume` rebuilds a `SamplerConfig` from these fields.
#[derive(Debug, Clone, Encode, Decode)]
pub struct SamplerStatePayload {
/// NOTE(review): `InferenceEngine::snapshot` fills this with the configured seed (or `0`),
/// not the live RNG state. `resume` does not read it.
pub rng_state: u64,
/// NOTE(review): `InferenceEngine::snapshot` fills this with `2.0 * mirostat_tau`, not the
/// live Mirostat `mu`. `resume` does not read it.
pub mirostat_mu: f32,
/// Mirrors `SamplerConfig::temperature`.
pub temperature: f32,
/// Mirrors `SamplerConfig::top_k`.
pub top_k: usize,
/// Mirrors `SamplerConfig::top_p`.
pub top_p: f32,
/// Mirrors `SamplerConfig::min_p`.
pub min_p: f32,
/// Mirrors `SamplerConfig::repetition_penalty`.
pub repetition_penalty: f32,
/// Mirrors `SamplerConfig::repetition_penalty_window`.
pub repetition_penalty_window: usize,
/// Mirrors `SamplerConfig::seed`.
pub seed: Option<u64>,
/// Mirrors `SamplerConfig::mirostat`; value meanings are defined there.
pub mirostat_mode: u8,
/// Mirrors `SamplerConfig::mirostat_tau`.
pub mirostat_tau: f32,
/// Mirrors `SamplerConfig::mirostat_eta`.
pub mirostat_eta: f32,
}
/// Grammar constraint captured in an [`EngineSnapshot`] as source text.
///
/// `InferenceEngine::resume` re-parses it with `Grammar::parse`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct GrammarStatePayload {
/// Source text of the grammar attached to the sampler configuration.
pub grammar_source: String,
}
/// Top-level, versioned snapshot record written by `InferenceEngine::snapshot` and read
/// back by `InferenceEngine::resume`.
///
/// NOTE(review): field order presumably defines the oxicode-encoded layout. Changing it
/// requires bumping `EngineSnapshot::VERSION`.
#[derive(Debug, Clone, Encode, Decode)]
pub struct EngineSnapshot {
/// Must equal [`SNAPSHOT_MAGIC`]; checked by `EngineSnapshot::deserialize`.
pub magic: [u8; 8],
/// Must equal `EngineSnapshot::VERSION`; checked by `EngineSnapshot::deserialize`.
pub version: u32,
/// Architecture string of the loaded model at snapshot time.
pub arch_id: String,
/// Model path recorded at snapshot time. `resume` loads from the caller-supplied path instead.
pub model_path: String,
/// Tokenizer path from the engine configuration; reused as-is by `resume`.
pub tokenizer_path: Option<String>,
/// Fingerprint that `resume` verifies against the model file it is given.
pub model_fingerprint: ModelFingerprint,
/// Token history. NOTE(review): `InferenceEngine::snapshot` currently writes an empty list.
pub tokens: Vec<u32>,
/// Architecture-specific sequence state (KV cache and/or SSM state).
pub sequence_state: SequenceStatePayload,
/// Sampler settings.
pub sampler_state: SamplerStatePayload,
/// Grammar constraint source, if the sampler had one.
pub grammar_state: Option<GrammarStatePayload>,
/// Passed back as `EngineConfig::context_size` by `resume`.
pub max_context_length: usize,
/// Engine thread count; restored into `EngineConfig::num_threads`.
pub num_threads: usize,
/// Engine prefill chunk size; restored into `EngineConfig::prefill_chunk_size`.
pub prefill_chunk_size: usize,
}
impl EngineSnapshot {
    /// Current snapshot format version; `deserialize` rejects any other value.
    pub const VERSION: u32 = 1;

    /// Encodes this snapshot with oxicode.
    ///
    /// # Errors
    /// Returns `RuntimeError::SnapshotIncompatible` if encoding fails.
    pub fn serialize(&self) -> RuntimeResult<Vec<u8>> {
        oxicode::encode_to_vec(self).map_err(|e| RuntimeError::SnapshotIncompatible {
            detail: format!("serialization failed: {e}"),
        })
    }

    /// Decodes a snapshot previously produced by [`EngineSnapshot::serialize`].
    ///
    /// `bytes` must contain exactly one encoded snapshot. Extra data after it is treated
    /// as corruption rather than silently ignored.
    ///
    /// # Errors
    /// Returns `RuntimeError::SnapshotIncompatible` if any of the following holds:
    /// - decoding fails;
    /// - the magic bytes differ from [`SNAPSHOT_MAGIC`];
    /// - the version differs from [`Self::VERSION`];
    /// - bytes remain after the encoded snapshot.
    pub fn deserialize(bytes: &[u8]) -> RuntimeResult<Self> {
        let (snap, consumed) = oxicode::decode_from_slice::<Self>(bytes).map_err(|e| {
            RuntimeError::SnapshotIncompatible {
                detail: format!("deserialization failed: {e}"),
            }
        })?;
        if &snap.magic != SNAPSHOT_MAGIC {
            return Err(RuntimeError::SnapshotIncompatible {
                detail: "invalid snapshot magic bytes".to_string(),
            });
        }
        if snap.version != Self::VERSION {
            return Err(RuntimeError::SnapshotIncompatible {
                detail: format!(
                    "snapshot version {} is not supported (expected {})",
                    snap.version,
                    Self::VERSION
                ),
            });
        }
        // Leftover bytes mean the buffer is not a single well-formed snapshot, for example
        // two snapshots concatenated or garbage appended to a file. Reject instead of
        // silently accepting a prefix.
        if consumed != bytes.len() {
            return Err(RuntimeError::SnapshotIncompatible {
                detail: format!(
                    "snapshot has {} trailing bytes after the encoded data",
                    bytes.len().saturating_sub(consumed)
                ),
            });
        }
        Ok(snap)
    }
}
impl InferenceEngine {
    /// Captures the engine's current state as a serialized [`EngineSnapshot`].
    ///
    /// The snapshot records:
    /// - the model fingerprint;
    /// - the attention KV cache;
    /// - the sampler settings;
    /// - the grammar source, if any;
    /// - the engine configuration needed by [`InferenceEngine::resume`] to rebuild an
    ///   equivalent engine.
    ///
    /// The configured `context_size` is preserved, and the model's maximum context length
    /// is used only when none was configured. This lets `resume` allocate a KV cache of
    /// the same size as the one being snapshotted.
    ///
    /// Current limitations:
    /// - the token history is not captured (`tokens` is always empty);
    /// - `rng_state` holds the configured seed (or 0), not the live RNG state;
    /// - `mirostat_mu` holds `2 * mirostat_tau`, not the live Mirostat `mu`;
    /// - only attention (KV-cache) sequence state is captured.
    ///
    /// # Errors
    /// - `RuntimeError::ModelNotLoaded` if no model config or KV cache is available.
    /// - I/O errors from fingerprinting the model file.
    /// - `RuntimeError::SnapshotIncompatible` if encoding fails.
    pub fn snapshot(&self) -> RuntimeResult<Vec<u8>> {
        let config = self.config();
        let model_config = self.model_config().ok_or(RuntimeError::ModelNotLoaded)?;
        let kv_cache = self.kv_cache_ref().ok_or(RuntimeError::ModelNotLoaded)?;

        let model_fingerprint = ModelFingerprint::compute(Path::new(config.model_path.as_str()))?;
        // Only attention (KV-cache) state is captured; SSM state is not snapshotted yet.
        let sequence_state = SequenceStatePayload::Attention(kv_cache.to_payload());

        let sampler_cfg = &config.sampler;
        let sampler_state = SamplerStatePayload {
            // These two record the configured seed and the initial Mirostat `mu`, not the
            // live sampler state; `resume` does not read them.
            rng_state: sampler_cfg.seed.unwrap_or(0),
            mirostat_mu: 2.0 * sampler_cfg.mirostat_tau,
            temperature: sampler_cfg.temperature,
            top_k: sampler_cfg.top_k,
            top_p: sampler_cfg.top_p,
            min_p: sampler_cfg.min_p,
            repetition_penalty: sampler_cfg.repetition_penalty,
            repetition_penalty_window: sampler_cfg.repetition_penalty_window,
            seed: sampler_cfg.seed,
            mirostat_mode: sampler_cfg.mirostat,
            mirostat_tau: sampler_cfg.mirostat_tau,
            mirostat_eta: sampler_cfg.mirostat_eta,
        };
        let grammar_state = sampler_cfg.grammar.as_ref().map(|g| GrammarStatePayload {
            grammar_source: g.source.clone(),
        });

        let snap = EngineSnapshot {
            magic: *SNAPSHOT_MAGIC,
            version: EngineSnapshot::VERSION,
            arch_id: model_config.architecture.clone(),
            model_path: config.model_path.clone(),
            tokenizer_path: config.tokenizer_path.clone(),
            model_fingerprint,
            // TODO: the token history is not captured yet.
            tokens: Vec::new(),
            sequence_state,
            sampler_state,
            grammar_state,
            // `resume` feeds this back as `EngineConfig::context_size`. Store the size the
            // engine was configured with so the resumed engine matches it, rather than
            // the model's (possibly much larger) maximum.
            max_context_length: config
                .context_size
                .unwrap_or(model_config.max_context_length),
            num_threads: config.num_threads,
            prefill_chunk_size: config.prefill_chunk_size,
        };
        snap.serialize()
    }

    /// Rebuilds an engine from bytes produced by [`InferenceEngine::snapshot`].
    ///
    /// `model_path` is the model file to load now; it may differ from the path recorded
    /// in the snapshot. Its fingerprint must match the one captured at snapshot time.
    ///
    /// # Errors
    /// - `RuntimeError::SnapshotIncompatible` if `bytes` is not a valid snapshot of the
    ///   current version. Also returned if the snapshot carries SSM (`Mamba2`/`Jamba`)
    ///   sequence state, which cannot be restored yet. This is checked before the model
    ///   is loaded.
    /// - `RuntimeError::ModelFingerprintMismatch` if `model_path` is not the snapshotted model.
    /// - `RuntimeError::ModelLoadError` if the grammar cannot be re-parsed or `model_path`
    ///   is not valid UTF-8.
    /// - Any error from loading the model or restoring the KV cache.
    pub fn resume(bytes: &[u8], model_path: &Path) -> RuntimeResult<Self> {
        use crate::sampling::grammar::Grammar;
        use std::sync::Arc;

        let snap = EngineSnapshot::deserialize(bytes)?;

        // Fail fast on state we cannot restore instead of silently returning an engine
        // whose sequence state is empty.
        let kv_payload = match &snap.sequence_state {
            SequenceStatePayload::Attention(kv_payload) => kv_payload,
            SequenceStatePayload::Mamba2(_) | SequenceStatePayload::Jamba { .. } => {
                return Err(RuntimeError::SnapshotIncompatible {
                    detail: "restoring SSM (Mamba2/Jamba) sequence state is not supported"
                        .to_string(),
                });
            }
        };

        snap.model_fingerprint.verify(model_path)?;

        let mut sampler_config = SamplerConfig {
            temperature: snap.sampler_state.temperature,
            top_k: snap.sampler_state.top_k,
            top_p: snap.sampler_state.top_p,
            min_p: snap.sampler_state.min_p,
            repetition_penalty: snap.sampler_state.repetition_penalty,
            repetition_penalty_window: snap.sampler_state.repetition_penalty_window,
            seed: snap.sampler_state.seed,
            mirostat: snap.sampler_state.mirostat_mode,
            mirostat_tau: snap.sampler_state.mirostat_tau,
            mirostat_eta: snap.sampler_state.mirostat_eta,
            grammar: None,
            token_vocab: None,
        };
        if let Some(gs) = &snap.grammar_state {
            let grammar =
                Grammar::parse(&gs.grammar_source).map_err(|e| RuntimeError::ModelLoadError {
                    message: format!("failed to re-parse grammar from snapshot: {e}"),
                })?;
            sampler_config.grammar = Some(Arc::new(grammar));
        }

        let config = EngineConfig {
            model_path: model_path
                .to_str()
                .ok_or_else(|| RuntimeError::ModelLoadError {
                    message: "model path contains non-UTF-8 characters".to_string(),
                })?
                .to_string(),
            tokenizer_path: snap.tokenizer_path.clone(),
            context_size: Some(snap.max_context_length),
            num_threads: snap.num_threads,
            sampler: sampler_config,
            prefill_chunk_size: snap.prefill_chunk_size,
        };

        let mut engine = Self::new(config);
        engine.load_model()?;
        let kv = engine.kv_cache_mut().ok_or(RuntimeError::ModelNotLoaded)?;
        kv.restore_from_payload(kv_payload)?;
        Ok(engine)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Temporary file that is deleted on drop, so a failing assertion does not leave
    /// stale files behind in the system temp directory.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Writes `contents` to `name` in the temp dir.
        ///
        /// The name is prefixed with the process id so concurrent test runs do not
        /// collide on the same path.
        fn with_contents(name: &str, contents: &[u8]) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{name}", std::process::id()));
            std::fs::write(&path, contents).expect("write temp file");
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Builds a small, fully-populated snapshot that passes `deserialize` validation.
    fn make_minimal_snapshot() -> EngineSnapshot {
        EngineSnapshot {
            magic: *SNAPSHOT_MAGIC,
            version: EngineSnapshot::VERSION,
            arch_id: "llama".to_string(),
            model_path: "/tmp/test.gguf".to_string(),
            tokenizer_path: None,
            model_fingerprint: ModelFingerprint {
                file_size: 1024,
                mtime_secs: 1_000_000,
                head_hash: [0u8; 32],
                tail_hash: [1u8; 32],
                probe_size: DEFAULT_PROBE_SIZE,
            },
            tokens: vec![1, 2, 3],
            sequence_state: SequenceStatePayload::Attention(KvStatePayload {
                keys: vec![vec![0.0f32; 4]],
                values: vec![vec![0.0f32; 4]],
                seq_len: 1,
                num_layers: 1,
                max_seq_len: 512,
                kv_dim: 4,
            }),
            sampler_state: SamplerStatePayload {
                rng_state: 42,
                mirostat_mu: 5.0,
                temperature: 0.7,
                top_k: 40,
                top_p: 0.9,
                min_p: 0.0,
                repetition_penalty: 1.1,
                repetition_penalty_window: 64,
                seed: Some(42),
                mirostat_mode: 0,
                mirostat_tau: 5.0,
                mirostat_eta: 0.1,
            },
            grammar_state: None,
            max_context_length: 512,
            num_threads: 4,
            prefill_chunk_size: 512,
        }
    }

    #[test]
    fn roundtrip_serialize_deserialize() {
        let snap = make_minimal_snapshot();
        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");
        assert_eq!(restored.arch_id, "llama");
        assert_eq!(restored.tokens, vec![1, 2, 3]);
        assert_eq!(restored.version, EngineSnapshot::VERSION);
        assert_eq!(&restored.magic, SNAPSHOT_MAGIC);
    }

    #[test]
    fn bad_magic_rejected() {
        let mut bytes = make_minimal_snapshot().serialize().expect("serialize");
        assert!(!bytes.is_empty(), "serialized snapshot must not be empty");
        bytes[0] ^= 0xFF;
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SnapshotIncompatible { .. })),
            "corrupted bytes must return SnapshotIncompatible"
        );
    }

    #[test]
    fn wrong_magic_field_rejected() {
        // Exercises the explicit magic check independently of the encoding layout.
        let mut snap = make_minimal_snapshot();
        snap.magic = *b"NOTASNAP";
        let bytes = snap.serialize().expect("serialize");
        assert!(
            matches!(
                EngineSnapshot::deserialize(&bytes),
                Err(RuntimeError::SnapshotIncompatible { .. })
            ),
            "wrong magic must return SnapshotIncompatible"
        );
    }

    #[test]
    fn incompatible_version_rejected() {
        let mut snap = make_minimal_snapshot();
        snap.version = 9999;
        let bytes = snap.serialize().expect("serialize");
        let result = EngineSnapshot::deserialize(&bytes);
        assert!(
            matches!(result, Err(RuntimeError::SnapshotIncompatible { .. })),
            "invalid version must return SnapshotIncompatible"
        );
    }

    #[test]
    fn model_fingerprint_compute_and_verify() {
        let file = TempFile::with_contents(
            "oxillama_snap_test_fingerprint.gguf",
            &[0xABu8; 100 * 1024],
        );
        let fp = ModelFingerprint::compute(file.path()).expect("compute fingerprint");
        assert_eq!(fp.file_size, 100 * 1024);
        fp.verify(file.path()).expect("verify same file");
        std::fs::write(file.path(), &[0xCDu8; 100 * 1024]).expect("write modified file");
        assert!(
            fp.verify(file.path()).is_err(),
            "fingerprint verify must fail after file modification"
        );
    }

    #[test]
    fn fingerprint_mismatch_error_type() {
        let file_a = TempFile::with_contents("oxillama_snap_fp_a.gguf", &[0xAAu8; 10_000]);
        let file_b = TempFile::with_contents("oxillama_snap_fp_b.gguf", &[0xBBu8; 10_000]);
        let fp_a = ModelFingerprint::compute(file_a.path()).expect("compute A");
        let result = fp_a.verify(file_b.path());
        assert!(
            matches!(result, Err(RuntimeError::ModelFingerprintMismatch { .. })),
            "mismatch must return ModelFingerprintMismatch"
        );
    }

    #[test]
    fn fingerprint_probe_windows() {
        // The whole file fits in the probe window, so head and tail hash the same bytes.
        let small = TempFile::with_contents("oxillama_snap_fp_small.gguf", &[0x11u8; 64]);
        let fp = ModelFingerprint::compute(small.path()).expect("compute small");
        assert_eq!(fp.head_hash, fp.tail_hash);

        // A small probe over a file whose ends differ yields distinct head/tail hashes.
        let mut contents = vec![0x00u8; 64];
        contents[63] = 0xFF;
        let ends = TempFile::with_contents("oxillama_snap_fp_ends.gguf", &contents);
        let fp = ModelFingerprint::compute_with_probe(ends.path(), 16).expect("compute ends");
        assert_ne!(fp.head_hash, fp.tail_hash);
        assert_eq!(fp.probe_size, 16);
        assert_eq!(fp.file_size, 64);
    }

    #[test]
    fn fingerprint_display_format() {
        let fp = make_minimal_snapshot().model_fingerprint;
        assert_eq!(
            fp.display(),
            "size=1024 mtime=1000000 head=00000000...00000000 tail=01010101...01010101"
        );
    }

    #[test]
    fn kv_state_payload_roundtrip_in_snapshot() {
        let kv = KvStatePayload {
            keys: vec![vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]],
            values: vec![vec![9.0f32, 10.0, 11.0, 12.0], vec![13.0, 14.0, 15.0, 16.0]],
            seq_len: 1,
            num_layers: 2,
            max_seq_len: 512,
            kv_dim: 4,
        };
        let mut snap = make_minimal_snapshot();
        snap.sequence_state = SequenceStatePayload::Attention(kv.clone());
        let bytes = snap.serialize().expect("serialize");
        let restored = EngineSnapshot::deserialize(&bytes).expect("deserialize");
        let SequenceStatePayload::Attention(restored_kv) = restored.sequence_state else {
            panic!("expected Attention sequence state payload");
        };
        assert_eq!(restored_kv.keys, kv.keys);
        assert_eq!(restored_kv.values, kv.values);
        assert_eq!(restored_kv.seq_len, kv.seq_len);
        assert_eq!(restored_kv.num_layers, kv.num_layers);
    }
}