mod batch;
mod cache;
mod embedder;
pub use batch::{BatchEmbedder, EmbeddedChunk};
pub use cache::{CacheStats, CachedBatchEmbedder, QueryCache, QueryCacheStats};
pub use embedder::{FastEmbedder, ModelType};
use anyhow::Result;
use std::env;
use std::sync::{Arc, Mutex};
/// Front door of the embedding module.
///
/// Owns a cache-wrapped batch embedder for chunks, remembers which model was
/// selected at construction time, and keeps a separate cache for query
/// embeddings.
pub struct EmbeddingService {
    /// Cache-wrapped batch embedder; chunk embedding is delegated to it, and the
    /// underlying model is reached through it for query embedding.
    cached_embedder: CachedBatchEmbedder,
    /// Model chosen when the service was built; backs the name accessors.
    model_type: ModelType,
    /// Cache of query text to its embedding vector.
    query_cache: QueryCache,
}
impl EmbeddingService {
    /// Creates a service that uses [`ModelType::default`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`EmbeddingService::with_model`] returns.
    pub fn new() -> Result<Self> {
        Self::with_model(ModelType::default())
    }

    /// Creates a service for `model_type` without specifying a cache directory
    /// (`None` is forwarded to [`EmbeddingService::with_cache_dir`]).
    ///
    /// # Errors
    ///
    /// Returns whatever error [`EmbeddingService::with_cache_dir`] returns.
    pub fn with_model(model_type: ModelType) -> Result<Self> {
        Self::with_cache_dir(model_type, None)
    }

    /// Creates a service for `model_type`, forwarding `cache_dir` to
    /// [`FastEmbedder::with_cache_dir`].
    ///
    /// The memory limit of the chunk-embedding cache is read from the
    /// `CODESEARCH_CACHE_MAX_MEMORY` environment variable. If the variable is
    /// unset or does not parse, `crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB`
    /// is used instead (the parse failure is not reported).
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying [`FastEmbedder`] cannot be created.
    pub fn with_cache_dir(
        model_type: ModelType,
        cache_dir: Option<&std::path::Path>,
    ) -> Result<Self> {
        let embedder = FastEmbedder::with_cache_dir(model_type, cache_dir)?;
        let arc_embedder = Arc::new(Mutex::new(embedder));
        let batch_embedder = BatchEmbedder::new(arc_embedder);

        // NOTE(review): the value is presumably in megabytes, matching the
        // default constant's name - confirm against `with_memory_limit`.
        let cache_limit_mb = env::var("CODESEARCH_CACHE_MAX_MEMORY")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB);

        let cached_embedder =
            CachedBatchEmbedder::with_memory_limit(batch_embedder, cache_limit_mb);
        let query_cache = QueryCache::new();

        Ok(Self {
            cached_embedder,
            model_type,
            query_cache,
        })
    }

    /// Embeds `chunks` by delegating to the cache-wrapped batch embedder.
    ///
    /// # Errors
    ///
    /// Returns whatever error the cached embedder's `embed_chunks` returns.
    pub fn embed_chunks(
        &mut self,
        chunks: Vec<crate::chunker::Chunk>,
    ) -> Result<Vec<EmbeddedChunk>> {
        self.cached_embedder.embed_chunks(chunks)
    }

    /// Returns the embedding for `query`, consulting the query cache first.
    ///
    /// On a cache miss the query is embedded with the underlying embedder and
    /// the result is stored in the query cache before being returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedder mutex is poisoned or if embedding fails.
    pub fn embed_query(&mut self, query: &str) -> Result<Vec<f32>> {
        if let Some(cached) = self.query_cache.get(query) {
            return Ok(cached);
        }

        let embedder_arc = &self.cached_embedder.batch_embedder.embedder;
        // The guard is a temporary of this statement, so the lock is released
        // before the cache is updated below.
        let embedding = embedder_arc
            .lock()
            .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
            .embed_one(query)?;

        self.query_cache.put(query, embedding.clone());
        Ok(embedding)
    }

    /// Returns one embedding per entry of `queries`, in the same order.
    ///
    /// Queries already present in the query cache are served from it; all
    /// remaining queries are embedded in a single `embed_batch` call and then
    /// stored in the cache. An empty input yields an empty result without
    /// touching the embedder.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedder mutex is poisoned, if embedding fails,
    /// or if the embedder returns a different number of embeddings than the
    /// number of queries it was asked to embed.
    pub fn embed_queries_batch(&mut self, queries: &[String]) -> Result<Vec<Vec<f32>>> {
        if queries.is_empty() {
            return Ok(Vec::new());
        }

        // One slot per input query: `Some` for cache hits, `None` for misses
        // that are filled in after the batch call. Writing into fixed slots
        // keeps input order without shifting elements on every miss.
        let mut slots: Vec<Option<Vec<f32>>> = Vec::with_capacity(queries.len());
        let mut missing_indices = Vec::new();
        let mut missing_queries = Vec::new();
        for (idx, query) in queries.iter().enumerate() {
            let cached = self.query_cache.get(query);
            if cached.is_none() {
                missing_indices.push(idx);
                missing_queries.push(query.clone());
            }
            slots.push(cached);
        }

        if !missing_queries.is_empty() {
            let expected = missing_queries.len();

            let mut embedder = self
                .cached_embedder
                .batch_embedder
                .embedder
                .lock()
                .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?;
            let new_embeddings = embedder.embed_batch(missing_queries)?;
            // Release the model lock before doing cache bookkeeping.
            drop(embedder);

            // A count mismatch would otherwise misalign results with `queries`.
            if new_embeddings.len() != expected {
                return Err(anyhow::anyhow!(
                    "Embedder returned {} embeddings for {} queries",
                    new_embeddings.len(),
                    expected
                ));
            }

            for (idx, embedding) in missing_indices.into_iter().zip(new_embeddings) {
                self.query_cache.put(&queries[idx], embedding.clone());
                slots[idx] = Some(embedding);
            }
        }

        // Every slot is filled at this point; the error arm only guards the
        // invariant instead of panicking.
        slots
            .into_iter()
            .map(|slot| slot.ok_or_else(|| anyhow::anyhow!("Missing embedding for query in batch")))
            .collect()
    }

    /// Returns the embedding dimension reported by the cached embedder.
    pub fn dimensions(&self) -> usize {
        self.cached_embedder.dimensions()
    }

    /// Returns the name of the model this service was built with.
    pub fn model_name(&self) -> &str {
        self.model_type.name()
    }

    /// Returns the short name of the model this service was built with.
    pub fn model_short_name(&self) -> &str {
        self.model_type.short_name()
    }

    /// Returns the statistics reported by the chunk-embedding cache.
    // Silences the unused-method lint; presumably kept for diagnostics -
    // confirm before removing.
    #[allow(dead_code)]
    pub fn cache_stats(&self) -> CacheStats {
        self.cached_embedder.cache_stats()
    }

    /// Returns the statistics reported by the query cache.
    // Silences the unused-method lint; presumably kept for diagnostics -
    // confirm before removing.
    #[allow(dead_code)]
    pub fn query_cache_stats(&self) -> QueryCacheStats {
        self.query_cache.stats()
    }
}
impl Default for EmbeddingService {
fn default() -> Self {
Self::new().expect("Failed to create default embedding service")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default model type must report 384 dimensions.
    #[test]
    fn test_model_type_default() {
        let default_model = ModelType::default();
        assert_eq!(default_model.dimensions(), 384);
    }

    /// Building a service with the default model succeeds and reports 384
    /// dimensions.
    // Ignored by default; presumably because it constructs a real embedder -
    // confirm before enabling in CI.
    #[test]
    #[ignore]
    fn test_embedding_service_creation() {
        let service = EmbeddingService::new();
        assert!(service.is_ok());

        let service = service.unwrap();
        assert_eq!(service.dimensions(), 384);
    }

    /// A single query embeds to a 384-element vector.
    // Ignored by default; presumably because it constructs a real embedder -
    // confirm before enabling in CI.
    #[test]
    #[ignore]
    fn test_embed_query() {
        let mut service = EmbeddingService::new().unwrap();
        let vector = service.embed_query("find authentication code").unwrap();
        assert_eq!(vector.len(), 384);
    }

    // Placeholder: no assertions yet.
    #[test]
    #[ignore]
    fn test_embed_and_search() {}

    // Placeholder: no assertions yet.
    #[test]
    #[ignore]
    fn test_search() {}
}