Skip to main content

inferd_daemon/
fetch.rs

1//! First-boot model bootstrap into the shared CAS store.
2//!
3//! Per ADR 0010 the daemon may issue outbound HTTPS for one purpose
4//! only: fetching a pinned GGUF named in `~/.inferd/config.json`. Per
5//! ADR 0011 the bytes land in the shared content-addressable store
6//! at `$MODELS_HOME/blobs/sha256/<aa>/<hash>/data`, with a manifest
7//! written at `$MODELS_HOME/manifests/<name>.json`.
8//!
//! Producer flow:
//!
//! 1. Lock-free fast path: if the manifest already names the expected
//!    blob and that blob exists, return the blob path immediately.
//! 2. Otherwise try to acquire `LOCK_EX` on
//!    `$MODELS_HOME/locks/<name>.lock` without blocking; a contended
//!    lock is reported as an error.
//! 3. Re-check under the lock: if a blob now sits at the CAS path,
//!    re-hash it and return it on match, otherwise quarantine it.
//! 4. Stream HTTPS into
//!    `$MODELS_HOME/blobs/sha256/<aa>/.partial-<hash>/data.tmp`,
//!    then SHA-256 the completed file.
//! 5. Constant-time compare computed vs expected SHA (F-5).
//! 6. On match: atomic-rename into place, write manifest. On
//!    mismatch: move bad bytes into `locks/quarantine/` and bail.
//! 7. Release the lock when the function returns.
21//!
22//! Progress events publish through a `StatusBroadcaster` so the
23//! admin socket can fan them out to UIs and middleware.
24
25use crate::admin::StatusBroadcaster;
26use crate::status::{LoadPhase, StatusEvent};
27use crate::store::{Manifest, ManifestSource, ModelStore, format_blob_ref, parse_blob_ref};
28use sha2::{Digest, Sha256};
29use std::fs::{File, OpenOptions, TryLockError};
30use std::io::{self, Read, Write};
31use std::path::{Path, PathBuf};
32use std::time::{Duration, Instant};
33use subtle::ConstantTimeEq;
34use tracing::{info, warn};
35
/// One downloadable GGUF model: a single URL plus the expected SHA-256.
///
/// Anything more elaborate (registries, mirrors) belongs in the
/// operator's HTTP proxy or a `wget` step, not here.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// Stable identifier, e.g. `"gemma-4-e4b"`. Used as the manifest
    /// filename (`<name>.json`) and the lock-file basename.
    pub name: String,
    /// Direct-download endpoint. `fetch_model` only inspects it when a
    /// download is actually needed, and then rejects anything that does
    /// not start with `https://`. An empty string is permitted for
    /// CLI-only mode, where the operator has pre-placed the blob. If no
    /// valid blob is found on disk, `fetch_model` then fails instead of
    /// downloading.
    pub source_url: String,
    /// Lowercase hex SHA-256 of the GGUF bytes. Required. Also determines
    /// the CAS blob path. Compared byte-for-byte (case-sensitive) against
    /// the lowercase digest computed locally.
    /// NOTE(review): an uppercase value will never match.
    pub sha256_hex: String,
    /// Advisory total size, used as the progress total only when the
    /// server sends no usable `Content-Length`. If both are absent,
    /// progress frames omit `total_bytes`. The manifest always records the
    /// size actually downloaded / on disk, never this value.
    pub size_bytes: Option<u64>,
    /// SPDX-style license id when known. Copied into the manifest for
    /// cross-tool consumers; not consulted at runtime.
    pub license: Option<String>,
    /// Diagnostic provenance for the manifest. When `None`, a minimal
    /// provenance (host + last path segment) is derived from `source_url`.
    pub source: Option<ManifestSource>,
}
62
63/// Errors produced by `fetch_model`.
64#[derive(Debug, thiserror::Error)]
65pub enum FetchError {
66    /// `source_url` was not `https://`.
67    #[error("model URL must be https:// (got {0:?})")]
68    InsecureUrl(String),
69    /// HTTP transport error (DNS, TLS, refused connection).
70    #[error("http transport: {0}")]
71    Transport(String),
72    /// Server returned a non-success status.
73    #[error("http status {0}")]
74    HttpStatus(u16),
75    /// I/O error reading body or writing dest.
76    #[error("io: {0}")]
77    Io(#[from] io::Error),
78    /// SHA-256 mismatch between downloaded bytes and `sha256_hex`.
79    /// File has been moved into `locks/quarantine/`.
80    #[error(
81        "SHA-256 mismatch (expected {expected}, got {actual}); quarantined to {quarantine_path}"
82    )]
83    HashMismatch {
84        /// What the config said.
85        expected: String,
86        /// What we computed.
87        actual: String,
88        /// Where the bad bytes were moved.
89        quarantine_path: PathBuf,
90    },
91    /// Atomic rename of the partial into the final blob path failed.
92    #[error("finalise rename: {0}")]
93    Finalise(io::Error),
94    /// Another producer holds the per-name lock. Another daemon is
95    /// currently fetching this same model.
96    #[error("model {name:?} is being fetched by another process")]
97    LockContended {
98        /// Model name that was contended.
99        name: String,
100    },
101    /// CLI-only mode: source_url is empty AND no manifest exists.
102    /// Operator must either set source_url or pre-write a manifest +
103    /// blob.
104    #[error("model {name:?} has no source_url and no manifest exists")]
105    NoSourceNoManifest {
106        /// Model name that couldn't be resolved.
107        name: String,
108    },
109}
110
111/// Resolve a model into its CAS blob path.
112///
113/// If the manifest exists and its referenced blob is on disk, return
114/// the blob path immediately (no network, no re-hash by default).
115/// Otherwise — if `source_url` is set — download into the partial
116/// area, verify, atomic-rename into place, write the manifest, and
117/// return the blob path.
118///
119/// Publishes phase events through `broadcaster`:
120/// - `CheckingLocal { path }` on entry.
121/// - `Download { downloaded, total, source_url }` periodically.
122/// - `Verify { path }` after download completes.
123/// - `Quarantine { ... }` on SHA mismatch.
124pub fn fetch_model(
125    spec: &ModelSpec,
126    store: &ModelStore,
127    broadcaster: &StatusBroadcaster,
128) -> Result<PathBuf, FetchError> {
129    store.ensure_layout()?;
130
131    let blob_path = store.blob_path(&spec.sha256_hex);
132
133    // Phase 1: check the manifest + blob.
134    broadcaster.publish(StatusEvent::LoadingModel {
135        phase: LoadPhase::CheckingLocal {
136            path: blob_path.clone(),
137        },
138    });
139
140    if let Some(manifest) = store.read_manifest(&spec.name)? {
141        // Manifest names a SHA. If it matches the expected SHA AND
142        // the blob is on disk, we're done — content addressing is
143        // the trust boundary.
144        if let Some(manifest_sha) = parse_blob_ref(&manifest.blob) {
145            if hex_ct_eq(manifest_sha, &spec.sha256_hex) && blob_path.exists() {
146                info!(
147                    name = %spec.name,
148                    blob = %blob_path.display(),
149                    "manifest + blob already present; skipping fetch"
150                );
151                return Ok(blob_path);
152            }
153            if !hex_ct_eq(manifest_sha, &spec.sha256_hex) {
154                warn!(
155                    name = %spec.name,
156                    expected = %spec.sha256_hex,
157                    in_manifest = %manifest_sha,
158                    "manifest blob ref disagrees with config sha; rewriting manifest"
159                );
160            }
161        }
162    }
163
164    // Acquire the per-name lock. Held until the function returns.
165    let _lock = acquire_name_lock(store, &spec.name)?;
166
167    // Re-check after lock acquisition (someone else may have
168    // finished between phase 1 and the lock).
169    if blob_path.exists() {
170        let actual = sha256_of_path(&blob_path)?;
171        if hex_ct_eq(&actual, &spec.sha256_hex) {
172            // Make sure the manifest reflects current truth.
173            write_manifest_for(store, spec, blob_path.metadata()?.len())?;
174            info!(name = %spec.name, "blob landed by concurrent producer; manifest written");
175            return Ok(blob_path);
176        }
177        // Blob exists at the right path but bytes are wrong. The CAS
178        // path IS the hash, so this should be impossible without
179        // tampering. Quarantine and re-fetch.
180        warn!(
181            name = %spec.name,
182            expected = %spec.sha256_hex,
183            actual = %actual,
184            "blob at CAS path failed re-hash; quarantining"
185        );
186        let qpath = store.quarantine(&blob_path, "sha-mismatch")?;
187        broadcaster.publish(StatusEvent::LoadingModel {
188            phase: LoadPhase::Quarantine {
189                path: blob_path.clone(),
190                expected_sha256: spec.sha256_hex.clone(),
191                actual_sha256: actual,
192                quarantine_path: qpath,
193            },
194        });
195    }
196
197    // Phase 2: download — guarded by source_url presence.
198    if spec.source_url.is_empty() {
199        return Err(FetchError::NoSourceNoManifest {
200            name: spec.name.clone(),
201        });
202    }
203    if !spec.source_url.starts_with("https://") {
204        return Err(FetchError::InsecureUrl(spec.source_url.clone()));
205    }
206
207    let partial = store.partial_path(&spec.sha256_hex);
208    if let Some(parent) = partial.parent() {
209        std::fs::create_dir_all(parent)?;
210    }
211    let downloaded = download_with_progress(spec, &partial, broadcaster)?;
212
213    // Phase 3: verify.
214    broadcaster.publish(StatusEvent::LoadingModel {
215        phase: LoadPhase::Verify {
216            path: partial.clone(),
217        },
218    });
219    let actual = sha256_of_path(&partial)?;
220    if !hex_ct_eq(&actual, &spec.sha256_hex) {
221        let qpath = store.quarantine(&partial, "sha-mismatch")?;
222        broadcaster.publish(StatusEvent::LoadingModel {
223            phase: LoadPhase::Quarantine {
224                path: partial.clone(),
225                expected_sha256: spec.sha256_hex.clone(),
226                actual_sha256: actual.clone(),
227                quarantine_path: qpath.clone(),
228            },
229        });
230        // Best-effort cleanup of the empty `.partial-<hash>` dir.
231        if let Some(parent) = partial.parent() {
232            let _ = std::fs::remove_dir(parent);
233        }
234        return Err(FetchError::HashMismatch {
235            expected: spec.sha256_hex.clone(),
236            actual,
237            quarantine_path: qpath,
238        });
239    }
240
241    // Phase 4: atomic rename into the CAS path.
242    if let Some(parent) = blob_path.parent() {
243        std::fs::create_dir_all(parent)?;
244    }
245    std::fs::rename(&partial, &blob_path).map_err(FetchError::Finalise)?;
246    if let Some(parent) = partial.parent() {
247        let _ = std::fs::remove_dir(parent);
248    }
249
250    // Phase 5: write manifest last. Readers don't trust a manifest
251    // until its blob is on disk, so manifest-after-blob is the safe
252    // ordering.
253    write_manifest_for(store, spec, downloaded)?;
254    info!(
255        name = %spec.name,
256        blob = %blob_path.display(),
257        "model installed"
258    );
259    Ok(blob_path)
260}
261
/// RAII guard for `$MODELS_HOME/locks/<name>.lock`.
///
/// Wraps the open `File` on which `acquire_name_lock` took a
/// non-blocking exclusive `try_lock`. Dropping the guard closes the file,
/// and closing the handle releases the lock. Binding the guard to a named
/// local (`let _lock = ...`, not `let _ = ...`) therefore keeps the lock
/// held until that local goes out of scope.
struct NameLock {
    // Never read; held only so the locked handle lives as long as the guard.
    _file: File,
}
267
268fn acquire_name_lock(store: &ModelStore, name: &str) -> Result<NameLock, FetchError> {
269    let lock_path = store.lock_path(name);
270    if let Some(parent) = lock_path.parent() {
271        std::fs::create_dir_all(parent)?;
272    }
273    let file = OpenOptions::new()
274        .read(true)
275        .write(true)
276        .create(true)
277        .truncate(false)
278        .open(&lock_path)?;
279    match file.try_lock() {
280        Ok(()) => Ok(NameLock { _file: file }),
281        Err(TryLockError::WouldBlock) => Err(FetchError::LockContended {
282            name: name.to_string(),
283        }),
284        Err(TryLockError::Error(e)) => Err(FetchError::Io(e)),
285    }
286}
287
288fn write_manifest_for(
289    store: &ModelStore,
290    spec: &ModelSpec,
291    size_bytes: u64,
292) -> Result<(), FetchError> {
293    let source = spec.source.clone().unwrap_or_else(|| ManifestSource {
294        registry: registry_from_url(&spec.source_url),
295        repo: String::new(),
296        revision: String::new(),
297        filename: filename_from_url(&spec.source_url),
298    });
299    let manifest = Manifest {
300        schema_version: 1,
301        name: spec.name.clone(),
302        format: "gguf".into(),
303        blob: format_blob_ref(&spec.sha256_hex),
304        size_bytes,
305        license: spec.license.clone(),
306        source,
307        produced_by: format!("inferd/{}", env!("CARGO_PKG_VERSION")),
308        produced_at: chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
309    };
310    store
311        .write_manifest(&manifest)
312        .map_err(FetchError::Io)
313        .map(|_| ())
314}
315
/// Best-effort registry name for manifest provenance.
///
/// Returns everything between the `https://` scheme and the first `/`.
/// Input that does not start with `https://` yields an empty string.
fn registry_from_url(url: &str) -> String {
    let Some(after_scheme) = url.strip_prefix("https://") else {
        return String::new();
    };
    let host = match after_scheme.split_once('/') {
        Some((host, _path)) => host,
        None => after_scheme,
    };
    host.to_owned()
}
322
/// Best-effort filename for manifest provenance.
///
/// Returns the last `/`-separated path segment of `url`, ignoring any
/// `?query` or `#fragment`. Without this, a link such as
/// `.../x.gguf?download=true` would be recorded as
/// `x.gguf?download=true`. A URL ending in `/` yields an empty string.
fn filename_from_url(url: &str) -> String {
    let path_part = url
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or(url);
    path_part.rsplit('/').next().unwrap_or("").to_string()
}
326
327fn download_with_progress(
328    spec: &ModelSpec,
329    dest: &Path,
330    broadcaster: &StatusBroadcaster,
331) -> Result<u64, FetchError> {
332    let agent = ureq::AgentBuilder::new()
333        .timeout_connect(Duration::from_secs(30))
334        .build();
335
336    info!(
337        url = %spec.source_url,
338        name = %spec.name,
339        "model download starting"
340    );
341
342    let resp = agent
343        .get(&spec.source_url)
344        .call()
345        .map_err(|e| FetchError::Transport(e.to_string()))?;
346    let status = resp.status();
347    if !(200..300).contains(&status) {
348        return Err(FetchError::HttpStatus(status));
349    }
350    let total = resp
351        .header("content-length")
352        .and_then(|s| s.parse::<u64>().ok())
353        .or(spec.size_bytes);
354    if let Some(t) = total {
355        info!(
356            total_bytes = t,
357            total_mib = t / (1024 * 1024),
358            "model download size known"
359        );
360    } else {
361        info!("model download size unknown (no Content-Length)");
362    }
363
364    let mut reader = resp.into_reader();
365    let mut file = OpenOptions::new()
366        .create(true)
367        .write(true)
368        .truncate(true)
369        .open(dest)?;
370
371    let mut buf = vec![0u8; 1 << 20]; // 1 MiB
372    let mut downloaded: u64 = 0;
373    let mut last_publish = Instant::now();
374    let mut next_byte_milestone: u64 = 32 << 20; // every 32 MiB
375
376    broadcaster.publish(StatusEvent::LoadingModel {
377        phase: LoadPhase::Download {
378            downloaded_bytes: 0,
379            total_bytes: total,
380            source_url: spec.source_url.clone(),
381        },
382    });
383
384    loop {
385        let n = reader.read(&mut buf)?;
386        if n == 0 {
387            break;
388        }
389        file.write_all(&buf[..n])?;
390        downloaded += n as u64;
391
392        let now = Instant::now();
393        let due = downloaded >= next_byte_milestone
394            || now.duration_since(last_publish) >= Duration::from_secs(5);
395        if due {
396            broadcaster.publish(StatusEvent::LoadingModel {
397                phase: LoadPhase::Download {
398                    downloaded_bytes: downloaded,
399                    total_bytes: total,
400                    source_url: spec.source_url.clone(),
401                },
402            });
403            // Stdout/journal-visible progress so an operator running
404            // the daemon manually (or watching journalctl) sees the
405            // download is alive. Without this the daemon was silent
406            // for the duration of a 5 GB pull. Mirrors the milestone
407            // cadence of the admin-socket event so subscribers and
408            // log tailers see the same numbers.
409            let pct = total
410                .map(|t| (downloaded as f64 / t as f64) * 100.0)
411                .map(|p| format!("{p:5.1}%"))
412                .unwrap_or_else(|| "  ?  ".to_string());
413            let mib = downloaded / (1024 * 1024);
414            let total_mib = total.map(|t| t / (1024 * 1024)).unwrap_or(0);
415            info!(
416                downloaded_mib = mib,
417                total_mib = total_mib,
418                pct = %pct,
419                "model download progress"
420            );
421            last_publish = now;
422            next_byte_milestone = downloaded + (32 << 20);
423        }
424    }
425    file.flush()?;
426
427    broadcaster.publish(StatusEvent::LoadingModel {
428        phase: LoadPhase::Download {
429            downloaded_bytes: downloaded,
430            total_bytes: total.or(Some(downloaded)),
431            source_url: spec.source_url.clone(),
432        },
433    });
434    info!(
435        downloaded_mib = downloaded / (1024 * 1024),
436        "model download complete"
437    );
438    Ok(downloaded)
439}
440
441/// Streaming SHA-256 of a file as lowercase hex.
442fn sha256_of_path(path: &Path) -> io::Result<String> {
443    let mut file = File::open(path)?;
444    let mut hasher = Sha256::new();
445    let mut buf = vec![0u8; 1 << 20];
446    loop {
447        let n = file.read(&mut buf)?;
448        if n == 0 {
449            break;
450        }
451        hasher.update(&buf[..n]);
452    }
453    let bytes = hasher.finalize();
454    let mut s = String::with_capacity(bytes.len() * 2);
455    for b in bytes {
456        s.push_str(&format!("{:02x}", b));
457    }
458    Ok(s)
459}
460
461fn hex_ct_eq(a: &str, b: &str) -> bool {
462    a.as_bytes().ct_eq(b.as_bytes()).into()
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// SHA-256 of the 11 ASCII bytes `"hello world"`.
    const HELLO_WORLD_SHA: &str =
        "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";

    /// Fresh broadcaster seeded with `Starting`.
    fn dummy_broadcaster() -> StatusBroadcaster {
        StatusBroadcaster::new(StatusEvent::Starting)
    }

    /// Spec named `"test"` that expects the `"hello world"` digest.
    fn hello_spec(source_url: &str, size_bytes: Option<u64>) -> ModelSpec {
        ModelSpec {
            name: "test".into(),
            source_url: source_url.into(),
            sha256_hex: HELLO_WORLD_SHA.into(),
            size_bytes,
            license: None,
            source: None,
        }
    }

    /// Write `contents` at the CAS path for `sha`, creating parent dirs.
    fn write_blob_at(store: &ModelStore, sha: &str, contents: &[u8]) -> PathBuf {
        let path = store.blob_path(sha);
        let parent = path.parent().expect("CAS blob path has a parent directory");
        std::fs::create_dir_all(parent).unwrap();
        std::fs::write(&path, contents).unwrap();
        path
    }

    #[test]
    fn fetch_returns_immediately_when_manifest_and_blob_present() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        store.ensure_layout().unwrap();

        // Pre-seed a matching manifest + blob.
        let expected_blob = write_blob_at(&store, HELLO_WORLD_SHA, b"hello world");
        store
            .write_manifest(&Manifest {
                schema_version: 1,
                name: "test".into(),
                format: "gguf".into(),
                blob: format_blob_ref(HELLO_WORLD_SHA),
                size_bytes: 11,
                license: None,
                source: ManifestSource {
                    registry: "example.invalid".into(),
                    repo: String::new(),
                    revision: String::new(),
                    filename: "blob.gguf".into(),
                },
                produced_by: "test".into(),
                produced_at: "2026-05-18T00:00:00Z".into(),
            })
            .unwrap();

        let spec = hello_spec("https://example.invalid/blob.gguf", Some(11));
        let broadcaster = dummy_broadcaster();
        let mut events = broadcaster.subscribe();

        let resolved = fetch_model(&spec, &store, &broadcaster).unwrap();
        assert_eq!(resolved, expected_blob);

        let first = events.try_recv().unwrap();
        assert!(matches!(
            first,
            StatusEvent::LoadingModel {
                phase: LoadPhase::CheckingLocal { .. }
            }
        ));
    }

    #[test]
    fn fetch_quarantines_blob_with_wrong_bytes() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        store.ensure_layout().unwrap();

        // WRONG bytes at the CAS path for HELLO_WORLD_SHA.
        let tampered = write_blob_at(&store, HELLO_WORLD_SHA, b"different bytes");
        let spec = hello_spec("https://example.invalid/blob.gguf", Some(11));

        // The follow-up download can't reach example.invalid. Only the
        // quarantine that happens before it is under test.
        let _ = fetch_model(&spec, &store, &dummy_broadcaster());

        assert!(!tampered.exists(), "bad blob should have been quarantined");
        let qdir = store.quarantine_dir();
        assert!(qdir.is_dir());
        let quarantined = std::fs::read_dir(&qdir)
            .unwrap()
            .filter_map(Result::ok)
            .count();
        assert!(quarantined > 0, "expected at least one quarantined file");
    }

    #[test]
    fn fetch_rejects_non_https_url() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        let spec = hello_spec("http://example.invalid/blob.gguf", None);

        let err = fetch_model(&spec, &store, &dummy_broadcaster()).unwrap_err();
        assert!(matches!(err, FetchError::InsecureUrl(_)));
    }

    #[test]
    fn fetch_errors_when_no_source_and_no_manifest() {
        let dir = tempdir().unwrap();
        let store = ModelStore::open(dir.path());
        let spec = hello_spec("", None);

        let err = fetch_model(&spec, &store, &dummy_broadcaster()).unwrap_err();
        assert!(matches!(err, FetchError::NoSourceNoManifest { .. }));
    }

    #[test]
    fn sha256_of_known_input() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("blob");
        std::fs::write(&path, b"hello world").unwrap();

        assert_eq!(sha256_of_path(&path).unwrap(), HELLO_WORLD_SHA);
    }

    #[test]
    fn registry_from_url_pulls_hostname() {
        assert_eq!(
            registry_from_url("https://huggingface.co/foo/bar.gguf"),
            "huggingface.co"
        );
        assert_eq!(registry_from_url("not-a-url"), "");
    }

    #[test]
    fn filename_from_url_pulls_basename() {
        assert_eq!(
            filename_from_url("https://huggingface.co/foo/x.gguf"),
            "x.gguf"
        );
    }
}