search_semantically/
embedder.rs1use anyhow::{Context, Result};
2use ort::session::Session;
3use std::path::{Path, PathBuf};
4use std::str::FromStr;
5use std::sync::Arc;
6
/// Hugging Face repository id of the model that `Embedder::new` configures.
const DEFAULT_MODEL_NAME: &str = "Xenova/all-MiniLM-L6-v2";

/// Embedding width assumed until it is read from the loaded model.
const DEFAULT_DIMENSION: usize = 384;

/// Hook invoked with the URL of each model file just before it is fetched.
///
/// Reference-counted and `Send + Sync` so it can be moved onto the download
/// thread while the caller keeps its own handle.
pub type DownloadCallback = Arc<dyn Fn(&str) + Send + Sync>;
11
/// Text embedder backed by an ONNX model and a `tokenizers` tokenizer.
///
/// `new` creates it unloaded; `initialize` downloads the model files if they are
/// missing and loads them, after which `embed` can be called.
pub struct Embedder {
    /// Cache root; model files live in `<model_cache_dir>/<model_name>/`.
    model_cache_dir: PathBuf,
    /// Hugging Face repo id; also used as the cache sub-directory name.
    model_name: String,
    /// Loaded ONNX session; `None` until `initialize` succeeds.
    session: Option<Session>,
    /// Tokenizer parsed from `tokenizer.json`; `None` until `initialize` succeeds.
    tokenizer: Option<tokenizers::Tokenizer>,
    /// Embedding width; starts at `DEFAULT_DIMENSION`, updated by `initialize`.
    dimension: usize,
    /// Optional hook called with each URL fetched while downloading the model.
    download_callback: Option<DownloadCallback>,
}
20
21impl Embedder {
22 pub fn new(model_cache_dir: PathBuf) -> Self {
23 Self {
24 model_cache_dir,
25 model_name: DEFAULT_MODEL_NAME.to_string(),
26 session: None,
27 tokenizer: None,
28 dimension: DEFAULT_DIMENSION,
29 download_callback: None,
30 }
31 }
32
33 pub fn with_download_callback(mut self, callback: DownloadCallback) -> Self {
34 self.download_callback = Some(callback);
35 self
36 }
37
38 pub fn set_download_callback(&mut self, callback: DownloadCallback) {
39 self.download_callback = Some(callback);
40 }
41
42 pub fn initialize(&mut self) -> Result<()> {
43 if self.session.is_some() {
44 return Ok(());
45 }
46
47 let model_dir = self.model_cache_dir.join(&self.model_name);
48 std::fs::create_dir_all(&model_dir)
49 .with_context(|| format!("Creating model cache dir: {}", model_dir.display()))?;
50
51 let onnx_path = model_dir.join("model.onnx");
52 let tokenizer_path = model_dir.join("tokenizer.json");
53
54 if !onnx_path.exists() || !tokenizer_path.exists() {
55 download_model(&self.model_name, &model_dir, self.download_callback.clone())?;
56 }
57
58 let session = Session::builder()
59 .context("Creating ONNX session builder")?
60 .commit_from_file(&onnx_path)
61 .with_context(|| format!("Loading ONNX model from {}", onnx_path.display()))?;
62
63 let tokenizer_data = std::fs::read_to_string(&tokenizer_path)
64 .with_context(|| format!("Reading tokenizer from {}", tokenizer_path.display()))?;
65 let tokenizer = tokenizers::Tokenizer::from_str(&tokenizer_data)
66 .map_err(|e| anyhow::anyhow!("Parsing tokenizer JSON: {e}"))?;
67
68 self.dimension = detect_dimension(&session);
69 self.session = Some(session);
70 self.tokenizer = Some(tokenizer);
71
72 Ok(())
73 }
74
75 pub fn embed(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
76 let tokenizer = self.tokenizer.as_ref().expect("Embedder not initialized");
77
78 let mut results = Vec::with_capacity(texts.len());
79
80 for text in texts {
81 let encoding = tokenizer
82 .encode(*text, true)
83 .map_err(|e| anyhow::anyhow!("Tokenization failed: {e}"))?;
84
85 let ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
86 let attention_mask: Vec<i64> = encoding
87 .get_attention_mask()
88 .iter()
89 .map(|&m| m as i64)
90 .collect();
91 let type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&t| t as i64).collect();
92
93 let len = ids.len();
94 let input_ids = ndarray::Array2::from_shape_vec((1, len), ids)
95 .context("Creating input_ids array")?;
96 let attn_mask = ndarray::Array2::from_shape_vec((1, len), attention_mask)
97 .context("Creating attention_mask array")?;
98 let token_types = ndarray::Array2::from_shape_vec((1, len), type_ids)
99 .context("Creating token_type_ids array")?;
100
101 let session = self.session.as_mut().expect("Embedder not session");
102
103 let input_ids_val =
104 ort::value::Tensor::from_array(input_ids).context("Creating input_ids tensor")?;
105 let attn_mask_val = ort::value::Tensor::from_array(attn_mask)
106 .context("Creating attention_mask tensor")?;
107 let token_types_val = ort::value::Tensor::from_array(token_types)
108 .context("Creating token_type_ids tensor")?;
109
110 let outputs = session
111 .run(ort::inputs! {
112 "input_ids" => input_ids_val,
113 "attention_mask" => attn_mask_val,
114 "token_type_ids" => token_types_val,
115 })
116 .context("Running ONNX inference")?;
117
118 let output = outputs.iter().next().context("No output from model")?.1;
119
120 let (_, data) = output
121 .try_extract_tensor::<f32>()
122 .context("Extracting tensor")?;
123
124 let mask_f32: Vec<f32> = encoding
125 .get_attention_mask()
126 .iter()
127 .map(|&m| m as f32)
128 .collect();
129 let embedding = mean_pool_normalize(data, len, self.dimension, &mask_f32);
130
131 results.push(embedding);
132 }
133
134 Ok(results)
135 }
136
137 pub fn dimension(&self) -> usize {
138 self.dimension
139 }
140}
141
142fn detect_dimension(session: &Session) -> usize {
143 session
144 .outputs()
145 .first()
146 .and_then(|outlet| outlet.dtype().tensor_shape())
147 .and_then(|shape| shape.last().copied())
148 .filter(|&d| d > 0)
149 .map(|d| d as usize)
150 .unwrap_or(DEFAULT_DIMENSION)
151}
152
/// Mean-pools per-token vectors under a mask, then scales the result to unit length.
///
/// `data` is read as `seq_len` consecutive rows of `dim` values (row `i` is
/// `data[i * dim..(i + 1) * dim]`); values past `seq_len * dim` are ignored.
/// Row `i` is weighted by `mask[i]`. The weighted sum is divided by the total
/// mask weight, but only if that total is positive. The result is then divided
/// by its Euclidean norm, but only if that norm is non-zero. An all-zero mask
/// therefore yields a zero vector.
///
/// # Panics
/// Panics if `mask` holds fewer than `seq_len` weights, or if `dim > 0` and
/// `data` holds fewer than `seq_len * dim` values.
fn mean_pool_normalize(data: &[f32], seq_len: usize, dim: usize, mask: &[f32]) -> Vec<f32> {
    assert!(
        mask.len() >= seq_len,
        "mask has {} weights, expected at least {seq_len}",
        mask.len()
    );
    let weights = &mask[..seq_len];

    // Left-to-right accumulation, same order as a plain loop over the weights.
    let mask_sum = weights.iter().fold(0.0_f32, |acc, &w| acc + w);

    let mut pooled = vec![0.0_f32; dim];
    // `chunks_exact(0)` panics, and with `dim == 0` there is nothing to pool.
    if dim > 0 {
        let needed = seq_len * dim;
        assert!(
            data.len() >= needed,
            "data has {} values, expected at least {needed} ({seq_len} x {dim})",
            data.len()
        );
        for (row, &weight) in data[..needed].chunks_exact(dim).zip(weights) {
            for (acc, &value) in pooled.iter_mut().zip(row) {
                *acc += value * weight;
            }
        }
    }

    if mask_sum > 0.0 {
        for val in pooled.iter_mut() {
            *val /= mask_sum;
        }
    }

    let norm: f32 = pooled.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        for val in pooled.iter_mut() {
            *val /= norm;
        }
    }

    pooled
}
180
181fn download_model(
182 model_name: &str,
183 target_dir: &Path,
184 callback: Option<DownloadCallback>,
185) -> Result<()> {
186 let files = ["model.onnx", "tokenizer.json"];
187
188 let model_name_owned = model_name.to_string();
189 let target_dir_owned = target_dir.to_path_buf();
190
191 let handle = std::thread::spawn(move || -> Result<()> {
192 for file in &files {
193 let url = format!("https://huggingface.co/{model_name_owned}/resolve/main/{file}");
194 let dest = target_dir_owned.join(file);
195
196 if let Some(ref cb) = callback {
197 cb(&url);
198 }
199
200 let response = reqwest::blocking::get(&url)
201 .with_context(|| format!("HTTP request to {url}"))?
202 .error_for_status()
203 .context("HTTP request failed")?;
204 let buf = response.bytes().context("Reading response body")?;
205 std::fs::write(&dest, &buf).with_context(|| format!("Writing {}", dest.display()))?;
206 }
207 Ok(())
208 });
209
210 handle
211 .join()
212 .map_err(|e| anyhow::anyhow!("Model download thread panicked: {e:?}"))?
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedder_new_has_no_session() {
        let embedder = Embedder::new(std::env::temp_dir());
        assert!(embedder.session.is_none());
        assert!(embedder.tokenizer.is_none());
    }

    #[test]
    fn embedder_default_dimension() {
        let embedder = Embedder::new(std::env::temp_dir());
        assert_eq!(embedder.dimension(), 384);
    }

    #[test]
    fn mean_pool_normalize_produces_unit_vector() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![1.0_f32, 1.0];
        let result = mean_pool_normalize(&data, 2, 3, &mask);

        let norm: f32 = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "Should be unit vector, got norm {norm}"
        );
    }

    #[test]
    fn mean_pool_normalize_with_zero_mask() {
        let data = vec![1.0_f32, 2.0, 3.0];
        let mask = vec![0.0_f32];
        let result = mean_pool_normalize(&data, 1, 3, &mask);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn mean_pool_normalize_ignores_masked_tokens() {
        // Only the first token is attended to, so the result is [3, 4] / 5.
        let data = vec![3.0_f32, 4.0, 100.0, -100.0];
        let mask = vec![1.0_f32, 0.0];
        let result = mean_pool_normalize(&data, 2, 2, &mask);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 0.6).abs() < 1e-6, "got {result:?}");
        assert!((result[1] - 0.8).abs() < 1e-6, "got {result:?}");
    }

    #[test]
    fn mean_pool_normalize_averages_attended_tokens() {
        // Mean of [1, 0] and [0, 1] is [0.5, 0.5], which normalizes to [1/√2, 1/√2].
        let data = vec![1.0_f32, 0.0, 0.0, 1.0];
        let mask = vec![1.0_f32, 1.0];
        let result = mean_pool_normalize(&data, 2, 2, &mask);
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!(
            result.iter().all(|&v| (v - expected).abs() < 1e-6),
            "got {result:?}"
        );
    }
}