Skip to main content

kyma_embed/
fastembed.rs

1use crate::{EmbedError, EmbeddingBackend};
2use async_trait::async_trait;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7/// ONNX-backed embedding via `fastembed-rs`. Loads the model at construction;
8/// inference runs on a tokio blocking thread per batch.
9///
10/// The inner `TextEmbedding` is held behind a mutex, which serializes
11/// `.embed()` calls across concurrent users of the same `FastembedBackend`.
12/// For schema-RAG (low QPS, small batches) this is fine. For bulk user-data
13/// embedding (Phase C+) we may want a pool of `TextEmbedding` instances or
14/// a lock-free inference path — revisit when that workload materializes.
15pub struct FastembedBackend {
16    id: String,
17    dimension: u16,
18    inner: Arc<Mutex<TextEmbedding>>,
19}
20
21impl std::fmt::Debug for FastembedBackend {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("FastembedBackend")
24            .field("id", &self.id)
25            .field("dimension", &self.dimension)
26            .finish()
27    }
28}
29
30impl FastembedBackend {
31    /// `model_id` is the short name (e.g., `"bge-small-en-v1.5"`).
32    /// `model_path` optionally points at a pre-downloaded ONNX dir for
33    /// air-gapped deployments (env `KYMA_EMBED_MODEL_PATH`).
34    pub async fn new(model_id: &str, model_path: Option<&str>) -> Result<Self, EmbedError> {
35        let em = pick_model(model_id)?;
36        let dimension = em_dimension(&em);
37        let mut opts = InitOptions::new(em);
38        if let Some(path) = model_path {
39            opts = opts.with_cache_dir(path.into());
40        }
41        let model = tokio::task::spawn_blocking(move || TextEmbedding::try_new(opts))
42            .await
43            .map_err(|e| EmbedError::ModelLoad(e.to_string()))?
44            .map_err(|e| EmbedError::ModelLoad(e.to_string()))?;
45        Ok(Self {
46            id: format!("fastembed/{model_id}"),
47            dimension,
48            inner: Arc::new(Mutex::new(model)),
49        })
50    }
51}
52
53#[async_trait]
54impl EmbeddingBackend for FastembedBackend {
55    fn id(&self) -> &str {
56        &self.id
57    }
58    fn dimension(&self) -> u16 {
59        self.dimension
60    }
61    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedError> {
62        if texts.is_empty() {
63            return Ok(vec![]);
64        }
65        let inner = self.inner.clone();
66        let owned: Vec<String> = texts.to_vec();
67        let dim = self.dimension;
68        tokio::task::spawn_blocking(move || {
69            let guard = inner.blocking_lock();
70            let vecs = guard
71                .embed(owned, None)
72                .map_err(|e| EmbedError::Request(e.to_string()))?;
73            for v in &vecs {
74                if v.len() != dim as usize {
75                    return Err(EmbedError::DimensionMismatch {
76                        got: v.len() as u16,
77                        expected: dim,
78                    });
79                }
80            }
81            Ok(vecs)
82        })
83        .await
84        .map_err(|e| EmbedError::Internal(e.to_string()))?
85    }
86}
87
88fn pick_model(id: &str) -> Result<EmbeddingModel, EmbedError> {
89    match id {
90        "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
91        "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
92        "all-MiniLM-L6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
93        other => Err(EmbedError::NotConfigured(format!(
94            "unknown fastembed model: {other}"
95        ))),
96    }
97}
98
99fn em_dimension(em: &EmbeddingModel) -> u16 {
100    match em {
101        EmbeddingModel::BGESmallENV15 => 384,
102        EmbeddingModel::BGEBaseENV15  => 768,
103        EmbeddingModel::AllMiniLML6V2 => 384,
104        other => unreachable!(
105            "em_dimension: pick_model accepted model {other:?} but em_dimension has no arm. Add the dimension here."),
106    }
107}