//! `model_manager.rs` — part of the `ctxgraph_extract` crate.
//!
//! Downloading, caching, and SHA-256 verification of ONNX model files.
use std::fs;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};

use sha2::{Digest, Sha256};
6
/// Specification for a downloadable ONNX model.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// File name relative to the cache directory; may contain subdirectories
    /// (e.g. `gliner_large-v2.1/onnx/model_int8.onnx`).
    pub name: String,
    /// Direct download URL for the model file.
    pub url: String,
    /// Expected SHA-256 digest as a hex string. NOTE(review): some predefined
    /// specs in this file carry placeholder values such as
    /// `"pending_verification"` rather than a real digest.
    pub sha256: String,
    /// Expected file size in bytes, used as a cheap cache-validity check.
    pub size_bytes: u64,
}
15
/// Manages downloading, caching, and verifying ONNX model files.
///
/// All models live under a single cache directory; by default
/// `~/.cache/ctxgraph/models/` (see [`ModelManager::default_cache_dir`]).
pub struct ModelManager {
    /// Root directory under which model files are stored.
    cache_dir: PathBuf,
}
20
21impl ModelManager {
22    /// Create a new `ModelManager` using the default cache directory
23    /// (`~/.cache/ctxgraph/models/`).
24    pub fn new() -> Result<Self, ModelManagerError> {
25        let cache = Self::default_cache_dir()?;
26        Ok(Self { cache_dir: cache })
27    }
28
29    /// Create a `ModelManager` with a custom cache directory (useful for tests).
30    pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, ModelManagerError> {
31        fs::create_dir_all(&cache_dir).map_err(|e| ModelManagerError::Io {
32            context: format!("creating cache dir {}", cache_dir.display()),
33            source: e,
34        })?;
35        Ok(Self { cache_dir })
36    }
37
38    /// Return the default cache directory (`~/.cache/ctxgraph/models/`),
39    /// creating it if it does not exist.
40    pub fn default_cache_dir() -> Result<PathBuf, ModelManagerError> {
41        let base = dirs::cache_dir().ok_or(ModelManagerError::NoCacheDir)?;
42        let dir = base.join("ctxgraph").join("models");
43        fs::create_dir_all(&dir).map_err(|e| ModelManagerError::Io {
44            context: format!("creating cache dir {}", dir.display()),
45            source: e,
46        })?;
47        Ok(dir)
48    }
49
50    /// Path where a given model would be stored locally.
51    pub fn model_path(&self, spec: &ModelSpec) -> PathBuf {
52        self.cache_dir.join(&spec.name)
53    }
54
55    /// Check whether the model file exists on disk and its size matches the spec.
56    pub fn is_cached(&self, spec: &ModelSpec) -> bool {
57        let path = self.model_path(spec);
58        match fs::metadata(&path) {
59            Ok(meta) => meta.len() == spec.size_bytes,
60            Err(_) => false,
61        }
62    }
63
64    /// Verify the SHA-256 hash of a cached model file.
65    /// Returns `Ok(true)` if the hash matches, `Ok(false)` if it doesn't,
66    /// or an error if the file cannot be read.
67    pub fn verify(&self, spec: &ModelSpec) -> Result<bool, ModelManagerError> {
68        let path = self.model_path(spec);
69        let mut file = fs::File::open(&path).map_err(|e| ModelManagerError::Io {
70            context: format!("opening {} for verification", path.display()),
71            source: e,
72        })?;
73
74        let mut hasher = Sha256::new();
75        let mut buf = [0u8; 8192];
76        loop {
77            let n = file.read(&mut buf).map_err(|e| ModelManagerError::Io {
78                context: "reading file for hash".into(),
79                source: e,
80            })?;
81            if n == 0 {
82                break;
83            }
84            hasher.update(&buf[..n]);
85        }
86
87        let digest = format!("{:x}", hasher.finalize());
88        Ok(digest == spec.sha256)
89    }
90
91    /// Download a model, verify its hash, and return the local path.
92    pub fn download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
93        let dest = self.model_path(spec);
94
95        let response = reqwest::blocking::get(&spec.url).map_err(|e| {
96            ModelManagerError::Download {
97                url: spec.url.clone(),
98                source: e,
99            }
100        })?;
101
102        if !response.status().is_success() {
103            return Err(ModelManagerError::HttpStatus {
104                url: spec.url.clone(),
105                status: response.status().as_u16(),
106            });
107        }
108
109        let total_size = response.content_length().unwrap_or(spec.size_bytes);
110
111        let pb = indicatif::ProgressBar::new(total_size);
112        pb.set_style(
113            indicatif::ProgressStyle::default_bar()
114                .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
115                .unwrap()
116                .progress_chars("#>-"),
117        );
118
119        let mut file = fs::File::create(&dest).map_err(|e| ModelManagerError::Io {
120            context: format!("creating {}", dest.display()),
121            source: e,
122        })?;
123
124        let mut downloaded: u64 = 0;
125        let mut reader = response;
126        let mut buf = [0u8; 8192];
127        loop {
128            let n = reader.read(&mut buf).map_err(|e| ModelManagerError::Io {
129                context: "reading download stream".into(),
130                source: e,
131            })?;
132            if n == 0 {
133                break;
134            }
135            file.write_all(&buf[..n]).map_err(|e| ModelManagerError::Io {
136                context: "writing model file".into(),
137                source: e,
138            })?;
139            downloaded += n as u64;
140            pb.set_position(downloaded);
141        }
142        pb.finish_with_message("download complete");
143
144        // Verify hash after download
145        let ok = self.verify(spec)?;
146        if !ok {
147            // Remove the corrupt file
148            let _ = fs::remove_file(&dest);
149            return Err(ModelManagerError::HashMismatch {
150                model: spec.name.clone(),
151            });
152        }
153
154        Ok(dest)
155    }
156
157    /// Return the cached model path if it exists and is valid, otherwise download it.
158    pub fn get_or_download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
159        if self.is_cached(spec) {
160            // Optionally verify hash of cached file
161            if self.verify(spec)? {
162                return Ok(self.model_path(spec));
163            }
164        }
165        self.download(spec)
166    }
167}
168
169// ---------------------------------------------------------------------------
170// Pre-defined model specs
171// ---------------------------------------------------------------------------
172
173/// GLiNER Large v2.1 INT8 quantized model (span-based NER).
174///
175/// From: <https://huggingface.co/onnx-community/gliner_large-v2.1>
176pub fn gliner_large_v21_int8() -> ModelSpec {
177    ModelSpec {
178        name: "gliner_large-v2.1/onnx/model_int8.onnx".into(),
179        url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/onnx/model_int8.onnx".into(),
180        sha256: "pending_verification".into(),
181        size_bytes: 653_000_000,
182    }
183}
184
185/// GLiNER Large v2.1 tokenizer.
186pub fn gliner_large_v21_tokenizer() -> ModelSpec {
187    ModelSpec {
188        name: "gliner_large-v2.1/tokenizer.json".into(),
189        url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/tokenizer.json".into(),
190        sha256: "pending_verification".into(),
191        size_bytes: 17_000_000,
192    }
193}
194
195/// GLiNER Multitask Large v0.5 (token-based NER + relation extraction).
196///
197/// NOTE: This model requires ONNX conversion from PyTorch. No pre-built ONNX
198/// export exists on HuggingFace. Use `scripts/convert_model.py` to convert.
199/// From: <https://huggingface.co/knowledgator/gliner-multitask-large-v0.5>
200pub fn gliner_multitask_large() -> ModelSpec {
201    ModelSpec {
202        name: "gliner-multitask-large-v0.5/onnx/model.onnx".into(),
203        url: "https://huggingface.co/knowledgator/gliner-multitask-large-v0.5/resolve/main/onnx/model.onnx".into(),
204        sha256: "pending_conversion".into(),
205        size_bytes: 1_760_000_000,
206    }
207}
208
209/// GLiNER Multitask Large v0.5 tokenizer.
210pub fn gliner_multitask_tokenizer() -> ModelSpec {
211    ModelSpec {
212        name: "gliner-multitask-large-v0.5/tokenizer.json".into(),
213        url: "https://huggingface.co/knowledgator/gliner-multitask-large-v0.5/resolve/main/tokenizer.json".into(),
214        sha256: "pending_verification".into(),
215        size_bytes: 8_660_000,
216    }
217}
218
219/// MiniLM L6 v2 sentence-embedding model (for v0.3 semantic search).
220pub fn minilm_l6_v2() -> ModelSpec {
221    ModelSpec {
222        name: "minilm-l6-v2.onnx".into(),
223        url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".into(),
224        sha256: "pending_verification".into(),
225        size_bytes: 80_000_000,
226    }
227}
228
229// ---------------------------------------------------------------------------
230// Convenience: download all models needed for NER extraction
231// ---------------------------------------------------------------------------
232
233impl ModelManager {
234    /// Download the NER model and tokenizer needed for Tier 1 extraction.
235    ///
236    /// Downloads `gliner_large-v2.1` INT8 model + tokenizer to the cache directory.
237    pub fn ensure_ner_models(&self) -> Result<(PathBuf, PathBuf), ModelManagerError> {
238        let model = self.get_or_download(&gliner_large_v21_int8())?;
239        let tokenizer = self.get_or_download(&gliner_large_v21_tokenizer())?;
240        Ok((model, tokenizer))
241    }
242
243    /// Download the multitask model and tokenizer needed for relation extraction.
244    ///
245    /// Returns `None` if the model is not available (needs ONNX conversion).
246    pub fn ensure_rel_models(&self) -> Option<(PathBuf, PathBuf)> {
247        let model = self.get_or_download(&gliner_multitask_large()).ok()?;
248        let tokenizer = self.get_or_download(&gliner_multitask_tokenizer()).ok()?;
249        Some((model, tokenizer))
250    }
251}
252
253// ---------------------------------------------------------------------------
254// Errors
255// ---------------------------------------------------------------------------
256
/// Errors produced by [`ModelManager`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// The platform cache directory could not be determined
    /// (`dirs::cache_dir()` returned `None`).
    #[error("could not determine cache directory")]
    NoCacheDir,

    /// A filesystem operation failed; `context` describes what was being
    /// attempted when the error occurred.
    #[error("I/O error ({context}): {source}")]
    Io {
        context: String,
        source: std::io::Error,
    },

    /// The HTTP request itself failed (connection, TLS, redirect, ...).
    #[error("download failed for {url}: {source}")]
    Download {
        url: String,
        source: reqwest::Error,
    },

    /// The server answered with a non-success HTTP status code.
    #[error("HTTP {status} for {url}")]
    HttpStatus { url: String, status: u16 },

    /// A downloaded file's SHA-256 digest did not match the spec's.
    #[error("SHA-256 hash mismatch for {model}")]
    HashMismatch { model: String },
}