coding_agent_search/daemon/
models.rs1use std::path::{Path, PathBuf};
7use std::sync::Arc;
8
9use parking_lot::RwLock;
10use tracing::{info, warn};
11
12use crate::search::embedder::{Embedder, EmbedderError, EmbedderResult};
13use crate::search::fastembed_embedder::FastEmbedder;
14use crate::search::fastembed_reranker::FastEmbedReranker;
15use crate::search::hash_embedder::HashEmbedder;
16use crate::search::reranker::{Reranker, RerankerError, RerankerResult, rerank_texts};
17
18pub struct ModelManager {
20 data_dir: PathBuf,
21 embedder: RwLock<Option<Arc<dyn Embedder>>>,
22 reranker: RwLock<Option<Arc<dyn Reranker>>>,
23 embedder_name: RwLock<String>,
24 reranker_name: RwLock<String>,
25 fallback_embedder: Arc<HashEmbedder>,
26}
27
28impl ModelManager {
29 pub fn new(data_dir: &Path) -> Self {
31 Self {
32 data_dir: data_dir.to_path_buf(),
33 embedder: RwLock::new(None),
34 reranker: RwLock::new(None),
35 embedder_name: RwLock::new("not-loaded".to_string()),
36 reranker_name: RwLock::new("not-loaded".to_string()),
37 fallback_embedder: Arc::new(HashEmbedder::new(384)),
38 }
39 }
40
41 pub fn is_ready(&self) -> bool {
43 self.embedder.read().is_some()
44 }
45
46 pub fn embedder_id(&self) -> String {
48 self.embedder
49 .read()
50 .as_ref()
51 .map(|e| e.id().to_string())
52 .unwrap_or_else(|| "hash-384".to_string())
53 }
54
55 pub fn embedder_name(&self) -> String {
57 self.embedder_name.read().clone()
58 }
59
60 pub fn embedder_dimension(&self) -> usize {
62 self.embedder
63 .read()
64 .as_ref()
65 .map(|e| e.dimension())
66 .unwrap_or(384)
67 }
68
69 pub fn embedder_loaded(&self) -> bool {
71 self.embedder.read().is_some()
72 }
73
74 pub fn reranker_id(&self) -> String {
76 self.reranker
77 .read()
78 .as_ref()
79 .map(|r| r.id().to_string())
80 .unwrap_or_else(|| "none".to_string())
81 }
82
83 pub fn reranker_name(&self) -> String {
85 self.reranker_name.read().clone()
86 }
87
88 pub fn reranker_loaded(&self) -> bool {
90 self.reranker.read().is_some()
91 }
92
93 pub fn warm_embedder(&self) -> EmbedderResult<()> {
95 if self.embedder.read().is_some() {
97 return Ok(());
98 }
99
100 let mut embedder_guard = self.embedder.write();
102 if embedder_guard.is_some() {
103 return Ok(());
104 }
105
106 let model_dir = FastEmbedder::default_model_dir(&self.data_dir);
107 info!(model_dir = %model_dir.display(), "Loading embedder");
108
109 match FastEmbedder::load_from_dir(&model_dir) {
110 Ok(embedder) => {
111 let id = embedder.id().to_string();
112 let dimension = embedder.dimension();
113 *embedder_guard = Some(Arc::new(embedder));
114 *self.embedder_name.write() = "MiniLM-L6-v2".to_string();
115 info!(id = %id, dimension = dimension, "Embedder loaded");
116 Ok(())
117 }
118 Err(e) => {
119 warn!(error = %e, "Failed to load embedder, using hash fallback");
120 *embedder_guard = Some(self.fallback_embedder.clone());
121 *self.embedder_name.write() = "hash-fallback".to_string();
122 Ok(())
124 }
125 }
126 }
127
128 pub fn warm_reranker(&self) -> RerankerResult<()> {
130 if self.reranker.read().is_some() {
132 return Ok(());
133 }
134
135 let mut reranker_guard = self.reranker.write();
137 if reranker_guard.is_some() {
138 return Ok(());
139 }
140
141 let model_dir = FastEmbedReranker::default_model_dir(&self.data_dir);
142 info!(model_dir = %model_dir.display(), "Loading reranker");
143
144 match FastEmbedReranker::load_from_dir(&model_dir) {
145 Ok(reranker) => {
146 let id = reranker.id().to_string();
147 *reranker_guard = Some(Arc::new(reranker));
148 *self.reranker_name.write() = "ms-marco-MiniLM-L-6-v2".to_string();
149 info!(id = %id, "Reranker loaded");
150 Ok(())
151 }
152 Err(e) => {
153 warn!(error = %e, "Failed to load reranker, reranking unavailable");
154 Err(e)
155 }
156 }
157 }
158
159 pub fn embed_batch(&self, texts: &[String]) -> EmbedderResult<Vec<Vec<f32>>> {
161 if self.embedder.read().is_none() {
163 self.warm_embedder()?;
164 }
165
166 let embedder = self.embedder.read();
167 let embedder = embedder
168 .as_ref()
169 .ok_or_else(|| EmbedderError::EmbedderUnavailable {
170 model: "unknown".to_string(),
171 reason: "embedder not loaded".to_string(),
172 })?;
173
174 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
176 embedder.embed_batch_sync(&text_refs)
177 }
178
179 pub fn embed(&self, text: &str) -> EmbedderResult<Vec<f32>> {
181 if self.embedder.read().is_none() {
183 self.warm_embedder()?;
184 }
185
186 let embedder = self.embedder.read();
187 let embedder = embedder
188 .as_ref()
189 .ok_or_else(|| EmbedderError::EmbedderUnavailable {
190 model: "unknown".to_string(),
191 reason: "embedder not loaded".to_string(),
192 })?;
193
194 embedder.embed_sync(text)
195 }
196
197 pub fn rerank(&self, query: &str, documents: &[String]) -> RerankerResult<Vec<f32>> {
199 if self.reranker.read().is_none() {
201 self.warm_reranker()?;
202 }
203
204 let reranker = self.reranker.read();
205 let reranker = reranker
206 .as_ref()
207 .ok_or_else(|| RerankerError::RerankerUnavailable {
208 model: "reranker".to_string(),
209 })?;
210
211 let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
213 rerank_texts(&**reranker, query, &doc_refs)
214 }
215
216 pub fn unload_all(&self) {
218 *self.embedder.write() = None;
219 *self.reranker.write() = None;
220 *self.embedder_name.write() = "not-loaded".to_string();
221 *self.reranker_name.write() = "not-loaded".to_string();
222 info!("All models unloaded");
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture root inside the crate, used as a data directory by the tests.
    fn test_data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("tests/fixtures");
        dir
    }

    /// Location of the bundled model fixture.
    // Not referenced by any test in this module, hence the lint allowance.
    #[allow(dead_code)]
    fn model_fixture_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("tests/fixtures/models/xenova-paraphrase-minilm-l3-v2-int8");
        dir
    }

    /// Manager rooted at a directory that is expected not to exist, so model
    /// loading fails and the fallback path is exercised.
    fn manager_without_models() -> ModelManager {
        ModelManager::new(Path::new("/tmp/nonexistent"))
    }

    /// A fresh manager has nothing loaded.
    #[test]
    fn test_model_manager_creation() {
        let manager = ModelManager::new(&test_data_dir());

        let loaded_flags = [
            manager.is_ready(),
            manager.embedder_loaded(),
            manager.reranker_loaded(),
        ];
        for flag in loaded_flags {
            assert!(!flag);
        }
    }

    /// Warming with no model on disk succeeds and installs the hash fallback.
    #[test]
    fn test_embedder_fallback_on_missing_model() {
        let manager = manager_without_models();

        assert!(manager.warm_embedder().is_ok());

        assert!(manager.embedder_loaded());
        assert_eq!(manager.embedder_name(), "hash-fallback");
    }

    /// Before any embedder is loaded the reported dimension is 384.
    #[test]
    fn test_embedder_dimension() {
        let dimension = ModelManager::new(&test_data_dir()).embedder_dimension();
        assert_eq!(dimension, 384);
    }

    /// `unload_all` empties both slots after an embedder has been warmed.
    #[test]
    fn test_unload_all() {
        let manager = ModelManager::new(&test_data_dir());

        // The outcome is irrelevant here; only the resulting state matters.
        manager.warm_embedder().ok();
        assert!(manager.embedder_loaded());

        manager.unload_all();

        assert!(!manager.embedder_loaded());
        assert!(!manager.reranker_loaded());
    }

    /// Embedding warms lazily and, via the fallback, yields a 384-element vector.
    #[test]
    fn test_embed_with_fallback() {
        let manager = manager_without_models();

        let embedding = manager
            .embed("test text")
            .expect("embedding should succeed via the hash fallback");

        assert_eq!(embedding.len(), 384);
    }
}