cognee_embedding/
download.rs1use crate::error::{EmbeddingError, EmbeddingResult};
6use std::path::{Path, PathBuf};
7use tokio::fs;
8use tokio::io::AsyncWriteExt;
9
/// Pair of download locations for one embedding model: the ONNX weights
/// and the matching tokenizer definition.
///
/// Both fields are `&'static str`, so the type is cheap to copy and can be
/// used in `const` items.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelUrls {
    /// URL the model file is fetched from.
    pub model_url: &'static str,
    /// URL the tokenizer file is fetched from.
    pub tokenizer_url: &'static str,
}
17
18impl ModelUrls {
19 pub const BGE_SMALL: ModelUrls = ModelUrls {
21 model_url: "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/onnx/model_quantized.onnx",
22 tokenizer_url: "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/tokenizer.json",
23 };
24
25 pub const MINILM_L6: ModelUrls = ModelUrls {
27 model_url: "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx",
28 tokenizer_url: "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/tokenizer.json",
29 };
30}
31
32async fn download_file(url: &str, dest: &Path) -> EmbeddingResult<()> {
37 if let Some(parent) = dest.parent() {
38 fs::create_dir_all(parent).await?;
39 }
40
41 let response = reqwest::get(url)
42 .await
43 .map_err(|e| EmbeddingError::ModelLoadError(format!("Failed to download {url}: {e}")))?;
44
45 if !response.status().is_success() {
46 return Err(EmbeddingError::ModelLoadError(format!(
47 "Failed to download {}: HTTP {}",
48 url,
49 response.status()
50 )));
51 }
52
53 let bytes = response
54 .bytes()
55 .await
56 .map_err(|e| EmbeddingError::ModelLoadError(format!("Failed to read response: {e}")))?;
57
58 let mut file = fs::File::create(dest).await?;
59 file.write_all(&bytes).await?;
60 file.flush().await?;
61
62 Ok(())
63}
64
65pub async fn ensure_model_exists(path: &Path, url: &str) -> EmbeddingResult<bool> {
76 if path.exists() {
77 return Ok(false);
78 }
79
80 download_file(url, path).await?;
81 Ok(true)
82}
83
84pub async fn ensure_tokenizer_exists(path: &Path, url: &str) -> EmbeddingResult<bool> {
95 if path.exists() {
96 return Ok(false);
97 }
98
99 download_file(url, path).await?;
100 Ok(true)
101}
102
103pub async fn download_model(
114 model_name: &str,
115 model_dir: &Path,
116) -> EmbeddingResult<(PathBuf, PathBuf)> {
117 let urls = match model_name.to_lowercase().as_str() {
118 "bge-small" | "bge-small-v1.5" => ModelUrls::BGE_SMALL,
119 "minilm-l6" | "all-minilm-l6-v2" => ModelUrls::MINILM_L6,
120 _ => {
121 return Err(EmbeddingError::ConfigError(format!(
122 "Unknown model name: {model_name}. Supported: bge-small, minilm-l6"
123 )));
124 }
125 };
126
127 let model_path = if model_name.contains("bge") {
128 model_dir.join("BGE-Small-v1.5-model_quantized.onnx")
129 } else {
130 model_dir.join("all-MiniLM-L6-v2.onnx")
131 };
132
133 let tokenizer_path = if model_name.contains("bge") {
134 model_dir.join("bge-small-tokenizer.json")
135 } else {
136 model_dir.join("minilm-l6-tokenizer.json")
137 };
138
139 ensure_model_exists(&model_path, urls.model_url).await?;
140
141 ensure_tokenizer_exists(&tokenizer_path, urls.tokenizer_url).await?;
142
143 Ok((model_path, tokenizer_path))
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-checks that the preset URLs point at the expected artifacts.
    #[test]
    fn test_model_urls() {
        // (url under test, fragment it must contain)
        let expectations = [
            (ModelUrls::BGE_SMALL.model_url, "bge-small"),
            (ModelUrls::BGE_SMALL.tokenizer_url, "tokenizer.json"),
            (ModelUrls::MINILM_L6.model_url, "MiniLM"),
        ];

        for (url, fragment) in expectations.iter() {
            assert!(url.contains(fragment));
        }
    }
}