ohms_adaptq/model_fetcher.rs

use std::path::{Path, PathBuf};
use std::process::Command;
use anyhow::Context;
use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT, ACCEPT};
use std::time::Duration;
use std::io::{Write, Read};
use indicatif::{ProgressBar, ProgressStyle};
use nu_ansi_term::Color::{White, Purple};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone)]
pub enum ModelSource {
    HuggingFace { repo: String, file: Option<String> },
    Url { url: String, filename: Option<String> },
    Ollama { model: String },
    LocalPath { path: PathBuf },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    pub model_id: String,
    pub architecture: String,
    pub parameters: u64,
    pub model_type: String,
    pub tokenizer_config: Option<HashMap<String, serde_json::Value>>,
    pub config: Option<HashMap<String, serde_json::Value>>,
    pub files: Vec<String>,
    pub license: Option<String>,
    pub tags: Vec<String>,
}

#[derive(Debug, Clone)]
pub struct FetchResult {
    pub local_path: PathBuf,
    pub metadata: Option<ModelMetadata>,
    pub model_format: ModelFormat,
}

#[derive(Debug, Clone)]
pub enum ModelFormat {
    SafeTensors,
    PyTorch,
    GGUF,
    ONNX,
    Unknown,
}

impl ModelSource {
    /// Check whether this source requires remote fetching.
    pub fn is_remote(&self) -> bool {
        !matches!(self, ModelSource::LocalPath { .. })
    }

    /// Convert the source to a direct URL suitable for streaming.
    pub fn to_url(&self) -> anyhow::Result<String> {
        match self {
            ModelSource::HuggingFace { repo, file } => {
                let file_name = file.as_deref().unwrap_or("model.safetensors");
                Ok(format!("https://huggingface.co/{}/resolve/main/{}", repo, file_name))
            }
            ModelSource::Url { url, .. } => Ok(url.clone()),
            ModelSource::Ollama { model } => {
                Err(anyhow::anyhow!("Ollama models must be pulled locally first: ollama pull {}", model))
            }
            ModelSource::LocalPath { .. } => {
                Err(anyhow::anyhow!("local paths have no URL"))
            }
        }
    }
}
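
// A minimal sketch of how `to_url` resolves a Hugging Face source (the repo
// and file names below are illustrative placeholders, not real repos):
//
//     let src = ModelSource::HuggingFace {
//         repo: "org/model".into(),
//         file: Some("weights.safetensors".into()),
//     };
//     assert_eq!(
//         src.to_url().unwrap(),
//         "https://huggingface.co/org/model/resolve/main/weights.safetensors"
//     );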

pub struct ModelFetcher;

impl ModelFetcher {
    pub fn fetch(source: &ModelSource) -> anyhow::Result<FetchResult> {
        match source {
            ModelSource::LocalPath { path } => {
                let format = detect_model_format(path)?;
                Ok(FetchResult {
                    local_path: path.clone(),
                    metadata: None,
                    model_format: format,
                })
            }
            ModelSource::Url { url, filename } => fetch_via_http(url, filename.as_deref()),
            ModelSource::HuggingFace { repo, file } => fetch_hf_with_metadata(repo, file.as_deref()),
            ModelSource::Ollama { model } => fetch_ollama_with_metadata(model),
        }
    }

    /// Fetch model metadata from the Hugging Face Hub API.
    pub fn fetch_metadata(repo: &str) -> anyhow::Result<ModelMetadata> {
        let token = std::env::var("HF_TOKEN").ok()
            .or_else(|| std::env::var("HUGGINGFACE_HUB_TOKEN").ok());

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()?;

        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("ohms-adaptq/2.0"));
        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));

        if let Some(t) = token {
            let hv = HeaderValue::from_str(&format!("Bearer {}", t))
                .map_err(|_| anyhow::anyhow!("invalid token format"))?;
            headers.insert(AUTHORIZATION, hv);
        }

        let url = format!("https://huggingface.co/api/models/{}", repo);
        let resp = client.get(&url).headers(headers).send()?;

        if !resp.status().is_success() {
            return Err(anyhow::anyhow!("failed to fetch metadata: {}", resp.status()));
        }

        let metadata: serde_json::Value = resp.json()?;

        // The Hub API usually reports the architecture under `config.model_type`;
        // fall back to it when the top-level field is absent.
        let model_type = metadata["model_type"].as_str()
            .or_else(|| metadata["config"]["model_type"].as_str())
            .unwrap_or("unknown")
            .to_string();

        Ok(ModelMetadata {
            model_id: repo.to_string(),
            architecture: model_type.clone(),
            parameters: metadata["safetensors"]["total"].as_u64().unwrap_or(0),
            model_type,
            tokenizer_config: metadata["tokenizer_config"].as_object()
                .map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
            config: metadata["config"].as_object()
                .map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
            files: metadata["siblings"]
                .as_array()
                .map(|arr| arr.iter()
                    .filter_map(|v| v["rfilename"].as_str().map(|s| s.to_string()))
                    .collect())
                .unwrap_or_default(),
            // The license typically lives under `cardData` in the API response.
            license: metadata["license"].as_str()
                .or_else(|| metadata["cardData"]["license"].as_str())
                .map(|s| s.to_string()),
            tags: metadata["tags"]
                .as_array()
                .map(|arr| arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect())
                .unwrap_or_default(),
        })
    }
}
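
// Hedged usage sketch for callers (the path is a hypothetical example):
//
//     let result = ModelFetcher::fetch(&ModelSource::LocalPath {
//         path: PathBuf::from("/models/tiny.gguf"),
//     })?;
//     println!("{:?} at {}", result.model_format, result.local_path.display());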

fn fetch_via_http(url: &str, filename: Option<&str>) -> anyhow::Result<FetchResult> {
    let name = filename
        .map(|s| s.to_string())
        .unwrap_or_else(|| infer_filename_from_url(url));
    // Same streaming download path as the authenticated variant, minus the token.
    fetch_via_http_with_auth(url, &name, None)
}

fn fetch_via_http_with_auth(url: &str, filename: &str, token: Option<String>) -> anyhow::Result<FetchResult> {
    let client = Client::builder()
        .timeout(Duration::from_secs(900))
        .build()?;
    let mut headers = HeaderMap::new();
    headers.insert(USER_AGENT, HeaderValue::from_static("ohms-adaptq/2.0"));
    headers.insert(ACCEPT, HeaderValue::from_static("application/octet-stream"));
    if let Some(t) = token {
        let hv = HeaderValue::from_str(&format!("Bearer {}", t))
            .map_err(|_| anyhow::anyhow!("invalid token format"))?;
        headers.insert(AUTHORIZATION, hv);
    }
    let mut resp = client
        .get(url)
        .headers(headers)
        .send()
        .with_context(|| format!("GET {} failed", url))?;
    anyhow::ensure!(resp.status().is_success(), "download failed: {}", resp.status());
    let path = std::env::temp_dir().join(filename);
    let mut file = std::fs::File::create(&path)?;

    // Stream the body to disk in 64 KiB chunks, reporting progress as we go.
    let total = resp.content_length().unwrap_or(0);
    let bar = ProgressBar::new(total);
    bar.set_style(
        ProgressStyle::with_template("{prefix:.bold} {bar:40.cyan/blue} {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
            .unwrap()
            .progress_chars("##-"),
    );
    bar.set_prefix(format!("{} {}",
        White.bold().paint("Downloading"),
        Purple.bold().paint("OHMS"),
    ));

    let mut buf = [0u8; 1 << 16];
    loop {
        let n = resp.read(&mut buf)?;
        if n == 0 { break; }
        file.write_all(&buf[..n])?;
        bar.inc(n as u64);
    }
    bar.finish_and_clear();
    file.flush()?;

    let format = detect_model_format(&path)?;
    Ok(FetchResult {
        local_path: path,
        metadata: None,
        model_format: format,
    })
}

fn fetch_hf_with_metadata(repo: &str, file: Option<&str>) -> anyhow::Result<FetchResult> {
    // Fetch repo metadata first so we can pick a sensible weight file.
    let metadata = ModelFetcher::fetch_metadata(repo)?;

    // Prefer huggingface-cli (with hf_transfer) if present; fall back to plain HTTP.
    if let Some(res) = try_hf_cli_download(repo, file)? {
        let format = detect_model_format(&res)?;
        return Ok(FetchResult {
            local_path: res,
            metadata: Some(metadata),
            model_format: format,
        });
    }

    // Lightweight HTTP with auth headers; try common filenames if none was given.
    let token = std::env::var("HF_TOKEN").ok()
        .or_else(|| std::env::var("HUGGINGFACE_HUB_TOKEN").ok());
    let try_files: Vec<String> = if let Some(f) = file {
        vec![f.to_string()]
    } else {
        // Start with common defaults, then add candidates from the repo metadata.
        let mut files: Vec<String> = vec![
            "model.safetensors".into(),
            "consolidated.safetensors".into(),
            "pytorch_model.bin".into(),
        ];
        for file_name in &metadata.files {
            if file_name.ends_with(".safetensors") || file_name.ends_with(".bin") {
                files.push(file_name.clone());
            }
        }
        files
    };

    let base = format!("https://huggingface.co/{}/resolve/main/", repo);
    let mut last_err: Option<anyhow::Error> = None;
    for fname in try_files {
        let url = format!("{}{}", &base, &fname);
        match fetch_via_http_with_auth(&url, &fname, token.clone()) {
            Ok(mut result) => {
                result.metadata = Some(metadata);
                return Ok(result);
            }
            Err(e) => { last_err = Some(e); }
        }
    }
    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("no compatible weight file found in repo '{}'; specify :<file>", repo)))
}

fn try_hf_cli_download(repo: &str, file: Option<&str>) -> anyhow::Result<Option<PathBuf>> {
    // Check whether huggingface-cli is on the PATH.
    let which = Command::new("bash").arg("-lc").arg("command -v huggingface-cli").output()?;
    if !which.status.success() { return Ok(None); }

    let tmp = std::env::temp_dir().join(format!("ohms-hf-{}", repo.replace('/', "_")));
    let _ = std::fs::create_dir_all(&tmp);

    // Build the CLI arguments.
    let mut args: Vec<String> = vec![
        "download".into(),
        repo.into(),
        "--local-dir".into(), tmp.to_string_lossy().to_string(),
        "--resume-download".into(),
    ];
    if let Some(f) = file {
        args.push("--include".into());
        args.push(f.into());
    } else {
        args.push("--include".into()); args.push("model.safetensors".into());
        args.push("--include".into()); args.push("consolidated.safetensors".into());
        args.push("--include".into()); args.push("pytorch_model.bin".into());
    }

    let mut cmd = Command::new("huggingface-cli");
    cmd.args(&args);
    // Enable accelerated transfer if the hf_transfer package is installed.
    cmd.env("HF_HUB_ENABLE_HF_TRANSFER", "1");
    if let Ok(tok) = std::env::var("HF_TOKEN") { cmd.env("HF_TOKEN", tok); }
    let status = cmd.status()?;
    if !status.success() { return Ok(None); }

    // Resolve the downloaded file's path.
    if let Some(f) = file {
        let p = tmp.join(f);
        if p.exists() { return Ok(Some(p)); }
    }
    for cand in ["model.safetensors", "consolidated.safetensors", "pytorch_model.bin"] {
        let p = tmp.join(cand);
        if p.exists() { return Ok(Some(p)); }
    }
    Ok(None)
}
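
// For reference, the spawned process above is roughly equivalent to running
// the following shell command (repo and file are placeholders; note that newer
// huggingface-cli releases resume by default, so --resume-download may be a
// no-op there):
//
//     HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download org/model \
//         --local-dir "$TMPDIR/ohms-hf-org_model" --resume-download \
//         --include model.safetensors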

fn fetch_ollama_with_metadata(model: &str) -> anyhow::Result<FetchResult> {
    // Ensure the model has been pulled locally.
    let status = Command::new("ollama").arg("pull").arg(model).status();
    anyhow::ensure!(status.map(|s| s.success()).unwrap_or(false), "ollama pull {} failed", model);

    // Try to get model info; if the output is not JSON this degrades to None.
    let output = Command::new("ollama")
        .arg("show")
        .arg(model)
        .output()?;

    let model_info = if output.status.success() {
        serde_json::from_slice::<serde_json::Value>(&output.stdout).ok()
    } else {
        None
    };

    // Export the model to a standalone GGUF file. Note: this assumes an
    // `ollama export` subcommand, which may not exist in every Ollama release.
    let export_path = std::env::temp_dir().join(format!("{}.gguf", model.replace(':', "_")));

    let export_status = Command::new("ollama")
        .arg("export")
        .arg(model)
        .arg(&export_path)
        .status()?;

    anyhow::ensure!(export_status.success(), "ollama export {} failed", model);

    // Build metadata from whatever Ollama reported.
    let metadata = ModelMetadata {
        model_id: model.to_string(),
        architecture: model_info.as_ref()
            .and_then(|info| info["model_type"].as_str())
            .unwrap_or("gguf")
            .to_string(),
        parameters: model_info.as_ref()
            .and_then(|info| info["parameter_size"].as_u64())
            .unwrap_or(0),
        model_type: "gguf".to_string(),
        tokenizer_config: None,
        config: model_info.as_ref()
            .and_then(|info| info["config"].as_object())
            .map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect()),
        files: vec![export_path.file_name().unwrap().to_string_lossy().to_string()],
        license: None,
        tags: vec!["ollama".to_string(), "gguf".to_string()],
    };

    Ok(FetchResult {
        local_path: export_path,
        metadata: Some(metadata),
        model_format: ModelFormat::GGUF,
    })
}

fn detect_model_format(path: &Path) -> anyhow::Result<ModelFormat> {
    let file_name = path.file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("");

    if file_name.ends_with(".safetensors") {
        Ok(ModelFormat::SafeTensors)
    } else if file_name.ends_with(".bin") || file_name.ends_with(".pt") || file_name.ends_with(".pth") {
        Ok(ModelFormat::PyTorch)
    } else if file_name.ends_with(".gguf") {
        Ok(ModelFormat::GGUF)
    } else if file_name.ends_with(".onnx") {
        Ok(ModelFormat::ONNX)
    } else {
        // Fall back to sniffing the file header.
        if let Ok(mut file) = std::fs::File::open(path) {
            let mut header = [0u8; 16];
            if file.read_exact(&mut header).is_ok() {
                // GGUF files start with the magic bytes "GGUF".
                if &header[0..4] == b"GGUF" {
                    return Ok(ModelFormat::GGUF);
                }
                // PyTorch checkpoints are ZIP archives ("PK\x03\x04").
                if &header[0..4] == b"PK\x03\x04" {
                    return Ok(ModelFormat::PyTorch);
                }
                // SafeTensors has no magic string: the first 8 bytes are a
                // little-endian u64 header length, immediately followed by a
                // JSON header that starts with '{'.
                if header[8] == b'{' {
                    return Ok(ModelFormat::SafeTensors);
                }
            }
        }
        Ok(ModelFormat::Unknown)
    }
}

fn infer_filename_from_url(url: &str) -> String {
    url.split('/')
        .last()
        .filter(|s| !s.is_empty())
        .unwrap_or("model.bin")
        .to_string()
}

pub fn parse_model_source(s: &str) -> ModelSource {
    // Accepted forms:
    //   hf:meta-llama/Llama-3-8B:consolidated.safetensors
    //   url:https://host/path/file.bin
    //   ollama:llama3:8b
    //   file:/abs/path/model.onnx
    if let Some(rest) = s.strip_prefix("hf:") {
        let mut parts = rest.splitn(2, ':');
        let repo = parts.next().unwrap_or("").to_string();
        let file = parts.next().map(|v| v.to_string());
        return ModelSource::HuggingFace { repo, file };
    }
    if let Some(rest) = s.strip_prefix("url:") {
        return ModelSource::Url { url: rest.to_string(), filename: None };
    }
    if let Some(rest) = s.strip_prefix("ollama:") {
        return ModelSource::Ollama { model: rest.to_string() };
    }
    if let Some(rest) = s.strip_prefix("file:") {
        return ModelSource::LocalPath { path: PathBuf::from(rest) };
    }
    // Default: treat the input as a local path.
    ModelSource::LocalPath { path: PathBuf::from(s) }
}
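
// A few unit tests for the pure helpers (no network or subprocesses); the
// source strings below are illustrative placeholders.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_prefixed_sources() {
        match parse_model_source("hf:org/model:weights.safetensors") {
            ModelSource::HuggingFace { repo, file } => {
                assert_eq!(repo, "org/model");
                assert_eq!(file.as_deref(), Some("weights.safetensors"));
            }
            other => panic!("expected HuggingFace source, got {:?}", other),
        }
        match parse_model_source("ollama:llama3:8b") {
            ModelSource::Ollama { model } => assert_eq!(model, "llama3:8b"),
            other => panic!("expected Ollama source, got {:?}", other),
        }
        // Unprefixed strings fall back to a local path.
        assert!(!parse_model_source("/tmp/model.gguf").is_remote());
    }

    #[test]
    fn infers_filenames_and_formats() {
        assert_eq!(infer_filename_from_url("https://host/path/file.bin"), "file.bin");
        // Extension-based detection does not touch the filesystem.
        assert!(matches!(
            detect_model_format(Path::new("/nonexistent/model.gguf")).unwrap(),
            ModelFormat::GGUF
        ));
    }
}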