Skip to main content

ctxgraph_extract/
model_manager.rs

1use std::fs;
2use std::io::{Read, Write};
3use std::path::PathBuf;
4
5use sha2::{Digest, Sha256};
6
/// Description of an ONNX model that can be fetched into the local model cache.
///
/// A `ModelSpec` is plain data: what to call the file on disk, where to download
/// it from, and the values used to recognise a good local copy.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// File name the model is stored under inside the cache directory.
    pub name: String,
    /// URL the model file is downloaded from.
    pub url: String,
    /// Expected SHA-256 digest of the file contents, as a hex string.
    pub sha256: String,
    /// Expected file size in bytes; used by the quick on-disk cache check.
    pub size_bytes: u64,
}
15
/// Manages downloading, caching, and verifying ONNX model files.
///
/// Each model is stored at `cache_dir.join(spec.name)`.
#[derive(Debug, Clone)]
pub struct ModelManager {
    /// Directory that holds the cached model files.
    cache_dir: PathBuf,
}
20
21impl ModelManager {
22    /// Create a new `ModelManager` using the default cache directory
23    /// (`~/.cache/ctxgraph/models/`).
24    pub fn new() -> Result<Self, ModelManagerError> {
25        let cache = Self::default_cache_dir()?;
26        Ok(Self { cache_dir: cache })
27    }
28
29    /// Create a `ModelManager` with a custom cache directory (useful for tests).
30    pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, ModelManagerError> {
31        fs::create_dir_all(&cache_dir).map_err(|e| ModelManagerError::Io {
32            context: format!("creating cache dir {}", cache_dir.display()),
33            source: e,
34        })?;
35        Ok(Self { cache_dir })
36    }
37
38    /// Return the default cache directory (`~/.cache/ctxgraph/models/`),
39    /// creating it if it does not exist.
40    pub fn default_cache_dir() -> Result<PathBuf, ModelManagerError> {
41        let base = dirs::cache_dir().ok_or(ModelManagerError::NoCacheDir)?;
42        let dir = base.join("ctxgraph").join("models");
43        fs::create_dir_all(&dir).map_err(|e| ModelManagerError::Io {
44            context: format!("creating cache dir {}", dir.display()),
45            source: e,
46        })?;
47        Ok(dir)
48    }
49
50    /// Path where a given model would be stored locally.
51    pub fn model_path(&self, spec: &ModelSpec) -> PathBuf {
52        self.cache_dir.join(&spec.name)
53    }
54
55    /// Check whether the model file exists on disk and its size matches the spec.
56    pub fn is_cached(&self, spec: &ModelSpec) -> bool {
57        let path = self.model_path(spec);
58        match fs::metadata(&path) {
59            Ok(meta) => meta.len() == spec.size_bytes,
60            Err(_) => false,
61        }
62    }
63
64    /// Verify the SHA-256 hash of a cached model file.
65    /// Returns `Ok(true)` if the hash matches, `Ok(false)` if it doesn't,
66    /// or an error if the file cannot be read.
67    pub fn verify(&self, spec: &ModelSpec) -> Result<bool, ModelManagerError> {
68        let path = self.model_path(spec);
69        let mut file = fs::File::open(&path).map_err(|e| ModelManagerError::Io {
70            context: format!("opening {} for verification", path.display()),
71            source: e,
72        })?;
73
74        let mut hasher = Sha256::new();
75        let mut buf = [0u8; 8192];
76        loop {
77            let n = file.read(&mut buf).map_err(|e| ModelManagerError::Io {
78                context: "reading file for hash".into(),
79                source: e,
80            })?;
81            if n == 0 {
82                break;
83            }
84            hasher.update(&buf[..n]);
85        }
86
87        let digest = format!("{:x}", hasher.finalize());
88        Ok(digest == spec.sha256)
89    }
90
91    /// Download a model, verify its hash, and return the local path.
92    pub fn download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
93        let dest = self.model_path(spec);
94
95        let response = reqwest::blocking::get(&spec.url).map_err(|e| {
96            ModelManagerError::Download {
97                url: spec.url.clone(),
98                source: e,
99            }
100        })?;
101
102        if !response.status().is_success() {
103            return Err(ModelManagerError::HttpStatus {
104                url: spec.url.clone(),
105                status: response.status().as_u16(),
106            });
107        }
108
109        let total_size = response.content_length().unwrap_or(spec.size_bytes);
110
111        let pb = indicatif::ProgressBar::new(total_size);
112        pb.set_style(
113            indicatif::ProgressStyle::default_bar()
114                .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
115                .unwrap()
116                .progress_chars("#>-"),
117        );
118
119        let mut file = fs::File::create(&dest).map_err(|e| ModelManagerError::Io {
120            context: format!("creating {}", dest.display()),
121            source: e,
122        })?;
123
124        let mut downloaded: u64 = 0;
125        let mut reader = response;
126        let mut buf = [0u8; 8192];
127        loop {
128            let n = reader.read(&mut buf).map_err(|e| ModelManagerError::Io {
129                context: "reading download stream".into(),
130                source: e,
131            })?;
132            if n == 0 {
133                break;
134            }
135            file.write_all(&buf[..n]).map_err(|e| ModelManagerError::Io {
136                context: "writing model file".into(),
137                source: e,
138            })?;
139            downloaded += n as u64;
140            pb.set_position(downloaded);
141        }
142        pb.finish_with_message("download complete");
143
144        // Verify hash after download
145        let ok = self.verify(spec)?;
146        if !ok {
147            // Remove the corrupt file
148            let _ = fs::remove_file(&dest);
149            return Err(ModelManagerError::HashMismatch {
150                model: spec.name.clone(),
151            });
152        }
153
154        Ok(dest)
155    }
156
157    /// Return the cached model path if it exists and is valid, otherwise download it.
158    pub fn get_or_download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
159        if self.is_cached(spec) {
160            // Optionally verify hash of cached file
161            if self.verify(spec)? {
162                return Ok(self.model_path(spec));
163            }
164        }
165        self.download(spec)
166    }
167}
168
169// ---------------------------------------------------------------------------
170// Pre-defined model specs
171// ---------------------------------------------------------------------------
172
173/// GLiNER2 Large quantised (Q8) model.
174pub fn gliner2_large() -> ModelSpec {
175    ModelSpec {
176        name: "gliner2-large-q8.onnx".into(),
177        url: "https://huggingface.co/ctxgraph/models/resolve/main/gliner2-large-q8.onnx".into(),
178        sha256: "placeholder_sha256_gliner2_large_q8".into(),
179        size_bytes: 200_000_000,
180    }
181}
182
183/// GLiREL Large relation-extraction model.
184pub fn glirel_large() -> ModelSpec {
185    ModelSpec {
186        name: "glirel-large.onnx".into(),
187        url: "https://huggingface.co/ctxgraph/models/resolve/main/glirel-large.onnx".into(),
188        sha256: "placeholder_sha256_glirel_large".into(),
189        size_bytes: 150_000_000,
190    }
191}
192
193/// MiniLM L6 v2 sentence-embedding model.
194pub fn minilm_l6_v2() -> ModelSpec {
195    ModelSpec {
196        name: "minilm-l6-v2.onnx".into(),
197        url: "https://huggingface.co/ctxgraph/models/resolve/main/minilm-l6-v2.onnx".into(),
198        sha256: "placeholder_sha256_minilm_l6_v2".into(),
199        size_bytes: 80_000_000,
200    }
201}
202
203// ---------------------------------------------------------------------------
204// Errors
205// ---------------------------------------------------------------------------
206
207#[derive(Debug, thiserror::Error)]
208pub enum ModelManagerError {
209    #[error("could not determine cache directory")]
210    NoCacheDir,
211
212    #[error("I/O error ({context}): {source}")]
213    Io {
214        context: String,
215        source: std::io::Error,
216    },
217
218    #[error("download failed for {url}: {source}")]
219    Download {
220        url: String,
221        source: reqwest::Error,
222    },
223
224    #[error("HTTP {status} for {url}")]
225    HttpStatus { url: String, status: u16 },
226
227    #[error("SHA-256 hash mismatch for {model}")]
228    HashMismatch { model: String },
229}