1use crate::admin::StatusBroadcaster;
26use crate::status::{LoadPhase, StatusEvent};
27use crate::store::{Manifest, ManifestSource, ModelStore, format_blob_ref, parse_blob_ref};
28use sha2::{Digest, Sha256};
29use std::fs::{File, OpenOptions, TryLockError};
30use std::io::{self, Read, Write};
31use std::path::{Path, PathBuf};
32use std::time::{Duration, Instant};
33use subtle::ConstantTimeEq;
34use tracing::{info, warn};
35
36#[derive(Debug, Clone)]
40pub struct ModelSpec {
41 pub name: String,
44 pub source_url: String,
48 pub sha256_hex: String,
50 pub size_bytes: Option<u64>,
55 pub license: Option<String>,
58 pub source: Option<ManifestSource>,
61}
62
63#[derive(Debug, thiserror::Error)]
65pub enum FetchError {
66 #[error("model URL must be https:// (got {0:?})")]
68 InsecureUrl(String),
69 #[error("http transport: {0}")]
71 Transport(String),
72 #[error("http status {0}")]
74 HttpStatus(u16),
75 #[error("io: {0}")]
77 Io(#[from] io::Error),
78 #[error(
81 "SHA-256 mismatch (expected {expected}, got {actual}); quarantined to {quarantine_path}"
82 )]
83 HashMismatch {
84 expected: String,
86 actual: String,
88 quarantine_path: PathBuf,
90 },
91 #[error("finalise rename: {0}")]
93 Finalise(io::Error),
94 #[error("model {name:?} is being fetched by another process")]
97 LockContended {
98 name: String,
100 },
101 #[error("model {name:?} has no source_url and no manifest exists")]
105 NoSourceNoManifest {
106 name: String,
108 },
109}
110
111pub fn fetch_model(
125 spec: &ModelSpec,
126 store: &ModelStore,
127 broadcaster: &StatusBroadcaster,
128) -> Result<PathBuf, FetchError> {
129 store.ensure_layout()?;
130
131 let blob_path = store.blob_path(&spec.sha256_hex);
132
133 broadcaster.publish(StatusEvent::LoadingModel {
135 phase: LoadPhase::CheckingLocal {
136 path: blob_path.clone(),
137 },
138 });
139
140 if let Some(manifest) = store.read_manifest(&spec.name)? {
141 if let Some(manifest_sha) = parse_blob_ref(&manifest.blob) {
145 if hex_ct_eq(manifest_sha, &spec.sha256_hex) && blob_path.exists() {
146 info!(
147 name = %spec.name,
148 blob = %blob_path.display(),
149 "manifest + blob already present; skipping fetch"
150 );
151 return Ok(blob_path);
152 }
153 if !hex_ct_eq(manifest_sha, &spec.sha256_hex) {
154 warn!(
155 name = %spec.name,
156 expected = %spec.sha256_hex,
157 in_manifest = %manifest_sha,
158 "manifest blob ref disagrees with config sha; rewriting manifest"
159 );
160 }
161 }
162 }
163
164 let _lock = acquire_name_lock(store, &spec.name)?;
166
167 if blob_path.exists() {
170 let actual = sha256_of_path(&blob_path)?;
171 if hex_ct_eq(&actual, &spec.sha256_hex) {
172 write_manifest_for(store, spec, blob_path.metadata()?.len())?;
174 info!(name = %spec.name, "blob landed by concurrent producer; manifest written");
175 return Ok(blob_path);
176 }
177 warn!(
181 name = %spec.name,
182 expected = %spec.sha256_hex,
183 actual = %actual,
184 "blob at CAS path failed re-hash; quarantining"
185 );
186 let qpath = store.quarantine(&blob_path, "sha-mismatch")?;
187 broadcaster.publish(StatusEvent::LoadingModel {
188 phase: LoadPhase::Quarantine {
189 path: blob_path.clone(),
190 expected_sha256: spec.sha256_hex.clone(),
191 actual_sha256: actual,
192 quarantine_path: qpath,
193 },
194 });
195 }
196
197 if spec.source_url.is_empty() {
199 return Err(FetchError::NoSourceNoManifest {
200 name: spec.name.clone(),
201 });
202 }
203 if !spec.source_url.starts_with("https://") {
204 return Err(FetchError::InsecureUrl(spec.source_url.clone()));
205 }
206
207 let partial = store.partial_path(&spec.sha256_hex);
208 if let Some(parent) = partial.parent() {
209 std::fs::create_dir_all(parent)?;
210 }
211 let downloaded = download_with_progress(spec, &partial, broadcaster)?;
212
213 broadcaster.publish(StatusEvent::LoadingModel {
215 phase: LoadPhase::Verify {
216 path: partial.clone(),
217 },
218 });
219 let actual = sha256_of_path(&partial)?;
220 if !hex_ct_eq(&actual, &spec.sha256_hex) {
221 let qpath = store.quarantine(&partial, "sha-mismatch")?;
222 broadcaster.publish(StatusEvent::LoadingModel {
223 phase: LoadPhase::Quarantine {
224 path: partial.clone(),
225 expected_sha256: spec.sha256_hex.clone(),
226 actual_sha256: actual.clone(),
227 quarantine_path: qpath.clone(),
228 },
229 });
230 if let Some(parent) = partial.parent() {
232 let _ = std::fs::remove_dir(parent);
233 }
234 return Err(FetchError::HashMismatch {
235 expected: spec.sha256_hex.clone(),
236 actual,
237 quarantine_path: qpath,
238 });
239 }
240
241 if let Some(parent) = blob_path.parent() {
243 std::fs::create_dir_all(parent)?;
244 }
245 std::fs::rename(&partial, &blob_path).map_err(FetchError::Finalise)?;
246 if let Some(parent) = partial.parent() {
247 let _ = std::fs::remove_dir(parent);
248 }
249
250 write_manifest_for(store, spec, downloaded)?;
254 info!(
255 name = %spec.name,
256 blob = %blob_path.display(),
257 "model installed"
258 );
259 Ok(blob_path)
260}
261
/// Guard returned by `acquire_name_lock`. It owns the file handle the lock was
/// taken on, so the handle stays open for as long as the guard is alive.
struct NameLock {
    /// Never read; stored only to keep the locked handle open.
    _file: File,
}
267
268fn acquire_name_lock(store: &ModelStore, name: &str) -> Result<NameLock, FetchError> {
269 let lock_path = store.lock_path(name);
270 if let Some(parent) = lock_path.parent() {
271 std::fs::create_dir_all(parent)?;
272 }
273 let file = OpenOptions::new()
274 .read(true)
275 .write(true)
276 .create(true)
277 .truncate(false)
278 .open(&lock_path)?;
279 match file.try_lock() {
280 Ok(()) => Ok(NameLock { _file: file }),
281 Err(TryLockError::WouldBlock) => Err(FetchError::LockContended {
282 name: name.to_string(),
283 }),
284 Err(TryLockError::Error(e)) => Err(FetchError::Io(e)),
285 }
286}
287
288fn write_manifest_for(
289 store: &ModelStore,
290 spec: &ModelSpec,
291 size_bytes: u64,
292) -> Result<(), FetchError> {
293 let source = spec.source.clone().unwrap_or_else(|| ManifestSource {
294 registry: registry_from_url(&spec.source_url),
295 repo: String::new(),
296 revision: String::new(),
297 filename: filename_from_url(&spec.source_url),
298 });
299 let manifest = Manifest {
300 schema_version: 1,
301 name: spec.name.clone(),
302 format: "gguf".into(),
303 blob: format_blob_ref(&spec.sha256_hex),
304 size_bytes,
305 license: spec.license.clone(),
306 source,
307 produced_by: format!("inferd/{}", env!("CARGO_PKG_VERSION")),
308 produced_at: chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
309 };
310 store
311 .write_manifest(&manifest)
312 .map_err(FetchError::Io)
313 .map(|_| ())
314}
315
/// Returns the part of an `https://` URL between the scheme and the first `/`.
///
/// Anything that does not start with `https://` yields an empty string.
fn registry_from_url(url: &str) -> String {
    let Some(after_scheme) = url.strip_prefix("https://") else {
        return String::new();
    };
    let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
    after_scheme[..host_end].to_owned()
}
322
/// Returns the last path segment of `url`, ignoring any query string or fragment.
///
/// Download links often carry a query (for example `?download=true`). Without
/// stripping it the query would become part of the filename, and a `/` inside
/// the query would make the wrong segment look like the basename. The URL is
/// therefore cut at the first `?` or `#` before the final `/`-separated
/// segment is taken.
///
/// A string without any `/` is returned unchanged (minus query and fragment);
/// a URL ending in `/` yields an empty string.
fn filename_from_url(url: &str) -> String {
    let path_end = url
        .find(|c: char| c == '?' || c == '#')
        .unwrap_or(url.len());
    let without_query = &url[..path_end];
    without_query
        .rsplit('/')
        .next()
        .unwrap_or("")
        .to_string()
}
326
327fn download_with_progress(
328 spec: &ModelSpec,
329 dest: &Path,
330 broadcaster: &StatusBroadcaster,
331) -> Result<u64, FetchError> {
332 let agent = ureq::AgentBuilder::new()
333 .timeout_connect(Duration::from_secs(30))
334 .build();
335
336 info!(
337 url = %spec.source_url,
338 name = %spec.name,
339 "model download starting"
340 );
341
342 let resp = agent
343 .get(&spec.source_url)
344 .call()
345 .map_err(|e| FetchError::Transport(e.to_string()))?;
346 let status = resp.status();
347 if !(200..300).contains(&status) {
348 return Err(FetchError::HttpStatus(status));
349 }
350 let total = resp
351 .header("content-length")
352 .and_then(|s| s.parse::<u64>().ok())
353 .or(spec.size_bytes);
354 if let Some(t) = total {
355 info!(
356 total_bytes = t,
357 total_mib = t / (1024 * 1024),
358 "model download size known"
359 );
360 } else {
361 info!("model download size unknown (no Content-Length)");
362 }
363
364 let mut reader = resp.into_reader();
365 let mut file = OpenOptions::new()
366 .create(true)
367 .write(true)
368 .truncate(true)
369 .open(dest)?;
370
371 let mut buf = vec![0u8; 1 << 20]; let mut downloaded: u64 = 0;
373 let mut last_publish = Instant::now();
374 let mut next_byte_milestone: u64 = 32 << 20; broadcaster.publish(StatusEvent::LoadingModel {
377 phase: LoadPhase::Download {
378 downloaded_bytes: 0,
379 total_bytes: total,
380 source_url: spec.source_url.clone(),
381 },
382 });
383
384 loop {
385 let n = reader.read(&mut buf)?;
386 if n == 0 {
387 break;
388 }
389 file.write_all(&buf[..n])?;
390 downloaded += n as u64;
391
392 let now = Instant::now();
393 let due = downloaded >= next_byte_milestone
394 || now.duration_since(last_publish) >= Duration::from_secs(5);
395 if due {
396 broadcaster.publish(StatusEvent::LoadingModel {
397 phase: LoadPhase::Download {
398 downloaded_bytes: downloaded,
399 total_bytes: total,
400 source_url: spec.source_url.clone(),
401 },
402 });
403 let pct = total
410 .map(|t| (downloaded as f64 / t as f64) * 100.0)
411 .map(|p| format!("{p:5.1}%"))
412 .unwrap_or_else(|| " ? ".to_string());
413 let mib = downloaded / (1024 * 1024);
414 let total_mib = total.map(|t| t / (1024 * 1024)).unwrap_or(0);
415 info!(
416 downloaded_mib = mib,
417 total_mib = total_mib,
418 pct = %pct,
419 "model download progress"
420 );
421 last_publish = now;
422 next_byte_milestone = downloaded + (32 << 20);
423 }
424 }
425 file.flush()?;
426
427 broadcaster.publish(StatusEvent::LoadingModel {
428 phase: LoadPhase::Download {
429 downloaded_bytes: downloaded,
430 total_bytes: total.or(Some(downloaded)),
431 source_url: spec.source_url.clone(),
432 },
433 });
434 info!(
435 downloaded_mib = downloaded / (1024 * 1024),
436 "model download complete"
437 );
438 Ok(downloaded)
439}
440
441fn sha256_of_path(path: &Path) -> io::Result<String> {
443 let mut file = File::open(path)?;
444 let mut hasher = Sha256::new();
445 let mut buf = vec![0u8; 1 << 20];
446 loop {
447 let n = file.read(&mut buf)?;
448 if n == 0 {
449 break;
450 }
451 hasher.update(&buf[..n]);
452 }
453 let bytes = hasher.finalize();
454 let mut s = String::with_capacity(bytes.len() * 2);
455 for b in bytes {
456 s.push_str(&format!("{:02x}", b));
457 }
458 Ok(s)
459}
460
461fn hex_ct_eq(a: &str, b: &str) -> bool {
462 a.as_bytes().ct_eq(b.as_bytes()).into()
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// SHA-256 of the bytes `hello world` (checked by `sha256_of_known_input`).
    const HELLO_WORLD_SHA: &str =
        "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";

    /// Broadcaster seeded with the `Starting` event.
    fn dummy_broadcaster() -> StatusBroadcaster {
        StatusBroadcaster::new(StatusEvent::Starting)
    }

    /// Writes `contents` at the store's blob path for `sha` and returns that path.
    fn write_blob_at(store: &ModelStore, sha: &str, contents: &[u8]) -> PathBuf {
        let target = store.blob_path(sha);
        std::fs::create_dir_all(target.parent().unwrap()).unwrap();
        std::fs::write(&target, contents).unwrap();
        target
    }

    /// Spec named `test` that expects `HELLO_WORLD_SHA`, with no license and
    /// no explicit source.
    fn hello_spec(source_url: &str, size_bytes: Option<u64>) -> ModelSpec {
        ModelSpec {
            name: "test".into(),
            source_url: source_url.into(),
            sha256_hex: HELLO_WORLD_SHA.into(),
            size_bytes,
            license: None,
            source: None,
        }
    }

    #[test]
    fn fetch_returns_immediately_when_manifest_and_blob_present() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        store.ensure_layout().unwrap();

        let blob = write_blob_at(&store, HELLO_WORLD_SHA, b"hello world");
        store
            .write_manifest(&Manifest {
                schema_version: 1,
                name: "test".into(),
                format: "gguf".into(),
                blob: format_blob_ref(HELLO_WORLD_SHA),
                size_bytes: 11,
                license: None,
                source: ManifestSource {
                    registry: "example.invalid".into(),
                    repo: String::new(),
                    revision: String::new(),
                    filename: "blob.gguf".into(),
                },
                produced_by: "test".into(),
                produced_at: "2026-05-18T00:00:00Z".into(),
            })
            .unwrap();

        let spec = hello_spec("https://example.invalid/blob.gguf", Some(11));

        let broadcaster = dummy_broadcaster();
        let mut rx = broadcaster.subscribe();
        let got = fetch_model(&spec, &store, &broadcaster).unwrap();
        assert_eq!(got, blob);

        // The first event observed by the subscriber is the local check.
        let first_event = rx.try_recv().unwrap();
        assert!(matches!(
            first_event,
            StatusEvent::LoadingModel {
                phase: LoadPhase::CheckingLocal { .. }
            }
        ));
    }

    #[test]
    fn fetch_quarantines_blob_with_wrong_bytes() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        store.ensure_layout().unwrap();

        let blob = write_blob_at(&store, HELLO_WORLD_SHA, b"different bytes");

        let spec = hello_spec("https://example.invalid/blob.gguf", Some(11));
        let broadcaster = dummy_broadcaster();
        // The overall result is deliberately ignored; only the quarantine
        // side effect is asserted below.
        let _ = fetch_model(&spec, &store, &broadcaster);

        assert!(!blob.exists(), "bad blob should have been quarantined");
        let qdir = store.quarantine_dir();
        assert!(qdir.is_dir());
        let quarantined = std::fs::read_dir(&qdir)
            .unwrap()
            .filter_map(Result::ok)
            .count();
        assert!(
            quarantined > 0,
            "expected at least one quarantined file"
        );
    }

    #[test]
    fn fetch_rejects_non_https_url() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        let spec = hello_spec("http://example.invalid/blob.gguf", None);
        let broadcaster = dummy_broadcaster();
        let err = fetch_model(&spec, &store, &broadcaster).unwrap_err();
        assert!(matches!(err, FetchError::InsecureUrl(_)));
    }

    #[test]
    fn fetch_errors_when_no_source_and_no_manifest() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        let spec = hello_spec("", None);
        let broadcaster = dummy_broadcaster();
        let err = fetch_model(&spec, &store, &broadcaster).unwrap_err();
        assert!(matches!(err, FetchError::NoSourceNoManifest { .. }));
    }

    #[test]
    fn sha256_of_known_input() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("blob");
        std::fs::write(&path, b"hello world").unwrap();
        assert_eq!(sha256_of_path(&path).unwrap(), HELLO_WORLD_SHA);
    }

    #[test]
    fn registry_from_url_pulls_hostname() {
        let from_https = registry_from_url("https://huggingface.co/foo/bar.gguf");
        assert_eq!(from_https, "huggingface.co");
        let from_garbage = registry_from_url("not-a-url");
        assert_eq!(from_garbage, "");
    }

    #[test]
    fn filename_from_url_pulls_basename() {
        let basename = filename_from_url("https://huggingface.co/foo/x.gguf");
        assert_eq!(basename, "x.gguf");
    }
}