// fastembed/common.rs — shared helpers for fastembed models.
use anyhow::Result;
#[cfg(feature = "hf-hub")]
use hf_hub::api::sync::{ApiBuilder, ApiRepo};
#[cfg(feature = "hf-hub")]
use std::path::PathBuf;
use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
/// Default on-disk cache directory used when `FASTEMBED_CACHE_DIR` is unset.
const DEFAULT_CACHE_DIR: &str = ".fastembed_cache";

/// Returns the model cache directory.
///
/// Reads the `FASTEMBED_CACHE_DIR` environment variable, falling back to
/// [`DEFAULT_CACHE_DIR`]. The fallback is built lazily (`unwrap_or_else`)
/// so no allocation happens when the variable is set.
pub fn get_cache_dir() -> String {
    std::env::var("FASTEMBED_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.to_string())
}
/// A sparse embedding: parallel arrays of token indices and their weights.
///
/// `indices[i]` is the vocabulary index whose activation is `values[i]`;
/// both vectors are expected to have the same length.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseEmbedding {
    /// Vocabulary indices of the non-zero dimensions.
    pub indices: Vec<usize>,
    /// Activation values matching `indices` position-for-position.
    pub values: Vec<f32>,
}
19/// Type alias for the embedding vector
20pub type Embedding = Vec<f32>;
21
22/// Type alias for the error type
23pub type Error = anyhow::Error;
/// Raw tokenizer artifacts for "bring your own" (user-defined) models.
///
/// Each field holds the bytes of the corresponding Hugging Face tokenizer
/// file, letting callers supply them from any source (disk, network, ...).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizerFiles {
    /// Contents of `tokenizer.json`.
    pub tokenizer_file: Vec<u8>,
    /// Contents of `config.json`.
    pub config_file: Vec<u8>,
    /// Contents of `special_tokens_map.json`.
    pub special_tokens_map_file: Vec<u8>,
    /// Contents of `tokenizer_config.json`.
    pub tokenizer_config_file: Vec<u8>,
}
/// Downloads the four tokenizer files from the Hugging Face hub and loads them.
///
/// This is separated from [`load_tokenizer`] (which expects raw bytes, from any
/// source) so the same loading path serves both hub-backed and user-defined models.
///
/// # Errors
/// Fails if any file cannot be fetched from the hub or read from the local cache.
#[cfg(feature = "hf-hub")]
pub fn load_tokenizer_hf_hub(model_repo: ApiRepo, max_length: usize) -> Result<Tokenizer> {
    // `ApiRepo::get` resolves each file to a local cache path, downloading on demand.
    let tokenizer_files = TokenizerFiles {
        tokenizer_file: std::fs::read(model_repo.get("tokenizer.json")?)?,
        config_file: std::fs::read(model_repo.get("config.json")?)?,
        special_tokens_map_file: std::fs::read(model_repo.get("special_tokens_map.json")?)?,
        tokenizer_config_file: std::fs::read(model_repo.get("tokenizer_config.json")?)?,
    };

    load_tokenizer(tokenizer_files, max_length)
}
49/// Function can be called directly from the try_new_from_user_defined function (providing file bytes)
50///
51/// Or indirectly from the try_new function via load_tokenizer_hf_hub (converting HF files to bytes)
52pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result<Tokenizer> {
53    let base_error_message =
54        "Error building TokenizerFiles for UserDefinedEmbeddingModel. Could not read {} file.";
55
56    // Deserialize each tokenizer file
57    let config: serde_json::Value =
58        serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| {
59            std::io::Error::new(
60                std::io::ErrorKind::InvalidData,
61                base_error_message.replace("{}", "config.json"),
62            )
63        })?;
64    let special_tokens_map: serde_json::Value =
65        serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| {
66            std::io::Error::new(
67                std::io::ErrorKind::InvalidData,
68                base_error_message.replace("{}", "special_tokens_map.json"),
69            )
70        })?;
71    let tokenizer_config: serde_json::Value =
72        serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| {
73            std::io::Error::new(
74                std::io::ErrorKind::InvalidData,
75                base_error_message.replace("{}", "tokenizer_config.json"),
76            )
77        })?;
78    let mut tokenizer: tokenizers::Tokenizer =
79        tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| {
80            std::io::Error::new(
81                std::io::ErrorKind::InvalidData,
82                base_error_message.replace("{}", "tokenizer.json"),
83            )
84        })?;
85
86    //For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656. Which fits in a f64
87    let model_max_length = tokenizer_config["model_max_length"]
88        .as_f64()
89        .expect("Error reading model_max_length from tokenizer_config.json")
90        as f32;
91    let max_length = max_length.min(model_max_length as usize);
92    let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
93    let pad_token = tokenizer_config["pad_token"]
94        .as_str()
95        .expect("Error reading pad_token from tokenizer_config.json")
96        .into();
97
98    let mut tokenizer = tokenizer
99        .with_padding(Some(PaddingParams {
100            // TODO: the user should be able to choose the padding strategy
101            strategy: PaddingStrategy::BatchLongest,
102            pad_token,
103            pad_id,
104            ..Default::default()
105        }))
106        .with_truncation(Some(TruncationParams {
107            max_length,
108            ..Default::default()
109        }))
110        .map_err(anyhow::Error::msg)?
111        .clone();
112    if let serde_json::Value::Object(root_object) = special_tokens_map {
113        for (_, value) in root_object.iter() {
114            if value.is_string() {
115                if let Some(content) = value.as_str() {
116                    tokenizer.add_special_tokens(&[AddedToken {
117                        content: content.into(),
118                        special: true,
119                        ..Default::default()
120                    }]);
121                }
122            } else if value.is_object() {
123                if let (
124                    Some(content),
125                    Some(single_word),
126                    Some(lstrip),
127                    Some(rstrip),
128                    Some(normalized),
129                ) = (
130                    value["content"].as_str(),
131                    value["single_word"].as_bool(),
132                    value["lstrip"].as_bool(),
133                    value["rstrip"].as_bool(),
134                    value["normalized"].as_bool(),
135                ) {
136                    tokenizer.add_special_tokens(&[AddedToken {
137                        content: content.into(),
138                        special: true,
139                        single_word,
140                        lstrip,
141                        rstrip,
142                        normalized,
143                    }]);
144                }
145            }
146        }
147    }
148    Ok(tokenizer.into())
149}
/// L2-normalizes a vector, returning a new `Vec<f32>` with (near) unit norm.
///
/// A tiny epsilon is added to the norm so an all-zero input does not divide
/// by zero — it simply yields a vector of zeros.
pub fn normalize(v: &[f32]) -> Vec<f32> {
    let epsilon: f32 = 1e-12;
    let squared_sum: f32 = v.iter().map(|val| val * val).sum();
    let denom = squared_sum.sqrt() + epsilon;
    v.iter().map(|&val| val / denom).collect()
}
/// Pulls a model repo from Hugging Face.
///
/// Environment overrides:
/// * `HF_HOME` — decides the location of the cache folder.
/// * `HF_ENDPOINT` — modifies the URL of the Hugging Face location.
#[cfg(feature = "hf-hub")]
pub fn pull_from_hf(
    model_name: String,
    default_cache_dir: PathBuf,
    show_download_progress: bool,
) -> anyhow::Result<ApiRepo> {
    use std::env;

    // Prefer the HF_HOME override; otherwise use the caller-supplied default.
    let cache_dir = match env::var("HF_HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => default_cache_dir,
    };

    let endpoint =
        env::var("HF_ENDPOINT").unwrap_or_else(|_| String::from("https://huggingface.co"));

    let builder = ApiBuilder::new()
        .with_cache_dir(cache_dir)
        .with_endpoint(endpoint)
        .with_progress(show_download_progress);

    Ok(builder.build()?.model(model_name))
}