// ctxgraph_extract/model_manager.rs
use std::fs;
use std::io::{ErrorKind, Read, Write};
use std::path::{Path, PathBuf};

use sha2::{Digest, Sha256};
6
/// Describes a downloadable model artifact: where to fetch it, what to call it
/// in the local cache, and the checksum/size used to validate the cached copy.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelSpec {
    /// File name of the model inside the cache directory.
    pub name: String,
    /// URL the model file is downloaded from.
    pub url: String,
    /// Expected SHA-256 digest of the file, hex-encoded.
    pub sha256: String,
    /// Expected file size in bytes. It is used as a cheap "is it cached?" check
    /// and as the progress-bar total when the server sends no `Content-Length`.
    pub size_bytes: u64,
}
15
/// Manages a local cache directory of downloaded model files.
///
/// Every constructor creates the cache directory if it does not already exist.
#[derive(Debug, Clone)]
pub struct ModelManager {
    /// Directory in which model files are stored.
    cache_dir: PathBuf,
}
20
21impl ModelManager {
22 pub fn new() -> Result<Self, ModelManagerError> {
25 let cache = Self::default_cache_dir()?;
26 Ok(Self { cache_dir: cache })
27 }
28
29 pub fn with_cache_dir(cache_dir: PathBuf) -> Result<Self, ModelManagerError> {
31 fs::create_dir_all(&cache_dir).map_err(|e| ModelManagerError::Io {
32 context: format!("creating cache dir {}", cache_dir.display()),
33 source: e,
34 })?;
35 Ok(Self { cache_dir })
36 }
37
38 pub fn default_cache_dir() -> Result<PathBuf, ModelManagerError> {
41 let base = dirs::cache_dir().ok_or(ModelManagerError::NoCacheDir)?;
42 let dir = base.join("ctxgraph").join("models");
43 fs::create_dir_all(&dir).map_err(|e| ModelManagerError::Io {
44 context: format!("creating cache dir {}", dir.display()),
45 source: e,
46 })?;
47 Ok(dir)
48 }
49
50 pub fn model_path(&self, spec: &ModelSpec) -> PathBuf {
52 self.cache_dir.join(&spec.name)
53 }
54
55 pub fn is_cached(&self, spec: &ModelSpec) -> bool {
57 let path = self.model_path(spec);
58 match fs::metadata(&path) {
59 Ok(meta) => meta.len() == spec.size_bytes,
60 Err(_) => false,
61 }
62 }
63
64 pub fn verify(&self, spec: &ModelSpec) -> Result<bool, ModelManagerError> {
68 let path = self.model_path(spec);
69 let mut file = fs::File::open(&path).map_err(|e| ModelManagerError::Io {
70 context: format!("opening {} for verification", path.display()),
71 source: e,
72 })?;
73
74 let mut hasher = Sha256::new();
75 let mut buf = [0u8; 8192];
76 loop {
77 let n = file.read(&mut buf).map_err(|e| ModelManagerError::Io {
78 context: "reading file for hash".into(),
79 source: e,
80 })?;
81 if n == 0 {
82 break;
83 }
84 hasher.update(&buf[..n]);
85 }
86
87 let digest = format!("{:x}", hasher.finalize());
88 Ok(digest == spec.sha256)
89 }
90
91 pub fn download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
93 let dest = self.model_path(spec);
94
95 let response = reqwest::blocking::get(&spec.url).map_err(|e| {
96 ModelManagerError::Download {
97 url: spec.url.clone(),
98 source: e,
99 }
100 })?;
101
102 if !response.status().is_success() {
103 return Err(ModelManagerError::HttpStatus {
104 url: spec.url.clone(),
105 status: response.status().as_u16(),
106 });
107 }
108
109 let total_size = response.content_length().unwrap_or(spec.size_bytes);
110
111 let pb = indicatif::ProgressBar::new(total_size);
112 pb.set_style(
113 indicatif::ProgressStyle::default_bar()
114 .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
115 .unwrap()
116 .progress_chars("#>-"),
117 );
118
119 let mut file = fs::File::create(&dest).map_err(|e| ModelManagerError::Io {
120 context: format!("creating {}", dest.display()),
121 source: e,
122 })?;
123
124 let mut downloaded: u64 = 0;
125 let mut reader = response;
126 let mut buf = [0u8; 8192];
127 loop {
128 let n = reader.read(&mut buf).map_err(|e| ModelManagerError::Io {
129 context: "reading download stream".into(),
130 source: e,
131 })?;
132 if n == 0 {
133 break;
134 }
135 file.write_all(&buf[..n]).map_err(|e| ModelManagerError::Io {
136 context: "writing model file".into(),
137 source: e,
138 })?;
139 downloaded += n as u64;
140 pb.set_position(downloaded);
141 }
142 pb.finish_with_message("download complete");
143
144 let ok = self.verify(spec)?;
146 if !ok {
147 let _ = fs::remove_file(&dest);
149 return Err(ModelManagerError::HashMismatch {
150 model: spec.name.clone(),
151 });
152 }
153
154 Ok(dest)
155 }
156
157 pub fn get_or_download(&self, spec: &ModelSpec) -> Result<PathBuf, ModelManagerError> {
159 if self.is_cached(spec) {
160 if self.verify(spec)? {
162 return Ok(self.model_path(spec));
163 }
164 }
165 self.download(spec)
166 }
167}
168
169pub fn gliner2_large() -> ModelSpec {
175 ModelSpec {
176 name: "gliner2-large-q8.onnx".into(),
177 url: "https://huggingface.co/ctxgraph/models/resolve/main/gliner2-large-q8.onnx".into(),
178 sha256: "placeholder_sha256_gliner2_large_q8".into(),
179 size_bytes: 200_000_000,
180 }
181}
182
183pub fn glirel_large() -> ModelSpec {
185 ModelSpec {
186 name: "glirel-large.onnx".into(),
187 url: "https://huggingface.co/ctxgraph/models/resolve/main/glirel-large.onnx".into(),
188 sha256: "placeholder_sha256_glirel_large".into(),
189 size_bytes: 150_000_000,
190 }
191}
192
193pub fn minilm_l6_v2() -> ModelSpec {
195 ModelSpec {
196 name: "minilm-l6-v2.onnx".into(),
197 url: "https://huggingface.co/ctxgraph/models/resolve/main/minilm-l6-v2.onnx".into(),
198 sha256: "placeholder_sha256_minilm_l6_v2".into(),
199 size_bytes: 80_000_000,
200 }
201}
202
203#[derive(Debug, thiserror::Error)]
208pub enum ModelManagerError {
209 #[error("could not determine cache directory")]
210 NoCacheDir,
211
212 #[error("I/O error ({context}): {source}")]
213 Io {
214 context: String,
215 source: std::io::Error,
216 },
217
218 #[error("download failed for {url}: {source}")]
219 Download {
220 url: String,
221 source: reqwest::Error,
222 },
223
224 #[error("HTTP {status} for {url}")]
225 HttpStatus { url: String, status: u16 },
226
227 #[error("SHA-256 hash mismatch for {model}")]
228 HashMismatch { model: String },
229}