use crate::search::SearchResult;
use crate::search::fusion::{normalize_minmax, weighted_sum_fuse};
use crate::search::rrf_fuse;
/// Selects how several per-backend result lists are merged into one ranking.
///
/// Values are comparable with `==`, so a configured strategy can be checked
/// directly (for example with `assert_eq!`) rather than through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionStrategy {
    /// Reciprocal-rank fusion: the lists are handed to `rrf_fuse` unchanged.
    Rrf {
        /// Constant forwarded to `rrf_fuse`.
        ///
        /// NOTE(review): presumably the usual RRF rank-smoothing term; confirm
        /// against `rrf_fuse`.
        k: u32,
    },
    /// Each list is min-max normalised and then combined by
    /// `weighted_sum_fuse`.
    WeightedSum,
}

impl Default for FusionStrategy {
    /// Reciprocal-rank fusion with `k = 60`.
    fn default() -> Self {
        Self::Rrf { k: 60 }
    }
}
/// Fuse already-fetched result lists into a single ranking.
///
/// * `lists` – one result list per source; taken by shared reference and
///   never modified.
/// * `strategy` – which fusion algorithm to apply.
///
/// With [`FusionStrategy::Rrf`] the lists go to `rrf_fuse` as they are. With
/// [`FusionStrategy::WeightedSum`] every list is copied, the copy is min-max
/// normalised, and the copies are combined by `weighted_sum_fuse` using a
/// weight of `1.0` for each list.
pub fn federate_results(
    lists: &[Vec<SearchResult>],
    strategy: &FusionStrategy,
) -> Vec<SearchResult> {
    match strategy {
        FusionStrategy::WeightedSum => {
            // Normalise private copies so the caller's lists stay untouched.
            let mut scaled: Vec<Vec<SearchResult>> = lists.to_vec();
            for list in &mut scaled {
                normalize_minmax(list);
            }
            let uniform = vec![1.0_f32; scaled.len()];
            weighted_sum_fuse(&scaled, &uniform)
        }
        FusionStrategy::Rrf { k } => rrf_fuse(lists, *k),
    }
}
#[cfg(feature = "federation")]
mod federated {
use std::sync::Arc;
use std::time::Duration;
use futures_util::future::join_all;
use crate::embedding::{AsyncVectorIndex, SearchHit};
use crate::error::{KernelError, Result};
use crate::search::SearchResult;
use crate::search::fusion::{normalize_minmax, weighted_sum_fuse};
use crate::search::rrf_fuse;
use super::FusionStrategy;
/// One registered search backend together with its fusion weight.
struct Backend {
/// Index that is queried through `AsyncVectorIndex::search`.
index: Arc<dyn AsyncVectorIndex>,
/// Weight handed to `weighted_sum_fuse` under the weighted-sum strategy;
/// the RRF branch of `FederatedSearch::search` does not read it.
weight: f32,
}
/// Convert raw index hits into `SearchResult`s.
///
/// The hit id is rendered as a string, the score is carried over unchanged,
/// and `text` is left empty because a `SearchHit` carries no text here.
fn hits_to_results(hits: Vec<SearchHit>) -> Vec<SearchResult> {
    let mut results = Vec::with_capacity(hits.len());
    for hit in hits {
        results.push(SearchResult {
            id: hit.id.to_string(),
            score: hit.score,
            text: String::new(),
        });
    }
    results
}
/// Queries several vector-index backends and fuses their hits into one list.
///
/// Configured builder-style via `with_backend`, `strategy` and `timeout`.
pub struct FederatedSearch {
/// Backends queried on every `search` call, in registration order.
backends: Vec<Backend>,
/// Fusion algorithm applied to the per-backend hit lists.
strategy: FusionStrategy,
/// Deadline applied to each backend's search individually.
timeout: Duration,
}
impl Default for FederatedSearch {
fn default() -> Self {
Self {
backends: Vec::new(),
strategy: FusionStrategy::default(),
timeout: Duration::from_secs(5),
}
}
}
impl FederatedSearch {
    /// Over-fetch multiplier: each backend is asked for `k_req` times this
    /// many hits, and the fused list is cut back to `k_req` afterwards.
    const OVER_FETCH_FACTOR: usize = 2;

    /// Create a searcher with no backends and default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Register `index` as an additional backend.
    ///
    /// `weight` is used only by [`FusionStrategy::WeightedSum`]; the RRF
    /// branch of [`search`](Self::search) ignores it.
    #[must_use]
    pub fn with_backend(mut self, index: Arc<dyn AsyncVectorIndex>, weight: f32) -> Self {
        self.backends.push(Backend { index, weight });
        self
    }

    /// Replace the fusion strategy.
    #[must_use]
    pub fn strategy(mut self, strategy: FusionStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Replace the per-backend timeout.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Query every backend concurrently and fuse whatever comes back.
    ///
    /// Each backend is asked for `k_req * 2` hits (saturating) under its own
    /// timeout. A backend that errors or times out is logged with
    /// `tracing::warn!` and left out of the fusion. The fused list is
    /// truncated to at most `k_req` entries.
    ///
    /// Returns an empty list without contacting any backend when no backend
    /// is registered or when `k_req` is `0`.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::Search` when backends were queried and every one
    /// of them failed or timed out.
    pub async fn search(&self, query: &[f32], k_req: usize) -> Result<Vec<SearchResult>> {
        // Nothing to ask, or nothing asked for: skip the backend round-trips.
        if self.backends.is_empty() || k_req == 0 {
            return Ok(Vec::new());
        }

        let timeout = self.timeout;
        let fetch_k = k_req.saturating_mul(Self::OVER_FETCH_FACTOR);

        // The futures are awaited inside this call, so they can borrow `query`
        // directly; only the `Arc` handle is cloned per backend.
        let futs = self.backends.iter().map(|backend| {
            let index = Arc::clone(&backend.index);
            let weight = backend.weight;
            async move {
                match tokio::time::timeout(timeout, index.search(query, fetch_k)).await {
                    Ok(Ok(hits)) => Some((weight, hits)),
                    Ok(Err(e)) => {
                        tracing::warn!("federated backend errored; excluding: {e}");
                        None
                    }
                    Err(_elapsed) => {
                        tracing::warn!(
                            "federated backend timed out after {:?}; excluding",
                            timeout
                        );
                        None
                    }
                }
            }
        });

        let collected: Vec<Option<(f32, Vec<SearchHit>)>> = join_all(futs).await;
        let ok: Vec<(f32, Vec<SearchHit>)> = collected.into_iter().flatten().collect();
        if ok.is_empty() {
            return Err(KernelError::Search(
                "all federated backends failed or timed out".into(),
            ));
        }

        let mut fused = match self.strategy {
            FusionStrategy::Rrf { k } => {
                // Backend weights play no part in rank-based fusion.
                let lists: Vec<Vec<SearchResult>> = ok
                    .into_iter()
                    .map(|(_w, hits)| hits_to_results(hits))
                    .collect();
                rrf_fuse(&lists, k)
            }
            FusionStrategy::WeightedSum => {
                let mut lists: Vec<Vec<SearchResult>> = Vec::with_capacity(ok.len());
                let mut weights: Vec<f32> = Vec::with_capacity(ok.len());
                for (w, hits) in ok {
                    let mut list = hits_to_results(hits);
                    // Normalise each list before the weights are applied.
                    normalize_minmax(&mut list);
                    lists.push(list);
                    weights.push(w);
                }
                weighted_sum_fuse(&lists, &weights)
            }
        };
        fused.truncate(k_req);
        Ok(fused)
    }
}
}
#[cfg(feature = "federation")]
pub use federated::FederatedSearch;
#[cfg(test)]
mod tests {
use super::*;
/// Build a result list from `(id, score)` pairs; `text` is left empty.
fn hits(ids: &[(&str, f32)]) -> Vec<SearchResult> {
    let mut out = Vec::with_capacity(ids.len());
    for &(id, score) in ids {
        out.push(SearchResult {
            id: id.to_owned(),
            score,
            text: String::new(),
        });
    }
    out
}
/// The id present in every list must rank first under the default strategy
/// even though the three lists use very different score ranges.
#[test]
fn rrf_fuses_heterogeneous_scales_correctly() {
    // In order: the qdrant-, es- and turbovec-style lists.
    let lists = [
        hits(&[("shared", 0.90), ("a", 0.50)]),
        hits(&[("shared", 0.97), ("b", 0.70)]),
        hits(&[("shared", 0.30), ("c", -0.50)]),
    ];
    let merged = federate_results(&lists, &FusionStrategy::default());
    assert_eq!(merged[0].id, "shared");
}
/// An id returned by several lists appears once in the output and outscores
/// an id that only one list returned.
#[test]
fn shared_id_is_deduped_and_boosted() {
    let lists = [
        hits(&[("shared", 1.0), ("only_q", 0.9)]),
        hits(&[("shared", 1.0)]),
        hits(&[("shared", 1.0)]),
    ];
    let merged = federate_results(&lists, &FusionStrategy::default());

    let score_of = |wanted: &str| merged.iter().find(|r| r.id == wanted).unwrap().score;

    let shared_count = merged.iter().filter(|r| r.id == "shared").count();
    assert_eq!(shared_count, 1);
    assert_eq!(merged.len(), 2);

    let shared_score = score_of("shared");
    let only_q_score = score_of("only_q");
    assert!(shared_score > only_q_score);
}
/// Weighted-sum fusion keeps all three distinct ids and ranks `y`, the top
/// scorer of both lists, first.
#[test]
fn weighted_sum_strategy_runs() {
    let lists = [
        hits(&[("x", 0.0), ("y", 1.0)]),
        hits(&[("y", 1.0), ("z", 0.4)]),
    ];
    let merged = federate_results(&lists, &FusionStrategy::WeightedSum);
    assert_eq!(merged.len(), 3);
    assert_eq!(merged[0].id, "y");
}
}
#[cfg(all(test, feature = "federation"))]
mod async_tests {
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use crate::embedding::{AsyncVectorIndex, SearchHit};
use crate::search::federation::{FederatedSearch, FusionStrategy};
/// Test double for `AsyncVectorIndex` with configurable latency and failure.
struct StubIndex {
/// Hits returned verbatim by `search` and `search_filtered`.
hits: Vec<SearchHit>,
/// When set, `search` sleeps this long before answering.
delay: Option<Duration>,
/// When true, `search` returns an error instead of the hits.
fail: bool,
/// Value reported by `dim`.
dim: usize,
}
#[async_trait]
impl AsyncVectorIndex for StubIndex {
    /// Accepts the input and does nothing with it.
    async fn add(&self, _vectors: &[Vec<f32>], _ids: &[u64]) -> Result<()> {
        Ok(())
    }

    /// Accepts the ids and does nothing with them.
    async fn remove(&self, _ids: &[u64]) -> Result<()> {
        Ok(())
    }

    /// Sleeps for `delay` if one is set, then fails if `fail` is set,
    /// otherwise returns a copy of all canned hits regardless of `_k`.
    async fn search(&self, _query: &[f32], _k: usize) -> Result<Vec<SearchHit>> {
        if let Some(pause) = self.delay {
            tokio::time::sleep(pause).await;
        }
        if self.fail {
            Err(anyhow!("stub backend failure"))
        } else {
            Ok(self.hits.clone())
        }
    }

    /// Returns a copy of all canned hits; `delay` and `fail` are not
    /// consulted here.
    async fn search_filtered(
        &self,
        _query: &[f32],
        _k: usize,
        _allowlist: &[u64],
    ) -> Result<Vec<SearchHit>> {
        let canned = self.hits.to_vec();
        Ok(canned)
    }

    /// Number of canned hits.
    async fn len(&self) -> Result<usize> {
        let count = self.hits.len();
        Ok(count)
    }

    /// The configured dimensionality.
    fn dim(&self) -> usize {
        self.dim
    }
}
/// Shorthand constructor for a `SearchHit`.
fn hit(id: u64, score: f32) -> SearchHit {
    let built = SearchHit { score, id };
    built
}
/// A backend that sleeps past the timeout is left out of the fused results,
/// while the fast backend's hits are still returned.
#[tokio::test]
async fn slow_backend_is_dropped_not_blocking() {
    let fast = Arc::new(StubIndex {
        hits: vec![hit(1, 0.9), hit(2, 0.5)],
        delay: None,
        fail: false,
        dim: 4,
    });
    // Sleeps 500 ms, well past the 50 ms timeout configured below.
    let slow = Arc::new(StubIndex {
        hits: vec![hit(3, 1.0)],
        delay: Some(Duration::from_millis(500)),
        fail: false,
        dim: 4,
    });

    let fed = FederatedSearch::new()
        .with_backend(fast, 1.0)
        .with_backend(slow, 1.0)
        .timeout(Duration::from_millis(50));
    let merged = fed.search(&[1.0, 0.0, 0.0, 0.0], 5).await.unwrap();

    let ids: Vec<&str> = merged.iter().map(|r| r.id.as_str()).collect();
    assert!(ids.contains(&"1"));
    assert!(!ids.contains(&"3"));
}
/// An erroring backend is skipped; the healthy backend's hit still comes back.
#[tokio::test]
async fn failing_backend_is_excluded() {
    let good = Arc::new(StubIndex {
        hits: vec![hit(7, 0.8)],
        delay: None,
        fail: false,
        dim: 4,
    });
    let bad = Arc::new(StubIndex {
        hits: Vec::new(),
        delay: None,
        fail: true,
        dim: 4,
    });

    let fed = FederatedSearch::new()
        .with_backend(good, 1.0)
        .with_backend(bad, 1.0);
    let merged = fed.search(&[0.0, 0.0, 0.0, 1.0], 3).await.unwrap();

    let found_seven = merged.iter().any(|r| r.id == "7");
    assert!(found_seven);
}
/// When every registered backend errors, the search itself reports an error.
#[tokio::test]
async fn all_backends_failing_returns_err() {
    let bad = Arc::new(StubIndex {
        hits: Vec::new(),
        delay: None,
        fail: true,
        dim: 4,
    });

    // The same failing stub is registered twice.
    let fed = FederatedSearch::new()
        .with_backend(bad.clone(), 1.0)
        .with_backend(bad, 1.0);
    let outcome = fed.search(&[0.0; 4], 3).await;

    assert!(outcome.is_err());
}
/// With the default strategy, the id returned by both backends ranks first
/// and the three distinct ids are all present.
#[tokio::test]
async fn two_backends_merge_via_rrf() {
    let a = Arc::new(StubIndex {
        hits: vec![hit(1, 0.99), hit(2, 0.4)],
        delay: None,
        fail: false,
        dim: 4,
    });
    let b = Arc::new(StubIndex {
        hits: vec![hit(2, 0.95), hit(3, 0.6)],
        delay: None,
        fail: false,
        dim: 4,
    });

    let fed = FederatedSearch::new()
        .with_backend(a, 1.0)
        .with_backend(b, 1.0);
    let merged = fed.search(&[1.0, 0.0, 0.0, 0.0], 5).await.unwrap();

    assert_eq!(merged[0].id, "2");
    assert_eq!(merged.len(), 3);

    // The searcher above never set a strategy, so it ran with this default.
    let default_strategy = FusionStrategy::default();
    assert!(matches!(default_strategy, FusionStrategy::Rrf { k: 60 }));
}
/// With weights 0.75 / 0.25, the top hit of the heavier backend leads, and
/// the ids from the lighter backend are still present in the output.
#[tokio::test]
async fn two_backends_merge_via_weighted_sum() {
    let a = Arc::new(StubIndex {
        hits: vec![hit(1, 1.0), hit(2, 0.2)],
        delay: None,
        fail: false,
        dim: 4,
    });
    let b = Arc::new(StubIndex {
        hits: vec![hit(2, 1.0), hit(3, 0.1)],
        delay: None,
        fail: false,
        dim: 4,
    });

    let fed = FederatedSearch::new()
        .with_backend(a, 0.75)
        .with_backend(b, 0.25)
        .strategy(FusionStrategy::WeightedSum);
    let merged = fed.search(&[1.0, 0.0, 0.0, 0.0], 5).await.unwrap();

    assert!(!merged.is_empty());
    assert_eq!(merged[0].id, "1");
    for wanted in ["2", "3"] {
        assert!(merged.iter().any(|r| r.id == wanted));
    }
}
/// A searcher with no registered backends yields an empty list, not an error.
#[tokio::test]
async fn no_backends_returns_empty() {
    let fed = FederatedSearch::new();
    let merged = fed.search(&[0.0; 4], 3).await.unwrap();
    assert!(merged.is_empty());
}
/// Test double for `AsyncVectorIndex` that honours the requested `k`.
struct RankAwareStub {
/// Canned hits; searches return at most the first `k` of them, in this order.
hits: Vec<SearchHit>,
/// Value reported by `dim`.
dim: usize,
}
#[async_trait]
impl AsyncVectorIndex for RankAwareStub {
    /// Accepts the input and does nothing with it.
    async fn add(&self, _vectors: &[Vec<f32>], _ids: &[u64]) -> Result<()> {
        Ok(())
    }

    /// Accepts the ids and does nothing with them.
    async fn remove(&self, _ids: &[u64]) -> Result<()> {
        Ok(())
    }

    /// Returns the first `k` canned hits, or all of them when fewer exist.
    async fn search(&self, _query: &[f32], k: usize) -> Result<Vec<SearchHit>> {
        let end = k.min(self.hits.len());
        Ok(self.hits[..end].to_vec())
    }

    /// Same as `search`; the allowlist is not consulted.
    async fn search_filtered(
        &self,
        _query: &[f32],
        k: usize,
        _allowlist: &[u64],
    ) -> Result<Vec<SearchHit>> {
        let end = k.min(self.hits.len());
        Ok(self.hits[..end].to_vec())
    }

    /// Number of canned hits.
    async fn len(&self) -> Result<usize> {
        let count = self.hits.len();
        Ok(count)
    }

    /// The configured dimensionality.
    fn dim(&self) -> usize {
        self.dim
    }
}
/// Id 7 sits third in backend `a`, outside a plain top-2, but first in
/// backend `b`; it must appear in the fused top-2.
#[tokio::test]
async fn over_fetch_preserves_rank_credit_across_backends() {
    let a = Arc::new(RankAwareStub {
        hits: vec![hit(101, 0.99), hit(102, 0.9), hit(7, 0.8), hit(8, 0.7)],
        dim: 4,
    });
    let b = Arc::new(RankAwareStub {
        hits: vec![hit(7, 1.0), hit(9, 0.6)],
        dim: 4,
    });

    let fed = FederatedSearch::new()
        .with_backend(a, 1.0)
        .with_backend(b, 1.0);
    let merged = fed.search(&[1.0, 0.0, 0.0, 0.0], 2).await.unwrap();

    assert_eq!(merged.len(), 2);
    let has_seven = merged.iter().any(|r| r.id == "7");
    assert!(
        has_seven,
        "id 7 should survive via over-fetch rank-credit: {merged:?}"
    );
}
#[tokio::test]
async fn fused_output_is_truncated_to_requested_k() {
let a = Arc::new(RankAwareStub {
hits: (1..=20).map(|i| hit(i, 1.0 - i as f32 * 0.01)).collect(),
dim: 4,
});
let b = Arc::new(RankAwareStub {
hits: (21..=40).map(|i| hit(i, 0.5 - i as f32 * 0.01)).collect(),
dim: 4,
});
let merged = FederatedSearch::new()
.with_backend(a, 1.0)
.with_backend(b, 1.0)
.search(&[1.0, 0.0, 0.0, 0.0], 5)
.await
.unwrap();
assert_eq!(merged.len(), 5, "fused output must be truncated to k");
}
/// Requesting zero results from a healthy backend yields an empty list
/// rather than an error.
#[tokio::test]
async fn k_zero_with_backends_returns_empty_not_err() {
    let a = Arc::new(RankAwareStub {
        hits: vec![hit(1, 0.9)],
        dim: 4,
    });

    let fed = FederatedSearch::new().with_backend(a, 1.0);
    let merged = fed.search(&[1.0, 0.0, 0.0, 0.0], 0).await.unwrap();

    assert!(merged.is_empty());
}
}