1use std::cmp::max;
2use std::collections::BTreeMap;
3
4use futures::future::join_all;
5use hashtree_core::{Cid, Store};
6
7use crate::{CollectionError, CollectionSource, SearchOptions};
8
/// Options controlling a [`federated_search`] call.
#[derive(Debug, Clone, Default)]
pub struct FederatedSearchOptions {
    /// Maximum number of merged hits to return. `None` falls back to 20;
    /// `Some(0)` makes the search return an empty list without querying.
    pub limit: Option<usize>,
    /// Forwarded unchanged as `SearchOptions::full_match` to every source.
    pub full_match: bool,
    /// Maximum number of results requested from each individual source.
    /// `None` falls back to `max(limit * 2, 20)` (saturating).
    pub per_source_limit: Option<usize>,
}
15
16#[derive(Clone)]
17pub struct FederatedCollectionSource<'a, S: Store> {
18 pub source_id: String,
19 pub boost: usize,
20 pub source: &'a CollectionSource<S>,
21}
22
23impl<'a, S: Store> FederatedCollectionSource<'a, S> {
24 pub fn new(source_id: impl Into<String>, source: &'a CollectionSource<S>) -> Self {
25 Self {
26 source_id: source_id.into(),
27 boost: 1,
28 source,
29 }
30 }
31
32 pub fn with_boost(mut self, boost: usize) -> Self {
33 self.boost = boost;
34 self
35 }
36}
37
/// A single source's contribution to a merged [`FederatedSearchHit`].
#[derive(Debug, Clone, PartialEq)]
pub struct FederatedSearchSourceHit {
    /// Label of the source that produced this hit.
    pub source_id: String,
    /// Content id as reported by that source.
    pub cid: Cid,
    /// Raw score reported by the source, before the boost is applied.
    pub score: usize,
    /// Boost of the source at the time of the search.
    pub boost: usize,
}
45
/// One merged result of [`federated_search`]: all per-source hits that share
/// the same result id, combined into a single ranked entry.
#[derive(Debug, Clone, PartialEq)]
pub struct FederatedSearchHit {
    /// Result id shared by every contributing hit (the merge key).
    pub id: String,
    /// Content id taken from the contributing hit with the highest boosted
    /// score; the earliest such hit wins on ties.
    pub cid: Cid,
    /// Saturating sum of the boosted scores of all contributing hits.
    pub score: usize,
    /// Label of the source whose hit supplied `cid`.
    pub best_source_id: String,
    /// Distinct source labels that returned this id, in first-seen order.
    pub source_ids: Vec<String>,
    /// Every contributing per-source hit, in the order they were merged.
    pub hits: Vec<FederatedSearchSourceHit>,
}
55
56pub async fn federated_search<'a, S: Store + 'a>(
57 sources: impl IntoIterator<Item = FederatedCollectionSource<'a, S>>,
58 index_name: &str,
59 query: &str,
60 options: FederatedSearchOptions,
61) -> Result<Vec<FederatedSearchHit>, CollectionError> {
62 let source_list = sources.into_iter().collect::<Vec<_>>();
63 let limit = options.limit.unwrap_or(20);
64 if limit == 0 {
65 return Ok(Vec::new());
66 }
67
68 let per_source_limit = options
69 .per_source_limit
70 .unwrap_or(max(limit.saturating_mul(2), 20));
71 let search_options = SearchOptions {
72 limit: Some(per_source_limit),
73 full_match: options.full_match,
74 };
75
76 let local_results = join_all(source_list.iter().map(|source_input| {
77 let source_id = source_input.source_id.clone();
78 let boost = source_input.boost;
79 let search_options = search_options.clone();
80 async move {
81 let results = source_input
82 .source
83 .search(index_name, query, search_options)
84 .await?;
85 Ok::<_, CollectionError>(
86 results
87 .into_iter()
88 .map(|result| WeightedFederatedSearchHit {
89 source_id: source_id.clone(),
90 cid: result.cid,
91 id: result.id,
92 score: result.score,
93 boost,
94 weighted_score: result.score.saturating_mul(boost),
95 })
96 .collect::<Vec<_>>(),
97 )
98 }
99 }))
100 .await;
101
102 #[derive(Debug)]
103 struct AggregateHit {
104 cid: Cid,
105 score: usize,
106 best_source_id: String,
107 source_ids: Vec<String>,
108 hits: Vec<FederatedSearchSourceHit>,
109 best_weighted_score: usize,
110 }
111
112 let mut merged = BTreeMap::<String, AggregateHit>::new();
113 for result_set in local_results {
114 for result in result_set? {
115 let hit = FederatedSearchSourceHit {
116 source_id: result.source_id.clone(),
117 cid: result.cid.clone(),
118 score: result.score,
119 boost: result.boost,
120 };
121
122 match merged.get_mut(&result.id) {
123 None => {
124 merged.insert(
125 result.id,
126 AggregateHit {
127 cid: result.cid,
128 score: result.weighted_score,
129 best_source_id: result.source_id.clone(),
130 source_ids: vec![result.source_id],
131 hits: vec![hit],
132 best_weighted_score: result.weighted_score,
133 },
134 );
135 }
136 Some(existing) => {
137 existing.score = existing.score.saturating_add(result.weighted_score);
138 if !existing
139 .source_ids
140 .iter()
141 .any(|source_id| source_id == &result.source_id)
142 {
143 existing.source_ids.push(result.source_id.clone());
144 }
145 existing.hits.push(hit);
146
147 if result.weighted_score > existing.best_weighted_score {
148 existing.best_weighted_score = result.weighted_score;
149 existing.best_source_id = result.source_id;
150 existing.cid = result.cid;
151 }
152 }
153 }
154 }
155 }
156
157 let mut results = merged
158 .into_iter()
159 .map(|(id, aggregate)| FederatedSearchHit {
160 id,
161 cid: aggregate.cid,
162 score: aggregate.score,
163 best_source_id: aggregate.best_source_id,
164 source_ids: aggregate.source_ids,
165 hits: aggregate.hits,
166 })
167 .collect::<Vec<_>>();
168
169 results.sort_by(|left, right| {
170 right
171 .score
172 .cmp(&left.score)
173 .then(right.source_ids.len().cmp(&left.source_ids.len()))
174 .then(left.id.cmp(&right.id))
175 });
176 results.truncate(limit);
177 Ok(results)
178}
179
/// Internal per-source search result tagged with its originating source and
/// the boosted score used while merging in [`federated_search`].
#[derive(Debug)]
struct WeightedFederatedSearchHit {
    /// Label of the source that returned this result.
    source_id: String,
    /// Content id as returned by the source.
    cid: Cid,
    /// Result id; hits with equal ids are merged into one federated hit.
    id: String,
    /// Raw score as returned by the source.
    score: usize,
    /// Boost of the originating source.
    boost: usize,
    /// `score * boost`, computed with saturating multiplication.
    weighted_score: usize,
}