Skip to main content

cognee_embedding/
download.rs

1//! Lazy downloading of embedding models and tokenizers from HuggingFace Hub.
2//!
3//! Automatically downloads missing model files when creating an embedding engine.
4
5use crate::error::{EmbeddingError, EmbeddingResult};
6use std::path::{Path, PathBuf};
7use tokio::fs;
8use tokio::io::AsyncWriteExt;
9
/// HuggingFace Hub download locations for one supported model.
///
/// Both fields are `&'static str`, so a value is trivially copyable and can be
/// stored in a `const`. The derives make the type printable in diagnostics and
/// comparable in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelUrls {
    /// URL to the ONNX model file.
    pub model_url: &'static str,
    /// URL to the tokenizer JSON file.
    pub tokenizer_url: &'static str,
}
17
18impl ModelUrls {
19    /// BGE-Small-v1.5 URLs
20    pub const BGE_SMALL: ModelUrls = ModelUrls {
21        model_url: "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/onnx/model_quantized.onnx",
22        tokenizer_url: "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/tokenizer.json",
23    };
24
25    /// all-MiniLM-L6-v2 URLs
26    pub const MINILM_L6: ModelUrls = ModelUrls {
27        model_url: "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx",
28        tokenizer_url: "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/tokenizer.json",
29    };
30}
31
32/// Download a file from a URL to a local path.
33///
34/// Creates parent directories if they don't exist.
35/// Shows progress during download.
36async fn download_file(url: &str, dest: &Path) -> EmbeddingResult<()> {
37    if let Some(parent) = dest.parent() {
38        fs::create_dir_all(parent).await?;
39    }
40
41    let response = reqwest::get(url)
42        .await
43        .map_err(|e| EmbeddingError::ModelLoadError(format!("Failed to download {url}: {e}")))?;
44
45    if !response.status().is_success() {
46        return Err(EmbeddingError::ModelLoadError(format!(
47            "Failed to download {}: HTTP {}",
48            url,
49            response.status()
50        )));
51    }
52
53    let bytes = response
54        .bytes()
55        .await
56        .map_err(|e| EmbeddingError::ModelLoadError(format!("Failed to read response: {e}")))?;
57
58    let mut file = fs::File::create(dest).await?;
59    file.write_all(&bytes).await?;
60    file.flush().await?;
61
62    Ok(())
63}
64
65/// Ensure a model file exists, downloading it if necessary.
66///
67/// # Arguments
68/// * `path` - Path where the model should be
69/// * `url` - URL to download from if file doesn't exist
70///
71/// # Returns
72/// * `Ok(true)` if file was downloaded
73/// * `Ok(false)` if file already existed
74/// * `Err` if download failed
75pub async fn ensure_model_exists(path: &Path, url: &str) -> EmbeddingResult<bool> {
76    if path.exists() {
77        return Ok(false);
78    }
79
80    download_file(url, path).await?;
81    Ok(true)
82}
83
84/// Ensure a tokenizer file exists, downloading it if necessary.
85///
86/// # Arguments
87/// * `path` - Path where tokenizer.json should be
88/// * `url` - URL to download from if file doesn't exist
89///
90/// # Returns
91/// * `Ok(true)` if file was downloaded
92/// * `Ok(false)` if file already existed
93/// * `Err` if download failed
94pub async fn ensure_tokenizer_exists(path: &Path, url: &str) -> EmbeddingResult<bool> {
95    if path.exists() {
96        return Ok(false);
97    }
98
99    download_file(url, path).await?;
100    Ok(true)
101}
102
103/// Download both model and tokenizer for a specific configuration.
104///
105/// Uses predefined URLs for known models.
106///
107/// # Arguments
108/// * `model_name` - Name of the model ("bge-small" or "minilm-l6")
109/// * `model_dir` - Directory to download into
110///
111/// # Returns
112/// * Tuple of (model_path, tokenizer_path)
113pub async fn download_model(
114    model_name: &str,
115    model_dir: &Path,
116) -> EmbeddingResult<(PathBuf, PathBuf)> {
117    let urls = match model_name.to_lowercase().as_str() {
118        "bge-small" | "bge-small-v1.5" => ModelUrls::BGE_SMALL,
119        "minilm-l6" | "all-minilm-l6-v2" => ModelUrls::MINILM_L6,
120        _ => {
121            return Err(EmbeddingError::ConfigError(format!(
122                "Unknown model name: {model_name}. Supported: bge-small, minilm-l6"
123            )));
124        }
125    };
126
127    let model_path = if model_name.contains("bge") {
128        model_dir.join("BGE-Small-v1.5-model_quantized.onnx")
129    } else {
130        model_dir.join("all-MiniLM-L6-v2.onnx")
131    };
132
133    let tokenizer_path = if model_name.contains("bge") {
134        model_dir.join("bge-small-tokenizer.json")
135    } else {
136        model_dir.join("minilm-l6-tokenizer.json")
137    };
138
139    ensure_model_exists(&model_path, urls.model_url).await?;
140
141    ensure_tokenizer_exists(&tokenizer_path, urls.tokenizer_url).await?;
142
143    Ok((model_path, tokenizer_path))
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-checks that the predefined URL constants point at the expected artifacts.
    #[test]
    fn test_model_urls() {
        let bge = ModelUrls::BGE_SMALL;
        let minilm = ModelUrls::MINILM_L6;

        assert!(bge.model_url.contains("bge-small"));
        assert!(bge.tokenizer_url.contains("tokenizer.json"));
        assert!(minilm.model_url.contains("MiniLM"));
    }
}