Skip to main content

codesearch/embed/
mod.rs

1mod batch;
2mod cache;
3mod embedder;
4
5pub use batch::{BatchEmbedder, EmbeddedChunk};
6pub use cache::{CacheStats, CachedBatchEmbedder, QueryCache, QueryCacheStats};
7pub use embedder::{FastEmbedder, ModelType};
8
9use anyhow::Result;
10use std::env;
11use std::sync::{Arc, Mutex};
12
13/// High-level embedding service that combines all features
14pub struct EmbeddingService {
15    cached_embedder: CachedBatchEmbedder,
16    model_type: ModelType,
17    query_cache: QueryCache,
18}
19
20impl EmbeddingService {
21    /// Create a new embedding service with default model
22    pub fn new() -> Result<Self> {
23        Self::with_model(ModelType::default())
24    }
25
26    /// Create a new embedding service with specified model
27    pub fn with_model(model_type: ModelType) -> Result<Self> {
28        Self::with_cache_dir(model_type, None)
29    }
30
31    /// Create a new embedding service with specified model and cache directory
32    pub fn with_cache_dir(
33        model_type: ModelType,
34        cache_dir: Option<&std::path::Path>,
35    ) -> Result<Self> {
36        let embedder = FastEmbedder::with_cache_dir(model_type, cache_dir)?;
37        let arc_embedder = Arc::new(Mutex::new(embedder));
38        let batch_embedder = BatchEmbedder::new(arc_embedder);
39
40        // Get cache memory limit from environment variable
41        let cache_limit_mb = env::var("CODESEARCH_CACHE_MAX_MEMORY")
42            .ok()
43            .and_then(|s| s.parse().ok())
44            .unwrap_or(crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB);
45
46        let cached_embedder =
47            CachedBatchEmbedder::with_memory_limit(batch_embedder, cache_limit_mb);
48
49        // Initialize query cache (separate from chunk cache)
50        let query_cache = QueryCache::new();
51
52        Ok(Self {
53            cached_embedder,
54            model_type,
55            query_cache,
56        })
57    }
58
59    /// Embed a batch of chunks with caching
60    pub fn embed_chunks(
61        &mut self,
62        chunks: Vec<crate::chunker::Chunk>,
63    ) -> Result<Vec<EmbeddedChunk>> {
64        self.cached_embedder.embed_chunks(chunks)
65    }
66
67    /// Embed query text (with caching)
68    pub fn embed_query(&mut self, query: &str) -> Result<Vec<f32>> {
69        // Check query cache first
70        if let Some(cached) = self.query_cache.get(query) {
71            return Ok(cached);
72        }
73
74        // Cache miss - embed the query
75        let embedder_arc = &self.cached_embedder.batch_embedder.embedder;
76        let embedding = embedder_arc
77            .lock()
78            .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
79            .embed_one(query)?;
80
81        // Store in cache
82        self.query_cache.put(query, embedding.clone());
83
84        Ok(embedding)
85    }
86
87    /// Batch embed multiple query texts with caching (single ONNX call for misses)
88    pub fn embed_queries_batch(&mut self, queries: &[String]) -> Result<Vec<Vec<f32>>> {
89        if queries.is_empty() {
90            return Ok(Vec::new());
91        }
92
93        let total = queries.len();
94        let mut results = Vec::with_capacity(total);
95        let mut queries_to_embed = Vec::new();
96        let mut cache_indices = Vec::new();
97
98        // Check cache first
99        for (idx, query) in queries.iter().enumerate() {
100            if let Some(cached) = self.query_cache.get(query) {
101                results.push(cached);
102            } else {
103                queries_to_embed.push(query.clone());
104                cache_indices.push(idx);
105            }
106        }
107
108        // Batch embed remaining queries (single ONNX call)
109        if !queries_to_embed.is_empty() {
110            // Clone once before passing to embed_batch (which takes ownership)
111            let queries_for_caching = queries_to_embed.clone();
112            let embedder_arc = &self.cached_embedder.batch_embedder.embedder;
113            let mut embedder = embedder_arc
114                .lock()
115                .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?;
116
117            let new_embeddings = embedder.embed_batch(queries_to_embed)?;
118
119            // Store in cache and add to results
120            for (i, embedding) in new_embeddings.into_iter().enumerate() {
121                self.query_cache
122                    .put(&queries_for_caching[i], embedding.clone());
123
124                // Place at correct position
125                results.insert(cache_indices[i], embedding);
126            }
127        }
128
129        Ok(results)
130    }
131
132    /// Get embedding dimensions
133    pub fn dimensions(&self) -> usize {
134        self.cached_embedder.dimensions()
135    }
136
137    /// Get model information
138    pub fn model_name(&self) -> &str {
139        self.model_type.name()
140    }
141
142    /// Get model short name (for storage)
143    pub fn model_short_name(&self) -> &str {
144        self.model_type.short_name()
145    }
146
147    /// Get cache statistics
148    #[allow(dead_code)] // Part of public API for debugging/monitoring
149    pub fn cache_stats(&self) -> CacheStats {
150        self.cached_embedder.cache_stats()
151    }
152
153    /// Get query cache statistics
154    #[allow(dead_code)] // Part of public API for debugging/monitoring
155    pub fn query_cache_stats(&self) -> QueryCacheStats {
156        self.query_cache.stats()
157    }
158}
159
160impl Default for EmbeddingService {
161    fn default() -> Self {
162        Self::new().expect("Failed to create default embedding service")
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// The default model is expected to produce 384-dimensional vectors.
    #[test]
    fn test_model_type_default() {
        let default_model = ModelType::default();
        assert_eq!(default_model.dimensions(), 384);
    }

    /// Constructing the default service succeeds and reports 384 dimensions.
    #[test]
    #[ignore] // Requires model download
    fn test_embedding_service_creation() {
        let created = EmbeddingService::new();
        assert!(created.is_ok());

        let built = created.unwrap();
        assert_eq!(built.dimensions(), 384);
    }

    /// A single query embeds to a vector of the model's dimensionality.
    #[test]
    #[ignore] // Requires model
    fn test_embed_query() {
        let mut svc = EmbeddingService::new().unwrap();
        let vector = svc.embed_query("find authentication code").unwrap();

        assert_eq!(vector.len(), 384);
    }

    #[test]
    #[ignore] // search method not implemented - uses VectorStore instead
    fn test_embed_and_search() {
        // Searching lives in VectorStore, not EmbeddingService.
        // Placeholder kept for documentation purposes.
    }

    #[test]
    #[ignore] // search method not implemented - uses VectorStore instead
    fn test_search() {
        // Searching lives in VectorStore, not EmbeddingService.
        // Placeholder kept for documentation purposes.
    }
}