Skip to main content

akuna_embed/
lib.rs

1//! Simple text embedding models built with Burn.
2//!
3//! # Example
4//!
5//! ```rust,no_run
6//! use akuna_embed::{EmbeddingModel, TextEmbedding, TextEmbeddingOptions};
7//!
8//! #[tokio::main]
9//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
10//!     let model = TextEmbedding::new(TextEmbeddingOptions {
11//!         model: EmbeddingModel::MiniLmL12,
12//!         ..Default::default()
13//!     })
14//!     .await?;
15//!
16//!     let single = model.embed("Hello world")?;
17//!     assert!(!single.is_empty());
18//!
19//!     let batch = model.embed_batch(&["Hello world", "Rust embeddings"], None)?;
20//!     assert_eq!(batch.len(), 2);
21//!
22//!     Ok(())
23//! }
24//! ```
25
26mod bert;
27
28use std::path::PathBuf;
29
30use anyhow::{Context, Result, bail};
31use burn::tensor::{Tensor, backend::Backend};
32use burn_wgpu::{Wgpu, WgpuDevice};
33
34use crate::bert::{
35    BertEmbeddingModel, BertEmbeddingVariant, EmbeddingInputKind,
36    load_pretrained_bert_embedding,
37};
38
/// Burn backend used by [`TextEmbedding::new`] and as the default type
/// parameter of [`TextEmbedding`].
pub type DefaultBackend = Wgpu;
/// Device type matching [`DefaultBackend`]; [`TextEmbedding::new`] places the
/// model on the default WGPU device.
pub type DefaultDevice = WgpuDevice;
/// Maximum number of inputs encoded per forward pass when the caller does not
/// supply an explicit batch size.
const DEFAULT_BATCH_SIZE: usize = 32;
42
/// Supported embedding model checkpoints.
///
/// Select one through [`TextEmbeddingOptions::model`];
/// [`TextEmbedding::model`] reports which one a loaded instance uses.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
    /// MiniLM-L6 checkpoint; parity-tested against
    /// `sentence-transformers/all-MiniLM-L6-v2`.
    MiniLmL6,
    /// MiniLM-L12 checkpoint; parity-tested against
    /// `sentence-transformers/all-MiniLM-L12-v2`. This is the default.
    #[default]
    MiniLmL12,
    /// `BAAI/bge-small-en-v1.5`, producing 384-dimensional vectors.
    BgeSmallEnV15,
    /// `BAAI/bge-base-en-v1.5`.
    BgeBaseEnV15,
}
53
54impl From<EmbeddingModel> for BertEmbeddingVariant {
55    fn from(value: EmbeddingModel) -> Self {
56        match value {
57            EmbeddingModel::MiniLmL6 => BertEmbeddingVariant::MiniLmL6,
58            EmbeddingModel::MiniLmL12 => BertEmbeddingVariant::MiniLmL12,
59            EmbeddingModel::BgeSmallEnV15 => {
60                BertEmbeddingVariant::BgeSmallEnV15
61            }
62            EmbeddingModel::BgeBaseEnV15 => BertEmbeddingVariant::BgeBaseEnV15,
63        }
64    }
65}
66
67/// Options for [`TextEmbedding`].
68#[derive(Debug, Clone, Default)]
69pub struct TextEmbeddingOptions {
70    /// Which embedding checkpoint to load.
71    pub model: EmbeddingModel,
72    /// Optional Hugging Face cache directory override.
73    pub cache_dir: Option<PathBuf>,
74}
75
76/// Minimal text embedding interface inspired by `fastembed-rs`.
77#[derive(Debug)]
78pub struct TextEmbedding<B: Backend = DefaultBackend> {
79    model: BertEmbeddingModel<B>,
80    device: B::Device,
81}
82
83impl TextEmbedding<DefaultBackend> {
84    /// Loads a MiniLM text embedding model onto the default WGPU device.
85    pub async fn new(options: TextEmbeddingOptions) -> Result<Self> {
86        let device = WgpuDevice::default();
87        Self::new_with_device(&device, options).await
88    }
89}
90
91impl<B> TextEmbedding<B>
92where
93    B: Backend,
94{
95    /// Loads a MiniLM text embedding model onto the provided device.
96    pub async fn new_with_device(
97        device: &B::Device,
98        options: TextEmbeddingOptions,
99    ) -> Result<Self> {
100        let model = load_pretrained_bert_embedding(
101            device,
102            options.model.into(),
103            options.cache_dir,
104        )
105        .await?;
106
107        Ok(Self {
108            model,
109            device: device.clone(),
110        })
111    }
112
113    /// Embeds a single document and returns one embedding vector.
114    pub fn embed(&self, document: impl AsRef<str>) -> Result<Vec<f32>> {
115        let document = document.as_ref();
116        let documents = [document];
117        let mut embeddings = self.embed_batch(documents.as_slice(), None)?;
118        embeddings
119            .pop()
120            .context("expected one embedding for a single input document")
121    }
122
123    /// Embeds a search query using any model-specific retrieval prompt.
124    ///
125    /// Some retrieval models train queries and documents with different text
126    /// prefixes. Use this when the text is the thing being searched for.
127    /// Use [`TextEmbedding::embed`] when the text is the content being indexed.
128    pub fn embed_query(&self, query: impl AsRef<str>) -> Result<Vec<f32>> {
129        let query = query.as_ref();
130        let queries = [query];
131        let mut embeddings =
132            self.embed_query_batch(queries.as_slice(), None)?;
133        embeddings
134            .pop()
135            .context("expected one embedding for a single input query")
136    }
137
138    /// Embeds documents in batches and returns one vector per input string.
139    pub fn embed_batch<S: AsRef<str>>(
140        &self,
141        documents: &[S],
142        batch_size: Option<usize>,
143    ) -> Result<Vec<Vec<f32>>> {
144        self.embed_batch_with_kind(
145            documents,
146            batch_size,
147            EmbeddingInputKind::Document,
148        )
149    }
150
151    /// Embeds search queries in batches using model-specific retrieval prompts.
152    ///
153    /// See [`TextEmbedding::embed_query`] for when query embeddings differ from
154    /// document embeddings.
155    pub fn embed_query_batch<S: AsRef<str>>(
156        &self,
157        queries: &[S],
158        batch_size: Option<usize>,
159    ) -> Result<Vec<Vec<f32>>> {
160        self.embed_batch_with_kind(
161            queries,
162            batch_size,
163            EmbeddingInputKind::Query,
164        )
165    }
166
167    fn embed_batch_with_kind<S: AsRef<str>>(
168        &self,
169        inputs: &[S],
170        batch_size: Option<usize>,
171        input_kind: EmbeddingInputKind,
172    ) -> Result<Vec<Vec<f32>>> {
173        if inputs.is_empty() {
174            return Ok(Vec::new());
175        }
176
177        let batch_size = batch_size_or_default(inputs.len(), batch_size)?;
178
179        let mut embeddings = Vec::with_capacity(inputs.len());
180        for batch in inputs.chunks(batch_size) {
181            let batch_inputs =
182                batch.iter().map(AsRef::as_ref).collect::<Vec<_>>();
183            let batch_embeddings =
184                self.model.encode(&batch_inputs, input_kind, &self.device)?;
185            embeddings.extend(tensor_to_rows(batch_embeddings)?);
186        }
187
188        Ok(embeddings)
189    }
190
191    /// Returns the loaded embedding checkpoint.
192    pub fn model(&self) -> EmbeddingModel {
193        match self.model.variant {
194            BertEmbeddingVariant::MiniLmL6 => EmbeddingModel::MiniLmL6,
195            BertEmbeddingVariant::MiniLmL12 => EmbeddingModel::MiniLmL12,
196            BertEmbeddingVariant::BgeSmallEnV15 => {
197                EmbeddingModel::BgeSmallEnV15
198            }
199            BertEmbeddingVariant::BgeBaseEnV15 => EmbeddingModel::BgeBaseEnV15,
200        }
201    }
202}
203
204fn batch_size_or_default(
205    document_count: usize,
206    batch_size: Option<usize>,
207) -> Result<usize> {
208    let batch_size =
209        batch_size.unwrap_or(document_count.min(DEFAULT_BATCH_SIZE));
210    if batch_size == 0 {
211        bail!("batch size must be greater than zero");
212    }
213
214    Ok(batch_size)
215}
216
217fn tensor_to_rows<B: Backend>(
218    embeddings: Tensor<B, 2>,
219) -> Result<Vec<Vec<f32>>> {
220    let [row_count, column_count] = embeddings.dims();
221    let data = embeddings.into_data().convert::<f32>();
222    let values = data
223        .as_slice::<f32>()
224        .map_err(|error| anyhow::anyhow!(error.to_string()))
225        .context("failed to read embedding output tensor")?;
226
227    Ok(values
228        .chunks(column_count)
229        .take(row_count)
230        .map(|row| row.to_vec())
231        .collect())
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Tensor;
    use burn_wgpu::{Wgpu, WgpuDevice};
    use std::sync::OnceLock;
    use tokio::sync::Mutex;

    /// Serializes the tests that load real checkpoints; each of them holds this
    /// lock for its whole body.
    static LIVE_MODEL_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    #[test]
    fn api_model_mapping_converts_all_public_variants() {
        assert_eq!(
            BertEmbeddingVariant::from(EmbeddingModel::MiniLmL6),
            BertEmbeddingVariant::MiniLmL6
        );
        assert_eq!(
            BertEmbeddingVariant::from(EmbeddingModel::MiniLmL12),
            BertEmbeddingVariant::MiniLmL12
        );
        assert_eq!(
            BertEmbeddingVariant::from(EmbeddingModel::BgeSmallEnV15),
            BertEmbeddingVariant::BgeSmallEnV15
        );
        assert_eq!(
            BertEmbeddingVariant::from(EmbeddingModel::BgeBaseEnV15),
            BertEmbeddingVariant::BgeBaseEnV15
        );
    }

    #[test]
    fn api_model_metadata_returns_bge_repo_ids() {
        assert_eq!(
            BertEmbeddingVariant::BgeSmallEnV15.repo_id(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(
            BertEmbeddingVariant::BgeBaseEnV15.repo_id(),
            "BAAI/bge-base-en-v1.5"
        );
    }

    #[test]
    fn api_options_default_uses_minilm_l12() {
        assert_eq!(
            TextEmbeddingOptions::default().model,
            EmbeddingModel::MiniLmL12
        );
    }

    #[tokio::test]
    async fn model_bge_small_embed_returns_document_and_query_vectors() {
        let _guard = live_model_test_lock().lock().await;
        let model = TextEmbedding::new(TextEmbeddingOptions {
            model: EmbeddingModel::BgeSmallEnV15,
            ..Default::default()
        })
        .await
        .expect("model should load");

        let document = model
            .embed("Hello world")
            .expect("document embed should work");
        let query = model
            .embed_query("Hello world")
            .expect("query embed should work");

        assert_eq!(document.len(), 384);
        assert_eq!(query.len(), 384);
    }

    #[tokio::test]
    async fn model_minilm_l6_backend_supports_i32_indices() {
        let _guard = live_model_test_lock().lock().await;
        let device = WgpuDevice::default();
        let model = TextEmbedding::<Wgpu<f32, i32>>::new_with_device(
            &device,
            TextEmbeddingOptions {
                model: EmbeddingModel::MiniLmL6,
                cache_dir: None,
            },
        )
        .await
        .expect("model should load");

        let single = model
            .embed("Hello world")
            .expect("single embed should work");
        assert!(!single.is_empty());
    }

    #[tokio::test]
    async fn model_minilm_l6_embed_returns_vectors() {
        let _guard = live_model_test_lock().lock().await;
        let model = TextEmbedding::new(TextEmbeddingOptions {
            model: EmbeddingModel::MiniLmL6,
            ..Default::default()
        })
        .await
        .expect("model should load");

        let single = model
            .embed("Hello world")
            .expect("single embed should work");
        assert!(!single.is_empty());

        let batch = model
            .embed_batch(&["Hello world", "Rust embeddings"], None)
            .expect("batch embed should work");
        assert_eq!(batch.len(), 2);
        assert!(batch.iter().all(|embedding| !embedding.is_empty()));
    }

    #[tokio::test]
    async fn parity_bge_base_document_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::BgeBaseEnV15,
            "BAAI/bge-base-en-v1.5",
            ReferenceInputKind::Document,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_bge_base_query_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::BgeBaseEnV15,
            "BAAI/bge-base-en-v1.5",
            ReferenceInputKind::Query,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_bge_small_document_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::BgeSmallEnV15,
            "BAAI/bge-small-en-v1.5",
            ReferenceInputKind::Document,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_bge_small_query_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::BgeSmallEnV15,
            "BAAI/bge-small-en-v1.5",
            ReferenceInputKind::Query,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_minilm_l12_document_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::MiniLmL12,
            "sentence-transformers/all-MiniLM-L12-v2",
            ReferenceInputKind::Document,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_minilm_l12_query_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::MiniLmL12,
            "sentence-transformers/all-MiniLM-L12-v2",
            ReferenceInputKind::Query,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_minilm_l6_document_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::MiniLmL6,
            "sentence-transformers/all-MiniLM-L6-v2",
            ReferenceInputKind::Document,
        )
        .await;
    }

    #[tokio::test]
    async fn parity_minilm_l6_query_matches_sentence_transformers() {
        assert_model_matches_sentence_transformers(
            EmbeddingModel::MiniLmL6,
            "sentence-transformers/all-MiniLM-L6-v2",
            ReferenceInputKind::Query,
        )
        .await;
    }

    #[test]
    fn util_batch_size_default_caps_large_batches() {
        let batch_size = batch_size_or_default(128, None)
            .expect("default batch size should work");
        assert_eq!(batch_size, DEFAULT_BATCH_SIZE);
    }

    #[test]
    fn util_batch_size_default_uses_document_count_when_small() {
        let batch_size = batch_size_or_default(4, None)
            .expect("default batch size should work");
        assert_eq!(batch_size, 4);
    }

    #[test]
    fn util_batch_size_validate_rejects_zero() {
        let error = batch_size_or_default(1, Some(0))
            .expect_err("zero batch size should fail");
        assert!(
            error
                .to_string()
                .contains("batch size must be greater than zero")
        );
    }

    #[test]
    fn util_tensor_rows_extract_returns_rows() {
        let device = WgpuDevice::default();
        let embeddings = Tensor::<Wgpu<f32, i64>, 2>::from_floats(
            [[1.0, 2.0], [3.0, 4.0]],
            &device,
        );

        let rows = tensor_to_rows(embeddings).expect("rows should extract");
        assert_eq!(rows, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    }

    /// Which embedding path a parity test exercises; its string form is passed
    /// to the reference script's `--kind` flag.
    #[derive(Debug, Clone, Copy)]
    enum ReferenceInputKind {
        Document,
        Query,
    }

    impl ReferenceInputKind {
        fn as_str(self) -> &'static str {
            match self {
                Self::Document => "document",
                Self::Query => "query",
            }
        }
    }

    /// Embeds two fixed texts with the Burn model and with the Python reference
    /// script, then asserts the results agree element-wise (max delta 1e-3)
    /// and by cosine similarity (at least 0.999).
    async fn assert_model_matches_sentence_transformers(
        model: EmbeddingModel,
        reference_model: &str,
        input_kind: ReferenceInputKind,
    ) {
        let _guard = live_model_test_lock().lock().await;
        let texts =
            vec!["Hello world".to_string(), "Rust embeddings".to_string()];
        let model = TextEmbedding::new(TextEmbeddingOptions {
            model,
            ..Default::default()
        })
        .await
        .expect("model should load");
        let actual = match input_kind {
            ReferenceInputKind::Document => model
                .embed_batch(&texts, Some(2))
                .expect("Burn document embeddings should work"),
            ReferenceInputKind::Query => model
                .embed_query_batch(&texts, Some(2))
                .expect("Burn query embeddings should work"),
        };
        let expected =
            reference_embeddings(reference_model, input_kind.as_str(), &texts)
                .expect("reference embeddings should work");

        assert_embedding_batches_close(&actual, &expected, 1e-3, 0.999);
    }

    fn live_model_test_lock() -> &'static Mutex<()> {
        LIVE_MODEL_TEST_LOCK.get_or_init(|| Mutex::new(()))
    }

    /// Runs `scripts/reference_embeddings.py` through `uv`. `texts` is sent as
    /// JSON on stdin and the embeddings are parsed as JSON from stdout.
    ///
    /// The child runs from the crate's manifest directory so the relative
    /// script path resolves regardless of the test runner's working directory.
    fn reference_embeddings(
        model: &str,
        kind: &str,
        texts: &[String],
    ) -> Result<Vec<Vec<f32>>> {
        use std::io::Write;
        use std::process::{Command, Stdio};

        let mut child = Command::new("uv")
            .current_dir(env!("CARGO_MANIFEST_DIR"))
            .args([
                "run",
                "scripts/reference_embeddings.py",
                "--model",
                model,
                "--kind",
                kind,
            ])
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .context("failed to spawn uv reference embedding script")?;

        let mut stdin = child
            .stdin
            .take()
            .context("failed to open reference script stdin")?;
        let input = serde_json::to_vec(texts)
            .context("failed to serialize reference input")?;
        stdin
            .write_all(&input)
            .context("failed to write reference input")?;
        // Close stdin so the script sees EOF and can finish reading input.
        drop(stdin);

        let output = child
            .wait_with_output()
            .context("failed to wait for reference script")?;
        if !output.status.success() {
            bail!(
                "reference script failed: {}",
                String::from_utf8_lossy(&output.stderr)
            );
        }

        serde_json::from_slice(&output.stdout)
            .context("failed to parse reference embeddings")
    }

    /// Asserts both batches have the same shape, and that each pair of rows is
    /// within `tolerance` element-wise and at least `min_cosine_similarity`
    /// apart in direction.
    fn assert_embedding_batches_close(
        actual: &[Vec<f32>],
        expected: &[Vec<f32>],
        tolerance: f32,
        min_cosine_similarity: f32,
    ) {
        assert_eq!(actual.len(), expected.len());
        for (actual, expected) in actual.iter().zip(expected) {
            assert_eq!(actual.len(), expected.len());
            let max_delta = actual
                .iter()
                .zip(expected)
                .map(|(actual, expected)| (actual - expected).abs())
                .fold(0.0f32, f32::max);
            assert!(
                max_delta <= tolerance,
                "max embedding delta {max_delta} exceeded tolerance {tolerance}"
            );
            let cosine_similarity = cosine_similarity(actual, expected);
            assert!(
                cosine_similarity >= min_cosine_similarity,
                "cosine similarity {cosine_similarity} fell below {min_cosine_similarity}"
            );
        }
    }

    /// Cosine similarity of two equal-length vectors. A zero-norm input yields
    /// NaN, which fails the `>=` check in the caller.
    fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
        let dot_product = left
            .iter()
            .zip(right)
            .map(|(left, right)| left * right)
            .sum::<f32>();
        let left_norm =
            left.iter().map(|value| value * value).sum::<f32>().sqrt();
        let right_norm =
            right.iter().map(|value| value * value).sum::<f32>().sqrt();

        dot_product / (left_norm * right_norm)
    }
}