1use anyhow::Result;
2#[cfg(feature = "hf-hub")]
3use hf_hub::api::sync::{ApiBuilder, ApiRepo};
4#[cfg(feature = "hf-hub")]
5use std::path::PathBuf;
6use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
7
const DEFAULT_CACHE_DIR: &str = ".fastembed_cache";

/// Returns the directory used to cache downloaded model files.
///
/// Reads the `FASTEMBED_CACHE_DIR` environment variable, falling back to
/// [`DEFAULT_CACHE_DIR`] when it is unset (or not valid Unicode).
pub fn get_cache_dir() -> String {
    // `unwrap_or_else` defers building the default `String` until it is
    // actually needed, instead of allocating it on every call.
    std::env::var("FASTEMBED_CACHE_DIR").unwrap_or_else(|_| DEFAULT_CACHE_DIR.into())
}
13
/// A sparse embedding: only the non-zero dimensions are stored.
///
/// `indices` holds the dimensions that carry weight and `values` holds the
/// corresponding weights (presumably parallel arrays, one value per index —
/// confirm against the producers of this type).
#[derive(Debug, Clone, PartialEq)]
pub struct SparseEmbedding {
    /// Dimensions with non-zero weight.
    pub indices: Vec<usize>,
    /// Weight for each entry of `indices`.
    pub values: Vec<f32>,
}
18
/// A dense embedding vector.
pub type Embedding = Vec<f32>;

/// The error type used throughout this crate (an alias of [`anyhow::Error`]).
pub type Error = anyhow::Error;
/// The raw byte contents of the four files required to build a tokenizer for
/// a user-defined embedding model (see `load_tokenizer`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenizerFiles {
    /// Contents of `tokenizer.json`.
    pub tokenizer_file: Vec<u8>,
    /// Contents of `config.json`.
    pub config_file: Vec<u8>,
    /// Contents of `special_tokens_map.json`.
    pub special_tokens_map_file: Vec<u8>,
    /// Contents of `tokenizer_config.json`.
    pub tokenizer_config_file: Vec<u8>,
}
33
/// Fetches the four tokenizer files from the given Hugging Face repo (using
/// the hub's local cache) and builds a [`Tokenizer`] from them.
#[cfg(feature = "hf-hub")]
pub fn load_tokenizer_hf_hub(model_repo: ApiRepo, max_length: usize) -> Result<Tokenizer> {
    // `ApiRepo::get` resolves each file to a local path (downloading on a
    // cache miss); the bytes are then read eagerly into memory.
    let tokenizer_file = std::fs::read(model_repo.get("tokenizer.json")?)?;
    let config_file = std::fs::read(model_repo.get("config.json")?)?;
    let special_tokens_map_file = std::fs::read(model_repo.get("special_tokens_map.json")?)?;
    let tokenizer_config_file = std::fs::read(model_repo.get("tokenizer_config.json")?)?;

    load_tokenizer(
        TokenizerFiles {
            tokenizer_file,
            config_file,
            special_tokens_map_file,
            tokenizer_config_file,
        },
        max_length,
    )
}
48
49pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result<Tokenizer> {
53 let base_error_message =
54 "Error building TokenizerFiles for UserDefinedEmbeddingModel. Could not read {} file.";
55
56 let config: serde_json::Value =
58 serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| {
59 std::io::Error::new(
60 std::io::ErrorKind::InvalidData,
61 base_error_message.replace("{}", "config.json"),
62 )
63 })?;
64 let special_tokens_map: serde_json::Value =
65 serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| {
66 std::io::Error::new(
67 std::io::ErrorKind::InvalidData,
68 base_error_message.replace("{}", "special_tokens_map.json"),
69 )
70 })?;
71 let tokenizer_config: serde_json::Value =
72 serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| {
73 std::io::Error::new(
74 std::io::ErrorKind::InvalidData,
75 base_error_message.replace("{}", "tokenizer_config.json"),
76 )
77 })?;
78 let mut tokenizer: tokenizers::Tokenizer =
79 tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| {
80 std::io::Error::new(
81 std::io::ErrorKind::InvalidData,
82 base_error_message.replace("{}", "tokenizer.json"),
83 )
84 })?;
85
86 let model_max_length = tokenizer_config["model_max_length"]
88 .as_f64()
89 .expect("Error reading model_max_length from tokenizer_config.json")
90 as f32;
91 let max_length = max_length.min(model_max_length as usize);
92 let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
93 let pad_token = tokenizer_config["pad_token"]
94 .as_str()
95 .expect("Error reading pad_token from tokenizer_config.json")
96 .into();
97
98 let mut tokenizer = tokenizer
99 .with_padding(Some(PaddingParams {
100 strategy: PaddingStrategy::BatchLongest,
102 pad_token,
103 pad_id,
104 ..Default::default()
105 }))
106 .with_truncation(Some(TruncationParams {
107 max_length,
108 ..Default::default()
109 }))
110 .map_err(anyhow::Error::msg)?
111 .clone();
112 if let serde_json::Value::Object(root_object) = special_tokens_map {
113 for (_, value) in root_object.iter() {
114 if value.is_string() {
115 if let Some(content) = value.as_str() {
116 tokenizer.add_special_tokens(&[AddedToken {
117 content: content.into(),
118 special: true,
119 ..Default::default()
120 }]);
121 }
122 } else if value.is_object() {
123 if let (
124 Some(content),
125 Some(single_word),
126 Some(lstrip),
127 Some(rstrip),
128 Some(normalized),
129 ) = (
130 value["content"].as_str(),
131 value["single_word"].as_bool(),
132 value["lstrip"].as_bool(),
133 value["rstrip"].as_bool(),
134 value["normalized"].as_bool(),
135 ) {
136 tokenizer.add_special_tokens(&[AddedToken {
137 content: content.into(),
138 special: true,
139 single_word,
140 lstrip,
141 rstrip,
142 normalized,
143 }]);
144 }
145 }
146 }
147 }
148 Ok(tokenizer.into())
149}
150
/// Scales `v` to unit (L2) length.
///
/// A tiny epsilon is added to the denominator so an all-zero vector maps to
/// zeros rather than NaNs.
pub fn normalize(v: &[f32]) -> Vec<f32> {
    // Sequential sum of squares, identical accumulation order to `Iterator::sum`.
    let mut sum_sq = 0.0f32;
    for &x in v {
        sum_sq += x * x;
    }
    let denom = sum_sq.sqrt() + 1e-12;

    v.iter().map(|&x| x / denom).collect()
}
158
/// Creates a Hugging Face Hub [`ApiRepo`] handle for `model_name`.
///
/// The cache directory is taken from the `HF_HOME` environment variable when
/// set, otherwise `default_cache_dir`; the endpoint comes from `HF_ENDPOINT`,
/// defaulting to `https://huggingface.co`.
#[cfg(feature = "hf-hub")]
pub fn pull_from_hf(
    model_name: String,
    default_cache_dir: PathBuf,
    show_download_progress: bool,
) -> anyhow::Result<ApiRepo> {
    use std::env;

    let cache_dir = match env::var("HF_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => default_cache_dir,
    };

    let endpoint =
        env::var("HF_ENDPOINT").unwrap_or_else(|_| String::from("https://huggingface.co"));

    let api = ApiBuilder::new()
        .with_cache_dir(cache_dir)
        .with_endpoint(endpoint)
        .with_progress(show_download_progress)
        .build()?;

    Ok(api.model(model_name))
}