use crate::config::Config;
use anyhow::{Context, Result};
use sha2::{Digest, Sha256};
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use tokio::io::AsyncWriteExt;
7
8pub async fn ensure_model_installed(config: &Config) -> Result<PathBuf> {
9 let default_path = crate::config::get_cache_dir().join("models/bge-micro-v2.onnx");
10 let model_path = config
11 .vector
12 .model_path
13 .as_ref()
14 .map(PathBuf::from)
15 .unwrap_or(default_path);
16
17 if model_path.exists() {
18 return Ok(model_path);
19 }
20
21 if let Some(parent) = model_path.parent() {
22 std::fs::create_dir_all(parent).context("Failed to create model parent directory")?;
23 }
24
25 let url = config
26 .vector
27 .model_url
28 .as_deref()
29 .unwrap_or("https://cdn.cmdhub.xyz/models/bge-micro-v2.onnx");
30
31 let client = reqwest::Client::builder()
32 .timeout(std::time::Duration::from_secs(
33 config.timeout_seconds.max(60),
34 ))
35 .build()
36 .context("Failed to build reqwest client for model download")?;
37
38 eprintln!(
39 "ONNX embedding model is missing. Downloading from {}...",
40 url
41 );
42
43 let mut response = client
44 .get(url)
45 .send()
46 .await
47 .context("Failed to send model download request")?;
48
49 if !response.status().is_success() {
50 anyhow::bail!(
51 "Server returned status code: {} when downloading model",
52 response.status()
53 );
54 }
55
56 let total_size = response.content_length().unwrap_or(0);
57
58 let staging_path = model_path.with_extension("onnx.tmp");
59 let mut file = tokio::fs::File::create(&staging_path)
60 .await
61 .context("Failed to create temporary staging file for model")?;
62
63 let mut downloaded: u64 = 0;
64 let mut last_progress_pct = 999; while let Some(chunk) = response
67 .chunk()
68 .await
69 .context("Error downloading model chunk")?
70 {
71 file.write_all(&chunk)
72 .await
73 .context("Failed to write model chunk to file")?;
74 downloaded += chunk.len() as u64;
75
76 if let Some(progress_pct) = (downloaded * 100).checked_div(total_size) {
77 let progress_pct = progress_pct as usize;
78 if progress_pct != last_progress_pct {
79 last_progress_pct = progress_pct;
80 let bar_width = 30;
81 let filled = progress_pct * bar_width / 100;
82 let empty = bar_width - filled;
83 let bar = format!(
84 "Downloading model: [{}{}] {}% ({:.1} MB / {:.1} MB)\r",
85 "=".repeat(filled),
86 " ".repeat(empty),
87 progress_pct,
88 (downloaded as f64) / 1_048_576.0,
89 (total_size as f64) / 1_048_576.0
90 );
91 let mut stderr = io::stderr();
92 let _ = stderr.write_all(bar.as_bytes());
93 let _ = stderr.flush();
94 }
95 } else {
96 let bar = format!(
97 "Downloading model: {:.1} MB...\r",
98 (downloaded as f64) / 1_048_576.0
99 );
100 let mut stderr = io::stderr();
101 let _ = stderr.write_all(bar.as_bytes());
102 let _ = stderr.flush();
103 }
104 }
105 eprintln!(); file.sync_all()
109 .await
110 .context("Failed to sync model file to disk")?;
111 drop(file);
112
113 eprintln!("Verifying model integrity...");
115 let file_bytes = std::fs::read(&staging_path).context("Failed to read staging model file")?;
116 let mut hasher = Sha256::new();
117 hasher.update(&file_bytes);
118 let hash_str = format!("{:x}", hasher.finalize());
119 let target_hash = config
120 .vector
121 .model_sha256
122 .as_deref()
123 .unwrap_or("d3b07384d113edec49eaa6238ad5ff00b192e2ad47a8a6cf23bdc1048b292e2a");
124
125 if hash_str != target_hash {
126 let _ = std::fs::remove_file(&staging_path);
127 anyhow::bail!(
128 "SHA-256 verification failed. Expected {}, got {}",
129 target_hash,
130 hash_str
131 );
132 }
133
134 std::fs::rename(&staging_path, &model_path)
135 .context("Failed to rename staging file to final model path")?;
136 eprintln!("Model installed successfully to {:?}", model_path);
137
138 Ok(model_path)
139}
140
141pub async fn install_vector(
142 config: &Config,
143 from_file: Option<PathBuf>,
144 force: bool,
145) -> Result<()> {
146 let default_path = crate::config::get_cache_dir().join("models/bge-micro-v2.onnx");
147 let model_path = config
148 .vector
149 .model_path
150 .as_ref()
151 .map(PathBuf::from)
152 .unwrap_or(default_path);
153
154 if !force && model_path.exists() {
155 println!("Model is already installed at {:?}", model_path);
156 return Ok(());
157 }
158
159 if let Some(src_path) = from_file {
160 if let Some(parent) = model_path.parent() {
161 std::fs::create_dir_all(parent)?;
162 }
163 println!("Copying model from {:?} to {:?}...", src_path, model_path);
164 std::fs::copy(&src_path, &model_path).context("Failed to copy custom model file")?;
165
166 let file_bytes = std::fs::read(&model_path)?;
168 let mut hasher = Sha256::new();
169 hasher.update(&file_bytes);
170 let hash_str = format!("{:x}", hasher.finalize());
171 let target_hash = config
172 .vector
173 .model_sha256
174 .as_deref()
175 .unwrap_or("d3b07384d113edec49eaa6238ad5ff00b192e2ad47a8a6cf23bdc1048b292e2a");
176 if hash_str != target_hash {
177 std::fs::remove_file(&model_path)?;
178 anyhow::bail!(
179 "SHA-256 verification failed. Expected {}, got {}",
180 target_hash,
181 hash_str
182 );
183 }
184 println!("Model installed successfully to {:?}", model_path);
185 } else {
186 if model_path.exists() {
188 let _ = std::fs::remove_file(&model_path);
189 }
190 ensure_model_installed(config).await?;
191 }
192 Ok(())
193}