#![allow(unsafe_code)]
use sha2::{Digest, Sha256};
use std::ffi::CString;
use std::fs::File;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::ptr::NonNull;
/// Errors returned while verifying or loading a model file.
#[derive(Debug, thiserror::Error)]
pub enum ModelLoadError {
/// An underlying I/O failure; `#[from]` lets `?` convert an `io::Error` into this variant.
#[error("io: {0}")]
Io(#[from] io::Error),
/// The SHA-256 digest computed over the file did not equal the expected digest.
#[error("model hash mismatch (expected != actual)")]
HashMismatch,
/// The path could not be turned into a C string because it contains a NUL byte.
#[error("path contains NUL byte: {0}")]
PathNul(PathBuf),
/// The llama loader returned a null pointer instead of a model.
#[error("llama_model_load_from_file returned null")]
LlamaLoadFailed,
}
/// Owning handle around a raw `llama_model` pointer.
///
/// In this file a handle is only constructed by `load_model` from a non-null
/// loader result, and its `Drop` impl passes the pointer to `llama_model_free`.
#[derive(Debug)]
pub struct ModelHandle {
// Non-null raw model pointer; the field is private, so the handle can only be built in this module.
ptr: NonNull<crate::ffi::llama_model>,
}
// SAFETY(review): these impls assert that the underlying `llama_model` may be moved to
// another thread and accessed through `&ModelHandle` from several threads at once.
// Nothing in this file establishes that; it presumably relies on the llama library
// treating a loaded model as safe to share — TODO confirm, especially since `as_ptr`
// hands out a `*mut` pointer from `&self`.
unsafe impl Send for ModelHandle {}
unsafe impl Sync for ModelHandle {}
impl ModelHandle {
pub(crate) fn as_ptr(&self) -> *mut crate::ffi::llama_model {
self.ptr.as_ptr()
}
}
impl Drop for ModelHandle {
    /// Releases the model by handing its pointer to `llama_model_free`.
    fn drop(&mut self) {
        let raw_model = self.ptr.as_ptr();
        // SAFETY: `raw_model` comes from a `NonNull`, so it is never null, and `drop`
        // runs at most once per handle. This assumes no other code has already freed
        // the pointer obtained through `as_ptr` — NOTE(review): confirm with callers.
        unsafe {
            crate::ffi::llama_model_free(raw_model);
        }
    }
}
/// Loads a llama model from `path`, optionally verifying its SHA-256 digest first.
///
/// * `path` - location of the model file.
/// * `expected_sha256` - when `Some`, the file is hashed and compared against this
///   digest before the loader is invoked; when `None`, no hashing is performed.
/// * `gpu_layers` - written to `n_gpu_layers` of the default model parameters.
///
/// The exact bytes of `path` are handed to the loader, so the file that was hashed
/// and the file that is loaded are addressed by the same name. Note that the file is
/// still opened twice (once for hashing, once by the loader), so a file replaced
/// between the two steps is not detected.
///
/// # Errors
///
/// * [`ModelLoadError::Io`] if the file cannot be read for hashing, or (on non-Unix
///   targets) if the path is not valid UTF-8.
/// * [`ModelLoadError::HashMismatch`] if the digest differs from `expected_sha256`.
/// * [`ModelLoadError::PathNul`] if the path contains a NUL byte.
/// * [`ModelLoadError::LlamaLoadFailed`] if the loader returns a null pointer.
pub fn load_model(
    path: &Path,
    expected_sha256: Option<&[u8; 32]>,
    gpu_layers: i32,
) -> Result<ModelHandle, ModelLoadError> {
    if let Some(expected) = expected_sha256 {
        verify_sha256(path, expected)?;
    }

    let cpath = path_to_cstring(path)?;

    // SAFETY: `cpath` is a valid NUL-terminated string that stays alive for the whole
    // call, and `params` is the library's own default value with one field overwritten.
    // Assumes the loader does not keep the path pointer after returning — TODO confirm.
    let model_ptr = unsafe {
        let mut params = crate::ffi::llama_model_default_params();
        params.n_gpu_layers = gpu_layers;
        crate::ffi::llama_model_load_from_file(cpath.as_ptr(), params)
    };

    NonNull::new(model_ptr)
        .map(|ptr| ModelHandle { ptr })
        .ok_or(ModelLoadError::LlamaLoadFailed)
}

/// Converts `path` into a C string without altering any of its bytes.
///
/// A lossy UTF-8 conversion would replace invalid sequences with U+FFFD and thereby
/// name a different file than the one that was hashed, so it is deliberately avoided.
fn path_to_cstring(path: &Path) -> Result<CString, ModelLoadError> {
    // On Unix a path is an arbitrary byte string; pass it through unchanged.
    #[cfg(unix)]
    let raw: &[u8] = std::os::unix::ffi::OsStrExt::as_bytes(path.as_os_str());

    // Elsewhere there is no byte view of the path, so require valid UTF-8 instead of
    // silently rewriting it.
    #[cfg(not(unix))]
    let raw: &[u8] = path
        .to_str()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "model path is not valid UTF-8",
            )
        })?
        .as_bytes();

    CString::new(raw).map_err(|_| ModelLoadError::PathNul(path.to_path_buf()))
}
/// Checks that the file at `path` has the SHA-256 digest `expected`.
///
/// Public entry point that forwards unchanged to the private `verify_sha256`.
/// The name suggests it is meant for multimodal projector ("mmproj") files, but the
/// check itself is not specific to any file type — TODO confirm intended use.
///
/// # Errors
///
/// Returns whatever `verify_sha256` returns: `ModelLoadError::Io` when the file
/// cannot be opened or read, `ModelLoadError::HashMismatch` when the digest differs.
pub fn verify_mmproj_sha256(path: &Path, expected: &[u8; 32]) -> Result<(), ModelLoadError> {
verify_sha256(path, expected)
}
fn verify_sha256(path: &Path, expected: &[u8; 32]) -> Result<(), ModelLoadError> {
let mut file = File::open(path)?;
let mut hasher = Sha256::new();
let mut buf = vec![0u8; 1 << 20]; loop {
let n = file.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
let actual = hasher.finalize();
use subtle::ConstantTimeEq;
if actual.as_slice().ct_eq(expected.as_slice()).into() {
Ok(())
} else {
Err(ModelLoadError::HashMismatch)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Payload that is certainly not a valid model file.
    const NOT_A_MODEL: &[u8] = b"definitely not a gguf model file";

    /// Creates a fresh temporary directory holding one file `name` with `contents`.
    ///
    /// The directory guard is returned alongside the path so the file lives until
    /// the calling test finishes.
    fn fixture(name: &str, contents: &[u8]) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(name);
        let mut file = File::create(&path).unwrap();
        file.write_all(contents).unwrap();
        file.sync_all().unwrap();
        (dir, path)
    }

    /// Decodes a 64-character hex string into a 32-byte array; panics on bad input.
    fn hex_lit(s: &str) -> [u8; 32] {
        let decoded: Vec<u8> = (0..s.len())
            .step_by(2)
            .map(|start| u8::from_str_radix(&s[start..start + 2], 16).unwrap())
            .collect();
        let mut digest = [0u8; 32];
        digest.copy_from_slice(&decoded);
        digest
    }

    /// The known SHA-256 of "hello world" must be accepted.
    #[test]
    fn verify_sha256_accepts_correct_hash() {
        let (_dir, path) = fixture("blob.bin", b"hello world");
        let expected = hex_lit("b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9");
        verify_sha256(&path, &expected).unwrap();
    }

    /// An all-zero digest must be rejected with `HashMismatch`.
    #[test]
    fn verify_sha256_rejects_mismatch() {
        let (_dir, path) = fixture("blob.bin", b"hello world");
        let wrong = [0u8; 32];
        let err = verify_sha256(&path, &wrong).unwrap_err();
        assert!(matches!(err, ModelLoadError::HashMismatch));
    }

    /// With a wrong digest, `load_model` must fail at the hash check.
    #[test]
    fn load_model_with_wrong_hash_fails_at_hash_check() {
        let (_dir, path) = fixture("not-a-gguf.bin", NOT_A_MODEL);
        let wrong = [0u8; 32];
        let err = load_model(&path, Some(&wrong), 0).unwrap_err();
        assert!(
            matches!(err, ModelLoadError::HashMismatch),
            "expected HashMismatch, got {err:?}"
        );
    }

    /// Without a digest, `load_model` must fail for a reason other than `HashMismatch`.
    #[test]
    fn load_model_with_no_hash_skips_copy_path() {
        let (_dir, path) = fixture("not-a-gguf.bin", NOT_A_MODEL);
        let err = load_model(&path, None, 0).unwrap_err();
        assert!(
            !matches!(err, ModelLoadError::HashMismatch),
            "no-hash path should not hit HashMismatch; got {err:?}"
        );
    }
}