Skip to main content

cmdhub_cli/
installer.rs

1use crate::config::Config;
2use anyhow::{Context, Result};
3use sha2::{Digest, Sha256};
4use std::io::{self, Write};
5use std::path::PathBuf;
6use tokio::io::AsyncWriteExt;
7
8pub async fn ensure_model_installed(config: &Config) -> Result<PathBuf> {
9    let default_path = crate::config::get_cache_dir().join("models/bge-micro-v2.onnx");
10    let model_path = config
11        .vector
12        .model_path
13        .as_ref()
14        .map(PathBuf::from)
15        .unwrap_or(default_path);
16
17    if model_path.exists() {
18        return Ok(model_path);
19    }
20
21    if let Some(parent) = model_path.parent() {
22        std::fs::create_dir_all(parent).context("Failed to create model parent directory")?;
23    }
24
25    let url = config
26        .vector
27        .model_url
28        .as_deref()
29        .unwrap_or("https://cdn.cmdhub.org/models/bge-micro-v2.onnx");
30
31    let client = reqwest::Client::builder()
32        .timeout(std::time::Duration::from_secs(
33            config.timeout_seconds.max(60),
34        ))
35        .build()
36        .context("Failed to build reqwest client for model download")?;
37
38    eprintln!(
39        "ONNX embedding model is missing. Downloading from {}...",
40        url
41    );
42
43    let mut response = client
44        .get(url)
45        .send()
46        .await
47        .context("Failed to send model download request")?;
48
49    if !response.status().is_success() {
50        anyhow::bail!(
51            "Server returned status code: {} when downloading model",
52            response.status()
53        );
54    }
55
56    let total_size = response.content_length().unwrap_or(0);
57
58    let staging_path = model_path.with_extension("onnx.tmp");
59    let mut file = tokio::fs::File::create(&staging_path)
60        .await
61        .context("Failed to create temporary staging file for model")?;
62
63    let mut downloaded: u64 = 0;
64    let mut last_progress_pct = 999; // force print on first chunk
65
66    while let Some(chunk) = response
67        .chunk()
68        .await
69        .context("Error downloading model chunk")?
70    {
71        file.write_all(&chunk)
72            .await
73            .context("Failed to write model chunk to file")?;
74        downloaded += chunk.len() as u64;
75
76        if let Some(progress_pct) = (downloaded * 100).checked_div(total_size) {
77            let progress_pct = progress_pct as usize;
78            if progress_pct != last_progress_pct {
79                last_progress_pct = progress_pct;
80                let bar_width = 30;
81                let filled = progress_pct * bar_width / 100;
82                let empty = bar_width - filled;
83                let bar = format!(
84                    "Downloading model: [{}{}] {}% ({:.1} MB / {:.1} MB)\r",
85                    "=".repeat(filled),
86                    " ".repeat(empty),
87                    progress_pct,
88                    (downloaded as f64) / 1_048_576.0,
89                    (total_size as f64) / 1_048_576.0
90                );
91                let mut stderr = io::stderr();
92                let _ = stderr.write_all(bar.as_bytes());
93                let _ = stderr.flush();
94            }
95        } else {
96            let bar = format!(
97                "Downloading model: {:.1} MB...\r",
98                (downloaded as f64) / 1_048_576.0
99            );
100            let mut stderr = io::stderr();
101            let _ = stderr.write_all(bar.as_bytes());
102            let _ = stderr.flush();
103        }
104    }
105    eprintln!(); // newline to clear carriage return
106
107    // Ensure staging file is synced to disk
108    file.sync_all()
109        .await
110        .context("Failed to sync model file to disk")?;
111    drop(file);
112
113    // Calculate SHA-256 of downloaded file
114    eprintln!("Verifying model integrity...");
115    let file_bytes = std::fs::read(&staging_path).context("Failed to read staging model file")?;
116    let mut hasher = Sha256::new();
117    hasher.update(&file_bytes);
118    let hash_str = format!("{:x}", hasher.finalize());
119    let target_hash = config
120        .vector
121        .model_sha256
122        .as_deref()
123        .unwrap_or("9f705befe60d00ca3d8d14c9dd61a3ecfca9f1920a39fbc4a5b056c0ccd977d4");
124
125    if hash_str != target_hash {
126        let _ = std::fs::remove_file(&staging_path);
127        anyhow::bail!(
128            "SHA-256 verification failed. Expected {}, got {}",
129            target_hash,
130            hash_str
131        );
132    }
133
134    std::fs::rename(&staging_path, &model_path)
135        .context("Failed to rename staging file to final model path")?;
136    eprintln!("Model installed successfully to {:?}", model_path);
137
138    Ok(model_path)
139}
140
141pub async fn install_vector(
142    config: &Config,
143    from_file: Option<PathBuf>,
144    force: bool,
145) -> Result<()> {
146    let default_path = crate::config::get_cache_dir().join("models/bge-micro-v2.onnx");
147    let model_path = config
148        .vector
149        .model_path
150        .as_ref()
151        .map(PathBuf::from)
152        .unwrap_or(default_path);
153
154    if !force && model_path.exists() {
155        println!("Model is already installed at {:?}", model_path);
156        return Ok(());
157    }
158
159    if let Some(src_path) = from_file {
160        if let Some(parent) = model_path.parent() {
161            std::fs::create_dir_all(parent)?;
162        }
163        println!("Copying model from {:?} to {:?}...", src_path, model_path);
164        std::fs::copy(&src_path, &model_path).context("Failed to copy custom model file")?;
165
166        // SHA-256 verification of copied file
167        let file_bytes = std::fs::read(&model_path)?;
168        let mut hasher = Sha256::new();
169        hasher.update(&file_bytes);
170        let hash_str = format!("{:x}", hasher.finalize());
171        let target_hash = config
172            .vector
173            .model_sha256
174            .as_deref()
175            .unwrap_or("9f705befe60d00ca3d8d14c9dd61a3ecfca9f1920a39fbc4a5b056c0ccd977d4");
176        if hash_str != target_hash {
177            std::fs::remove_file(&model_path)?;
178            anyhow::bail!(
179                "SHA-256 verification failed. Expected {}, got {}",
180                target_hash,
181                hash_str
182            );
183        }
184        println!("Model installed successfully to {:?}", model_path);
185    } else {
186        // Force re-download by deleting existing first
187        if model_path.exists() {
188            let _ = std::fs::remove_file(&model_path);
189        }
190        ensure_model_installed(config).await?;
191    }
192    Ok(())
193}