use std::fs;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};

use sha2::{Digest, Sha256};
6
/// Describes one downloadable model artifact: where to fetch it, where it
/// lives inside the cache, and how to validate the downloaded bytes.
#[derive(Debug, Clone)]
pub struct ModelSpec {
    /// Cache-relative path; may contain `/` separators, nesting subdirectories.
    pub name: String,
    /// Direct-download URL for the artifact.
    pub url: String,
    /// Expected lowercase hex SHA-256. Values starting with "pending" (or the
    /// literal "skip") disable hash verification in `ModelManager::verify`.
    pub sha256: String,
    /// Expected file size in bytes; `ModelManager::is_cached` requires an
    /// exact match against the on-disk length.
    pub size_bytes: u64,
}
15
/// Downloads model artifacts on demand and caches them on disk.
pub struct ModelManager {
    // Root directory under which all cached model files are stored.
    cache_dir: PathBuf,
}
20
21impl ModelManager {
22 pub fn new() -> Result<Self, ModelManagerError> {
25 let cache = Self::default_cache_dir()?;
26 Ok(Self { cache_dir: cache })
27 }
28
29 pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, ModelManagerError> {
31 fs::create_dir_all(&cache_dir).map_err(|e| ModelManagerError::Io {
32 context: format!("creating cache dir {}", cache_dir.display()),
33 source: e,
34 })?;
35 Ok(Self { cache_dir })
36 }
37
38 pub fn default_cache_dir() -> Result<PathBuf, ModelManagerError> {
41 let base = dirs::cache_dir().ok_or(ModelManagerError::NoCacheDir)?;
42 let dir = base.join("ctxgraph").join("models");
43 fs::create_dir_all(&dir).map_err(|e| ModelManagerError::Io {
44 context: format!("creating cache dir {}", dir.display()),
45 source: e,
46 })?;
47 Ok(dir)
48 }
49
50 pub fn model_path(&self, spec: &ModelSpec) -> PathBuf {
52 self.cache_dir.join(&spec.name)
53 }
54
55 pub fn is_cached(&self, spec: &ModelSpec) -> bool {
57 let path = self.model_path(spec);
58 match fs::metadata(&path) {
59 Ok(meta) => meta.len() == spec.size_bytes,
60 Err(_) => false,
61 }
62 }
63
64 pub fn verify(&self, spec: &ModelSpec) -> Result<bool, ModelManagerError> {
71 if spec.sha256.starts_with("pending") || spec.sha256 == "skip" {
73 return Ok(true);
74 }
75
76 let path = self.model_path(spec);
77 let mut file = fs::File::open(&path).map_err(|e| ModelManagerError::Io {
78 context: format!("opening {} for verification", path.display()),
79 source: e,
80 })?;
81
82 let mut hasher = Sha256::new();
83 let mut buf = [0u8; 8192];
84 loop {
85 let n = file.read(&mut buf).map_err(|e| ModelManagerError::Io {
86 context: "reading file for hash".into(),
87 source: e,
88 })?;
89 if n == 0 {
90 break;
91 }
92 hasher.update(&buf[..n]);
93 }
94
95 let digest = format!("{:x}", hasher.finalize());
96 Ok(digest == spec.sha256)
97 }
98
99 pub fn download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
101 let dest = self.model_path(spec);
102
103 let response = reqwest::blocking::get(&spec.url).map_err(|e| {
104 ModelManagerError::Download {
105 url: spec.url.clone(),
106 source: e,
107 }
108 })?;
109
110 if !response.status().is_success() {
111 return Err(ModelManagerError::HttpStatus {
112 url: spec.url.clone(),
113 status: response.status().as_u16(),
114 });
115 }
116
117 let total_size = response.content_length().unwrap_or(spec.size_bytes);
118
119 let pb = indicatif::ProgressBar::new(total_size);
120 pb.set_style(
121 indicatif::ProgressStyle::default_bar()
122 .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
123 .unwrap()
124 .progress_chars("#>-"),
125 );
126
127 let mut file = fs::File::create(&dest).map_err(|e| ModelManagerError::Io {
128 context: format!("creating {}", dest.display()),
129 source: e,
130 })?;
131
132 let mut downloaded: u64 = 0;
133 let mut reader = response;
134 let mut buf = [0u8; 8192];
135 loop {
136 let n = reader.read(&mut buf).map_err(|e| ModelManagerError::Io {
137 context: "reading download stream".into(),
138 source: e,
139 })?;
140 if n == 0 {
141 break;
142 }
143 file.write_all(&buf[..n]).map_err(|e| ModelManagerError::Io {
144 context: "writing model file".into(),
145 source: e,
146 })?;
147 downloaded += n as u64;
148 pb.set_position(downloaded);
149 }
150 pb.finish_with_message("download complete");
151
152 let ok = self.verify(spec)?;
154 if !ok {
155 let _ = fs::remove_file(&dest);
157 return Err(ModelManagerError::HashMismatch {
158 model: spec.name.clone(),
159 });
160 }
161
162 Ok(dest)
163 }
164
165 pub fn get_or_download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
167 if self.is_cached(spec) {
168 if self.verify(spec)? {
170 return Ok(self.model_path(spec));
171 }
172 }
173 self.download(spec)
174 }
175}
176
177pub fn gliner_large_v21_int8() -> ModelSpec {
185 ModelSpec {
186 name: "gliner_large-v2.1/onnx/model_int8.onnx".into(),
187 url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/onnx/model_int8.onnx".into(),
188 sha256: "pending_verification".into(),
189 size_bytes: 653_000_000,
190 }
191}
192
193pub fn gliner_large_v21_tokenizer() -> ModelSpec {
195 ModelSpec {
196 name: "gliner_large-v2.1/tokenizer.json".into(),
197 url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/tokenizer.json".into(),
198 sha256: "pending_verification".into(),
199 size_bytes: 17_000_000,
200 }
201}
202
203pub fn gliner_multitask_large() -> ModelSpec {
215 ModelSpec {
216 name: "gliner-multitask-large-v0.5/onnx/model_int8.onnx".into(),
217 url: "https://huggingface.co/onnx-community/gliner-multitask-large-v0.5/resolve/main/onnx/model_int8.onnx".into(),
218 sha256: "pending_verification".into(),
219 size_bytes: 647_920_426, }
221}
222
223pub fn gliner_multitask_tokenizer() -> ModelSpec {
225 ModelSpec {
226 name: "gliner-multitask-large-v0.5/tokenizer.json".into(),
227 url: "https://huggingface.co/onnx-community/gliner-multitask-large-v0.5/resolve/main/tokenizer.json".into(),
228 sha256: "pending_verification".into(),
229 size_bytes: 8_657_198,
230 }
231}
232
233pub fn nli_deberta_v3_small() -> ModelSpec {
240 ModelSpec {
241 name: "nli-deberta-v3-small/onnx/model.onnx".into(),
242 url: "https://huggingface.co/cross-encoder/nli-deberta-v3-small/resolve/main/onnx/model.onnx".into(),
243 sha256: "pending_verification".into(),
244 size_bytes: 541_700_000,
245 }
246}
247
248pub fn nli_deberta_v3_small_tokenizer() -> ModelSpec {
250 ModelSpec {
251 name: "nli-deberta-v3-small/tokenizer.json".into(),
252 url: "https://huggingface.co/cross-encoder/nli-deberta-v3-small/resolve/main/tokenizer.json".into(),
253 sha256: "pending_verification".into(),
254 size_bytes: 8_250_000,
255 }
256}
257
258pub fn minilm_l6_v2() -> ModelSpec {
260 ModelSpec {
261 name: "minilm-l6-v2.onnx".into(),
262 url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".into(),
263 sha256: "pending_verification".into(),
264 size_bytes: 80_000_000,
265 }
266}
267
268impl ModelManager {
273 pub fn ensure_ner_models(&self) -> Result<(PathBuf, PathBuf), ModelManagerError> {
277 let model = self.get_or_download(&gliner_large_v21_int8())?;
278 let tokenizer = self.get_or_download(&gliner_large_v21_tokenizer())?;
279 Ok((model, tokenizer))
280 }
281
282 pub fn ensure_rel_models(&self) -> Option<(PathBuf, PathBuf)> {
286 let model = self.get_or_download(&gliner_multitask_large()).ok()?;
287 let tokenizer = self.get_or_download(&gliner_multitask_tokenizer()).ok()?;
288 Some((model, tokenizer))
289 }
290
291 pub fn ensure_nli_models(&self) -> Result<(PathBuf, PathBuf), ModelManagerError> {
295 let model = self.get_or_download(&nli_deberta_v3_small())?;
296 let tokenizer = self.get_or_download(&nli_deberta_v3_small_tokenizer())?;
297 Ok((model, tokenizer))
298 }
299
300 pub fn find_nli_model(&self) -> Option<(PathBuf, PathBuf)> {
302 let model = self.model_path(&nli_deberta_v3_small());
303 let tokenizer = self.model_path(&nli_deberta_v3_small_tokenizer());
304 if model.exists() && tokenizer.exists() {
305 Some((model, tokenizer))
306 } else {
307 None
308 }
309 }
310
311 pub fn find_relation_classifier(&self) -> Option<std::path::PathBuf> {
318 let base = self.cache_dir.join("relation_classifier");
319
320 [
321 base.join("model_int8.onnx"),
322 base.join("model.onnx"),
323 ]
324 .into_iter()
325 .find(|p| p.exists())
326 }
327
328 pub fn find_relex_model(&self) -> Option<(PathBuf, PathBuf)> {
335 let base = self.cache_dir.join("gliner-relex-large-v0.5");
336
337 let model = [
339 base.join("onnx/model_quantized.onnx"),
340 base.join("onnx/model.onnx"),
341 ]
342 .into_iter()
343 .find(|p| p.exists())?;
344
345 let tokenizer = [
346 base.join("tokenizer.json"),
347 base.join("onnx/tokenizer.json"),
348 ]
349 .into_iter()
350 .find(|p| p.exists())?;
351
352 Some((model, tokenizer))
353 }
354}
355
/// Errors surfaced by [`ModelManager`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// The platform cache directory could not be determined
    /// (`dirs::cache_dir()` returned `None`).
    #[error("could not determine cache directory")]
    NoCacheDir,

    /// A filesystem operation failed; `context` names the operation.
    #[error("I/O error ({context}): {source}")]
    Io {
        context: String,
        source: std::io::Error,
    },

    /// The HTTP request itself failed (connection, TLS, redirect, etc.).
    #[error("download failed for {url}: {source}")]
    Download {
        url: String,
        source: reqwest::Error,
    },

    /// The server responded with a non-success status code.
    #[error("HTTP {status} for {url}")]
    HttpStatus { url: String, status: u16 },

    /// The downloaded file's SHA-256 digest did not match the spec's hash.
    #[error("SHA-256 hash mismatch for {model}")]
    HashMismatch { model: String },
}