use std::sync::Arc;
use async_trait::async_trait;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use solo_core::{Embedder, Embedding, EmbeddingDtype, Error, Result};
use tokio::sync::{Mutex, OnceCell};
/// Identifier reported by the bundled embedder's `Embedder::name`.
pub const BUNDLED_EMBEDDER_NAME: &'static str = "bundled:all-MiniLM-L6-v2";
/// Version tag reported by the bundled embedder's `Embedder::version`.
pub const BUNDLED_EMBEDDER_VERSION: &'static str = "v1";
/// Number of `f32` components in every vector the bundled embedder returns;
/// vectors of any other length are rejected as errors.
pub const BUNDLED_EMBEDDER_DIM: usize = 384;
/// Embedder backed by fastembed's all-MiniLM-L6-v2 model.
///
/// The model is loaded lazily on first use and cached behind an `Arc`, so
/// constructing one is cheap and every clone of a `BundledEmbedder` shares the
/// same (eventually) loaded model.
#[derive(Clone)]
pub struct BundledEmbedder {
    /// Lazily-initialised model slot; the inner `Mutex` serialises access to
    /// the loaded `TextEmbedding` across concurrent callers.
    model: Arc<OnceCell<Arc<Mutex<TextEmbedding>>>>,
}
impl Default for BundledEmbedder {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for BundledEmbedder {
    /// Prints the static model identity plus whether the lazy model has been
    /// loaded yet. Formatting never triggers a load.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let loaded = self.model.initialized();
        let mut out = f.debug_struct("BundledEmbedder");
        out.field("name", &BUNDLED_EMBEDDER_NAME);
        out.field("dim", &BUNDLED_EMBEDDER_DIM);
        out.field("loaded", &loaded);
        out.finish()
    }
}
impl BundledEmbedder {
pub fn new() -> Self {
Self {
model: Arc::new(OnceCell::new()),
}
}
pub async fn try_new(&self) -> Result<Arc<Mutex<TextEmbedding>>> {
self.ensure_model().await.cloned()
}
async fn ensure_model(&self) -> Result<&Arc<Mutex<TextEmbedding>>> {
self.model
.get_or_try_init(|| async {
tracing::info!(
model = BUNDLED_EMBEDDER_NAME,
"loading bundled embedder (first-use download from hf-hub if not cached)"
);
let model = tokio::task::spawn_blocking(|| {
TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::AllMiniLML6V2)
.with_show_download_progress(false),
)
})
.await
.map_err(|e| {
Error::embedder(format!(
"bundled embedder init task panicked or was cancelled: {e}"
))
})?
.map_err(|e| {
Error::embedder(format!(
"bundled embedder init failed (fastembed/ort): {e}. \
Fall back to SOLO_EMBEDDER=ollama or set [embedder] \
name = \"stub\" in solo.config.toml."
))
})?;
Ok(Arc::new(Mutex::new(model)))
})
.await
}
}
#[async_trait]
impl Embedder for BundledEmbedder {
fn name(&self) -> &str {
BUNDLED_EMBEDDER_NAME
}
fn version(&self) -> &str {
BUNDLED_EMBEDDER_VERSION
}
fn dim(&self) -> usize {
BUNDLED_EMBEDDER_DIM
}
fn dtype(&self) -> EmbeddingDtype {
EmbeddingDtype::F32
}
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let model = self.ensure_model().await?.clone();
let owned: Vec<String> = texts.iter().map(|t| (*t).to_string()).collect();
let vectors: Vec<Vec<f32>> = tokio::task::spawn_blocking(move || {
let mut guard = model.blocking_lock();
guard.embed(&owned, None)
})
.await
.map_err(|e| {
Error::embedder(format!(
"bundled embedder inference task panicked or was cancelled: {e}"
))
})?
.map_err(|e| {
Error::embedder(format!("bundled embedder embed_batch failed: {e}"))
})?;
let mut out = Vec::with_capacity(vectors.len());
for v in vectors {
if v.len() != BUNDLED_EMBEDDER_DIM {
return Err(Error::embedder(format!(
"bundled embedder returned dim {} (expected {})",
v.len(),
BUNDLED_EMBEDDER_DIM
)));
}
let mut data = Vec::with_capacity(v.len() * 4);
for f in &v {
data.extend_from_slice(&f.to_le_bytes());
}
out.push(Embedding {
dtype: EmbeddingDtype::F32,
dim: BUNDLED_EMBEDDER_DIM,
data,
});
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    // Most tests below load the real model. Per the log message emitted on
    // first load, that may mean downloading from hf-hub, so they need network
    // access or a warm cache. The two identity/laziness tests never touch the
    // model and are plain synchronous tests.
    use super::*;
    use std::sync::OnceLock;

    /// Process-wide embedder shared by the model-loading tests so the model is
    /// loaded at most once per test binary.
    fn shared() -> &'static BundledEmbedder {
        static SHARED: OnceLock<BundledEmbedder> = OnceLock::new();
        SHARED.get_or_init(BundledEmbedder::new)
    }

    /// Cosine similarity of two f32 embeddings; the denominator is clamped to
    /// avoid dividing by zero for all-zero vectors.
    fn cosine(a: &Embedding, b: &Embedding) -> f32 {
        let av = a.as_f32_slice().expect("a is f32");
        let bv = b.as_f32_slice().expect("b is f32");
        assert_eq!(av.len(), bv.len(), "dim mismatch");
        let dot: f32 = av.iter().zip(bv.iter()).map(|(x, y)| x * y).sum();
        let na: f32 = av.iter().map(|x| x * x).sum::<f32>().sqrt();
        let nb: f32 = bv.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (na * nb).max(1e-9)
    }

    #[tokio::test]
    async fn bundled_embedder_produces_384_dim_vectors() {
        let v = shared()
            .embed("hello world")
            .await
            .expect("embed should succeed");
        assert_eq!(v.dim, BUNDLED_EMBEDDER_DIM);
        assert_eq!(v.dtype, EmbeddingDtype::F32);
        assert_eq!(v.data.len(), BUNDLED_EMBEDDER_DIM * 4);
        v.validate().expect("embedding length invariant");
    }

    // Pure metadata check: no model load and no `.await`, so no runtime needed.
    #[test]
    fn bundled_embedder_emits_expected_identity() {
        let e = BundledEmbedder::new();
        assert_eq!(e.name(), "bundled:all-MiniLM-L6-v2");
        assert_eq!(e.version(), "v1");
        assert_eq!(e.dim(), 384);
        assert_eq!(e.dtype(), EmbeddingDtype::F32);
    }

    #[tokio::test]
    async fn bundled_embedder_is_deterministic_across_calls() {
        let a = shared().embed("the quick brown fox").await.unwrap();
        let b = shared().embed("the quick brown fox").await.unwrap();
        assert_eq!(a.data, b.data, "same input must produce identical bytes");
    }

    #[tokio::test]
    async fn bundled_embedder_distinct_inputs_produce_distinct_vectors() {
        let a = shared().embed("alpha").await.unwrap();
        let b = shared().embed("beta").await.unwrap();
        assert_ne!(a.data, b.data);
    }

    #[tokio::test]
    async fn bundled_embedder_does_semantic_work() {
        let a = shared().embed("the cat sat on the mat").await.unwrap();
        let b = shared().embed("a feline rested on the rug").await.unwrap();
        let c = shared()
            .embed("Rust's borrow checker enforces aliasing rules")
            .await
            .unwrap();
        let sim_ab = cosine(&a, &b);
        let sim_ac = cosine(&a, &c);
        assert!(
            sim_ab > sim_ac,
            "semantically similar pair (cat/feline) should beat dissimilar \
             (cat/Rust): sim_ab={sim_ab} sim_ac={sim_ac}"
        );
        assert!(sim_ab > 0.0, "semantic similarity should be positive");
    }

    #[tokio::test]
    async fn bundled_embedder_handles_utf8_multi_byte() {
        let v = shared()
            .embed("こんにちは 🦀 مرحبا")
            .await
            .expect("multi-byte UTF-8 must embed cleanly");
        assert_eq!(v.dim, BUNDLED_EMBEDDER_DIM);
        v.validate().unwrap();
    }

    #[tokio::test]
    async fn bundled_embedder_empty_input_returns_empty_batch() {
        let out = shared().embed_batch(&[]).await.unwrap();
        assert_eq!(out.len(), 0);
    }

    #[tokio::test]
    async fn bundled_embedder_empty_string_returns_valid_vector() {
        let v = shared()
            .embed("")
            .await
            .expect("empty string is valid input");
        assert_eq!(v.dim, BUNDLED_EMBEDDER_DIM);
        v.validate().unwrap();
    }

    #[tokio::test]
    async fn bundled_embedder_batch_preserves_input_order() {
        let inputs = ["one", "two", "three", "four"];
        let batch = shared().embed_batch(&inputs).await.unwrap();
        assert_eq!(batch.len(), inputs.len());
        for (i, text) in inputs.iter().enumerate() {
            let single = shared().embed(text).await.unwrap();
            assert_eq!(batch[i].data, single.data, "batch[{i}] != single({text})");
        }
    }

    #[tokio::test]
    async fn bundled_embedder_concurrent_calls_do_not_deadlock() {
        let mut handles = Vec::new();
        for i in 0..8 {
            handles.push(tokio::spawn(async move {
                let text = format!("concurrent call number {i}");
                shared().embed(&text).await
            }));
        }
        for h in handles {
            let v = h.await.expect("join").expect("embed");
            assert_eq!(v.dim, BUNDLED_EMBEDDER_DIM);
        }
    }

    // Construction alone must not load the model; no runtime is required.
    #[test]
    fn bundled_embedder_is_lazy_at_construction() {
        let e = BundledEmbedder::new();
        assert!(
            e.model.get().is_none(),
            "OnceCell must be empty before any embed/try_new call"
        );
    }

    #[tokio::test]
    async fn bundled_embedder_try_new_loads_eagerly_and_is_idempotent() {
        let model1 = shared().try_new().await.expect("eager init");
        let model2 = shared().try_new().await.expect("second eager init");
        assert!(Arc::ptr_eq(&model1, &model2), "try_new should be idempotent");
    }

    // Despite the name, this only checks that every component is finite.
    #[tokio::test]
    async fn bundled_embedder_normalised_or_valid_floats() {
        let v = shared().embed("finite floats only").await.unwrap();
        let slice = v.as_f32_slice().unwrap();
        for (i, f) in slice.iter().enumerate() {
            assert!(
                f.is_finite(),
                "non-finite component at index {i}: {f}"
            );
        }
    }
}