//! ctxgraph_extract — model_manager.rs: downloading, caching, and verifying ONNX models.

1use std::fs;
2use std::io::{Read, Write};
3use std::path::PathBuf;
4
5use sha2::{Digest, Sha256};
6
/// Specification for a downloadable ONNX model.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// Relative path under the cache directory where the file is stored.
    /// May contain subdirectories (e.g. `gliner_large-v2.1/onnx/model_int8.onnx`).
    pub name: String,
    /// Direct download URL for the file.
    pub url: String,
    /// Expected SHA-256 hex digest. A value starting with "pending" or equal
    /// to "skip" bypasses verification (see `ModelManager::verify`).
    pub sha256: String,
    /// Expected file size in bytes; used by the cheap `is_cached` check.
    pub size_bytes: u64,
}
15
16/// Manages downloading, caching, and verifying ONNX model files.
17pub struct ModelManager {
18    cache_dir: PathBuf,
19}
20
21impl ModelManager {
22    /// Create a new `ModelManager` using the default cache directory
23    /// (`~/.cache/ctxgraph/models/`).
24    pub fn new() -> Result<Self, ModelManagerError> {
25        let cache = Self::default_cache_dir()?;
26        Ok(Self { cache_dir: cache })
27    }
28
29    /// Create a `ModelManager` with a custom cache directory (useful for tests).
30    pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, ModelManagerError> {
31        fs::create_dir_all(&cache_dir).map_err(|e| ModelManagerError::Io {
32            context: format!("creating cache dir {}", cache_dir.display()),
33            source: e,
34        })?;
35        Ok(Self { cache_dir })
36    }
37
38    /// Return the default cache directory (`~/.cache/ctxgraph/models/`),
39    /// creating it if it does not exist.
40    pub fn default_cache_dir() -> Result<PathBuf, ModelManagerError> {
41        let base = dirs::cache_dir().ok_or(ModelManagerError::NoCacheDir)?;
42        let dir = base.join("ctxgraph").join("models");
43        fs::create_dir_all(&dir).map_err(|e| ModelManagerError::Io {
44            context: format!("creating cache dir {}", dir.display()),
45            source: e,
46        })?;
47        Ok(dir)
48    }
49
50    /// Path where a given model would be stored locally.
51    pub fn model_path(&self, spec: &ModelSpec) -> PathBuf {
52        self.cache_dir.join(&spec.name)
53    }
54
55    /// Check whether the model file exists on disk and its size matches the spec.
56    pub fn is_cached(&self, spec: &ModelSpec) -> bool {
57        let path = self.model_path(spec);
58        match fs::metadata(&path) {
59            Ok(meta) => meta.len() == spec.size_bytes,
60            Err(_) => false,
61        }
62    }
63
64    /// Verify the SHA-256 hash of a cached model file.
65    /// Returns `Ok(true)` if the hash matches, `Ok(false)` if it doesn't,
66    /// or an error if the file cannot be read.
67    ///
68    /// If `spec.sha256` starts with "pending" or equals "skip", verification
69    /// is bypassed and `Ok(true)` is returned unconditionally.
70    pub fn verify(&self, spec: &ModelSpec) -> Result<bool, ModelManagerError> {
71        // Skip verification when we don't yet have the authoritative hash.
72        if spec.sha256.starts_with("pending") || spec.sha256 == "skip" {
73            return Ok(true);
74        }
75
76        let path = self.model_path(spec);
77        let mut file = fs::File::open(&path).map_err(|e| ModelManagerError::Io {
78            context: format!("opening {} for verification", path.display()),
79            source: e,
80        })?;
81
82        let mut hasher = Sha256::new();
83        let mut buf = [0u8; 8192];
84        loop {
85            let n = file.read(&mut buf).map_err(|e| ModelManagerError::Io {
86                context: "reading file for hash".into(),
87                source: e,
88            })?;
89            if n == 0 {
90                break;
91            }
92            hasher.update(&buf[..n]);
93        }
94
95        let digest = format!("{:x}", hasher.finalize());
96        Ok(digest == spec.sha256)
97    }
98
99    /// Download a model, verify its hash, and return the local path.
100    pub fn download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
101        let dest = self.model_path(spec);
102
103        let response = reqwest::blocking::get(&spec.url).map_err(|e| {
104            ModelManagerError::Download {
105                url: spec.url.clone(),
106                source: e,
107            }
108        })?;
109
110        if !response.status().is_success() {
111            return Err(ModelManagerError::HttpStatus {
112                url: spec.url.clone(),
113                status: response.status().as_u16(),
114            });
115        }
116
117        let total_size = response.content_length().unwrap_or(spec.size_bytes);
118
119        let pb = indicatif::ProgressBar::new(total_size);
120        pb.set_style(
121            indicatif::ProgressStyle::default_bar()
122                .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
123                .unwrap()
124                .progress_chars("#>-"),
125        );
126
127        let mut file = fs::File::create(&dest).map_err(|e| ModelManagerError::Io {
128            context: format!("creating {}", dest.display()),
129            source: e,
130        })?;
131
132        let mut downloaded: u64 = 0;
133        let mut reader = response;
134        let mut buf = [0u8; 8192];
135        loop {
136            let n = reader.read(&mut buf).map_err(|e| ModelManagerError::Io {
137                context: "reading download stream".into(),
138                source: e,
139            })?;
140            if n == 0 {
141                break;
142            }
143            file.write_all(&buf[..n]).map_err(|e| ModelManagerError::Io {
144                context: "writing model file".into(),
145                source: e,
146            })?;
147            downloaded += n as u64;
148            pb.set_position(downloaded);
149        }
150        pb.finish_with_message("download complete");
151
152        // Verify hash after download
153        let ok = self.verify(spec)?;
154        if !ok {
155            // Remove the corrupt file
156            let _ = fs::remove_file(&dest);
157            return Err(ModelManagerError::HashMismatch {
158                model: spec.name.clone(),
159            });
160        }
161
162        Ok(dest)
163    }
164
165    /// Return the cached model path if it exists and is valid, otherwise download it.
166    pub fn get_or_download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
167        if self.is_cached(spec) {
168            // Optionally verify hash of cached file
169            if self.verify(spec)? {
170                return Ok(self.model_path(spec));
171            }
172        }
173        self.download(spec)
174    }
175}
176
177// ---------------------------------------------------------------------------
178// Pre-defined model specs
179// ---------------------------------------------------------------------------
180
181/// GLiNER Large v2.1 INT8 quantized model (span-based NER).
182///
183/// From: <https://huggingface.co/onnx-community/gliner_large-v2.1>
184pub fn gliner_large_v21_int8() -> ModelSpec {
185    ModelSpec {
186        name: "gliner_large-v2.1/onnx/model_int8.onnx".into(),
187        url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/onnx/model_int8.onnx".into(),
188        sha256: "pending_verification".into(),
189        size_bytes: 653_000_000,
190    }
191}
192
193/// GLiNER Large v2.1 tokenizer.
194pub fn gliner_large_v21_tokenizer() -> ModelSpec {
195    ModelSpec {
196        name: "gliner_large-v2.1/tokenizer.json".into(),
197        url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/tokenizer.json".into(),
198        sha256: "pending_verification".into(),
199        size_bytes: 17_000_000,
200    }
201}
202
203/// GLiNER Multitask Large v0.5 INT8 quantized (token-level, for relation extraction).
204///
205/// Community ONNX export of `knowledgator/gliner-multitask-large-v0.5`.
206/// This model uses `span_mode: "token_level"` (4 inputs: input_ids, attention_mask,
207/// words_mask, text_lengths) — the only format compatible with gline-rs
208/// `RelationPipeline` and `TokenPipeline`.
209///
210/// NOTE: Do NOT confuse with `gliner_multi-v2.1` which is span-level (6 inputs)
211/// and incompatible with gline-rs RelationPipeline.
212///
213/// From: <https://huggingface.co/onnx-community/gliner-multitask-large-v0.5>
214pub fn gliner_multitask_large() -> ModelSpec {
215    ModelSpec {
216        name: "gliner-multitask-large-v0.5/onnx/model_int8.onnx".into(),
217        url: "https://huggingface.co/onnx-community/gliner-multitask-large-v0.5/resolve/main/onnx/model_int8.onnx".into(),
218        sha256: "pending_verification".into(),
219        size_bytes: 647_920_426, // INT8 quantized
220    }
221}
222
223/// GLiNER Multitask Large v0.5 tokenizer.
224pub fn gliner_multitask_tokenizer() -> ModelSpec {
225    ModelSpec {
226        name: "gliner-multitask-large-v0.5/tokenizer.json".into(),
227        url: "https://huggingface.co/onnx-community/gliner-multitask-large-v0.5/resolve/main/tokenizer.json".into(),
228        sha256: "pending_verification".into(),
229        size_bytes: 8_657_198,
230    }
231}
232
233/// NLI cross-encoder (DeBERTa-v3-xsmall) INT8 quantized model.
234///
235/// Used for zero-shot relation classification via natural language inference.
236/// Input: (premise, hypothesis) pair → output: [contradiction, entailment, neutral] logits.
237///
238/// From: <https://huggingface.co/cross-encoder/nli-deberta-v3-xsmall>
239pub fn nli_deberta_v3_small() -> ModelSpec {
240    ModelSpec {
241        name: "nli-deberta-v3-small/onnx/model.onnx".into(),
242        url: "https://huggingface.co/cross-encoder/nli-deberta-v3-small/resolve/main/onnx/model.onnx".into(),
243        sha256: "pending_verification".into(),
244        size_bytes: 541_700_000,
245    }
246}
247
248/// NLI cross-encoder tokenizer.
249pub fn nli_deberta_v3_small_tokenizer() -> ModelSpec {
250    ModelSpec {
251        name: "nli-deberta-v3-small/tokenizer.json".into(),
252        url: "https://huggingface.co/cross-encoder/nli-deberta-v3-small/resolve/main/tokenizer.json".into(),
253        sha256: "pending_verification".into(),
254        size_bytes: 8_250_000,
255    }
256}
257
258/// MiniLM L6 v2 sentence-embedding model (for v0.3 semantic search).
259pub fn minilm_l6_v2() -> ModelSpec {
260    ModelSpec {
261        name: "minilm-l6-v2.onnx".into(),
262        url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".into(),
263        sha256: "pending_verification".into(),
264        size_bytes: 80_000_000,
265    }
266}
267
268// ---------------------------------------------------------------------------
269// Convenience: download all models needed for NER extraction
270// ---------------------------------------------------------------------------
271
impl ModelManager {
    /// Download the NER model and tokenizer needed for Tier 1 extraction.
    ///
    /// Downloads `gliner_large-v2.1` INT8 model + tokenizer to the cache directory.
    /// Returns `(model_path, tokenizer_path)`.
    pub fn ensure_ner_models(&self) -> Result<(PathBuf, PathBuf), ModelManagerError> {
        let model = self.get_or_download(&gliner_large_v21_int8())?;
        let tokenizer = self.get_or_download(&gliner_large_v21_tokenizer())?;
        Ok((model, tokenizer))
    }

    /// Download the multitask model and tokenizer needed for relation extraction.
    ///
    /// Returns `None` if either download fails; the underlying error is
    /// discarded by the `.ok()?` conversions.
    pub fn ensure_rel_models(&self) -> Option<(PathBuf, PathBuf)> {
        let model = self.get_or_download(&gliner_multitask_large()).ok()?;
        let tokenizer = self.get_or_download(&gliner_multitask_tokenizer()).ok()?;
        Some((model, tokenizer))
    }

    /// Download the NLI cross-encoder model and tokenizer.
    ///
    /// Used for zero-shot relation classification via entailment scoring.
    /// Returns `(model_path, tokenizer_path)`.
    pub fn ensure_nli_models(&self) -> Result<(PathBuf, PathBuf), ModelManagerError> {
        let model = self.get_or_download(&nli_deberta_v3_small())?;
        let tokenizer = self.get_or_download(&nli_deberta_v3_small_tokenizer())?;
        Ok((model, tokenizer))
    }

    /// Find locally cached NLI model. Returns `None` if not downloaded yet.
    ///
    /// Existence-only check; file sizes and hashes are not verified here.
    pub fn find_nli_model(&self) -> Option<(PathBuf, PathBuf)> {
        let model = self.model_path(&nli_deberta_v3_small());
        let tokenizer = self.model_path(&nli_deberta_v3_small_tokenizer());
        if model.exists() && tokenizer.exists() {
            Some((model, tokenizer))
        } else {
            None
        }
    }

    /// Check for a locally available fine-tuned relation classifier model.
    ///
    /// Looks for `relation_classifier/model_int8.onnx` (preferred) or
    /// `relation_classifier/model.onnx` in the cache directory.
    ///
    /// Returns `Some(model_path)` if found, `None` otherwise.
    /// NOTE(review): unlike the other finders, this does not locate a
    /// tokenizer — callers presumably supply their own; confirm at call sites.
    pub fn find_relation_classifier(&self) -> Option<std::path::PathBuf> {
        let base = self.cache_dir.join("relation_classifier");

        [
            base.join("model_int8.onnx"),
            base.join("model.onnx"),
        ]
        .into_iter()
        .find(|p| p.exists())
    }

    /// Check for locally exported gliner-relex ONNX model.
    ///
    /// The relex model must be exported manually using:
    ///   `python scripts/export_relex_onnx.py [--quantize]`
    ///
    /// Returns `Some((model_path, tokenizer_path))` if found, `None` otherwise.
    pub fn find_relex_model(&self) -> Option<(PathBuf, PathBuf)> {
        let base = self.cache_dir.join("gliner-relex-large-v0.5");

        // Check for quantized first, then full precision
        let model = [
            base.join("onnx/model_quantized.onnx"),
            base.join("onnx/model.onnx"),
        ]
        .into_iter()
        .find(|p| p.exists())?;

        // The exporter may place tokenizer.json at the root or beside the model.
        let tokenizer = [
            base.join("tokenizer.json"),
            base.join("onnx/tokenizer.json"),
        ]
        .into_iter()
        .find(|p| p.exists())?;

        Some((model, tokenizer))
    }
}
355
356// ---------------------------------------------------------------------------
357// Errors
358// ---------------------------------------------------------------------------
359
/// Errors produced by [`ModelManager`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// The platform cache directory could not be determined
    /// (`dirs::cache_dir()` returned `None`).
    #[error("could not determine cache directory")]
    NoCacheDir,

    /// A filesystem operation failed; `context` describes what was being done.
    #[error("I/O error ({context}): {source}")]
    Io {
        context: String,
        source: std::io::Error,
    },

    /// The HTTP request itself failed (connection, TLS, redirect, ...).
    #[error("download failed for {url}: {source}")]
    Download {
        url: String,
        source: reqwest::Error,
    },

    /// The server responded with a non-success status code.
    #[error("HTTP {status} for {url}")]
    HttpStatus { url: String, status: u16 },

    /// The downloaded file's SHA-256 digest did not match the spec.
    #[error("SHA-256 hash mismatch for {model}")]
    HashMismatch { model: String },
}