use anyhow::{Context, Result};
use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
use super::{model_cache_dir, EMBEDDING_DIM};
/// Batch size handed to the model when embedding many texts at once
/// (passed as `Some(EMBED_BATCH_SIZE)` by `EmbeddingEngine::embed_batch`).
const EMBED_BATCH_SIZE: usize = 64;
/// Owns a fastembed [`TextEmbedding`] model and exposes single-text and
/// batched embedding on top of it.
pub struct EmbeddingEngine {
/// The underlying fastembed model; the embedding methods borrow it mutably.
model: TextEmbedding,
}
impl EmbeddingEngine {
    /// Creates an engine, passing `false` to
    /// `TextInitOptions::with_show_download_progress`.
    ///
    /// # Errors
    ///
    /// Returns an error (context: "Failed to initialize embedding model")
    /// if `TextEmbedding::try_new` fails.
    pub fn new() -> Result<Self> {
        Self::init(false)
    }

    /// Creates an engine, passing `true` to
    /// `TextInitOptions::with_show_download_progress`.
    ///
    /// Otherwise identical to [`EmbeddingEngine::new`].
    ///
    /// # Errors
    ///
    /// Returns an error (context: "Failed to initialize embedding model")
    /// if `TextEmbedding::try_new` fails.
    pub fn new_with_progress() -> Result<Self> {
        Self::init(true)
    }

    /// Shared constructor: builds the `BGESmallENV15Q` model using the cache
    /// directory from `model_cache_dir()` and the given progress flag.
    ///
    /// Keeping this in one place guarantees both public constructors use the
    /// same model, cache directory and error context.
    fn init(show_download_progress: bool) -> Result<Self> {
        let options = TextInitOptions::new(EmbeddingModel::BGESmallENV15Q)
            .with_cache_dir(model_cache_dir())
            .with_show_download_progress(show_download_progress);
        let model =
            TextEmbedding::try_new(options).context("Failed to initialize embedding model")?;
        Ok(Self { model })
    }

    /// Embeds a single text and returns its vector.
    ///
    /// # Errors
    ///
    /// Returns an error if the model call fails ("Embedding failed") or if
    /// the model returns no vectors at all ("No embedding returned").
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if the returned vector
    /// does not have exactly `EMBEDDING_DIM` components. Release builds skip
    /// this check.
    pub fn embed(&mut self, text: &str) -> Result<Vec<f32>> {
        let results = self
            .model
            .embed(vec![text], Some(1))
            .context("Embedding failed")?;
        // Only one text was submitted, so only the first result is relevant.
        let embedding = results
            .into_iter()
            .next()
            .context("No embedding returned")?;
        debug_assert_eq!(
            embedding.len(),
            EMBEDDING_DIM,
            "Expected {EMBEDDING_DIM}-dim embedding, got {}",
            embedding.len()
        );
        Ok(embedding)
    }

    /// Embeds every text in `texts`, using a model batch size of
    /// `EMBED_BATCH_SIZE`, and returns the vectors produced by the model.
    ///
    /// An empty `texts` slice yields an empty `Vec` without calling the model.
    ///
    /// # Errors
    ///
    /// Returns an error ("Batch embedding failed") if the model call fails.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if any returned vector
    /// does not have exactly `EMBEDDING_DIM` components.
    pub fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let results = self
            .model
            .embed(texts, Some(EMBED_BATCH_SIZE))
            .context("Batch embedding failed")?;
        debug_assert!(
            results.iter().all(|v| v.len() == EMBEDDING_DIM),
            "All embeddings should be {EMBEDDING_DIM}-dim"
        );
        Ok(results)
    }
}
/// Serializes an embedding into a flat byte buffer.
///
/// Each `f32` is written as its 4-byte little-endian representation, in
/// input order, so the result is `embedding.len() * 4` bytes long.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    out.extend(embedding.iter().flat_map(|component| component.to_le_bytes()));
    out
}
/// Deserializes a byte buffer of little-endian `f32` values into a vector.
///
/// The input is read in 4-byte groups; any trailing bytes that do not form a
/// complete group are ignored.
// NOTE(review): the reason for this `allow` is not recorded; within this file
// the function is only called from the test module — confirm against callers.
#[allow(dead_code)]
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        // `chunks_exact(4)` only yields 4-byte slices, so this copy cannot panic.
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding must give back exactly the original floats.
    #[test]
    fn test_embedding_roundtrip() {
        let original = vec![0.1_f32, -0.5, 1.0, 0.0, std::f32::consts::PI];
        let restored = bytes_to_embedding(&embedding_to_bytes(&original));
        assert_eq!(original, restored);
    }

    /// A full-size embedding serializes to four bytes per component.
    #[test]
    fn test_embedding_byte_length() {
        let zeros = vec![0.0_f32; EMBEDDING_DIM];
        assert_eq!(embedding_to_bytes(&zeros).len(), EMBEDDING_DIM * 4);
    }

    /// Empty input stays empty in both directions.
    #[test]
    fn test_empty_bytes_roundtrip() {
        let empty: Vec<f32> = Vec::new();
        let encoded = embedding_to_bytes(&empty);
        assert!(encoded.is_empty());
        assert!(bytes_to_embedding(&encoded).is_empty());
    }

    /// A single embedding has the expected dimension and unit L2 norm.
    #[test]
    fn test_engine_embed_dimension() {
        // If the engine cannot be constructed the test is skipped, not failed.
        if let Ok(mut engine) = EmbeddingEngine::new() {
            let vec = engine.embed("fn validate_token(token: &str)").unwrap();
            assert_eq!(vec.len(), EMBEDDING_DIM);

            let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 0.01,
                "embedding should be L2-normalized, got norm={norm}"
            );
        }
    }

    /// A batch of three texts yields three vectors of the expected dimension.
    #[test]
    fn test_engine_embed_batch() {
        // If the engine cannot be constructed the test is skipped, not failed.
        if let Ok(mut engine) = EmbeddingEngine::new() {
            let texts = [
                "fn foo() -> i32",
                "class AuthService",
                "def validate(token)",
            ];
            let results = engine.embed_batch(&texts).unwrap();
            assert_eq!(results.len(), 3);
            results
                .iter()
                .for_each(|v| assert_eq!(v.len(), EMBEDDING_DIM));
        }
    }

    /// An empty batch yields an empty result.
    #[test]
    fn test_engine_embed_batch_empty() {
        // If the engine cannot be constructed the test is skipped, not failed.
        if let Ok(mut engine) = EmbeddingEngine::new() {
            let no_texts: &[&str] = &[];
            let results = engine.embed_batch(no_texts).unwrap();
            assert!(results.is_empty());
        }
    }
}