Skip to main content

hashtree_collection/
federated.rs

1use std::cmp::max;
2use std::collections::BTreeMap;
3
4use futures::future::join_all;
5use hashtree_core::{Cid, Store};
6
7use crate::{CollectionError, CollectionSource, SearchOptions};
8
/// Tuning knobs for [`federated_search`].
///
/// All fields are optional refinements; `FederatedSearchOptions::default()` yields the
/// defaults described on each field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FederatedSearchOptions {
    /// Maximum number of merged hits returned. `None` means 20; `Some(0)` returns no hits
    /// without querying any source.
    pub limit: Option<usize>,
    /// Forwarded unchanged as `full_match` in the `SearchOptions` sent to every source.
    pub full_match: bool,
    /// Maximum number of hits requested from each individual source. `None` means
    /// `max(2 * limit, 20)` (with saturating multiplication).
    pub per_source_limit: Option<usize>,
}
15
16#[derive(Clone)]
17pub struct FederatedCollectionSource<'a, S: Store> {
18    pub source_id: String,
19    pub boost: usize,
20    pub source: &'a CollectionSource<S>,
21}
22
23impl<'a, S: Store> FederatedCollectionSource<'a, S> {
24    pub fn new(source_id: impl Into<String>, source: &'a CollectionSource<S>) -> Self {
25        Self {
26            source_id: source_id.into(),
27            boost: 1,
28            source,
29        }
30    }
31
32    pub fn with_boost(mut self, boost: usize) -> Self {
33        self.boost = boost;
34        self
35    }
36}
37
38#[derive(Debug, Clone, PartialEq)]
39pub struct FederatedSearchSourceHit {
40    pub source_id: String,
41    pub cid: Cid,
42    pub score: usize,
43    pub boost: usize,
44}
45
46#[derive(Debug, Clone, PartialEq)]
47pub struct FederatedSearchHit {
48    pub id: String,
49    pub cid: Cid,
50    pub score: usize,
51    pub best_source_id: String,
52    pub source_ids: Vec<String>,
53    pub hits: Vec<FederatedSearchSourceHit>,
54}
55
56pub async fn federated_search<'a, S: Store + 'a>(
57    sources: impl IntoIterator<Item = FederatedCollectionSource<'a, S>>,
58    index_name: &str,
59    query: &str,
60    options: FederatedSearchOptions,
61) -> Result<Vec<FederatedSearchHit>, CollectionError> {
62    let source_list = sources.into_iter().collect::<Vec<_>>();
63    let limit = options.limit.unwrap_or(20);
64    if limit == 0 {
65        return Ok(Vec::new());
66    }
67
68    let per_source_limit = options
69        .per_source_limit
70        .unwrap_or(max(limit.saturating_mul(2), 20));
71    let search_options = SearchOptions {
72        limit: Some(per_source_limit),
73        full_match: options.full_match,
74    };
75
76    let local_results = join_all(source_list.iter().map(|source_input| {
77        let source_id = source_input.source_id.clone();
78        let boost = source_input.boost;
79        let search_options = search_options.clone();
80        async move {
81            let results = source_input
82                .source
83                .search(index_name, query, search_options)
84                .await?;
85            Ok::<_, CollectionError>(
86                results
87                    .into_iter()
88                    .map(|result| WeightedFederatedSearchHit {
89                        source_id: source_id.clone(),
90                        cid: result.cid,
91                        id: result.id,
92                        score: result.score,
93                        boost,
94                        weighted_score: result.score.saturating_mul(boost),
95                    })
96                    .collect::<Vec<_>>(),
97            )
98        }
99    }))
100    .await;
101
102    #[derive(Debug)]
103    struct AggregateHit {
104        cid: Cid,
105        score: usize,
106        best_source_id: String,
107        source_ids: Vec<String>,
108        hits: Vec<FederatedSearchSourceHit>,
109        best_weighted_score: usize,
110    }
111
112    let mut merged = BTreeMap::<String, AggregateHit>::new();
113    for result_set in local_results {
114        for result in result_set? {
115            let hit = FederatedSearchSourceHit {
116                source_id: result.source_id.clone(),
117                cid: result.cid.clone(),
118                score: result.score,
119                boost: result.boost,
120            };
121
122            match merged.get_mut(&result.id) {
123                None => {
124                    merged.insert(
125                        result.id,
126                        AggregateHit {
127                            cid: result.cid,
128                            score: result.weighted_score,
129                            best_source_id: result.source_id.clone(),
130                            source_ids: vec![result.source_id],
131                            hits: vec![hit],
132                            best_weighted_score: result.weighted_score,
133                        },
134                    );
135                }
136                Some(existing) => {
137                    existing.score = existing.score.saturating_add(result.weighted_score);
138                    if !existing
139                        .source_ids
140                        .iter()
141                        .any(|source_id| source_id == &result.source_id)
142                    {
143                        existing.source_ids.push(result.source_id.clone());
144                    }
145                    existing.hits.push(hit);
146
147                    if result.weighted_score > existing.best_weighted_score {
148                        existing.best_weighted_score = result.weighted_score;
149                        existing.best_source_id = result.source_id;
150                        existing.cid = result.cid;
151                    }
152                }
153            }
154        }
155    }
156
157    let mut results = merged
158        .into_iter()
159        .map(|(id, aggregate)| FederatedSearchHit {
160            id,
161            cid: aggregate.cid,
162            score: aggregate.score,
163            best_source_id: aggregate.best_source_id,
164            source_ids: aggregate.source_ids,
165            hits: aggregate.hits,
166        })
167        .collect::<Vec<_>>();
168
169    results.sort_by(|left, right| {
170        right
171            .score
172            .cmp(&left.score)
173            .then(right.source_ids.len().cmp(&left.source_ids.len()))
174            .then(left.id.cmp(&right.id))
175    });
176    results.truncate(limit);
177    Ok(results)
178}
179
/// Internal: one search result from a single source, tagged with that source and pre-weighted.
///
/// Built inside `federated_search` for every result a source returns; `weighted_score` is what
/// the merge step sums and compares.
#[derive(Debug)]
struct WeightedFederatedSearchHit {
    // Id of the source that returned this result.
    source_id: String,
    // `Cid` of the result as reported by the source.
    cid: Cid,
    // Document id; the merge key across sources.
    id: String,
    // Raw score from the source.
    score: usize,
    // Boost of the originating source.
    boost: usize,
    // `score.saturating_mul(boost)`.
    weighted_score: usize,
}