Skip to main content

coding_agent_search/daemon/
models.rs

1//! Model manager for lazy loading embedder and reranker models.
2//!
3//! This module provides lazy-loaded access to embedding and reranking models,
4//! supporting graceful fallback when models are unavailable.
5
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8
9use parking_lot::RwLock;
10use tracing::{info, warn};
11
12use crate::search::embedder::{Embedder, EmbedderError, EmbedderResult};
13use crate::search::fastembed_embedder::FastEmbedder;
14use crate::search::fastembed_reranker::FastEmbedReranker;
15use crate::search::hash_embedder::HashEmbedder;
16use crate::search::reranker::{Reranker, RerankerError, RerankerResult, rerank_texts};
17
18/// Model manager that handles lazy loading of embedder and reranker models.
19pub struct ModelManager {
20    data_dir: PathBuf,
21    embedder: RwLock<Option<Arc<dyn Embedder>>>,
22    reranker: RwLock<Option<Arc<dyn Reranker>>>,
23    embedder_name: RwLock<String>,
24    reranker_name: RwLock<String>,
25    fallback_embedder: Arc<HashEmbedder>,
26}
27
28impl ModelManager {
29    /// Create a new model manager with the given data directory.
30    pub fn new(data_dir: &Path) -> Self {
31        Self {
32            data_dir: data_dir.to_path_buf(),
33            embedder: RwLock::new(None),
34            reranker: RwLock::new(None),
35            embedder_name: RwLock::new("not-loaded".to_string()),
36            reranker_name: RwLock::new("not-loaded".to_string()),
37            fallback_embedder: Arc::new(HashEmbedder::new(384)),
38        }
39    }
40
41    /// Check if any model is loaded and ready.
42    pub fn is_ready(&self) -> bool {
43        self.embedder.read().is_some()
44    }
45
46    /// Get the embedder ID.
47    pub fn embedder_id(&self) -> String {
48        self.embedder
49            .read()
50            .as_ref()
51            .map(|e| e.id().to_string())
52            .unwrap_or_else(|| "hash-384".to_string())
53    }
54
55    /// Get the embedder name.
56    pub fn embedder_name(&self) -> String {
57        self.embedder_name.read().clone()
58    }
59
60    /// Get the embedder dimension.
61    pub fn embedder_dimension(&self) -> usize {
62        self.embedder
63            .read()
64            .as_ref()
65            .map(|e| e.dimension())
66            .unwrap_or(384)
67    }
68
69    /// Check if embedder is loaded.
70    pub fn embedder_loaded(&self) -> bool {
71        self.embedder.read().is_some()
72    }
73
74    /// Get the reranker ID.
75    pub fn reranker_id(&self) -> String {
76        self.reranker
77            .read()
78            .as_ref()
79            .map(|r| r.id().to_string())
80            .unwrap_or_else(|| "none".to_string())
81    }
82
83    /// Get the reranker name.
84    pub fn reranker_name(&self) -> String {
85        self.reranker_name.read().clone()
86    }
87
88    /// Check if reranker is loaded.
89    pub fn reranker_loaded(&self) -> bool {
90        self.reranker.read().is_some()
91    }
92
93    /// Pre-warm the embedder by loading it.
94    pub fn warm_embedder(&self) -> EmbedderResult<()> {
95        // Fast path: already loaded
96        if self.embedder.read().is_some() {
97            return Ok(());
98        }
99
100        // Slow path: need to load. Take write lock and check again.
101        let mut embedder_guard = self.embedder.write();
102        if embedder_guard.is_some() {
103            return Ok(());
104        }
105
106        let model_dir = FastEmbedder::default_model_dir(&self.data_dir);
107        info!(model_dir = %model_dir.display(), "Loading embedder");
108
109        match FastEmbedder::load_from_dir(&model_dir) {
110            Ok(embedder) => {
111                let id = embedder.id().to_string();
112                let dimension = embedder.dimension();
113                *embedder_guard = Some(Arc::new(embedder));
114                *self.embedder_name.write() = "MiniLM-L6-v2".to_string();
115                info!(id = %id, dimension = dimension, "Embedder loaded");
116                Ok(())
117            }
118            Err(e) => {
119                warn!(error = %e, "Failed to load embedder, using hash fallback");
120                *embedder_guard = Some(self.fallback_embedder.clone());
121                *self.embedder_name.write() = "hash-fallback".to_string();
122                // Return Ok since we have a fallback
123                Ok(())
124            }
125        }
126    }
127
128    /// Pre-warm the reranker by loading it.
129    pub fn warm_reranker(&self) -> RerankerResult<()> {
130        // Fast path: already loaded
131        if self.reranker.read().is_some() {
132            return Ok(());
133        }
134
135        // Slow path: need to load. Take write lock and check again.
136        let mut reranker_guard = self.reranker.write();
137        if reranker_guard.is_some() {
138            return Ok(());
139        }
140
141        let model_dir = FastEmbedReranker::default_model_dir(&self.data_dir);
142        info!(model_dir = %model_dir.display(), "Loading reranker");
143
144        match FastEmbedReranker::load_from_dir(&model_dir) {
145            Ok(reranker) => {
146                let id = reranker.id().to_string();
147                *reranker_guard = Some(Arc::new(reranker));
148                *self.reranker_name.write() = "ms-marco-MiniLM-L-6-v2".to_string();
149                info!(id = %id, "Reranker loaded");
150                Ok(())
151            }
152            Err(e) => {
153                warn!(error = %e, "Failed to load reranker, reranking unavailable");
154                Err(e)
155            }
156        }
157    }
158
159    /// Embed a batch of texts.
160    pub fn embed_batch(&self, texts: &[String]) -> EmbedderResult<Vec<Vec<f32>>> {
161        // Ensure embedder is loaded
162        if self.embedder.read().is_none() {
163            self.warm_embedder()?;
164        }
165
166        let embedder = self.embedder.read();
167        let embedder = embedder
168            .as_ref()
169            .ok_or_else(|| EmbedderError::EmbedderUnavailable {
170                model: "unknown".to_string(),
171                reason: "embedder not loaded".to_string(),
172            })?;
173
174        // Convert to &str slice for the batch call
175        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
176        embedder.embed_batch_sync(&text_refs)
177    }
178
179    /// Embed a single text.
180    pub fn embed(&self, text: &str) -> EmbedderResult<Vec<f32>> {
181        // Ensure embedder is loaded
182        if self.embedder.read().is_none() {
183            self.warm_embedder()?;
184        }
185
186        let embedder = self.embedder.read();
187        let embedder = embedder
188            .as_ref()
189            .ok_or_else(|| EmbedderError::EmbedderUnavailable {
190                model: "unknown".to_string(),
191                reason: "embedder not loaded".to_string(),
192            })?;
193
194        embedder.embed_sync(text)
195    }
196
197    /// Rerank documents against a query.
198    pub fn rerank(&self, query: &str, documents: &[String]) -> RerankerResult<Vec<f32>> {
199        // Ensure reranker is loaded
200        if self.reranker.read().is_none() {
201            self.warm_reranker()?;
202        }
203
204        let reranker = self.reranker.read();
205        let reranker = reranker
206            .as_ref()
207            .ok_or_else(|| RerankerError::RerankerUnavailable {
208                model: "reranker".to_string(),
209            })?;
210
211        // Convert to &str slice and use rerank_texts bridge
212        let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
213        rerank_texts(&**reranker, query, &doc_refs)
214    }
215
216    /// Unload all models to free memory.
217    pub fn unload_all(&self) {
218        *self.embedder.write() = None;
219        *self.reranker.write() = None;
220        *self.embedder_name.write() = "not-loaded".to_string();
221        *self.reranker_name.write() = "not-loaded".to_string();
222        info!("All models unloaded");
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture data directory inside the crate.
    fn test_data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
    }

    /// A data directory with no model files, used to force the fallback and
    /// error paths.
    ///
    /// It is built under the OS temp dir instead of a hard-coded `/tmp`, so it
    /// is portable. The process id in the name makes it very unlikely to exist.
    fn missing_model_dir() -> PathBuf {
        std::env::temp_dir().join(format!("model-manager-no-models-{}", std::process::id()))
    }

    // Kept for tests that need real model fixtures; no test in this module
    // uses it yet, hence the allow.
    #[allow(dead_code)]
    fn model_fixture_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/models/xenova-paraphrase-minilm-l3-v2-int8")
    }

    #[test]
    fn test_model_manager_creation() {
        let manager = ModelManager::new(&test_data_dir());
        assert!(!manager.is_ready());
        assert!(!manager.embedder_loaded());
        assert!(!manager.reranker_loaded());
    }

    #[test]
    fn test_embedder_fallback_on_missing_model() {
        let manager = ModelManager::new(&missing_model_dir());

        // Should succeed with fallback.
        let result = manager.warm_embedder();
        assert!(result.is_ok());

        // Should be using hash fallback.
        assert!(manager.embedder_loaded());
        assert_eq!(manager.embedder_name(), "hash-fallback");
    }

    #[test]
    fn test_embedder_dimension() {
        let manager = ModelManager::new(&test_data_dir());
        // Before loading, should return default dimension.
        assert_eq!(manager.embedder_dimension(), 384);
    }

    #[test]
    fn test_unload_all() {
        let manager = ModelManager::new(&test_data_dir());
        let _ = manager.warm_embedder();

        assert!(manager.embedder_loaded());

        manager.unload_all();

        assert!(!manager.embedder_loaded());
        assert!(!manager.reranker_loaded());
        // Names must be reset together with the model slots.
        assert_eq!(manager.embedder_name(), "not-loaded");
        assert_eq!(manager.reranker_name(), "not-loaded");
    }

    #[test]
    fn test_embed_with_fallback() {
        let manager = ModelManager::new(&missing_model_dir());

        // Should work with fallback.
        let result = manager.embed("test text");
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 384);
    }

    #[test]
    fn test_embed_batch_with_fallback() {
        let manager = ModelManager::new(&missing_model_dir());
        let texts = vec!["first".to_string(), "second".to_string()];

        let embeddings = manager
            .embed_batch(&texts)
            .expect("hash fallback should embed a batch");
        assert_eq!(embeddings.len(), texts.len());
        assert!(embeddings.iter().all(|e| e.len() == 384));
    }

    #[test]
    fn test_rerank_without_model_is_error() {
        let manager = ModelManager::new(&missing_model_dir());

        // There is no reranker fallback, so a missing model is an error.
        assert!(manager.rerank("query", &["doc".to_string()]).is_err());
        assert!(!manager.reranker_loaded());
        assert_eq!(manager.reranker_id(), "none");
    }
}