1mod bert;
27
28use std::path::PathBuf;
29
30use anyhow::{Context, Result, bail};
31use burn::tensor::{Tensor, backend::Backend};
32use burn_wgpu::{Wgpu, WgpuDevice};
33
34use crate::bert::{
35 BertEmbeddingModel, BertEmbeddingVariant, EmbeddingInputKind,
36 load_pretrained_bert_embedding,
37};
38
/// Backend used by [`TextEmbedding`] when no backend type parameter is given.
pub type DefaultBackend = Wgpu;
/// Device type for the WGPU backend; [`TextEmbedding::new`] uses its default
/// value.
pub type DefaultDevice = WgpuDevice;
/// Upper bound on the batch size picked when a caller passes `None`.
const DEFAULT_BATCH_SIZE: usize = 32;
42
/// Pretrained embedding models that can be selected through
/// [`TextEmbeddingOptions`].
///
/// Every variant converts one-to-one into the `BertEmbeddingVariant` of the
/// same name. The enum is `#[non_exhaustive]`, so `match` expressions outside
/// this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum EmbeddingModel {
    /// MiniLM L6 model; this file's parity tests compare it against
    /// `sentence-transformers/all-MiniLM-L6-v2`.
    MiniLmL6,
    /// MiniLM L12 model (the default); this file's parity tests compare it
    /// against `sentence-transformers/all-MiniLM-L12-v2`.
    #[default]
    MiniLmL12,
    /// BGE small English v1.5 (`BAAI/bge-small-en-v1.5`); the live test in
    /// this file expects 384-value embeddings from it.
    BgeSmallEnV15,
    /// BGE base English v1.5 (`BAAI/bge-base-en-v1.5`).
    BgeBaseEnV15,
}
53
54impl From<EmbeddingModel> for BertEmbeddingVariant {
55 fn from(value: EmbeddingModel) -> Self {
56 match value {
57 EmbeddingModel::MiniLmL6 => BertEmbeddingVariant::MiniLmL6,
58 EmbeddingModel::MiniLmL12 => BertEmbeddingVariant::MiniLmL12,
59 EmbeddingModel::BgeSmallEnV15 => {
60 BertEmbeddingVariant::BgeSmallEnV15
61 }
62 EmbeddingModel::BgeBaseEnV15 => BertEmbeddingVariant::BgeBaseEnV15,
63 }
64 }
65}
66
/// Options accepted by [`TextEmbedding::new`] and
/// [`TextEmbedding::new_with_device`].
#[derive(Debug, Clone, Default)]
pub struct TextEmbeddingOptions {
    /// Model to load; defaults to [`EmbeddingModel::MiniLmL12`].
    pub model: EmbeddingModel,
    /// Optional directory forwarded unchanged to the pretrained-model loader.
    /// NOTE(review): presumably `None` lets the loader pick its own cache
    /// location — confirm in the `bert` module.
    pub cache_dir: Option<PathBuf>,
}
75
/// A loaded embedding model together with the device it encodes on.
///
/// `B` is the Burn backend and defaults to [`DefaultBackend`].
#[derive(Debug)]
pub struct TextEmbedding<B: Backend = DefaultBackend> {
    /// Loaded model; its `variant` field backs [`TextEmbedding::model`].
    model: BertEmbeddingModel<B>,
    /// Clone of the device given at construction, passed to every `encode`
    /// call.
    device: B::Device,
}
82
83impl TextEmbedding<DefaultBackend> {
84 pub async fn new(options: TextEmbeddingOptions) -> Result<Self> {
86 let device = WgpuDevice::default();
87 Self::new_with_device(&device, options).await
88 }
89}
90
91impl<B> TextEmbedding<B>
92where
93 B: Backend,
94{
95 pub async fn new_with_device(
97 device: &B::Device,
98 options: TextEmbeddingOptions,
99 ) -> Result<Self> {
100 let model = load_pretrained_bert_embedding(
101 device,
102 options.model.into(),
103 options.cache_dir,
104 )
105 .await?;
106
107 Ok(Self {
108 model,
109 device: device.clone(),
110 })
111 }
112
113 pub fn embed(&self, document: impl AsRef<str>) -> Result<Vec<f32>> {
115 let document = document.as_ref();
116 let documents = [document];
117 let mut embeddings = self.embed_batch(documents.as_slice(), None)?;
118 embeddings
119 .pop()
120 .context("expected one embedding for a single input document")
121 }
122
123 pub fn embed_query(&self, query: impl AsRef<str>) -> Result<Vec<f32>> {
129 let query = query.as_ref();
130 let queries = [query];
131 let mut embeddings =
132 self.embed_query_batch(queries.as_slice(), None)?;
133 embeddings
134 .pop()
135 .context("expected one embedding for a single input query")
136 }
137
138 pub fn embed_batch<S: AsRef<str>>(
140 &self,
141 documents: &[S],
142 batch_size: Option<usize>,
143 ) -> Result<Vec<Vec<f32>>> {
144 self.embed_batch_with_kind(
145 documents,
146 batch_size,
147 EmbeddingInputKind::Document,
148 )
149 }
150
151 pub fn embed_query_batch<S: AsRef<str>>(
156 &self,
157 queries: &[S],
158 batch_size: Option<usize>,
159 ) -> Result<Vec<Vec<f32>>> {
160 self.embed_batch_with_kind(
161 queries,
162 batch_size,
163 EmbeddingInputKind::Query,
164 )
165 }
166
167 fn embed_batch_with_kind<S: AsRef<str>>(
168 &self,
169 inputs: &[S],
170 batch_size: Option<usize>,
171 input_kind: EmbeddingInputKind,
172 ) -> Result<Vec<Vec<f32>>> {
173 if inputs.is_empty() {
174 return Ok(Vec::new());
175 }
176
177 let batch_size = batch_size_or_default(inputs.len(), batch_size)?;
178
179 let mut embeddings = Vec::with_capacity(inputs.len());
180 for batch in inputs.chunks(batch_size) {
181 let batch_inputs =
182 batch.iter().map(AsRef::as_ref).collect::<Vec<_>>();
183 let batch_embeddings =
184 self.model.encode(&batch_inputs, input_kind, &self.device)?;
185 embeddings.extend(tensor_to_rows(batch_embeddings)?);
186 }
187
188 Ok(embeddings)
189 }
190
191 pub fn model(&self) -> EmbeddingModel {
193 match self.model.variant {
194 BertEmbeddingVariant::MiniLmL6 => EmbeddingModel::MiniLmL6,
195 BertEmbeddingVariant::MiniLmL12 => EmbeddingModel::MiniLmL12,
196 BertEmbeddingVariant::BgeSmallEnV15 => {
197 EmbeddingModel::BgeSmallEnV15
198 }
199 BertEmbeddingVariant::BgeBaseEnV15 => EmbeddingModel::BgeBaseEnV15,
200 }
201 }
202}
203
204fn batch_size_or_default(
205 document_count: usize,
206 batch_size: Option<usize>,
207) -> Result<usize> {
208 let batch_size =
209 batch_size.unwrap_or(document_count.min(DEFAULT_BATCH_SIZE));
210 if batch_size == 0 {
211 bail!("batch size must be greater than zero");
212 }
213
214 Ok(batch_size)
215}
216
217fn tensor_to_rows<B: Backend>(
218 embeddings: Tensor<B, 2>,
219) -> Result<Vec<Vec<f32>>> {
220 let [row_count, column_count] = embeddings.dims();
221 let data = embeddings.into_data().convert::<f32>();
222 let values = data
223 .as_slice::<f32>()
224 .map_err(|error| anyhow::anyhow!(error.to_string()))
225 .context("failed to read embedding output tensor")?;
226
227 Ok(values
228 .chunks(column_count)
229 .take(row_count)
230 .map(|row| row.to_vec())
231 .collect())
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;
    use burn_wgpu::{Wgpu, WgpuDevice};
    use std::sync::OnceLock;
    use tokio::sync::Mutex;

    /// Lock held by every test that loads a live model, so those tests run
    /// one at a time.
    static LIVE_MODEL_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Declares one async parity test that forwards its arguments to
    /// `assert_model_matches_sentence_transformers`.
    macro_rules! parity_test {
        (
            $test_name:ident,
            $model:ident,
            $reference:literal,
            $kind:ident
        ) => {
            #[tokio::test]
            async fn $test_name() {
                assert_model_matches_sentence_transformers(
                    EmbeddingModel::$model,
                    $reference,
                    ReferenceInputKind::$kind,
                )
                .await;
            }
        };
    }

    #[test]
    fn api_model_mapping_converts_all_public_variants() {
        let pairs = [
            (EmbeddingModel::MiniLmL6, BertEmbeddingVariant::MiniLmL6),
            (EmbeddingModel::MiniLmL12, BertEmbeddingVariant::MiniLmL12),
            (
                EmbeddingModel::BgeSmallEnV15,
                BertEmbeddingVariant::BgeSmallEnV15,
            ),
            (
                EmbeddingModel::BgeBaseEnV15,
                BertEmbeddingVariant::BgeBaseEnV15,
            ),
        ];

        for (public, internal) in pairs {
            assert_eq!(BertEmbeddingVariant::from(public), internal);
        }
    }

    #[test]
    fn api_model_metadata_returns_bge_repo_ids() {
        let expectations = [
            (BertEmbeddingVariant::BgeSmallEnV15, "BAAI/bge-small-en-v1.5"),
            (BertEmbeddingVariant::BgeBaseEnV15, "BAAI/bge-base-en-v1.5"),
        ];

        for (variant, repo_id) in expectations {
            assert_eq!(variant.repo_id(), repo_id);
        }
    }

    #[test]
    fn api_options_default_uses_minilm_l12() {
        let options = TextEmbeddingOptions::default();
        assert_eq!(options.model, EmbeddingModel::MiniLmL12);
    }

    #[tokio::test]
    async fn model_bge_small_embed_returns_document_and_query_vectors() {
        let _guard = live_model_test_lock().lock().await;
        let embedder =
            load_default_backend_model(EmbeddingModel::BgeSmallEnV15).await;

        let document = embedder
            .embed("Hello world")
            .expect("document embed should work");
        let query = embedder
            .embed_query("Hello world")
            .expect("query embed should work");

        assert_eq!(document.len(), 384);
        assert_eq!(query.len(), 384);
    }

    #[tokio::test]
    async fn model_minilm_l6_backend_supports_i32_indices() {
        let _guard = live_model_test_lock().lock().await;
        let device = WgpuDevice::default();
        let options = TextEmbeddingOptions {
            model: EmbeddingModel::MiniLmL6,
            cache_dir: None,
        };
        let embedder =
            TextEmbedding::<Wgpu<f32, i32>>::new_with_device(&device, options)
                .await
                .expect("model should load");

        let single = embedder
            .embed("Hello world")
            .expect("single embed should work");
        assert!(!single.is_empty());
    }

    #[tokio::test]
    async fn model_minilm_l6_embed_returns_vectors() {
        let _guard = live_model_test_lock().lock().await;
        let embedder =
            load_default_backend_model(EmbeddingModel::MiniLmL6).await;

        let single = embedder
            .embed("Hello world")
            .expect("single embed should work");
        assert!(!single.is_empty());

        let batch = embedder
            .embed_batch(&["Hello world", "Rust embeddings"], None)
            .expect("batch embed should work");
        assert_eq!(batch.len(), 2);
        assert!(!batch.iter().any(Vec::is_empty));
    }

    parity_test!(
        parity_bge_base_document_matches_sentence_transformers,
        BgeBaseEnV15,
        "BAAI/bge-base-en-v1.5",
        Document
    );

    parity_test!(
        parity_bge_base_query_matches_sentence_transformers,
        BgeBaseEnV15,
        "BAAI/bge-base-en-v1.5",
        Query
    );

    parity_test!(
        parity_bge_small_document_matches_sentence_transformers,
        BgeSmallEnV15,
        "BAAI/bge-small-en-v1.5",
        Document
    );

    parity_test!(
        parity_bge_small_query_matches_sentence_transformers,
        BgeSmallEnV15,
        "BAAI/bge-small-en-v1.5",
        Query
    );

    parity_test!(
        parity_minilm_l12_document_matches_sentence_transformers,
        MiniLmL12,
        "sentence-transformers/all-MiniLM-L12-v2",
        Document
    );

    parity_test!(
        parity_minilm_l12_query_matches_sentence_transformers,
        MiniLmL12,
        "sentence-transformers/all-MiniLM-L12-v2",
        Query
    );

    parity_test!(
        parity_minilm_l6_document_matches_sentence_transformers,
        MiniLmL6,
        "sentence-transformers/all-MiniLM-L6-v2",
        Document
    );

    parity_test!(
        parity_minilm_l6_query_matches_sentence_transformers,
        MiniLmL6,
        "sentence-transformers/all-MiniLM-L6-v2",
        Query
    );

    #[test]
    fn util_batch_size_default_caps_large_batches() {
        let resolved = batch_size_or_default(128, None)
            .expect("default batch size should work");
        assert_eq!(resolved, DEFAULT_BATCH_SIZE);
    }

    #[test]
    fn util_batch_size_default_uses_document_count_when_small() {
        let resolved = batch_size_or_default(4, None)
            .expect("default batch size should work");
        assert_eq!(resolved, 4);
    }

    #[test]
    fn util_batch_size_validate_rejects_zero() {
        let message = batch_size_or_default(1, Some(0))
            .expect_err("zero batch size should fail")
            .to_string();
        assert!(message.contains("batch size must be greater than zero"));
    }

    #[test]
    fn util_tensor_rows_extract_returns_rows() {
        let device = WgpuDevice::default();
        let source = [[1.0, 2.0], [3.0, 4.0]];
        let embeddings =
            Tensor::<Wgpu<f32, i64>, 2>::from_floats(source, &device);

        let rows = tensor_to_rows(embeddings).expect("rows should extract");
        assert_eq!(rows, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    }

    /// Which embedding entry point a parity test exercises.
    #[derive(Debug, Clone, Copy)]
    enum ReferenceInputKind {
        Document,
        Query,
    }

    impl ReferenceInputKind {
        /// Value passed to the reference script's `--kind` argument.
        fn as_str(self) -> &'static str {
            match self {
                Self::Document => "document",
                Self::Query => "query",
            }
        }
    }

    /// Loads `model` on the default backend with default options, panicking
    /// when loading fails.
    async fn load_default_backend_model(
        model: EmbeddingModel,
    ) -> TextEmbedding {
        let options = TextEmbeddingOptions {
            model,
            ..Default::default()
        };

        TextEmbedding::new(options)
            .await
            .expect("model should load")
    }

    /// Embeds two fixed texts with `model` and checks the result against the
    /// vectors the reference script produces for `reference_model`.
    async fn assert_model_matches_sentence_transformers(
        model: EmbeddingModel,
        reference_model: &str,
        input_kind: ReferenceInputKind,
    ) {
        let _guard = live_model_test_lock().lock().await;
        let texts = ["Hello world", "Rust embeddings"].map(String::from);
        let embedder = load_default_backend_model(model).await;

        let actual = match input_kind {
            ReferenceInputKind::Document => embedder
                .embed_batch(&texts, Some(2))
                .expect("Burn document embeddings should work"),
            ReferenceInputKind::Query => embedder
                .embed_query_batch(&texts, Some(2))
                .expect("Burn query embeddings should work"),
        };
        let kind = input_kind.as_str();
        let expected = reference_embeddings(reference_model, kind, &texts)
            .expect("reference embeddings should work");

        assert_embedding_batches_close(&actual, &expected, 1e-3, 0.999);
    }

    /// Returns the shared live-model lock, creating it on first use.
    fn live_model_test_lock() -> &'static Mutex<()> {
        LIVE_MODEL_TEST_LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Runs `uv run scripts/reference_embeddings.py --model <model> --kind
    /// <kind>`, sends `texts` to it as JSON on stdin, and parses its stdout
    /// as JSON embeddings.
    ///
    /// Fails when the process cannot be spawned or written to, exits
    /// unsuccessfully (stderr is included in the error), or prints output
    /// that does not parse.
    fn reference_embeddings(
        model: &str,
        kind: &str,
        texts: &[String],
    ) -> Result<Vec<Vec<f32>>> {
        use std::io::Write;
        use std::process::{Command, Stdio};

        let mut command = Command::new("uv");
        command
            .arg("run")
            .arg("scripts/reference_embeddings.py")
            .args(["--model", model])
            .args(["--kind", kind])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let mut child = command
            .spawn()
            .context("failed to spawn uv reference embedding script")?;

        // Scoped so the stdin handle is dropped, closing the pipe, before
        // waiting on the child.
        {
            let mut stdin = child
                .stdin
                .take()
                .context("failed to open reference script stdin")?;
            let payload = serde_json::to_vec(texts)
                .context("failed to serialize reference input")?;
            stdin
                .write_all(&payload)
                .context("failed to write reference input")?;
        }

        let output = child
            .wait_with_output()
            .context("failed to wait for reference script")?;
        if !output.status.success() {
            bail!(
                "reference script failed: {}",
                String::from_utf8_lossy(&output.stderr)
            );
        }

        serde_json::from_slice(&output.stdout)
            .context("failed to parse reference embeddings")
    }

    /// Asserts that both batches have the same shape and that every pair of
    /// rows stays within `tolerance` per component and reaches at least
    /// `min_cosine_similarity`.
    fn assert_embedding_batches_close(
        actual: &[Vec<f32>],
        expected: &[Vec<f32>],
        tolerance: f32,
        min_cosine_similarity: f32,
    ) {
        assert_eq!(actual.len(), expected.len());

        for (actual_row, expected_row) in actual.iter().zip(expected) {
            assert_eq!(actual_row.len(), expected_row.len());

            let mut max_delta = 0.0f32;
            for (got, want) in actual_row.iter().zip(expected_row) {
                max_delta = max_delta.max((got - want).abs());
            }
            assert!(
                max_delta <= tolerance,
                "max embedding delta {max_delta} exceeded tolerance {tolerance}"
            );

            let cosine_similarity =
                cosine_similarity(actual_row, expected_row);
            assert!(
                cosine_similarity >= min_cosine_similarity,
                "cosine similarity {cosine_similarity} fell below {min_cosine_similarity}"
            );
        }
    }

    /// Cosine similarity of two vectors: dot product divided by the product
    /// of their Euclidean norms.
    fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
        let norm = |vector: &[f32]| {
            vector.iter().map(|value| value * value).sum::<f32>().sqrt()
        };
        let dot_product: f32 =
            left.iter().zip(right).map(|(l, r)| l * r).sum();

        dot_product / (norm(left) * norm(right))
    }
}