1use crate::hnsw::{DistanceMetric, SearchResult};
13use ipfrs_core::{Cid, Error, Result};
14use parking_lot::RwLock;
15use std::collections::HashMap;
16use std::sync::Arc;
17
18#[derive(Debug, Clone)]
20pub struct FederatedConfig {
21 pub max_concurrent_queries: usize,
23 pub query_timeout_ms: u64,
25 pub privacy_preserving: bool,
27 pub privacy_noise_level: f32,
29 pub aggregation_strategy: AggregationStrategy,
31 pub normalize_scores: bool,
33}
34
35impl Default for FederatedConfig {
36 fn default() -> Self {
37 Self {
38 max_concurrent_queries: 10,
39 query_timeout_ms: 5000,
40 privacy_preserving: false,
41 privacy_noise_level: 0.0,
42 aggregation_strategy: AggregationStrategy::RankFusion,
43 normalize_scores: true,
44 }
45 }
46}
47
/// Strategy used to merge the ranked lists returned by individual indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Concatenate every index's hits and sort them ascending by raw score
    /// (lower is treated as better). Then keep the first `k`.
    Simple,
    /// Reciprocal rank fusion: each hit contributes a value that decays with
    /// its rank, and CIDs returned by several indices accumulate those values.
    RankFusion,
    /// Min-max normalize scores within each index, then sort ascending.
    ScoreNormalization,
    /// Award points by rank position and sum them per CID.
    BordaCount,
}
60
61#[async_trait::async_trait]
63pub trait QueryableIndex: Send + Sync {
64 async fn query(&self, embedding: &[f32], k: usize) -> Result<Vec<SearchResult>>;
66
67 fn distance_metric(&self) -> DistanceMetric;
69
70 fn index_id(&self) -> String;
72
73 fn size(&self) -> usize;
75}
76
77pub struct LocalIndexAdapter {
79 index: Arc<RwLock<crate::hnsw::VectorIndex>>,
80 index_id: String,
81}
82
83impl LocalIndexAdapter {
84 pub fn new(index: Arc<RwLock<crate::hnsw::VectorIndex>>, index_id: String) -> Self {
86 Self { index, index_id }
87 }
88}
89
90#[async_trait::async_trait]
91impl QueryableIndex for LocalIndexAdapter {
92 async fn query(&self, embedding: &[f32], k: usize) -> Result<Vec<SearchResult>> {
93 let index = self.index.read();
94 let ef_search = k * 10; index.search(embedding, k, ef_search)
96 }
97
98 fn distance_metric(&self) -> DistanceMetric {
99 let index = self.index.read();
100 index.metric()
101 }
102
103 fn index_id(&self) -> String {
104 self.index_id.clone()
105 }
106
107 fn size(&self) -> usize {
108 let index = self.index.read();
109 index.len()
110 }
111}
112
113#[derive(Debug, Clone)]
115pub struct FederatedSearchResult {
116 pub cid: Cid,
118 pub score: f32,
120 pub source_index_id: String,
122 pub source_rank: usize,
124 pub source_metric: DistanceMetric,
126}
127
128pub struct FederatedQueryExecutor {
130 config: FederatedConfig,
132 indices: Arc<RwLock<HashMap<String, Arc<dyn QueryableIndex>>>>,
134 stats: Arc<RwLock<FederatedQueryStats>>,
136}
137
/// Cumulative counters and smoothed averages gathered by the executor.
#[derive(Debug, Clone, Default)]
pub struct FederatedQueryStats {
    /// Number of federated queries that reached aggregation.
    pub total_queries: u64,
    /// Cumulative number of per-index queries that returned successfully.
    pub total_indices_queried: u64,
    /// End-to-end latency in milliseconds, as an exponential moving average
    /// with smoothing factor 0.1.
    pub avg_latency_ms: f64,
    /// Merged results per query, as an exponential moving average with
    /// smoothing factor 0.1.
    pub avg_results_per_query: f64,
    /// Cumulative number of per-index queries that hit the timeout.
    pub timeouts: u64,
}
152
153impl FederatedQueryExecutor {
154 pub fn new(config: FederatedConfig) -> Self {
156 Self {
157 config,
158 indices: Arc::new(RwLock::new(HashMap::new())),
159 stats: Arc::new(RwLock::new(FederatedQueryStats::default())),
160 }
161 }
162
163 pub fn register_index(&self, index: Arc<dyn QueryableIndex>) -> Result<()> {
165 let index_id = index.index_id();
166 let mut indices = self.indices.write();
167
168 if indices.contains_key(&index_id) {
169 return Err(Error::InvalidInput(format!(
170 "Index '{}' is already registered",
171 index_id
172 )));
173 }
174
175 indices.insert(index_id.clone(), index);
176 tracing::info!("Registered index '{}' for federated queries", index_id);
177 Ok(())
178 }
179
180 pub fn unregister_index(&self, index_id: &str) -> Result<()> {
182 let mut indices = self.indices.write();
183 if indices.remove(index_id).is_some() {
184 tracing::info!("Unregistered index '{}'", index_id);
185 Ok(())
186 } else {
187 Err(Error::NotFound(format!("Index '{}' not found", index_id)))
188 }
189 }
190
191 pub async fn query(&self, embedding: &[f32], k: usize) -> Result<Vec<FederatedSearchResult>> {
193 let start = std::time::Instant::now();
194
195 let indices = {
197 let indices_lock = self.indices.read();
198 indices_lock
199 .iter()
200 .map(|(id, idx)| (id.clone(), Arc::clone(idx)))
201 .collect::<Vec<_>>()
202 };
203
204 if indices.is_empty() {
205 return Err(Error::InvalidInput(
206 "No indices registered for federated query".to_string(),
207 ));
208 }
209
210 let query_embedding = if self.config.privacy_preserving {
212 self.apply_privacy_noise(embedding)
213 } else {
214 embedding.to_vec()
215 };
216
217 let mut tasks = Vec::new();
219 for (index_id, index) in indices {
220 let query_emb = query_embedding.clone();
221 let task = tokio::spawn(async move {
222 let result = index.query(&query_emb, k).await;
223 (index_id, index.distance_metric(), result)
224 });
225 tasks.push(task);
226 }
227
228 let mut all_results = Vec::new();
230 let mut indices_queried = 0;
231 let mut timeouts = 0;
232
233 for task in tasks {
234 match tokio::time::timeout(
235 std::time::Duration::from_millis(self.config.query_timeout_ms),
236 task,
237 )
238 .await
239 {
240 Ok(Ok((index_id, metric, Ok(results)))) => {
241 indices_queried += 1;
242 for (rank, result) in results.into_iter().enumerate() {
243 all_results.push((index_id.clone(), metric, rank, result));
244 }
245 }
246 Ok(Ok((index_id, _, Err(e)))) => {
247 tracing::warn!("Query failed for index '{}': {:?}", index_id, e);
248 }
249 Ok(Err(e)) => {
250 tracing::warn!("Task panicked: {:?}", e);
251 }
252 Err(_) => {
253 timeouts += 1;
254 tracing::warn!("Query timeout for an index");
255 }
256 }
257 }
258
259 let aggregated = self.aggregate_results(all_results, k)?;
261
262 let latency = start.elapsed().as_millis() as f64;
264 self.update_stats(indices_queried, aggregated.len(), latency, timeouts);
265
266 Ok(aggregated)
267 }
268
269 pub async fn query_indices(
271 &self,
272 embedding: &[f32],
273 k: usize,
274 index_ids: &[String],
275 ) -> Result<Vec<FederatedSearchResult>> {
276 let start = std::time::Instant::now();
277
278 let indices = {
280 let indices_lock = self.indices.read();
281 index_ids
282 .iter()
283 .filter_map(|id| {
284 indices_lock
285 .get(id)
286 .map(|idx| (id.clone(), Arc::clone(idx)))
287 })
288 .collect::<Vec<_>>()
289 };
290
291 if indices.is_empty() {
292 return Err(Error::InvalidInput(
293 "None of the requested indices are registered".to_string(),
294 ));
295 }
296
297 let query_embedding = if self.config.privacy_preserving {
299 self.apply_privacy_noise(embedding)
300 } else {
301 embedding.to_vec()
302 };
303
304 let mut tasks = Vec::new();
306 for (index_id, index) in indices {
307 let query_emb = query_embedding.clone();
308 let task = tokio::spawn(async move {
309 let result = index.query(&query_emb, k).await;
310 (index_id, index.distance_metric(), result)
311 });
312 tasks.push(task);
313 }
314
315 let mut all_results = Vec::new();
317 let mut indices_queried = 0;
318 let mut timeouts = 0;
319
320 for task in tasks {
321 match tokio::time::timeout(
322 std::time::Duration::from_millis(self.config.query_timeout_ms),
323 task,
324 )
325 .await
326 {
327 Ok(Ok((index_id, metric, Ok(results)))) => {
328 indices_queried += 1;
329 for (rank, result) in results.into_iter().enumerate() {
330 all_results.push((index_id.clone(), metric, rank, result));
331 }
332 }
333 Ok(Ok((index_id, _, Err(e)))) => {
334 tracing::warn!("Query failed for index '{}': {:?}", index_id, e);
335 }
336 Ok(Err(e)) => {
337 tracing::warn!("Task panicked: {:?}", e);
338 }
339 Err(_) => {
340 timeouts += 1;
341 tracing::warn!("Query timeout for an index");
342 }
343 }
344 }
345
346 let aggregated = self.aggregate_results(all_results, k)?;
347
348 let latency = start.elapsed().as_millis() as f64;
349 self.update_stats(indices_queried, aggregated.len(), latency, timeouts);
350
351 Ok(aggregated)
352 }
353
354 fn apply_privacy_noise(&self, embedding: &[f32]) -> Vec<f32> {
356 use rand::Rng;
357 let mut rng = rand::rng();
358
359 embedding
360 .iter()
361 .map(|&x| {
362 let noise = rng.random_range(
363 -self.config.privacy_noise_level..self.config.privacy_noise_level,
364 );
365 x + noise
366 })
367 .collect()
368 }
369
370 fn aggregate_results(
372 &self,
373 results: Vec<(String, DistanceMetric, usize, SearchResult)>,
374 k: usize,
375 ) -> Result<Vec<FederatedSearchResult>> {
376 match self.config.aggregation_strategy {
377 AggregationStrategy::Simple => self.aggregate_simple(results, k),
378 AggregationStrategy::RankFusion => self.aggregate_rank_fusion(results, k),
379 AggregationStrategy::ScoreNormalization => {
380 self.aggregate_score_normalization(results, k)
381 }
382 AggregationStrategy::BordaCount => self.aggregate_borda_count(results, k),
383 }
384 }
385
386 fn aggregate_simple(
388 &self,
389 results: Vec<(String, DistanceMetric, usize, SearchResult)>,
390 k: usize,
391 ) -> Result<Vec<FederatedSearchResult>> {
392 let mut federated: Vec<_> = results
393 .into_iter()
394 .map(|(index_id, metric, rank, result)| FederatedSearchResult {
395 cid: result.cid,
396 score: result.score,
397 source_index_id: index_id,
398 source_rank: rank,
399 source_metric: metric,
400 })
401 .collect();
402
403 federated.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
405 federated.truncate(k);
406
407 Ok(federated)
408 }
409
410 fn aggregate_rank_fusion(
412 &self,
413 results: Vec<(String, DistanceMetric, usize, SearchResult)>,
414 k: usize,
415 ) -> Result<Vec<FederatedSearchResult>> {
416 let mut scores: HashMap<Cid, (f32, String, usize, DistanceMetric)> = HashMap::new();
417 const RRF_K: f32 = 60.0;
418
419 for (index_id, metric, rank, result) in results {
420 let rrf_score = 1.0 / (RRF_K + rank as f32);
421
422 scores
423 .entry(result.cid)
424 .and_modify(|(score, _, _, _)| *score += rrf_score)
425 .or_insert((rrf_score, index_id.clone(), rank, metric));
426 }
427
428 let mut federated: Vec<_> = scores
429 .into_iter()
430 .map(
431 |(cid, (score, index_id, rank, metric))| FederatedSearchResult {
432 cid,
433 score,
434 source_index_id: index_id,
435 source_rank: rank,
436 source_metric: metric,
437 },
438 )
439 .collect();
440
441 federated.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
443 federated.truncate(k);
444
445 Ok(federated)
446 }
447
448 fn aggregate_score_normalization(
450 &self,
451 results: Vec<(String, DistanceMetric, usize, SearchResult)>,
452 k: usize,
453 ) -> Result<Vec<FederatedSearchResult>> {
454 let mut by_index: HashMap<String, Vec<(DistanceMetric, usize, SearchResult)>> =
456 HashMap::new();
457
458 for (index_id, metric, rank, result) in results {
459 by_index
460 .entry(index_id)
461 .or_default()
462 .push((metric, rank, result));
463 }
464
465 let mut normalized = Vec::new();
467 for (index_id, index_results) in by_index {
468 if index_results.is_empty() {
469 continue;
470 }
471
472 let scores: Vec<f32> = index_results.iter().map(|(_, _, r)| r.score).collect();
474 let min_score = scores.iter().copied().fold(f32::INFINITY, f32::min);
475 let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
476 let range = max_score - min_score;
477
478 for (metric, rank, result) in index_results {
479 let normalized_score = if range > 1e-6 {
480 (result.score - min_score) / range
481 } else {
482 0.5 };
484
485 normalized.push(FederatedSearchResult {
486 cid: result.cid,
487 score: normalized_score,
488 source_index_id: index_id.clone(),
489 source_rank: rank,
490 source_metric: metric,
491 });
492 }
493 }
494
495 normalized.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
497 normalized.truncate(k);
498
499 Ok(normalized)
500 }
501
502 fn aggregate_borda_count(
504 &self,
505 results: Vec<(String, DistanceMetric, usize, SearchResult)>,
506 k: usize,
507 ) -> Result<Vec<FederatedSearchResult>> {
508 let mut borda_scores: HashMap<Cid, (usize, String, usize, DistanceMetric)> = HashMap::new();
509
510 let max_rank = results
512 .iter()
513 .map(|(_, _, rank, _)| *rank)
514 .max()
515 .unwrap_or(0);
516
517 for (index_id, metric, rank, result) in results {
518 let borda_points = max_rank.saturating_sub(rank);
519
520 borda_scores
521 .entry(result.cid)
522 .and_modify(|(points, _, _, _)| *points += borda_points)
523 .or_insert((borda_points, index_id.clone(), rank, metric));
524 }
525
526 let mut federated: Vec<_> = borda_scores
527 .into_iter()
528 .map(
529 |(cid, (points, index_id, rank, metric))| FederatedSearchResult {
530 cid,
531 score: points as f32,
532 source_index_id: index_id,
533 source_rank: rank,
534 source_metric: metric,
535 },
536 )
537 .collect();
538
539 federated.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
541 federated.truncate(k);
542
543 Ok(federated)
544 }
545
546 fn update_stats(&self, indices_queried: u64, num_results: usize, latency: f64, timeouts: u64) {
548 let mut stats = self.stats.write();
549 stats.total_queries += 1;
550 stats.total_indices_queried += indices_queried;
551 stats.timeouts += timeouts;
552
553 let alpha = 0.1;
555 stats.avg_latency_ms = alpha * latency + (1.0 - alpha) * stats.avg_latency_ms;
556 stats.avg_results_per_query =
557 alpha * num_results as f64 + (1.0 - alpha) * stats.avg_results_per_query;
558 }
559
560 pub fn stats(&self) -> FederatedQueryStats {
562 self.stats.read().clone()
563 }
564
565 pub fn registered_indices(&self) -> Vec<String> {
567 let indices = self.indices.read();
568 indices.keys().cloned().collect()
569 }
570
571 pub fn total_size(&self) -> usize {
573 let indices = self.indices.read();
574 indices.values().map(|idx| idx.size()).sum()
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::VectorIndex;
    use multihash_codetable::{Code, MultihashDigest};

    /// Derives a deterministic raw-codec CIDv1 from `label`.
    fn cid_for(label: &str) -> Cid {
        Cid::new_v1(0x55, Code::Sha2_256.digest(label.as_bytes()))
    }

    /// Builds the 128-element vector whose `j`-th entry is `(offset + j) * step`.
    fn ramp(offset: i32, step: f32) -> Vec<f32> {
        (0..128).map(|j: i32| (offset + j) as f32 * step).collect()
    }

    /// Creates an empty index, sized for the 128-element test vectors, behind a shared lock.
    fn shared_index(metric: DistanceMetric) -> Arc<RwLock<VectorIndex>> {
        Arc::new(RwLock::new(VectorIndex::new(128, metric, 16, 200).unwrap()))
    }

    #[tokio::test]
    async fn test_federated_executor_creation() {
        let executor = FederatedQueryExecutor::new(FederatedConfig::default());
        assert!(executor.registered_indices().is_empty());
    }

    #[tokio::test]
    async fn test_register_and_unregister_index() {
        let executor = FederatedQueryExecutor::new(FederatedConfig::default());
        let adapter =
            LocalIndexAdapter::new(shared_index(DistanceMetric::Cosine), "test_index".to_string());

        executor.register_index(Arc::new(adapter)).unwrap();
        assert_eq!(executor.registered_indices().len(), 1);

        executor.unregister_index("test_index").unwrap();
        assert!(executor.registered_indices().is_empty());
    }

    #[tokio::test]
    async fn test_federated_query_single_index() {
        let executor = FederatedQueryExecutor::new(FederatedConfig::default());
        let index = shared_index(DistanceMetric::Cosine);

        for i in 0..100 {
            index
                .write()
                .insert(&cid_for(&format!("vector_{}", i)), &ramp(i, 0.01))
                .unwrap();
        }

        executor
            .register_index(Arc::new(LocalIndexAdapter::new(
                Arc::clone(&index),
                "index1".to_string(),
            )))
            .unwrap();

        let results = executor.query(&ramp(0, 0.01), 10).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 10);
    }

    #[tokio::test]
    async fn test_federated_query_multiple_indices() {
        let executor = FederatedQueryExecutor::new(FederatedConfig {
            aggregation_strategy: AggregationStrategy::RankFusion,
            ..Default::default()
        });

        let cosine = shared_index(DistanceMetric::Cosine);
        let euclidean = shared_index(DistanceMetric::L2);

        for i in 0..50 {
            cosine
                .write()
                .insert(&cid_for(&format!("vector_a_{}", i)), &ramp(i, 0.01))
                .unwrap();
        }
        for i in 25..75 {
            euclidean
                .write()
                .insert(&cid_for(&format!("vector_b_{}", i)), &ramp(i, 0.01))
                .unwrap();
        }

        for (lock, name) in [(&cosine, "index1"), (&euclidean, "index2")] {
            executor
                .register_index(Arc::new(LocalIndexAdapter::new(
                    Arc::clone(lock),
                    name.to_string(),
                )))
                .unwrap();
        }

        let results = executor.query(&ramp(0, 0.02), 10).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 10);

        let stats = executor.stats();
        assert_eq!(stats.total_queries, 1);
        assert!(stats.total_indices_queried >= 1);
    }

    #[tokio::test]
    async fn test_different_aggregation_strategies() {
        let strategies = [
            AggregationStrategy::Simple,
            AggregationStrategy::RankFusion,
            AggregationStrategy::ScoreNormalization,
            AggregationStrategy::BordaCount,
        ];

        for strategy in strategies {
            let executor = FederatedQueryExecutor::new(FederatedConfig {
                aggregation_strategy: strategy,
                ..Default::default()
            });

            let index = shared_index(DistanceMetric::Cosine);
            for i in 0..20 {
                index
                    .write()
                    .insert(&cid_for(&format!("vec_{}", i)), &ramp(i, 0.01))
                    .unwrap();
            }

            executor
                .register_index(Arc::new(LocalIndexAdapter::new(
                    index,
                    format!("index_{:?}", strategy),
                )))
                .unwrap();

            let results = executor.query(&ramp(0, 0.01), 5).await.unwrap();
            assert!(!results.is_empty(), "Strategy {:?} failed", strategy);
        }
    }

    #[tokio::test]
    async fn test_privacy_preserving_mode() {
        let executor = FederatedQueryExecutor::new(FederatedConfig {
            privacy_preserving: true,
            privacy_noise_level: 0.1,
            ..Default::default()
        });

        let index = shared_index(DistanceMetric::Cosine);
        for i in 0..30 {
            index
                .write()
                .insert(&cid_for(&format!("private_vec_{}", i)), &ramp(i, 0.01))
                .unwrap();
        }

        executor
            .register_index(Arc::new(LocalIndexAdapter::new(
                index,
                "private_index".to_string(),
            )))
            .unwrap();

        let results = executor.query(&ramp(0, 0.01), 5).await.unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_query_specific_indices() {
        let executor = FederatedQueryExecutor::new(FederatedConfig::default());

        for idx_num in 0..3 {
            let index = shared_index(DistanceMetric::Cosine);
            for i in 0..20 {
                index
                    .write()
                    .insert(
                        &cid_for(&format!("vec_{}_{}", idx_num, i)),
                        &ramp(i + idx_num, 0.01),
                    )
                    .unwrap();
            }

            executor
                .register_index(Arc::new(LocalIndexAdapter::new(
                    index,
                    format!("index_{}", idx_num),
                )))
                .unwrap();
        }

        let wanted = ["index_0".to_string(), "index_2".to_string()];
        let results = executor
            .query_indices(&ramp(0, 0.01), 10, &wanted)
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert!(results
            .iter()
            .all(|r| r.source_index_id == "index_0" || r.source_index_id == "index_2"));
    }
}