ctxgraph_extract/
model_manager.rs1use std::fs;
2use std::io::{Read, Write};
3use std::path::PathBuf;
4
5use sha2::{Digest, Sha256};
6
/// Describes one downloadable model artifact.
///
/// Value equality and hashing are derived so specs can be compared directly
/// and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelSpec {
    /// Path of the artifact relative to the cache directory; may contain
    /// `/`-separated sub-directories.
    pub name: String,
    /// URL the artifact is downloaded from.
    pub url: String,
    /// Expected SHA-256 digest as a hex string.
    pub sha256: String,
    /// Expected size of the artifact in bytes.
    pub size_bytes: u64,
}
15
/// Locates, downloads and verifies model files inside a local cache directory.
#[derive(Debug, Clone)]
pub struct ModelManager {
    /// Directory under which every model file is stored.
    cache_dir: PathBuf,
}
20
21impl ModelManager {
22 pub fn new() -> Result<Self, ModelManagerError> {
25 let cache = Self::default_cache_dir()?;
26 Ok(Self { cache_dir: cache })
27 }
28
29 pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, ModelManagerError> {
31 fs::create_dir_all(&cache_dir).map_err(|e| ModelManagerError::Io {
32 context: format!("creating cache dir {}", cache_dir.display()),
33 source: e,
34 })?;
35 Ok(Self { cache_dir })
36 }
37
38 pub fn default_cache_dir() -> Result<PathBuf, ModelManagerError> {
41 let base = dirs::cache_dir().ok_or(ModelManagerError::NoCacheDir)?;
42 let dir = base.join("ctxgraph").join("models");
43 fs::create_dir_all(&dir).map_err(|e| ModelManagerError::Io {
44 context: format!("creating cache dir {}", dir.display()),
45 source: e,
46 })?;
47 Ok(dir)
48 }
49
50 pub fn model_path(&self, spec: &ModelSpec) -> PathBuf {
52 self.cache_dir.join(&spec.name)
53 }
54
55 pub fn is_cached(&self, spec: &ModelSpec) -> bool {
57 let path = self.model_path(spec);
58 match fs::metadata(&path) {
59 Ok(meta) => meta.len() == spec.size_bytes,
60 Err(_) => false,
61 }
62 }
63
64 pub fn verify(&self, spec: &ModelSpec) -> Result<bool, ModelManagerError> {
68 let path = self.model_path(spec);
69 let mut file = fs::File::open(&path).map_err(|e| ModelManagerError::Io {
70 context: format!("opening {} for verification", path.display()),
71 source: e,
72 })?;
73
74 let mut hasher = Sha256::new();
75 let mut buf = [0u8; 8192];
76 loop {
77 let n = file.read(&mut buf).map_err(|e| ModelManagerError::Io {
78 context: "reading file for hash".into(),
79 source: e,
80 })?;
81 if n == 0 {
82 break;
83 }
84 hasher.update(&buf[..n]);
85 }
86
87 let digest = format!("{:x}", hasher.finalize());
88 Ok(digest == spec.sha256)
89 }
90
91 pub fn download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
93 let dest = self.model_path(spec);
94
95 let response = reqwest::blocking::get(&spec.url).map_err(|e| {
96 ModelManagerError::Download {
97 url: spec.url.clone(),
98 source: e,
99 }
100 })?;
101
102 if !response.status().is_success() {
103 return Err(ModelManagerError::HttpStatus {
104 url: spec.url.clone(),
105 status: response.status().as_u16(),
106 });
107 }
108
109 let total_size = response.content_length().unwrap_or(spec.size_bytes);
110
111 let pb = indicatif::ProgressBar::new(total_size);
112 pb.set_style(
113 indicatif::ProgressStyle::default_bar()
114 .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
115 .unwrap()
116 .progress_chars("#>-"),
117 );
118
119 let mut file = fs::File::create(&dest).map_err(|e| ModelManagerError::Io {
120 context: format!("creating {}", dest.display()),
121 source: e,
122 })?;
123
124 let mut downloaded: u64 = 0;
125 let mut reader = response;
126 let mut buf = [0u8; 8192];
127 loop {
128 let n = reader.read(&mut buf).map_err(|e| ModelManagerError::Io {
129 context: "reading download stream".into(),
130 source: e,
131 })?;
132 if n == 0 {
133 break;
134 }
135 file.write_all(&buf[..n]).map_err(|e| ModelManagerError::Io {
136 context: "writing model file".into(),
137 source: e,
138 })?;
139 downloaded += n as u64;
140 pb.set_position(downloaded);
141 }
142 pb.finish_with_message("download complete");
143
144 let ok = self.verify(spec)?;
146 if !ok {
147 let _ = fs::remove_file(&dest);
149 return Err(ModelManagerError::HashMismatch {
150 model: spec.name.clone(),
151 });
152 }
153
154 Ok(dest)
155 }
156
157 pub fn get_or_download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
159 if self.is_cached(spec) {
160 if self.verify(spec)? {
162 return Ok(self.model_path(spec));
163 }
164 }
165 self.download(spec)
166 }
167}
168
169pub fn gliner_large_v21_int8() -> ModelSpec {
177 ModelSpec {
178 name: "gliner_large-v2.1/onnx/model_int8.onnx".into(),
179 url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/onnx/model_int8.onnx".into(),
180 sha256: "pending_verification".into(),
181 size_bytes: 653_000_000,
182 }
183}
184
185pub fn gliner_large_v21_tokenizer() -> ModelSpec {
187 ModelSpec {
188 name: "gliner_large-v2.1/tokenizer.json".into(),
189 url: "https://huggingface.co/onnx-community/gliner_large-v2.1/resolve/main/tokenizer.json".into(),
190 sha256: "pending_verification".into(),
191 size_bytes: 17_000_000,
192 }
193}
194
195pub fn gliner_multitask_large() -> ModelSpec {
201 ModelSpec {
202 name: "gliner-multitask-large-v0.5/onnx/model.onnx".into(),
203 url: "https://huggingface.co/knowledgator/gliner-multitask-large-v0.5/resolve/main/onnx/model.onnx".into(),
204 sha256: "pending_conversion".into(),
205 size_bytes: 1_760_000_000,
206 }
207}
208
209pub fn gliner_multitask_tokenizer() -> ModelSpec {
211 ModelSpec {
212 name: "gliner-multitask-large-v0.5/tokenizer.json".into(),
213 url: "https://huggingface.co/knowledgator/gliner-multitask-large-v0.5/resolve/main/tokenizer.json".into(),
214 sha256: "pending_verification".into(),
215 size_bytes: 8_660_000,
216 }
217}
218
219pub fn minilm_l6_v2() -> ModelSpec {
221 ModelSpec {
222 name: "minilm-l6-v2.onnx".into(),
223 url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".into(),
224 sha256: "pending_verification".into(),
225 size_bytes: 80_000_000,
226 }
227}
228
229impl ModelManager {
234 pub fn ensure_ner_models(&self) -> Result<(PathBuf, PathBuf), ModelManagerError> {
238 let model = self.get_or_download(&gliner_large_v21_int8())?;
239 let tokenizer = self.get_or_download(&gliner_large_v21_tokenizer())?;
240 Ok((model, tokenizer))
241 }
242
243 pub fn ensure_rel_models(&self) -> Option<(PathBuf, PathBuf)> {
247 let model = self.get_or_download(&gliner_multitask_large()).ok()?;
248 let tokenizer = self.get_or_download(&gliner_multitask_tokenizer()).ok()?;
249 Some((model, tokenizer))
250 }
251}
252
253#[derive(Debug, thiserror::Error)]
258pub enum ModelManagerError {
259 #[error("could not determine cache directory")]
260 NoCacheDir,
261
262 #[error("I/O error ({context}): {source}")]
263 Io {
264 context: String,
265 source: std::io::Error,
266 },
267
268 #[error("download failed for {url}: {source}")]
269 Download {
270 url: String,
271 source: reqwest::Error,
272 },
273
274 #[error("HTTP {status} for {url}")]
275 HttpStatus { url: String, status: u16 },
276
277 #[error("SHA-256 hash mismatch for {model}")]
278 HashMismatch { model: String },
279}