1use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Mutex;
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14
15use clawft_core::embeddings::hnsw_store::HnswStore;
16
17use crate::health::HealthStatus;
18use crate::service::{ServiceType, SystemService};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct HnswServiceConfig {
25 pub ef_search: usize,
27 pub ef_construction: usize,
29 pub default_dimensions: usize,
31}
32
33impl Default for HnswServiceConfig {
34 fn default() -> Self {
35 Self {
36 ef_search: 100,
37 ef_construction: 200,
38 default_dimensions: 384,
39 }
40 }
41}
42
/// A single hit returned by `HnswService::search` / `HnswService::search_batch`.
///
/// Every field is copied one-to-one from the result items yielded by
/// `HnswStore::query`.
#[derive(Debug, Clone)]
pub struct HnswSearchResult {
    /// Identifier the vector was inserted under.
    pub id: String,
    /// Score reported by the store for this hit.
    /// NOTE(review): the tests expect ~1.0 for an exact match, which suggests
    /// higher means closer — confirm the metric used by `HnswStore`.
    pub score: f32,
    /// Metadata JSON that was stored alongside the vector at insert time.
    pub metadata: serde_json::Value,
}
55
56pub struct HnswService {
64 store: Mutex<HnswStore>,
65 config: HnswServiceConfig,
66 insert_count: AtomicU64,
67 search_count: AtomicU64,
68}
69
70impl HnswService {
71 pub fn new(config: HnswServiceConfig) -> Self {
73 let store = HnswStore::with_params(config.ef_search, config.ef_construction);
74 Self {
75 store: Mutex::new(store),
76 config,
77 insert_count: AtomicU64::new(0),
78 search_count: AtomicU64::new(0),
79 }
80 }
81
82 pub fn insert(&self, id: String, embedding: Vec<f32>, metadata: serde_json::Value) {
84 let mut store = self.store.lock().expect("HnswStore lock poisoned");
85 store.insert(id, embedding, metadata);
86 self.insert_count.fetch_add(1, Ordering::Relaxed);
87 }
88
89 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<HnswSearchResult> {
91 let mut store = self.store.lock().expect("HnswStore lock poisoned");
92 self.search_count.fetch_add(1, Ordering::Relaxed);
93 store
94 .query(query, top_k)
95 .into_iter()
96 .map(|r| HnswSearchResult {
97 id: r.id,
98 score: r.score,
99 metadata: r.metadata,
100 })
101 .collect()
102 }
103
104 pub fn search_batch(
109 &self,
110 queries: &[&[f32]],
111 top_k: usize,
112 ) -> Vec<Vec<HnswSearchResult>> {
113 let mut store = self.store.lock().expect("HnswStore lock poisoned");
114 self.search_count
115 .fetch_add(queries.len() as u64, Ordering::Relaxed);
116 queries
117 .iter()
118 .map(|query| {
119 store
120 .query(query, top_k)
121 .into_iter()
122 .map(|r| HnswSearchResult {
123 id: r.id,
124 score: r.score,
125 metadata: r.metadata,
126 })
127 .collect()
128 })
129 .collect()
130 }
131
132 pub fn len(&self) -> usize {
134 let store = self.store.lock().expect("HnswStore lock poisoned");
135 store.len()
136 }
137
138 pub fn is_empty(&self) -> bool {
140 let store = self.store.lock().expect("HnswStore lock poisoned");
141 store.is_empty()
142 }
143
144 pub fn insert_count(&self) -> u64 {
146 self.insert_count.load(Ordering::Relaxed)
147 }
148
149 pub fn search_count(&self) -> u64 {
151 self.search_count.load(Ordering::Relaxed)
152 }
153
154 pub fn clear(&self) {
158 let mut store = self.store.lock().expect("HnswStore lock poisoned");
159 *store = HnswStore::with_params(self.config.ef_search, self.config.ef_construction);
160 }
161
162 pub fn config(&self) -> &HnswServiceConfig {
164 &self.config
165 }
166
167 pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), std::io::Error> {
172 let store = self.store.lock().expect("HnswStore lock poisoned");
173 store.save(path)
174 }
175
176 pub fn load_from_file(path: &std::path::Path) -> Result<Self, std::io::Error> {
181 let store = HnswStore::load(path)?;
182 let config = HnswServiceConfig {
183 ef_search: 100, ef_construction: 200,
185 default_dimensions: 384,
186 };
187 Ok(Self {
188 store: Mutex::new(store),
189 config,
190 insert_count: AtomicU64::new(0),
191 search_count: AtomicU64::new(0),
192 })
193 }
194}
195
196#[async_trait]
199impl SystemService for HnswService {
200 fn name(&self) -> &str {
201 "ecc.hnsw"
202 }
203
204 fn service_type(&self) -> ServiceType {
205 ServiceType::Core
206 }
207
208 async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
209 Ok(())
210 }
211
212 async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
213 Ok(())
214 }
215
216 async fn health_check(&self) -> HealthStatus {
217 HealthStatus::Healthy
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    fn make_service() -> HnswService {
        HnswService::new(HnswServiceConfig::default())
    }

    /// Inserts three orthogonal unit vectors `a`, `b`, `c` along the x, y, z axes.
    fn insert_axes(svc: &HnswService) {
        svc.insert("a".into(), vec![1.0, 0.0, 0.0], serde_json::json!({}));
        svc.insert("b".into(), vec![0.0, 1.0, 0.0], serde_json::json!({}));
        svc.insert("c".into(), vec![0.0, 0.0, 1.0], serde_json::json!({}));
    }

    #[test]
    fn new_service_empty() {
        let svc = make_service();
        assert!(svc.is_empty());
        assert_eq!(svc.len(), 0);
    }

    #[test]
    fn insert_and_len() {
        let svc = make_service();
        svc.insert("a".into(), vec![1.0, 0.0], serde_json::json!({}));
        svc.insert("b".into(), vec![0.0, 1.0], serde_json::json!({}));
        assert_eq!(svc.len(), 2);
        assert!(!svc.is_empty());
    }

    #[test]
    fn insert_upsert() {
        let svc = make_service();
        svc.insert("a".into(), vec![1.0, 0.0], serde_json::json!({"v": 1}));
        svc.insert("a".into(), vec![0.0, 1.0], serde_json::json!({"v": 2}));
        assert_eq!(svc.len(), 1);
    }

    #[test]
    fn search_empty_returns_empty() {
        let svc = make_service();
        let results = svc.search(&[1.0, 0.0], 5);
        assert!(results.is_empty());
    }

    #[test]
    fn search_returns_results() {
        let svc = make_service();
        insert_axes(&svc);

        let results = svc.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn search_count_incremented() {
        let svc = make_service();
        assert_eq!(svc.search_count(), 0);
        svc.search(&[1.0], 1);
        svc.search(&[1.0], 1);
        assert_eq!(svc.search_count(), 2);
    }

    #[test]
    fn search_batch_returns_per_query_results_and_counts_each_query() {
        let svc = make_service();
        insert_axes(&svc);

        let q1: &[f32] = &[1.0, 0.0, 0.0];
        let q2: &[f32] = &[0.0, 1.0, 0.0];
        let results = svc.search_batch(&[q1, q2], 1);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 1);
        assert_eq!(results[1].len(), 1);
        assert_eq!(results[0][0].id, "a");
        assert_eq!(results[1][0].id, "b");
        assert_eq!(svc.search_count(), 2);
    }

    #[test]
    fn search_batch_empty_queries() {
        let svc = make_service();
        let results = svc.search_batch(&[], 5);
        assert!(results.is_empty());
        assert_eq!(svc.search_count(), 0);
    }

    #[test]
    fn insert_count_incremented() {
        let svc = make_service();
        assert_eq!(svc.insert_count(), 0);
        svc.insert("a".into(), vec![1.0], serde_json::json!({}));
        svc.insert("b".into(), vec![0.0], serde_json::json!({}));
        assert_eq!(svc.insert_count(), 2);
    }

    #[test]
    fn clear_resets() {
        let svc = make_service();
        svc.insert("a".into(), vec![1.0], serde_json::json!({}));
        svc.insert("b".into(), vec![0.0], serde_json::json!({}));
        assert_eq!(svc.len(), 2);

        svc.clear();
        assert!(svc.is_empty());
        assert_eq!(svc.len(), 0);
        // Counters are lifetime statistics and survive `clear`.
        assert_eq!(svc.insert_count(), 2);
    }

    #[test]
    fn config_default() {
        let cfg = HnswServiceConfig::default();
        assert_eq!(cfg.ef_search, 100);
        assert_eq!(cfg.ef_construction, 200);
        assert_eq!(cfg.default_dimensions, 384);

        let svc = HnswService::new(cfg);
        let c = svc.config();
        assert_eq!(c.ef_search, 100);
        assert_eq!(c.ef_construction, 200);
        assert_eq!(c.default_dimensions, 384);
    }

    #[test]
    fn service_name_is_ecc_hnsw() {
        let svc = make_service();
        assert_eq!(svc.name(), "ecc.hnsw");
        assert_eq!(svc.service_type(), ServiceType::Core);
    }

    #[tokio::test]
    async fn service_lifecycle() {
        let svc = make_service();
        svc.start().await.unwrap();
        let health = svc.health_check().await;
        assert_eq!(health, HealthStatus::Healthy);
        svc.stop().await.unwrap();
    }

    /// Unique-ish temp path for persistence tests (timestamp-suffixed to avoid collisions).
    fn tmp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "hnsw_test_{name}_{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ))
    }

    #[test]
    fn persist_empty_index() {
        let svc = make_service();
        let path = tmp_path("empty");
        svc.save_to_file(&path).unwrap();
        let loaded = HnswService::load_from_file(&path).unwrap();
        // Remove the file before asserting so a failure does not leak it.
        let _ = std::fs::remove_file(&path);

        assert!(loaded.is_empty());
        assert_eq!(loaded.len(), 0);
    }

    #[test]
    fn persist_index_with_vectors_search_matches() {
        let svc = make_service();
        svc.insert("a".into(), vec![1.0, 0.0, 0.0], serde_json::json!({"tag": "first"}));
        svc.insert("b".into(), vec![0.0, 1.0, 0.0], serde_json::json!({"tag": "second"}));
        svc.insert("c".into(), vec![0.0, 0.0, 1.0], serde_json::json!({"tag": "third"}));

        let path = tmp_path("vectors");
        svc.save_to_file(&path).unwrap();
        let loaded = HnswService::load_from_file(&path).unwrap();
        let _ = std::fs::remove_file(&path);

        assert_eq!(loaded.len(), 3);

        let results = loaded.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 0.01);
    }

    #[test]
    fn load_from_file_uses_default_config_and_fresh_counters() {
        let svc = make_service();
        svc.insert("a".into(), vec![1.0, 0.0], serde_json::json!({}));
        let path = tmp_path("config");
        svc.save_to_file(&path).unwrap();
        let loaded = HnswService::load_from_file(&path).unwrap();
        let _ = std::fs::remove_file(&path);

        let defaults = HnswServiceConfig::default();
        let c = loaded.config();
        assert_eq!(c.ef_search, defaults.ef_search);
        assert_eq!(c.ef_construction, defaults.ef_construction);
        assert_eq!(c.default_dimensions, defaults.default_dimensions);
        assert_eq!(loaded.insert_count(), 0);
        assert_eq!(loaded.search_count(), 0);
    }
}