use std::io::Write;
use std::path::{Path, PathBuf};
use oxillama_runtime::snapshot::EngineSnapshot;
use pyo3::exceptions::PyOSError;
use pyo3::prelude::*;
use crate::error::runtime_to_py;
/// Describes the hub location a model file can be fetched from, plus its expected digest.
///
/// The fields are consumed by `resolve_model_path_with_hub`: `repo_id` and `filename` are
/// handed to the hub downloader, and `sha256` is checked with `verify_sha256`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
pub struct HubOrigin {
/// Repository identifier passed to the hub downloader.
pub repo_id: String,
/// Name of the model file inside the repository.
pub filename: String,
/// Expected SHA-256 of the model file as a hex string; an empty string skips verification
/// in `resolve_model_path_with_hub`.
pub sha256: String,
}
impl HubOrigin {
    /// Builds a [`HubOrigin`] from a Python dict.
    ///
    /// The dict must contain the keys `"repo_id"`, `"filename"` and `"sha256"`, each holding a
    /// value that extracts as a Rust `String`.
    ///
    /// # Errors
    ///
    /// * `ValueError` with the message `hub_origin dict must contain '<key>'` when a key is
    ///   missing. Keys are checked in the order `repo_id`, `filename`, `sha256`, and the first
    ///   missing one is reported.
    /// * Whatever error the dict lookup or the `String` extraction raises.
    pub fn from_py_dict(dict: &pyo3::Bound<'_, pyo3::types::PyDict>) -> PyResult<Self> {
        // One lookup routine shared by all three keys: fetch, require presence, extract.
        let required = |key: &str| -> PyResult<String> {
            dict.get_item(key)?
                .ok_or_else(|| {
                    pyo3::exceptions::PyValueError::new_err(format!(
                        "hub_origin dict must contain '{key}'"
                    ))
                })?
                .extract::<String>()
        };

        // Struct fields are evaluated in source order, so the first missing key wins.
        Ok(Self {
            repo_id: required("repo_id")?,
            filename: required("filename")?,
            sha256: required("sha256")?,
        })
    }
}
// Summary of an `EngineSnapshot`, exposed to Python under the class name `SnapshotInfo`.
// Every field carries `#[pyo3(get)]`, so each one is readable as an attribute from Python.
// Within this file, instances are produced by `snapshot_info_from_snap`.
#[pyclass(name = "SnapshotInfo", from_py_object)]
#[derive(Clone)]
pub struct PySnapshotInfo {
// Copy of the snapshot's `arch_id`.
#[pyo3(get)]
pub arch_id: String,
// Copy of the snapshot's `model_path`.
#[pyo3(get)]
pub model_path: String,
// Copy of the snapshot's `tokenizer_path`; `None` when the snapshot has none.
#[pyo3(get)]
pub tokenizer_path: Option<String>,
// Copy of the snapshot's `max_context_length`.
#[pyo3(get)]
pub max_context_length: usize,
// Copy of the snapshot's `num_threads`.
#[pyo3(get)]
pub num_threads: usize,
// Copy of the snapshot's `version`.
#[pyo3(get)]
pub version: u32,
// The snapshot's `magic` bytes, copied into an owned vector.
#[pyo3(get)]
pub magic: Vec<u8>,
// Length of the snapshot's `tokens` collection.
#[pyo3(get)]
pub tokens_count: usize,
}
#[pymethods]
impl PySnapshotInfo {
    // Python `repr()`: shows `arch_id` and `model_path` in Rust debug-quoted form, followed by
    // the numeric `version`. The remaining fields are intentionally left out of the text.
    fn __repr__(&self) -> String {
        let Self {
            arch_id,
            model_path,
            version,
            ..
        } = self;
        format!(
            "SnapshotInfo(arch_id={:?}, model_path={:?}, version={})",
            arch_id, model_path, version
        )
    }
}
/// Metadata stored as JSON alongside a snapshot file.
///
/// `write_meta` serializes this struct to the path returned by `meta_path_for`, and `read_meta`
/// loads it back.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
pub struct EngineSnapshotMeta {
/// Model path recorded for the snapshot (copied from the snapshot in `from_engine_snapshot`).
pub model_path: String,
/// Optional hub location of the model; `None` when no hub origin was supplied.
pub hub_origin: Option<HubOrigin>,
}
impl EngineSnapshotMeta {
    /// Creates the metadata record for `snap`.
    ///
    /// The model path is cloned out of the snapshot; `hub_origin` is stored as given.
    pub fn from_engine_snapshot(snap: &EngineSnapshot, hub_origin: Option<HubOrigin>) -> Self {
        let model_path = snap.model_path.clone();
        Self {
            model_path,
            hub_origin,
        }
    }
}
/// Writes `bytes` to `path` via a temporary file created in the same directory.
///
/// The data is written to a `tempfile::NamedTempFile` next to the destination, flushed to
/// stable storage, and only then moved onto `path` with `persist`. An existing file at `path`
/// is replaced.
///
/// # Errors
///
/// * `ErrorKind::InvalidInput` when `path` has no parent component.
/// * Any I/O error from creating, writing, syncing or persisting the temporary file (for
///   example when the parent directory does not exist).
pub fn write_snapshot_atomic(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
    let parent = path.parent().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "path has no parent directory",
        )
    })?;

    let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
    tmp.write_all(bytes)?;
    // Push the contents to disk before the rename so that a crash cannot leave a renamed but
    // empty or truncated file at the destination.
    tmp.as_file().sync_all()?;
    // On failure the `PersistError` still owns the temp file; dropping it cleans the file up.
    tmp.persist(path).map_err(|e| e.error)?;
    Ok(())
}
pub fn read_and_peek_snapshot(path: &Path) -> Result<(Vec<u8>, EngineSnapshot), PyErr> {
let bytes = std::fs::read(path).map_err(|e| PyOSError::new_err(e.to_string()))?;
let snap = EngineSnapshot::deserialize(&bytes).map_err(runtime_to_py)?;
Ok((bytes, snap))
}
/// Converts a Rust I/O error into a Python `OSError` whose message is the error's display text.
///
/// NOTE(review): only the message is forwarded; the error's `ErrorKind` is not used to choose
/// a more specific Python exception type.
pub fn io_to_py(err: std::io::Error) -> PyErr {
    let message = err.to_string();
    PyOSError::new_err(message)
}
/// Builds the Python-facing [`PySnapshotInfo`] summary from a decoded snapshot.
///
/// String and optional fields are cloned, `magic` is copied into an owned `Vec<u8>`, and
/// `tokens_count` records only the length of the snapshot's `tokens`, not the tokens themselves.
pub fn snapshot_info_from_snap(snap: &EngineSnapshot) -> PySnapshotInfo {
    let arch_id = snap.arch_id.clone();
    let model_path = snap.model_path.clone();
    let tokenizer_path = snap.tokenizer_path.clone();
    let max_context_length = snap.max_context_length;
    let num_threads = snap.num_threads;
    let version = snap.version;
    let magic = snap.magic.to_vec();
    let tokens_count = snap.tokens.len();

    PySnapshotInfo {
        arch_id,
        model_path,
        tokenizer_path,
        max_context_length,
        num_threads,
        version,
        magic,
        tokens_count,
    }
}
/// Reads the snapshot file at `path` and returns its [`PySnapshotInfo`] summary.
///
/// The parameter is a borrowed `&Path`; call sites that hold a `PathBuf` can keep passing
/// `&path_buf`, which coerces automatically.
///
/// # Errors
///
/// * The error produced by `io_to_py` when the file cannot be read.
/// * The error produced by `runtime_to_py` when the bytes do not deserialize as a snapshot.
pub fn snapshot_info_from_path(path: &Path) -> PyResult<PySnapshotInfo> {
    let bytes = std::fs::read(path).map_err(io_to_py)?;
    let snap = EngineSnapshot::deserialize(&bytes).map_err(runtime_to_py)?;
    Ok(snapshot_info_from_snap(&snap))
}
/// Returns the SHA-256 digest of `data` as a lowercase hexadecimal string.
///
/// Each digest byte contributes exactly two characters, high nibble first.
pub fn sha256_hex(data: &[u8]) -> String {
    use sha2::Digest;

    // Lookup table for one hex digit per nibble (lowercase, matching `{:02x}`).
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hasher = sha2::Sha256::new();
    hasher.update(data);
    let digest = hasher.finalize();

    let mut hex = String::with_capacity(digest.len() * 2);
    for &byte in digest.iter() {
        hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
pub fn verify_sha256(path: &Path, expected_hex: &str) -> PyResult<()> {
let data = std::fs::read(path).map_err(io_to_py)?;
let actual = sha256_hex(&data);
if actual != expected_hex.to_lowercase() {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"SHA-256 mismatch for {}: expected {}, got {}",
path.display(),
expected_hex,
actual
)));
}
Ok(())
}
/// Returns the sidecar metadata path for `snap_path`.
///
/// The suffix `.meta.json` is appended to the full path text; the existing extension is kept,
/// not replaced (`engine.snap` becomes `engine.snap.meta.json`).
pub fn meta_path_for(snap_path: &Path) -> PathBuf {
    let mut sidecar = snap_path.to_path_buf().into_os_string();
    sidecar.push(".meta.json");
    sidecar.into()
}
/// Serializes `meta` as pretty-printed JSON and writes it to the sidecar path of `snap_path`.
///
/// The destination is `meta_path_for(snap_path)` and the write goes through
/// `write_snapshot_atomic`.
///
/// # Errors
///
/// * Python `OSError` with a `Failed to serialize snapshot metadata` message when JSON
///   encoding fails.
/// * The error produced by `io_to_py` when the write fails.
pub fn write_meta(snap_path: &Path, meta: &EngineSnapshotMeta) -> PyResult<()> {
    let encoded = match serde_json::to_vec_pretty(meta) {
        Ok(buffer) => buffer,
        Err(e) => {
            return Err(PyOSError::new_err(format!(
                "Failed to serialize snapshot metadata: {e}"
            )))
        }
    };
    let sidecar = meta_path_for(snap_path);
    write_snapshot_atomic(&sidecar, &encoded).map_err(io_to_py)
}
pub fn read_meta(snap_path: &Path) -> PyResult<Option<EngineSnapshotMeta>> {
let path = meta_path_for(snap_path);
if !path.exists() {
return Ok(None);
}
let data = std::fs::read(&path).map_err(io_to_py)?;
let meta = serde_json::from_slice::<EngineSnapshotMeta>(&data)
.map_err(|e| PyOSError::new_err(format!("Failed to deserialize snapshot metadata: {e}")))?;
Ok(Some(meta))
}
/// Resolves the model file to use, downloading it from the hub when it is not present locally.
///
/// If `model_path` exists on disk it is used as-is; otherwise the file named by
/// `hub_origin.repo_id` / `hub_origin.filename` is fetched through
/// `crate::hub::download_model_from_hub`. In both cases, a non-empty `hub_origin.sha256` is
/// verified against the chosen file before the path is returned.
///
/// # Errors
///
/// Propagates download errors and the errors of `verify_sha256` (including a digest mismatch
/// on an already-present local file).
#[cfg(feature = "hub")]
pub fn resolve_model_path_with_hub(model_path: &str, hub_origin: &HubOrigin) -> PyResult<PathBuf> {
    let local = PathBuf::from(model_path);
    let resolved = if local.exists() {
        local
    } else {
        let fetched = crate::hub::download_model_from_hub(
            &hub_origin.repo_id,
            Some(&hub_origin.filename),
            None,
            None,
        )?;
        PathBuf::from(&fetched)
    };

    // An empty digest means "nothing to check".
    if !hub_origin.sha256.is_empty() {
        verify_sha256(&resolved, &hub_origin.sha256)?;
    }
    Ok(resolved)
}
#[cfg(not(feature = "hub"))]
pub fn resolve_model_path_with_hub(model_path: &str, _hub_origin: &HubOrigin) -> PyResult<PathBuf> {
Ok(PathBuf::from(model_path))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a path in the system temp directory whose name includes the current process id,
    /// so concurrent test runs on the same machine do not share scratch files.
    fn scratch_path(stem: &str, ext: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{stem}_{}{ext}", std::process::id()))
    }

    /// Bytes written through `write_snapshot_atomic` are read back unchanged.
    #[test]
    fn test_write_snapshot_atomic_roundtrip() {
        let path = scratch_path("oxillama_py_snap_roundtrip", ".bin");
        let data: Vec<u8> = (0u8..100).collect();
        write_snapshot_atomic(&path, &data).expect("write_snapshot_atomic must succeed");
        let read_back = std::fs::read(&path).expect("read back must succeed");
        assert_eq!(read_back, data, "roundtrip data must match");
        let _ = std::fs::remove_file(&path);
    }

    /// A second write to the same path replaces the first file's contents.
    #[test]
    fn test_write_snapshot_atomic_overwrites_existing() {
        let path = scratch_path("oxillama_py_snap_overwrite", ".bin");
        let data_a: Vec<u8> = vec![0xAA; 64];
        let data_b: Vec<u8> = vec![0xBB; 128];
        write_snapshot_atomic(&path, &data_a).expect("first write must succeed");
        write_snapshot_atomic(&path, &data_b).expect("second write must succeed");
        let read_back = std::fs::read(&path).expect("read back must succeed");
        assert_eq!(read_back, data_b, "second write must overwrite first");
        let _ = std::fs::remove_file(&path);
    }

    /// `io_to_py` accepts an I/O error without panicking; the Python-side exception is not
    /// inspected here.
    #[test]
    fn test_io_to_py_kind_not_found() {
        let err = std::io::Error::new(std::io::ErrorKind::NotFound, "test not found");
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
        assert!(err.to_string().contains("test not found"));
        let _py_err = io_to_py(err);
    }

    /// Writing below a parent directory that does not exist must fail rather than create it.
    #[test]
    fn test_write_snapshot_atomic_no_parent_fails() {
        let nonexistent_dir = scratch_path("oxillama_py_no_such_dir_xyz_abc123", "");
        let _ = std::fs::remove_dir_all(&nonexistent_dir);
        let path = nonexistent_dir.join("snap.bin");
        let result = write_snapshot_atomic(&path, b"x");
        assert!(
            result.is_err(),
            "write to a path whose parent directory does not exist must return Err"
        );
    }

    /// `HubOrigin` survives a JSON serialize/deserialize cycle field by field.
    #[test]
    fn hub_origin_serde_roundtrip() {
        let origin = HubOrigin {
            repo_id: "mistralai/Mixtral-8x7B-Instruct-v0.1".to_string(),
            filename: "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
            sha256: "deadbeef01234567deadbeef01234567deadbeef01234567deadbeef01234567".to_string(),
        };
        let json = serde_json::to_string(&origin).expect("serialization must succeed");
        assert!(
            json.contains("Mixtral-8x7B"),
            "JSON must contain repo_id fragment"
        );
        let decoded: HubOrigin = serde_json::from_str(&json).expect("deserialization must succeed");
        assert_eq!(decoded.repo_id, origin.repo_id);
        assert_eq!(decoded.filename, origin.filename);
        assert_eq!(decoded.sha256, origin.sha256);
    }

    /// Metadata carrying a hub origin keeps both the model path and the origin through JSON.
    #[test]
    fn snapshot_meta_with_hub_origin_roundtrip() {
        let meta = EngineSnapshotMeta {
            model_path: "/home/user/.cache/huggingface/hub/models--mistralai/blobs/abc.gguf"
                .to_string(),
            hub_origin: Some(HubOrigin {
                repo_id: "mistralai/Mixtral-8x7B-Instruct-v0.1".to_string(),
                filename: "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
                sha256: "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234"
                    .to_string(),
            }),
        };
        let json = serde_json::to_string(&meta).expect("serialization must succeed");
        assert!(
            json.contains("hub_origin"),
            "JSON must contain hub_origin key"
        );
        assert!(
            json.contains("Mixtral"),
            "JSON must contain repo_id fragment"
        );
        let decoded: EngineSnapshotMeta =
            serde_json::from_str(&json).expect("deserialization must succeed");
        assert_eq!(decoded.model_path, meta.model_path);
        let hub = decoded.hub_origin.expect("hub_origin must be present");
        assert_eq!(hub.repo_id, "mistralai/Mixtral-8x7B-Instruct-v0.1");
        assert_eq!(hub.filename, "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf");
        assert!(hub.sha256.starts_with("abcd1234"), "sha256 must roundtrip");
    }

    /// Metadata without a hub origin decodes back to `hub_origin == None`.
    #[test]
    fn snapshot_meta_without_hub_origin_roundtrip() {
        let meta = EngineSnapshotMeta {
            model_path: "/tmp/model.gguf".to_string(),
            hub_origin: None,
        };
        let json = serde_json::to_string(&meta).expect("serialization must succeed");
        let decoded: EngineSnapshotMeta =
            serde_json::from_str(&json).expect("deserialization must succeed");
        assert_eq!(decoded.model_path, "/tmp/model.gguf");
        assert!(decoded.hub_origin.is_none(), "hub_origin must be None");
    }

    /// `sha256_hex` yields 64 lowercase hex characters and matches a known digest.
    #[test]
    fn sha256_hex_length_and_format() {
        let hex = sha256_hex(b"hello world");
        assert_eq!(hex.len(), 64, "SHA-256 hex must be 64 chars");
        assert!(
            hex.chars().all(|c| c.is_ascii_hexdigit()),
            "SHA-256 hex must contain only hex digits"
        );
        // Compare against the published SHA-256 of "hello world" so the check can actually fail.
        assert_eq!(
            hex, "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9",
            "sha256_hex must match the known digest of \"hello world\""
        );
        assert_eq!(
            sha256_hex(b"hello world"),
            hex,
            "sha256_hex must be deterministic"
        );
    }

    /// The sidecar path is the snapshot path with `.meta.json` appended.
    #[test]
    fn meta_path_for_appends_suffix() {
        let snap = PathBuf::from("/tmp/engine.snap");
        let meta = meta_path_for(&snap);
        assert_eq!(
            meta.to_str().expect("path must be valid UTF-8"),
            "/tmp/engine.snap.meta.json"
        );
    }
}