use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::Duration;

use anyhow::Context;
use indicatif::{ProgressBar, ProgressStyle};
use nu_ansi_term::Color::{Purple, White};
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize};

/// Where a model can be fetched from.
#[derive(Debug, Clone)]
pub enum ModelSource {
    HuggingFace { repo: String, file: Option<String> },
    Url { url: String, filename: Option<String> },
    Ollama { model: String },
    LocalPath { path: PathBuf },
}

/// Metadata gathered from the model's source (e.g. the Hugging Face Hub API).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    pub model_id: String,
    pub architecture: String,
    pub parameters: u64,
    pub model_type: String,
    pub tokenizer_config: Option<HashMap<String, serde_json::Value>>,
    pub config: Option<HashMap<String, serde_json::Value>>,
    pub files: Vec<String>,
    pub license: Option<String>,
    pub tags: Vec<String>,
}

/// Result of a successful fetch: where the weights landed and what we know about them.
#[derive(Debug, Clone)]
pub struct FetchResult {
    pub local_path: PathBuf,
    pub metadata: Option<ModelMetadata>,
    pub model_format: ModelFormat,
}

/// On-disk weight format, detected by file extension or magic bytes.
#[derive(Debug, Clone)]
pub enum ModelFormat {
    SafeTensors,
    PyTorch,
    GGUF,
    ONNX,
    Unknown,
}

impl ModelSource {
    pub fn is_remote(&self) -> bool {
        !matches!(self, ModelSource::LocalPath { .. })
    }

    pub fn to_url(&self) -> anyhow::Result<String> {
        match self {
            ModelSource::HuggingFace { repo, file } => {
                let file_name = file.as_deref().unwrap_or("model.safetensors");
                Ok(format!("https://huggingface.co/{}/resolve/main/{}", repo, file_name))
            }
            ModelSource::Url { url, .. } => Ok(url.clone()),
            ModelSource::Ollama { model } => {
                Err(anyhow::anyhow!("Ollama models require a local pull first: ollama pull {}", model))
            }
            ModelSource::LocalPath { .. } => Err(anyhow::anyhow!("local paths have no URL")),
        }
    }
}
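
// A minimal sanity check for `to_url`, exercising only the URL construction
// above; no network access. The repo name is a hypothetical placeholder.
#[cfg(test)]
mod to_url_tests {
    use super::*;

    #[test]
    fn hf_source_defaults_to_model_safetensors() {
        let src = ModelSource::HuggingFace { repo: "org/name".into(), file: None };
        assert_eq!(
            src.to_url().unwrap(),
            "https://huggingface.co/org/name/resolve/main/model.safetensors"
        );
        assert!(src.is_remote());
    }
}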

/// Fetches models from a [`ModelSource`] into a local file.
pub struct ModelFetcher;

impl ModelFetcher {
    /// Fetch the model described by `source`, returning the local path and the
    /// detected on-disk format.
    pub fn fetch(source: &ModelSource) -> anyhow::Result<FetchResult> {
        match source {
            ModelSource::LocalPath { path } => {
                let format = detect_model_format(path)?;
                Ok(FetchResult {
                    local_path: path.clone(),
                    metadata: None,
                    model_format: format,
                })
            }
            ModelSource::Url { url, filename } => fetch_via_http(url, filename.as_deref()),
            ModelSource::HuggingFace { repo, file } => fetch_hf_with_metadata(repo, file.as_deref()),
            ModelSource::Ollama { model } => fetch_ollama_with_metadata(model),
        }
    }

    pub fn fetch_metadata(repo: &str) -> anyhow::Result<ModelMetadata> {
        let token = std::env::var("HF_TOKEN")
            .ok()
            .or_else(|| std::env::var("HUGGINGFACE_HUB_TOKEN").ok());

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()?;

        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("ohms-adaptq/2.0"));
        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));

        if let Some(t) = token {
            let hv = HeaderValue::from_str(&format!("Bearer {}", t))
                .map_err(|_| anyhow::anyhow!("invalid HF token format"))?;
            headers.insert(AUTHORIZATION, hv);
        }

        let url = format!("https://huggingface.co/api/models/{}", repo);
        let resp = client.get(&url).headers(headers).send()?;

        if !resp.status().is_success() {
            return Err(anyhow::anyhow!("failed to fetch metadata for '{}': {}", repo, resp.status()));
        }

        let metadata: serde_json::Value = resp.json()?;

        // The Hub API reports `model_type` inside the `config` object; reading it
        // from the response's top level (as this code previously did) almost
        // always yields "unknown".
        let model_type = metadata["config"]["model_type"]
            .as_str()
            .unwrap_or("unknown")
            .to_string();

        Ok(ModelMetadata {
            model_id: repo.to_string(),
            architecture: model_type.clone(),
            parameters: metadata["safetensors"]["total"].as_u64().unwrap_or(0),
            model_type,
            tokenizer_config: metadata["tokenizer_config"]
                .as_object()
                .map(|o| o.clone().into_iter().collect()),
            config: metadata["config"]
                .as_object()
                .map(|o| o.clone().into_iter().collect()),
            files: metadata["siblings"]
                .as_array()
                .map(|arr| arr.iter()
                    .filter_map(|v| v["rfilename"].as_str().map(String::from))
                    .collect())
                .unwrap_or_default(),
            // The license usually lives in the model card data rather than at the
            // top level, so check both.
            license: metadata["license"]
                .as_str()
                .or_else(|| metadata["cardData"]["license"].as_str())
                .map(String::from),
            tags: metadata["tags"]
                .as_array()
                .map(|arr| arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect())
                .unwrap_or_default(),
        })
    }
}
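
// A minimal offline sketch of the fetch flow: a local file round-trips through
// `ModelFetcher::fetch` and format detection without any network access. The
// fixture file name is hypothetical.
#[cfg(test)]
mod fetch_local_tests {
    use super::*;

    #[test]
    fn local_safetensors_path_is_detected_by_extension() {
        let path = std::env::temp_dir().join("ohms_fixture.safetensors");
        std::fs::File::create(&path).unwrap();
        let res = ModelFetcher::fetch(&ModelSource::LocalPath { path: path.clone() }).unwrap();
        assert!(matches!(res.model_format, ModelFormat::SafeTensors));
        assert_eq!(res.local_path, path);
        let _ = std::fs::remove_file(path);
    }
}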

fn fetch_via_http(url: &str, filename: Option<&str>) -> anyhow::Result<FetchResult> {
    // The unauthenticated path is the authenticated one without a bearer token;
    // delegating removes an exact duplicate of the download loop below.
    let name = filename.map(str::to_string).unwrap_or_else(|| infer_filename_from_url(url));
    fetch_via_http_with_auth(url, &name, None)
}

fn fetch_via_http_with_auth(url: &str, filename: &str, token: Option<String>) -> anyhow::Result<FetchResult> {
    let client = Client::builder()
        .timeout(Duration::from_secs(900))
        .build()?;
    let mut headers = HeaderMap::new();
    headers.insert(USER_AGENT, HeaderValue::from_static("ohms-adaptq/2.0"));
    headers.insert(ACCEPT, HeaderValue::from_static("application/octet-stream"));
    if let Some(t) = token {
        let hv = HeaderValue::from_str(&format!("Bearer {}", t))
            .map_err(|_| anyhow::anyhow!("invalid token format"))?;
        headers.insert(AUTHORIZATION, hv);
    }
    let mut resp = client
        .get(url)
        .headers(headers)
        .send()
        .with_context(|| format!("GET {} failed", url))?;
    anyhow::ensure!(resp.status().is_success(), "download failed: {}", resp.status());
    let path = std::env::temp_dir().join(filename);
    let mut file = std::fs::File::create(&path)?;

    // Stream the body to disk in 64 KiB chunks, updating the progress bar as
    // bytes arrive; the content length may be absent, leaving the bar total at 0.
    let total = resp.content_length().unwrap_or(0);
    let bar = ProgressBar::new(total);
    bar.set_style(
        ProgressStyle::with_template("{prefix:.bold} {bar:40.cyan/blue} {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
            .unwrap()
            .progress_chars("##-"),
    );
    bar.set_prefix(format!(
        "{} {}",
        White.bold().paint("Downloading"),
        Purple.bold().paint("OHMS"),
    ));

    let mut buf = [0u8; 1 << 16];
    loop {
        let n = resp.read(&mut buf)?;
        if n == 0 {
            break;
        }
        file.write_all(&buf[..n])?;
        bar.inc(n as u64);
    }
    bar.finish_and_clear();
    file.flush()?;

    let format = detect_model_format(&path)?;
    Ok(FetchResult {
        local_path: path,
        metadata: None,
        model_format: format,
    })
}

fn fetch_hf_with_metadata(repo: &str, file: Option<&str>) -> anyhow::Result<FetchResult> {
    let metadata = ModelFetcher::fetch_metadata(repo)?;

    // Prefer huggingface-cli when available: it resumes partial downloads and
    // handles authentication on its own.
    if let Some(res) = try_hf_cli_download(repo, file)? {
        let format = detect_model_format(&res)?;
        return Ok(FetchResult {
            local_path: res,
            metadata: Some(metadata),
            model_format: format,
        });
    }

    let token = std::env::var("HF_TOKEN")
        .ok()
        .or_else(|| std::env::var("HUGGINGFACE_HUB_TOKEN").ok());
    let try_files: Vec<String> = if let Some(f) = file {
        vec![f.to_string()]
    } else {
        // Try the common single-file weight names first, then anything the repo
        // metadata lists that looks like a weight file.
        let mut files = vec![
            "model.safetensors".to_string(),
            "consolidated.safetensors".to_string(),
            "pytorch_model.bin".to_string(),
        ];
        for file_name in &metadata.files {
            if file_name.ends_with(".safetensors") || file_name.ends_with(".bin") {
                files.push(file_name.clone());
            }
        }
        files
    };

    let base = format!("https://huggingface.co/{}/resolve/main/", repo);
    let mut last_err: Option<anyhow::Error> = None;
    for fname in try_files {
        let url = format!("{}{}", base, fname);
        match fetch_via_http_with_auth(&url, &fname, token.clone()) {
            Ok(mut result) => {
                result.metadata = Some(metadata);
                return Ok(result);
            }
            Err(e) => last_err = Some(e),
        }
    }
    Err(last_err.unwrap_or_else(|| {
        anyhow::anyhow!("no compatible weight file found in repo '{}'; specify :<file>", repo)
    }))
}

fn try_hf_cli_download(repo: &str, file: Option<&str>) -> anyhow::Result<Option<PathBuf>> {
    // Bail out quietly if huggingface-cli is not on PATH.
    let which = Command::new("bash").arg("-lc").arg("command -v huggingface-cli").output()?;
    if !which.status.success() {
        return Ok(None);
    }

    let tmp = std::env::temp_dir().join(format!("ohms-hf-{}", repo.replace('/', "_")));
    let _ = std::fs::create_dir_all(&tmp);

    let mut args: Vec<String> = vec![
        "download".into(),
        repo.into(),
        "--local-dir".into(),
        tmp.to_string_lossy().to_string(),
        "--resume-download".into(),
    ];
    if let Some(f) = file {
        args.push("--include".into());
        args.push(f.into());
    } else {
        for cand in ["model.safetensors", "consolidated.safetensors", "pytorch_model.bin"] {
            args.push("--include".into());
            args.push(cand.into());
        }
    }

    let mut cmd = Command::new("huggingface-cli");
    cmd.args(&args);
    // Opt in to the faster hf_transfer backend; the CLI errors out if the
    // hf_transfer package is missing, which surfaces as a failed status below.
    cmd.env("HF_HUB_ENABLE_HF_TRANSFER", "1");
    if let Ok(tok) = std::env::var("HF_TOKEN") {
        cmd.env("HF_TOKEN", tok);
    }
    let status = cmd.status()?;
    if !status.success() {
        return Ok(None);
    }

    // Return the first requested weight file that actually landed on disk.
    if let Some(f) = file {
        let p = tmp.join(f);
        if p.exists() {
            return Ok(Some(p));
        }
    }
    for cand in ["model.safetensors", "consolidated.safetensors", "pytorch_model.bin"] {
        let p = tmp.join(cand);
        if p.exists() {
            return Ok(Some(p));
        }
    }
    Ok(None)
}

fn fetch_ollama_with_metadata(model: &str) -> anyhow::Result<FetchResult> {
    let status = Command::new("ollama").arg("pull").arg(model).status();
    anyhow::ensure!(status.map(|s| s.success()).unwrap_or(false), "ollama pull {} failed", model);

    // `ollama show` prints human-readable text in stock builds, so the JSON
    // parse below usually yields None and the defaults are used instead.
    let output = Command::new("ollama")
        .arg("show")
        .arg(model)
        .output()?;

    let model_info = if output.status.success() {
        serde_json::from_slice::<serde_json::Value>(&output.stdout).ok()
    } else {
        None
    };

    let export_path = std::env::temp_dir().join(format!("{}.gguf", model.replace(':', "_")));

    // NOTE: this assumes an `ollama export` subcommand; stock ollama releases do
    // not ship one, in which case this step fails and the error surfaces below.
    let export_status = Command::new("ollama")
        .arg("export")
        .arg(model)
        .arg(&export_path)
        .status()?;

    anyhow::ensure!(export_status.success(), "ollama export {} failed", model);

    let metadata = ModelMetadata {
        model_id: model.to_string(),
        architecture: model_info.as_ref()
            .and_then(|info| info["model_type"].as_str())
            .unwrap_or("gguf")
            .to_string(),
        parameters: model_info.as_ref()
            .and_then(|info| info["parameter_size"].as_u64())
            .unwrap_or(0),
        model_type: "gguf".to_string(),
        tokenizer_config: None,
        config: model_info.as_ref()
            .and_then(|info| info["config"].as_object())
            .map(|o| o.clone().into_iter().collect()),
        files: vec![export_path.file_name().unwrap().to_string_lossy().to_string()],
        license: None,
        tags: vec!["ollama".to_string(), "gguf".to_string()],
    };

    Ok(FetchResult {
        local_path: export_path,
        metadata: Some(metadata),
        model_format: ModelFormat::GGUF,
    })
}

fn detect_model_format(path: &Path) -> anyhow::Result<ModelFormat> {
    let file_name = path.file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("");

    if file_name.ends_with(".safetensors") {
        Ok(ModelFormat::SafeTensors)
    } else if file_name.ends_with(".bin") || file_name.ends_with(".pt") || file_name.ends_with(".pth") {
        Ok(ModelFormat::PyTorch)
    } else if file_name.ends_with(".gguf") {
        Ok(ModelFormat::GGUF)
    } else if file_name.ends_with(".onnx") {
        Ok(ModelFormat::ONNX)
    } else {
        // Fall back to sniffing magic bytes when the extension is uninformative.
        if let Ok(mut file) = std::fs::File::open(path) {
            let mut header = [0u8; 16];
            if file.read_exact(&mut header).is_ok() {
                // GGUF files begin with the ASCII magic "GGUF".
                if &header[0..4] == b"GGUF" {
                    return Ok(ModelFormat::GGUF);
                }
                // PyTorch checkpoints saved with the zipfile serializer are zip
                // archives, which begin with "PK\x03\x04". (The old check compared
                // eight bytes against this four-byte magic and could never match.)
                if &header[0..4] == b"PK\x03\x04" {
                    return Ok(ModelFormat::PyTorch);
                }
                // SafeTensors files start with an 8-byte little-endian JSON header
                // length followed by the JSON header itself, so byte 8 is '{'.
                // (The old "__safetensors__" magic does not exist in the format.)
                if header[8] == b'{' {
                    return Ok(ModelFormat::SafeTensors);
                }
            }
        }
        Ok(ModelFormat::Unknown)
    }
}
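
// A small offline check of the magic-byte fallback above, assuming an
// extension-less fixture file written to the temp dir.
#[cfg(test)]
mod detect_format_tests {
    use super::*;

    #[test]
    fn sniffs_gguf_magic_without_extension() {
        let path = std::env::temp_dir().join("ohms_fixture_gguf");
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(b"GGUF").unwrap();
        f.write_all(&[0u8; 12]).unwrap(); // pad so the 16-byte header read succeeds
        drop(f);
        assert!(matches!(detect_model_format(&path).unwrap(), ModelFormat::GGUF));
        let _ = std::fs::remove_file(path);
    }
}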

fn infer_filename_from_url(url: &str) -> String {
    // Use the last path segment, dropping any query string; fall back to a
    // generic name when the URL ends with a slash.
    url.split('/')
        .last()
        .and_then(|s| s.split('?').next())
        .filter(|s| !s.is_empty())
        .unwrap_or("model.bin")
        .to_string()
}

/// Parse a model source string. Recognized prefixes are `hf:<repo>[:<file>]`,
/// `url:<url>`, `ollama:<model>`, and `file:<path>`; anything else is treated
/// as a local path.
pub fn parse_model_source(s: &str) -> ModelSource {
    if let Some(rest) = s.strip_prefix("hf:") {
        let mut parts = rest.splitn(2, ':');
        let repo = parts.next().unwrap_or("").to_string();
        let file = parts.next().map(|v| v.to_string());
        return ModelSource::HuggingFace { repo, file };
    }
    if let Some(rest) = s.strip_prefix("url:") {
        return ModelSource::Url { url: rest.to_string(), filename: None };
    }
    if let Some(rest) = s.strip_prefix("ollama:") {
        return ModelSource::Ollama { model: rest.to_string() };
    }
    if let Some(rest) = s.strip_prefix("file:") {
        return ModelSource::LocalPath { path: PathBuf::from(rest) };
    }
    ModelSource::LocalPath { path: PathBuf::from(s) }
}
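
// End-to-end sanity checks for source parsing and filename inference; purely
// offline, with hypothetical repo and URL strings.
#[cfg(test)]
mod parse_source_tests {
    use super::*;

    #[test]
    fn parses_hf_prefix_with_optional_file() {
        match parse_model_source("hf:org/name:model.gguf") {
            ModelSource::HuggingFace { repo, file } => {
                assert_eq!(repo, "org/name");
                assert_eq!(file.as_deref(), Some("model.gguf"));
            }
            other => panic!("unexpected source: {:?}", other),
        }
    }

    #[test]
    fn bare_strings_fall_back_to_local_paths() {
        assert!(matches!(
            parse_model_source("./weights/model.safetensors"),
            ModelSource::LocalPath { .. }
        ));
    }

    #[test]
    fn infers_filename_from_last_url_segment() {
        assert_eq!(infer_filename_from_url("https://example.com/m/weights.gguf"), "weights.gguf");
        assert_eq!(infer_filename_from_url("https://example.com/m/"), "model.bin");
    }
}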