use std::sync::Arc;
use anyhow::Result;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use tokio::sync::Mutex;
/// Number of components in every vector produced by [`EmbeddingEngine`].
///
/// The tests in this module assert that each embedding has exactly this length.
// NOTE(review): `dead_code` is presumably allowed because nothing outside the tests
// references this constant yet — confirm, and drop the allow once it is used.
#[allow(dead_code)]
pub const EMBEDDING_DIMENSIONS: usize = 384;

/// Handle to a fastembed [`TextEmbedding`] model that can be used from async code.
///
/// The model lives behind an `Arc<Mutex<_>>`. Cloning the handle is therefore
/// cheap: every clone shares the same loaded model. Embedding calls made through
/// [`EmbeddingEngine::embed_batch`] take the mutex, so they run one at a time.
#[derive(Clone)]
pub struct EmbeddingEngine {
    inner: Arc<Mutex<TextEmbedding>>,
}
impl EmbeddingEngine {
    /// Initialises fastembed's `EmbeddingModel::BGESmallENV15` model.
    ///
    /// Download progress display is enabled. The tests in this file note that
    /// initialisation requires a model download (~23MB), so the first call may
    /// be slow and needs network access.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `TextEmbedding::try_new`.
    pub fn try_new() -> Result<Self> {
        let options =
            InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(true);
        let model = TextEmbedding::try_new(options)?;
        Ok(Self {
            inner: Arc::new(Mutex::new(model)),
        })
    }

    /// Embeds `texts` with the loaded model.
    ///
    /// Returns the vectors produced by the model; the tests in this file expect
    /// one vector per input text. An empty input returns an empty `Vec`
    /// immediately, without touching the model.
    ///
    /// # Errors
    ///
    /// Returns an error if the model call fails (`"embedding failed: ..."`), or
    /// if the blocking task could not be joined, e.g. because it panicked
    /// (`"spawn_blocking join error: ..."`).
    pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: avoid dispatching a blocking-pool task and contending
        // for the model lock just to produce an empty result.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // `spawn_blocking` requires a `'static` closure, so move an owned handle in.
        let model = Arc::clone(&self.inner);
        tokio::task::spawn_blocking(move || {
            // `embed` is a synchronous call, so it runs on tokio's blocking pool.
            // `blocking_lock` is correct here because this closure is not async
            // (it would panic if called from inside an async execution context).
            let mut guard = model.blocking_lock();
            guard
                // `None` leaves the batch size to fastembed's default.
                .embed(texts, None)
                // `{:#}` keeps the full cause chain when the error is an
                // `anyhow::Error`; plain `{}` would print only the outermost message.
                .map_err(|e| anyhow::anyhow!("embedding failed: {e:#}"))
        })
        .await
        .map_err(|e| anyhow::anyhow!("spawn_blocking join error: {}", e))?
    }
}
/// Embeds a list of code symbols given as `(name, path, line)` triples.
///
/// Each symbol is rendered as the text `"{name} in {path}:{line}"`, for example
/// `"my_fn in src/lib.rs:5"`. All rendered texts are embedded in a single call to
/// [`EmbeddingEngine::embed_batch`], and its result is returned unchanged.
///
/// # Errors
///
/// Propagates any error from [`EmbeddingEngine::embed_batch`].
pub async fn embed_symbols(
    engine: &EmbeddingEngine,
    symbols: &[(String, String, usize)],
) -> Result<Vec<Vec<f32>>> {
    let mut rendered = Vec::with_capacity(symbols.len());
    for (name, path, line) in symbols {
        rendered.push(format!("{name} in {path}:{line}"));
    }
    engine.embed_batch(rendered).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every vector in `embeddings` has `EMBEDDING_DIMENSIONS` components.
    fn assert_all_have_model_dimensions(embeddings: &[Vec<f32>]) {
        for vector in embeddings {
            assert_eq!(
                vector.len(),
                EMBEDDING_DIMENSIONS,
                "each embedding should have {} dimensions",
                EMBEDDING_DIMENSIONS
            );
        }
    }

    #[test]
    #[ignore = "requires model download (~23MB); run explicitly with --ignored"]
    fn embedding_engine_new() {
        let outcome = EmbeddingEngine::try_new();
        assert!(outcome.is_ok(), "EmbeddingEngine::try_new() should succeed");
    }

    #[tokio::test]
    #[ignore = "requires model download (~23MB); run explicitly with --ignored"]
    async fn embed_batch_produces_384_dim_vectors() {
        let engine = EmbeddingEngine::try_new().expect("engine should initialize");
        let inputs: Vec<String> = [
            "authenticate_user in src/auth.rs:42",
            "UserService in src/services/user.ts:10",
        ]
        .iter()
        .map(|text| text.to_string())
        .collect();

        let vectors = engine
            .embed_batch(inputs)
            .await
            .expect("embed_batch should succeed");

        assert_eq!(vectors.len(), 2, "should return one embedding per text");
        assert_all_have_model_dimensions(&vectors);
    }

    #[tokio::test]
    #[ignore = "requires model download (~23MB); run explicitly with --ignored"]
    async fn embed_symbols_produces_correct_count() {
        let engine = EmbeddingEngine::try_new().expect("engine should initialize");
        let fixtures = vec![
            ("my_fn".to_string(), "src/lib.rs".to_string(), 5usize),
            ("MyStruct".to_string(), "src/types.rs".to_string(), 20usize),
            ("run_loop".to_string(), "src/main.rs".to_string(), 100usize),
        ];

        let vectors = embed_symbols(&engine, &fixtures)
            .await
            .expect("embed_symbols should succeed");

        assert_eq!(
            vectors.len(),
            fixtures.len(),
            "should return one embedding per symbol"
        );
        assert_all_have_model_dimensions(&vectors);
    }
}