use std::path::{Path, PathBuf};
use anyhow::{bail, Context, Result};
use sha2::{Digest, Sha256};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSpec {
pub id: &'static str,
pub file: &'static str,
pub url: &'static str,
pub sha256: &'static str,
pub size_bytes: u64,
pub min_ram_mb: u64,
}
impl ModelSpec {
pub fn is_pinned(&self) -> bool {
self.sha256.len() == 64
}
}
pub const MODEL_PRIMARY: ModelSpec = ModelSpec {
id: "qwen3-4b-instruct-q4_k_m",
file: "qwen3-4b-instruct-q4_k_m.gguf",
url: "",
sha256: "",
size_bytes: 2_500_000_000,
min_ram_mb: 6_000,
};
pub const MODEL_FALLBACK: ModelSpec = ModelSpec {
id: "qwen3-1.7b-instruct-q4_k_m",
file: "qwen3-1.7b-instruct-q4_k_m.gguf",
url: "",
sha256: "",
size_bytes: 1_100_000_000,
min_ram_mb: 0,
};
pub fn select_spec(ram_mb: u64) -> &'static ModelSpec {
if ram_mb >= MODEL_PRIMARY.min_ram_mb {
&MODEL_PRIMARY
} else {
&MODEL_FALLBACK
}
}
pub fn detect_ram_mb() -> u64 {
#[cfg(target_os = "linux")]
{
if let Ok(text) = std::fs::read_to_string("/proc/meminfo") {
for line in text.lines() {
if let Some(rest) = line.strip_prefix("MemTotal:") {
if let Some(kb) = rest.split_whitespace().next() {
if let Ok(kb) = kb.parse::<u64>() {
return kb / 1024;
}
}
}
}
}
}
#[cfg(target_os = "macos")]
{
if let Ok(out) = std::process::Command::new("sysctl")
.args(["-n", "hw.memsize"])
.output()
{
if let Ok(s) = String::from_utf8(out.stdout) {
if let Ok(bytes) = s.trim().parse::<u64>() {
return bytes / (1024 * 1024);
}
}
}
}
4096
}
pub fn sha256_file(path: &Path) -> Result<String> {
let bytes = std::fs::read(path).with_context(|| format!("read {}", path.display()))?;
let mut hasher = Sha256::new();
hasher.update(&bytes);
Ok(hex::encode(hasher.finalize()))
}
pub fn verify(spec: &ModelSpec, path: &Path) -> Result<bool> {
if !spec.is_pinned() {
bail!(
"model {} has no pinned checksum; refusing to use it",
spec.id
);
}
Ok(sha256_file(path)?.eq_ignore_ascii_case(spec.sha256))
}
pub fn ensure_weights(spec: &ModelSpec, dir: &Path) -> Result<PathBuf> {
if !spec.is_pinned() {
bail!(
"model {} is not pinned (set its url + sha256 before download)",
spec.id
);
}
let path = dir.join(spec.file);
if path.is_file() {
if verify(spec, &path)? {
return Ok(path);
}
bail!("checksum mismatch for {}", path.display());
}
#[cfg(feature = "download")]
{
std::fs::create_dir_all(dir).with_context(|| format!("create {}", dir.display()))?;
download(spec, &path)?;
if !verify(spec, &path)? {
let _ = std::fs::remove_file(&path);
bail!("downloaded weights failed checksum for {}", spec.id);
}
Ok(path)
}
#[cfg(not(feature = "download"))]
{
bail!(
"weights for {} not present at {} (build with --features download to fetch)",
spec.id,
path.display()
)
}
}
#[cfg(feature = "download")]
fn download(spec: &ModelSpec, dest: &Path) -> Result<()> {
if spec.url.is_empty() {
bail!("model {} has no pinned URL", spec.id);
}
let resp = reqwest::blocking::get(spec.url)
.with_context(|| format!("GET {}", spec.url))?
.error_for_status()?;
let bytes = resp.bytes()?;
std::fs::write(dest, &bytes).with_context(|| format!("write {}", dest.display()))?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn selects_primary_with_enough_ram_else_fallback() {
assert_eq!(select_spec(16_000).id, MODEL_PRIMARY.id);
assert_eq!(select_spec(6_000).id, MODEL_PRIMARY.id);
assert_eq!(select_spec(4_000).id, MODEL_FALLBACK.id);
assert_eq!(select_spec(0).id, MODEL_FALLBACK.id);
}
#[test]
fn detect_ram_is_positive() {
assert!(detect_ram_mb() > 0);
}
#[test]
fn unpinned_specs_are_refused() {
assert!(!MODEL_PRIMARY.is_pinned());
let tmp = tempfile::tempdir().unwrap();
assert!(ensure_weights(&MODEL_PRIMARY, tmp.path()).is_err());
assert!(verify(&MODEL_PRIMARY, tmp.path()).is_err());
}
#[test]
fn checksum_roundtrip_and_match() {
let tmp = tempfile::tempdir().unwrap();
let f = tmp.path().join("blob.bin");
std::fs::write(&f, b"hello kintsugi").unwrap();
let digest = sha256_file(&f).unwrap();
assert_eq!(digest.len(), 64);
let good = ModelSpec {
sha256: Box::leak(digest.clone().into_boxed_str()),
..MODEL_FALLBACK
};
assert!(verify(&good, &f).unwrap());
let bad = ModelSpec {
sha256: "0000000000000000000000000000000000000000000000000000000000000000",
..MODEL_FALLBACK
};
assert!(!verify(&bad, &f).unwrap());
}
}