1mod batch;
2mod cache;
3mod embedder;
4
5pub use batch::{BatchEmbedder, EmbeddedChunk};
6pub use cache::{CacheStats, CachedBatchEmbedder, QueryCache, QueryCacheStats};
7pub use embedder::{FastEmbedder, ModelType};
8
9use anyhow::Result;
10use std::env;
11use std::sync::{Arc, Mutex};
12
/// Front-end for producing embeddings for code chunks and search queries.
///
/// Chunk embeddings are delegated to a [`CachedBatchEmbedder`]; query
/// embeddings are memoized in a [`QueryCache`] keyed by the query text.
pub struct EmbeddingService {
    // Chunk-embedding pipeline (batch embedder plus cache). The query paths also
    // reach into its inner embedder directly.
    cached_embedder: CachedBatchEmbedder,
    // The model this service was constructed with; reported by `model_name` /
    // `model_short_name`.
    model_type: ModelType,
    // Memoized query embeddings, looked up and stored by query string.
    query_cache: QueryCache,
}
19
20impl EmbeddingService {
21 pub fn new() -> Result<Self> {
23 Self::with_model(ModelType::default())
24 }
25
26 pub fn with_model(model_type: ModelType) -> Result<Self> {
28 Self::with_cache_dir(model_type, None)
29 }
30
31 pub fn with_cache_dir(
33 model_type: ModelType,
34 cache_dir: Option<&std::path::Path>,
35 ) -> Result<Self> {
36 let embedder = FastEmbedder::with_cache_dir(model_type, cache_dir)?;
37 let arc_embedder = Arc::new(Mutex::new(embedder));
38 let batch_embedder = BatchEmbedder::new(arc_embedder);
39
40 let cache_limit_mb = env::var("CODESEARCH_CACHE_MAX_MEMORY")
42 .ok()
43 .and_then(|s| s.parse().ok())
44 .unwrap_or(crate::constants::DEFAULT_CACHE_MAX_MEMORY_MB);
45
46 let cached_embedder =
47 CachedBatchEmbedder::with_memory_limit(batch_embedder, cache_limit_mb);
48
49 let query_cache = QueryCache::new();
51
52 Ok(Self {
53 cached_embedder,
54 model_type,
55 query_cache,
56 })
57 }
58
59 pub fn embed_chunks(
61 &mut self,
62 chunks: Vec<crate::chunker::Chunk>,
63 ) -> Result<Vec<EmbeddedChunk>> {
64 self.cached_embedder.embed_chunks(chunks)
65 }
66
67 pub fn embed_query(&mut self, query: &str) -> Result<Vec<f32>> {
69 if let Some(cached) = self.query_cache.get(query) {
71 return Ok(cached);
72 }
73
74 let embedder_arc = &self.cached_embedder.batch_embedder.embedder;
76 let embedding = embedder_arc
77 .lock()
78 .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?
79 .embed_one(query)?;
80
81 self.query_cache.put(query, embedding.clone());
83
84 Ok(embedding)
85 }
86
87 pub fn embed_queries_batch(&mut self, queries: &[String]) -> Result<Vec<Vec<f32>>> {
89 if queries.is_empty() {
90 return Ok(Vec::new());
91 }
92
93 let total = queries.len();
94 let mut results = Vec::with_capacity(total);
95 let mut queries_to_embed = Vec::new();
96 let mut cache_indices = Vec::new();
97
98 for (idx, query) in queries.iter().enumerate() {
100 if let Some(cached) = self.query_cache.get(query) {
101 results.push(cached);
102 } else {
103 queries_to_embed.push(query.clone());
104 cache_indices.push(idx);
105 }
106 }
107
108 if !queries_to_embed.is_empty() {
110 let queries_for_caching = queries_to_embed.clone();
112 let embedder_arc = &self.cached_embedder.batch_embedder.embedder;
113 let mut embedder = embedder_arc
114 .lock()
115 .map_err(|e| anyhow::anyhow!("Embedder mutex poisoned: {}", e))?;
116
117 let new_embeddings = embedder.embed_batch(queries_to_embed)?;
118
119 for (i, embedding) in new_embeddings.into_iter().enumerate() {
121 self.query_cache
122 .put(&queries_for_caching[i], embedding.clone());
123
124 results.insert(cache_indices[i], embedding);
126 }
127 }
128
129 Ok(results)
130 }
131
132 pub fn dimensions(&self) -> usize {
134 self.cached_embedder.dimensions()
135 }
136
137 pub fn model_name(&self) -> &str {
139 self.model_type.name()
140 }
141
142 pub fn model_short_name(&self) -> &str {
144 self.model_type.short_name()
145 }
146
147 #[allow(dead_code)] pub fn cache_stats(&self) -> CacheStats {
150 self.cached_embedder.cache_stats()
151 }
152
153 #[allow(dead_code)] pub fn query_cache_stats(&self) -> QueryCacheStats {
156 self.query_cache.stats()
157 }
158}
159
160impl Default for EmbeddingService {
161 fn default() -> Self {
162 Self::new().expect("Failed to create default embedding service")
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_type_default() {
        let model = ModelType::default();
        assert_eq!(model.dimensions(), 384);
    }

    // The tests below construct a real `FastEmbedder`, which presumably needs model
    // files and is slow, so they are ignored by default. Run them with
    // `cargo test -- --ignored`.

    #[test]
    #[ignore]
    fn test_embedding_service_creation() {
        let service = EmbeddingService::new().expect("default service should construct");
        assert_eq!(service.dimensions(), 384);
    }

    #[test]
    #[ignore]
    fn test_embed_query() {
        let mut service = EmbeddingService::new().unwrap();
        let first = service.embed_query("find authentication code").unwrap();
        assert_eq!(first.len(), 384);

        // A repeated query must come back identical (served from the query cache).
        let second = service.embed_query("find authentication code").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    #[ignore]
    fn test_embed_queries_batch_empty() {
        let mut service = EmbeddingService::new().unwrap();
        assert!(service.embed_queries_batch(&[]).unwrap().is_empty());
    }

    #[test]
    #[ignore]
    fn test_embed_queries_batch_preserves_order_with_cache_hits() {
        let mut service = EmbeddingService::new().unwrap();

        // Warm the cache for the middle query so the batch mixes hits and misses.
        let warm = service.embed_query("beta").unwrap();

        let queries = vec![
            "alpha".to_string(),
            "beta".to_string(),
            "gamma".to_string(),
        ];
        let batch = service.embed_queries_batch(&queries).unwrap();

        assert_eq!(batch.len(), queries.len());
        assert_eq!(batch[1], warm);
        // Each position must match the embedding cached for that same query.
        for (query, embedding) in queries.iter().zip(&batch) {
            assert_eq!(&service.embed_query(query).unwrap(), embedding);
        }
    }
}