1use std::path::{Path, PathBuf};
9
10use anyhow::{bail, Context, Result};
11use sha2::{Digest, Sha256};
12
/// A set of model weights: where to fetch them, how to verify them, and when to use them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSpec {
    /// Identifier used in error messages.
    pub id: &'static str,
    /// File name of the weights inside the model directory.
    pub file: &'static str,
    /// Download URL; empty until pinned.
    pub url: &'static str,
    /// Expected SHA-256 of the weights as 64 hex digits; empty until pinned.
    pub sha256: &'static str,
    /// Nominal size of the weights in bytes (the values used in this file are round figures).
    pub size_bytes: u64,
    /// Minimum total RAM, in MiB, required for this spec to be selected.
    pub min_ram_mb: u64,
}

impl ModelSpec {
    /// Returns `true` if `sha256` is a well-formed SHA-256 digest: exactly 64
    /// ASCII hex digits, in either case.
    ///
    /// A length check alone would accept strings (e.g. 64 non-hex letters, or
    /// 32 two-byte UTF-8 characters) that can never equal a real digest. Such a
    /// typo'd pin would then only surface as a checksum mismatch after the
    /// weights had already been downloaded, instead of being refused up front.
    pub fn is_pinned(&self) -> bool {
        self.sha256.len() == 64 && self.sha256.bytes().all(|b| b.is_ascii_hexdigit())
    }
}
36
37pub const MODEL_PRIMARY: ModelSpec = ModelSpec {
58 id: "qwen3-4b-instruct-q4_k_m",
59 file: "qwen3-4b-instruct-q4_k_m.gguf",
60 url: "",
61 sha256: "",
62 size_bytes: 2_500_000_000,
63 min_ram_mb: 6_000,
64};
65
66pub const MODEL_FALLBACK: ModelSpec = ModelSpec {
69 id: "qwen3-1.7b-instruct-q4_k_m",
70 file: "qwen3-1.7b-instruct-q4_k_m.gguf",
71 url: "",
72 sha256: "",
73 size_bytes: 1_100_000_000,
74 min_ram_mb: 0,
75};
76
77pub fn select_spec(ram_mb: u64) -> &'static ModelSpec {
79 if ram_mb >= MODEL_PRIMARY.min_ram_mb {
80 &MODEL_PRIMARY
81 } else {
82 &MODEL_FALLBACK
83 }
84}
85
/// Best-effort total physical memory, in MiB.
///
/// On Linux this reads the first parsable `MemTotal:` value (kB) from
/// `/proc/meminfo`. On macOS it runs `sysctl -n hw.memsize` (bytes). On other
/// platforms, or if the probe fails, it returns 4096.
pub fn detect_ram_mb() -> u64 {
    const UNKNOWN_RAM_MB: u64 = 4096;

    #[cfg(target_os = "linux")]
    {
        let total_kb = std::fs::read_to_string("/proc/meminfo")
            .ok()
            .and_then(|meminfo| {
                meminfo
                    .lines()
                    .filter_map(|entry| entry.strip_prefix("MemTotal:"))
                    .filter_map(|value| value.split_whitespace().next())
                    .find_map(|field| field.parse::<u64>().ok())
            });
        if let Some(kb) = total_kb {
            return kb / 1024;
        }
    }

    #[cfg(target_os = "macos")]
    {
        let total_bytes = std::process::Command::new("sysctl")
            .args(["-n", "hw.memsize"])
            .output()
            .ok()
            .and_then(|out| String::from_utf8(out.stdout).ok())
            .and_then(|text| text.trim().parse::<u64>().ok());
        if let Some(bytes) = total_bytes {
            return bytes / (1024 * 1024);
        }
    }

    UNKNOWN_RAM_MB
}
117
118pub fn sha256_file(path: &Path) -> Result<String> {
120 let bytes = std::fs::read(path).with_context(|| format!("read {}", path.display()))?;
121 let mut hasher = Sha256::new();
122 hasher.update(&bytes);
123 Ok(hex::encode(hasher.finalize()))
124}
125
126pub fn verify(spec: &ModelSpec, path: &Path) -> Result<bool> {
128 if !spec.is_pinned() {
129 bail!(
130 "model {} has no pinned checksum; refusing to use it",
131 spec.id
132 );
133 }
134 Ok(sha256_file(path)?.eq_ignore_ascii_case(spec.sha256))
135}
136
137pub fn ensure_weights(spec: &ModelSpec, dir: &Path) -> Result<PathBuf> {
143 if !spec.is_pinned() {
144 bail!(
145 "model {} is not pinned (set its url + sha256 before download)",
146 spec.id
147 );
148 }
149 let path = dir.join(spec.file);
150 if path.is_file() {
151 if verify(spec, &path)? {
152 return Ok(path);
153 }
154 bail!("checksum mismatch for {}", path.display());
155 }
156
157 #[cfg(feature = "download")]
158 {
159 std::fs::create_dir_all(dir).with_context(|| format!("create {}", dir.display()))?;
160 download(spec, &path)?;
161 if !verify(spec, &path)? {
162 let _ = std::fs::remove_file(&path);
163 bail!("downloaded weights failed checksum for {}", spec.id);
164 }
165 Ok(path)
166 }
167 #[cfg(not(feature = "download"))]
168 {
169 bail!(
170 "weights for {} not present at {} (build with --features download to fetch)",
171 spec.id,
172 path.display()
173 )
174 }
175}
176
/// Fetches `spec.url` and stores the response body at `dest`.
///
/// The body is streamed to `<dest>.part` and renamed onto `dest` only after it
/// has been fully written and synced. A failed or interrupted transfer therefore
/// never leaves a truncated file at `dest`. Such a file would otherwise be
/// picked up by the caller's `is_file()` check on a later run and rejected,
/// rather than re-downloaded.
///
/// # Errors
/// Fails if `spec.url` is empty, the HTTP client cannot be built, the request
/// fails or returns an error status, or the body cannot be written and renamed
/// into place.
#[cfg(feature = "download")]
fn download(spec: &ModelSpec, dest: &Path) -> Result<()> {
    if spec.url.is_empty() {
        bail!("model {} has no pinned URL", spec.id);
    }
    // reqwest's blocking client defaults to a 30 s timeout, which a
    // multi-gigabyte transfer will exceed. Bound only the connect phase and let
    // the body stream for as long as it needs.
    let client = reqwest::blocking::Client::builder()
        .timeout(None::<std::time::Duration>)
        .connect_timeout(std::time::Duration::from_secs(30))
        .build()
        .context("build HTTP client")?;
    let mut resp = client
        .get(spec.url)
        .send()
        .with_context(|| format!("GET {}", spec.url))?
        .error_for_status()?;

    let mut tmp_name = dest.as_os_str().to_owned();
    tmp_name.push(".part");
    let tmp = PathBuf::from(tmp_name);

    if let Err(err) = write_body(&mut resp, &tmp) {
        // Best-effort cleanup; the write/transfer error is the one worth reporting.
        let _ = std::fs::remove_file(&tmp);
        return Err(err);
    }
    std::fs::rename(&tmp, dest)
        .with_context(|| format!("rename {} to {}", tmp.display(), dest.display()))?;
    Ok(())
}

/// Streams `resp`'s body into a newly created (or truncated) file at `path`.
///
/// The file is synced to disk before returning, so the caller can rename it
/// into place.
#[cfg(feature = "download")]
fn write_body(resp: &mut reqwest::blocking::Response, path: &Path) -> Result<()> {
    let mut file =
        std::fs::File::create(path).with_context(|| format!("write {}", path.display()))?;
    resp.copy_to(&mut file)
        .with_context(|| format!("download to {}", path.display()))?;
    file.sync_all()
        .with_context(|| format!("write {}", path.display()))?;
    Ok(())
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// `select_spec` switches to the fallback strictly below the primary's RAM floor.
    #[test]
    fn selects_primary_with_enough_ram_else_fallback() {
        let expectations = [
            (16_000, MODEL_PRIMARY.id),
            (6_000, MODEL_PRIMARY.id),
            (4_000, MODEL_FALLBACK.id),
            (0, MODEL_FALLBACK.id),
        ];
        for (ram_mb, want) in expectations {
            assert_eq!(select_spec(ram_mb).id, want, "ram_mb = {ram_mb}");
        }
    }

    /// Whatever the platform, the probe (or its fallback) reports some memory.
    #[test]
    fn detect_ram_is_positive() {
        let ram_mb = detect_ram_mb();
        assert!(ram_mb > 0, "detect_ram_mb() returned {ram_mb}");
    }

    /// Specs without a pinned checksum are rejected before any file is touched.
    #[test]
    fn unpinned_specs_are_refused() {
        // The shipped primary spec still has placeholder url/sha256 values.
        assert!(!MODEL_PRIMARY.is_pinned());

        let scratch = tempfile::tempdir().unwrap();
        let dir = scratch.path();
        assert!(ensure_weights(&MODEL_PRIMARY, dir).is_err());
        assert!(verify(&MODEL_PRIMARY, dir).is_err());
    }

    /// A digest produced by `sha256_file` verifies; an unrelated digest does not.
    #[test]
    fn checksum_roundtrip_and_match() {
        let scratch = tempfile::tempdir().unwrap();
        let blob = scratch.path().join("blob.bin");
        std::fs::write(&blob, b"hello kintsugi").unwrap();

        let digest = sha256_file(&blob).unwrap();
        assert_eq!(digest.len(), 64);

        // `ModelSpec::sha256` is `&'static str`; leaking one string is fine in a test.
        let pinned: &'static str = Box::leak(digest.into_boxed_str());
        let good = ModelSpec {
            sha256: pinned,
            ..MODEL_FALLBACK
        };
        assert!(verify(&good, &blob).unwrap());

        let bad = ModelSpec {
            sha256: "0000000000000000000000000000000000000000000000000000000000000000",
            ..MODEL_FALLBACK
        };
        assert!(!verify(&bad, &blob).unwrap());
    }
}