hashtree-collection 0.2.36

Immutable collections, schema hooks, and federated search for hashtree
Documentation
use std::cmp::max;
use std::collections::BTreeMap;

use futures::future::join_all;
use hashtree_core::{Cid, Store};

use crate::{CollectionError, CollectionSource, SearchOptions};

/// Tuning knobs for [`federated_search`].
///
/// Every field is optional in spirit: `FederatedSearchOptions::default()` yields the
/// defaults described on each field.
#[derive(Default, Clone, Debug)]
pub struct FederatedSearchOptions {
    /// Maximum number of merged hits returned. `None` means 20; `Some(0)` returns nothing
    /// without querying any source.
    pub limit: Option<usize>,
    /// Passed through unchanged as `full_match` in the [`SearchOptions`] sent to each source.
    pub full_match: bool,
    /// Maximum number of hits requested from each individual source. `None` means
    /// `max(limit * 2 (saturating), 20)`.
    pub per_source_limit: Option<usize>,
}

/// One collection taking part in a [`federated_search`], tagged with an id and a score boost.
///
/// `Clone` is implemented by hand rather than derived: `#[derive(Clone)]` would require
/// `S: Clone`, but the struct only holds a shared reference to the source, so cloning never
/// needs to clone the store itself.
pub struct FederatedCollectionSource<'a, S: Store> {
    /// Identifier reported back in [`FederatedSearchSourceHit::source_id`] and
    /// [`FederatedSearchHit::source_ids`].
    pub source_id: String,
    /// Multiplier applied (saturating) to every score this source returns.
    /// [`FederatedCollectionSource::new`] sets it to 1.
    pub boost: usize,
    /// The collection that is queried.
    pub source: &'a CollectionSource<S>,
}

impl<S: Store> Clone for FederatedCollectionSource<'_, S> {
    fn clone(&self) -> Self {
        Self {
            source_id: self.source_id.clone(),
            boost: self.boost,
            // Copying the shared reference; the underlying source is not duplicated.
            source: self.source,
        }
    }
}

impl<'a, S: Store> FederatedCollectionSource<'a, S> {
    /// Wraps `source` under the identifier `source_id` with a neutral boost of 1.
    pub fn new(source_id: impl Into<String>, source: &'a CollectionSource<S>) -> Self {
        let source_id: String = source_id.into();
        Self {
            source,
            boost: 1,
            source_id,
        }
    }

    /// Returns this source with its score multiplier replaced by `boost`.
    ///
    /// During [`federated_search`] each raw hit score from this source is multiplied
    /// (saturating) by `boost`, so a boost of 0 zeroes every weighted score it contributes.
    pub fn with_boost(self, boost: usize) -> Self {
        Self { boost, ..self }
    }
}

/// A single hit as reported by one source, before hits are merged across sources.
#[derive(Debug, Clone, PartialEq)]
pub struct FederatedSearchSourceHit {
    /// Id of the [`FederatedCollectionSource`] that produced this hit.
    pub source_id: String,
    /// Cid the source returned for this hit.
    pub cid: Cid,
    /// Raw score from the source, before `boost` is applied.
    pub score: usize,
    /// Boost of the source that produced this hit.
    pub boost: usize,
}

/// One merged result of [`federated_search`]: every source hit that shared the same id.
#[derive(Debug, Clone, PartialEq)]
pub struct FederatedSearchHit {
    /// The id the per-source hits were merged on.
    pub id: String,
    /// Cid taken from the hit with the highest boosted score; on ties the earliest source wins.
    pub cid: Cid,
    /// Saturating sum of `score * boost` over all contributing hits.
    pub score: usize,
    /// Id of the source whose hit supplied `cid`.
    pub best_source_id: String,
    /// Distinct contributing source ids, in the order they were first seen.
    pub source_ids: Vec<String>,
    /// Every contributing per-source hit, in merge order.
    pub hits: Vec<FederatedSearchSourceHit>,
}

/// Runs `query` against `index_name` on every source concurrently and merges the hits by id.
///
/// Each source is searched with the same [`SearchOptions`] (per-source limit plus
/// `options.full_match`). Every hit score is multiplied (saturating) by its source's `boost`.
/// Hits that share an id are merged into one [`FederatedSearchHit`]:
/// - `score` is the saturating sum of all weighted scores;
/// - `cid` and `best_source_id` come from the hit with the highest weighted score, and the
///   first one seen (in source order) wins ties.
///
/// Results are ordered by descending score, then by descending number of distinct sources,
/// then by ascending id. The list is truncated to `options.limit` (default 20). A limit of 0
/// returns an empty list without querying any source.
///
/// # Errors
///
/// All searches are awaited to completion. If any of them failed, the error from the first
/// failing source (in the iteration order of `sources`) is returned and no hits are produced.
pub async fn federated_search<'a, S: Store + 'a>(
    sources: impl IntoIterator<Item = FederatedCollectionSource<'a, S>>,
    index_name: &str,
    query: &str,
    options: FederatedSearchOptions,
) -> Result<Vec<FederatedSearchHit>, CollectionError> {
    // `Entry` is only needed by the merge loop below, so it is imported locally.
    use std::collections::btree_map::Entry;

    /// Running merge state for one id across all sources.
    struct AggregateHit {
        cid: Cid,
        score: usize,
        best_source_id: String,
        source_ids: Vec<String>,
        hits: Vec<FederatedSearchSourceHit>,
        best_weighted_score: usize,
    }

    let source_list = sources.into_iter().collect::<Vec<_>>();
    let limit = options.limit.unwrap_or(20);
    if limit == 0 {
        return Ok(Vec::new());
    }

    // By default each source is asked for more than `limit` hits, because summing scores
    // across sources can reorder hits relative to any single source's ranking.
    let per_source_limit = options
        .per_source_limit
        .unwrap_or(max(limit.saturating_mul(2), 20));
    let search_options = SearchOptions {
        limit: Some(per_source_limit),
        full_match: options.full_match,
    };

    // Query all sources concurrently. Each source yields its own Result, checked below.
    let local_results = join_all(source_list.iter().map(|source_input| {
        let source_id = source_input.source_id.clone();
        let boost = source_input.boost;
        let search_options = search_options.clone();
        async move {
            let results = source_input
                .source
                .search(index_name, query, search_options)
                .await?;
            Ok::<_, CollectionError>(
                results
                    .into_iter()
                    .map(|result| WeightedFederatedSearchHit {
                        source_id: source_id.clone(),
                        cid: result.cid,
                        id: result.id,
                        score: result.score,
                        boost,
                        weighted_score: result.score.saturating_mul(boost),
                    })
                    .collect::<Vec<_>>(),
            )
        }
    }))
    .await;

    let mut merged = BTreeMap::<String, AggregateHit>::new();
    for result_set in local_results {
        for result in result_set? {
            let hit = FederatedSearchSourceHit {
                source_id: result.source_id.clone(),
                cid: result.cid.clone(),
                score: result.score,
                boost: result.boost,
            };

            // One map lookup per hit: start a new aggregate or fold into the existing one.
            match merged.entry(result.id) {
                Entry::Vacant(slot) => {
                    slot.insert(AggregateHit {
                        cid: result.cid,
                        score: result.weighted_score,
                        best_source_id: result.source_id.clone(),
                        source_ids: vec![result.source_id],
                        hits: vec![hit],
                        best_weighted_score: result.weighted_score,
                    });
                }
                Entry::Occupied(slot) => {
                    let existing = slot.into_mut();
                    existing.score = existing.score.saturating_add(result.weighted_score);
                    if !existing.source_ids.contains(&result.source_id) {
                        existing.source_ids.push(result.source_id.clone());
                    }
                    existing.hits.push(hit);

                    // Strictly greater: on ties the earlier source keeps `cid`/`best_source_id`.
                    if result.weighted_score > existing.best_weighted_score {
                        existing.best_weighted_score = result.weighted_score;
                        existing.best_source_id = result.source_id;
                        existing.cid = result.cid;
                    }
                }
            }
        }
    }

    let mut results = merged
        .into_iter()
        .map(|(id, aggregate)| FederatedSearchHit {
            id,
            cid: aggregate.cid,
            score: aggregate.score,
            best_source_id: aggregate.best_source_id,
            source_ids: aggregate.source_ids,
            hits: aggregate.hits,
        })
        .collect::<Vec<_>>();

    // Ids are unique (they were map keys), so the comparator never reports a tie and an
    // unstable sort yields the same order as a stable one.
    results.sort_unstable_by(|left, right| {
        right
            .score
            .cmp(&left.score)
            .then_with(|| right.source_ids.len().cmp(&left.source_ids.len()))
            .then_with(|| left.id.cmp(&right.id))
    });
    results.truncate(limit);
    Ok(results)
}

/// Internal carrier for one source hit plus its boosted score.
///
/// Built per source inside [`federated_search`] and consumed by its merge step.
#[derive(Debug)]
struct WeightedFederatedSearchHit {
    /// Id of the source that produced the hit.
    source_id: String,
    /// Cid returned by the source.
    cid: Cid,
    /// Id used as the merge key across sources.
    id: String,
    /// Raw score reported by the source.
    score: usize,
    /// Boost of the originating source.
    boost: usize,
    /// `score.saturating_mul(boost)`.
    weighted_score: usize,
}