use async_trait::async_trait;
use std::sync::Arc;
use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
use crate::document::{DocumentTree, NodeId};
use crate::graph::DocumentGraph;
use crate::retrieval::RetrievalContext;
use crate::retrieval::types::QueryComplexity;
/// Identifier used to key registered documents and their search results.
pub type DocumentId = std::string::String;
/// A document registered with a [`CrossDocumentStrategy`] for searching.
pub struct DocumentEntry {
    /// Identifier reported alongside every hit from this document.
    pub id: DocumentId,
    /// Human-readable title, copied into the per-document result.
    pub title: String,
    /// The document's node tree; searches start from its root.
    pub tree: DocumentTree,
}
impl DocumentEntry {
    /// Builds an entry from anything convertible into the id and title strings.
    pub fn new(id: impl Into<String>, title: impl Into<String>, tree: DocumentTree) -> Self {
        let id: String = id.into();
        let title: String = title.into();
        Self { id, title, tree }
    }
}
/// Hits found within a single document during a cross-document search.
#[derive(Debug, Clone)]
pub struct DocumentResult {
    /// Id of the searched document.
    pub doc_id: DocumentId,
    /// Title of the searched document.
    pub doc_title: String,
    /// `(node, evaluation)` hits; as produced by the strategy they are ordered
    /// by descending score.
    pub evaluations: Vec<(NodeId, NodeEvaluation)>,
    /// Score of the first hit in `evaluations`, or `0.0` when there are none.
    pub best_score: f32,
}
/// How per-document hits are combined into a single ranked list.
///
/// Derives `Hash` alongside `Eq` so the strategy can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MergeStrategy {
    /// Pool the hits from all documents and keep the highest-scoring ones overall.
    #[default]
    TopK,
    /// Keep at most one hit per document: that document's highest-scoring one.
    BestPerDocument,
    /// Scale each hit by its document's best score relative to the best
    /// document's score, then keep the highest-scoring hits.
    WeightedByRelevance,
    /// Like `TopK`, but hits in documents linked (via the attached document
    /// graph) to documents with strong hits receive a score boost. Without a
    /// graph this behaves like `TopK`.
    GraphBoosted,
}
/// Tuning knobs for [`CrossDocumentStrategy`].
#[derive(Debug, Clone)]
pub struct CrossDocumentConfig {
    /// Maximum number of documents that can be registered; extra documents
    /// passed to `add_document` / `with_documents` are discarded.
    pub max_documents: usize,
    /// Maximum number of hits kept from any single document.
    pub max_results_per_doc: usize,
    /// Maximum number of hits returned after merging.
    pub max_total_results: usize,
    /// Minimum score for a node to count as a hit; also scales the threshold
    /// above which a node's children are explored.
    pub min_score: f32,
    /// How per-document hits are merged.
    pub merge_strategy: MergeStrategy,
    /// NOTE(review): not read anywhere in this module — confirm whether any
    /// caller relies on it or whether it is a placeholder for concurrent search.
    pub parallel_search: bool,
}
impl Default for CrossDocumentConfig {
fn default() -> Self {
Self {
max_documents: 10,
max_results_per_doc: 3,
max_total_results: 10,
min_score: 0.3,
merge_strategy: MergeStrategy::TopK,
parallel_search: true,
}
}
}
/// Retrieval strategy that holds several documents, scores their nodes with a
/// wrapped strategy, and merges the per-document hits.
pub struct CrossDocumentStrategy {
    /// Strategy used to score individual nodes.
    inner: Box<dyn RetrievalStrategy>,
    /// Limits and merge behaviour.
    config: CrossDocumentConfig,
    /// Registered documents, capped at `config.max_documents` on insertion.
    documents: Vec<DocumentEntry>,
    /// Optional inter-document link graph, used only when boosting by graph.
    graph: Option<Arc<DocumentGraph>>,
}
impl CrossDocumentStrategy {
    /// A node's children are explored when its score reaches
    /// `min_score * EXPLORE_SCORE_MULTIPLIER`.
    const EXPLORE_SCORE_MULTIPLIER: f32 = 1.5;
    /// Number of tree levels searched below a promising top-level section.
    const MAX_SUBTREE_DEPTH: usize = 2;
    /// Hits scoring strictly above this act as boost sources for linked
    /// documents under [`MergeStrategy::GraphBoosted`].
    const GRAPH_BOOST_SOURCE_THRESHOLD: f32 = 0.5;
    /// Scale applied to graph-propagated boosts under [`MergeStrategy::GraphBoosted`].
    const GRAPH_BOOST_FACTOR: f32 = 0.15;

    /// Creates a strategy that scores nodes with `inner`, using the default
    /// [`CrossDocumentConfig`], no documents and no document graph.
    pub fn new(inner: Box<dyn RetrievalStrategy>) -> Self {
        Self {
            inner,
            config: CrossDocumentConfig::default(),
            documents: Vec::new(),
            graph: None,
        }
    }

    /// Replaces the configuration.
    ///
    /// Already-registered documents beyond the new `max_documents` are dropped,
    /// matching the cap that [`Self::add_document`] and [`Self::with_documents`]
    /// enforce. Those methods cap against the configuration current when they
    /// run, so call this first when raising the limit.
    pub fn with_config(mut self, config: CrossDocumentConfig) -> Self {
        self.config = config;
        self.documents.truncate(self.config.max_documents);
        self
    }

    /// Registers `doc`, unless `max_documents` documents are already
    /// registered, in which case `doc` is silently discarded.
    pub fn add_document(&mut self, doc: DocumentEntry) {
        if self.documents.len() < self.config.max_documents {
            self.documents.push(doc);
        }
    }

    /// Replaces the registered documents with the first `max_documents`
    /// entries of `documents`; the rest are discarded.
    pub fn with_documents(mut self, documents: Vec<DocumentEntry>) -> Self {
        self.documents = documents
            .into_iter()
            .take(self.config.max_documents)
            .collect();
        self
    }

    /// Number of registered documents.
    pub fn document_count(&self) -> usize {
        self.documents.len()
    }

    /// Attaches the document link graph used by [`MergeStrategy::GraphBoosted`].
    pub fn with_graph(mut self, graph: Arc<DocumentGraph>) -> Self {
        self.graph = Some(graph);
        self
    }

    /// Searches every registered document and merges the hits according to
    /// `config.merge_strategy`.
    ///
    /// Returns `(document id, node id, evaluation)` triples ordered by
    /// descending score, at most `config.max_total_results` long. Documents are
    /// searched one after another; `config.parallel_search` is not consulted.
    pub async fn search_documents(
        &self,
        context: &RetrievalContext,
    ) -> Vec<(DocumentId, NodeId, NodeEvaluation)> {
        let mut doc_results = Vec::with_capacity(self.documents.len());
        for doc in &self.documents {
            doc_results.push(self.search_document(doc, context).await);
        }
        self.merge_results(doc_results)
    }

    /// Descending-score comparator shared by every sort in this type.
    ///
    /// Incomparable (NaN) scores compare as equal, as in the original
    /// comparator. Hits are filtered with `>= min_score` before sorting, which
    /// already rejects NaN.
    fn by_score_desc(a: &NodeEvaluation, b: &NodeEvaluation) -> std::cmp::Ordering {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    }

    /// Score a node must reach for its children to be explored.
    fn explore_threshold(&self) -> f32 {
        self.config.min_score * Self::EXPLORE_SCORE_MULTIPLIER
    }

    /// Flattens per-document hits into `(doc_id, node_id, evaluation)` triples.
    fn flatten(doc_results: Vec<DocumentResult>) -> Vec<(DocumentId, NodeId, NodeEvaluation)> {
        doc_results
            .into_iter()
            .flat_map(|doc| {
                let doc_id = doc.doc_id;
                doc.evaluations
                    .into_iter()
                    .map(move |(node_id, eval)| (doc_id.clone(), node_id, eval))
            })
            .collect()
    }

    /// Adds graph-propagated boosts to `results` in place (does not re-sort).
    ///
    /// Each document with at least one hit scoring above
    /// `GRAPH_BOOST_SOURCE_THRESHOLD` acts once as a boost source, using its
    /// best such score. Every hit belonging to a neighbour of that document (per
    /// the attached graph) gains `boost_factor * edge.weight * source_score`.
    /// Sources are collected from pre-boost scores, so the outcome does not
    /// depend on the order of `results`. Does nothing when no graph is attached.
    fn apply_graph_boost(
        &self,
        results: &mut [(DocumentId, NodeId, NodeEvaluation)],
        boost_factor: f32,
    ) {
        let graph = match &self.graph {
            Some(g) => g,
            None => return,
        };

        // One entry per source document. A Vec (not a HashMap) keeps first-seen
        // order, so the float additions below happen in a deterministic order.
        let mut sources: Vec<(DocumentId, f32)> = Vec::new();
        for (doc_id, _, eval) in results.iter() {
            if eval.score > Self::GRAPH_BOOST_SOURCE_THRESHOLD {
                match sources.iter_mut().find(|(id, _)| id == doc_id) {
                    Some((_, best)) => *best = best.max(eval.score),
                    None => sources.push((doc_id.clone(), eval.score)),
                }
            }
        }

        for (source_doc, source_score) in &sources {
            for edge in graph.get_neighbors(source_doc) {
                let boost = boost_factor * edge.weight * source_score;
                for result in results.iter_mut() {
                    if result.0 == edge.target_doc_id {
                        result.2.score += boost;
                    }
                }
            }
        }
    }

    /// Searches one document.
    ///
    /// Scores the root's children and keeps those reaching `min_score`. It
    /// explores up to `MAX_SUBTREE_DEPTH` levels beneath those that also reach
    /// the exploration threshold. It returns at most `max_results_per_doc`
    /// distinct nodes, ordered by descending score.
    async fn search_document(
        &self,
        doc: &DocumentEntry,
        context: &RetrievalContext,
    ) -> DocumentResult {
        let root_id = doc.tree.root();
        let children = doc.tree.children(root_id);
        let evaluations = self
            .inner
            .evaluate_nodes(&doc.tree, &children, context)
            .await;

        let mut candidates: Vec<(NodeId, NodeEvaluation)> = children
            .into_iter()
            .zip(evaluations)
            .filter(|(_, eval)| eval.score >= self.config.min_score)
            .collect();

        let explore_threshold = self.explore_threshold();
        let promising: Vec<NodeId> = candidates
            .iter()
            .filter(|(_, eval)| eval.score >= explore_threshold)
            .map(|(id, _)| *id)
            .collect();
        for node_id in promising {
            let deeper = self
                .search_subtree(&doc.tree, node_id, context, 0, Self::MAX_SUBTREE_DEPTH)
                .await;
            candidates.extend(deeper);
        }

        candidates.sort_by(|a, b| Self::by_score_desc(&a.1, &b.1));

        // Keep the first (highest-scoring) occurrence of each node and stop once
        // the quota is filled. `dedup_by` would only collapse *adjacent*
        // duplicates, which a score sort does not guarantee.
        let quota = self.config.max_results_per_doc;
        let mut kept: Vec<(NodeId, NodeEvaluation)> =
            Vec::with_capacity(quota.min(candidates.len()));
        for (node_id, eval) in candidates {
            if kept.len() >= quota {
                break;
            }
            if kept.iter().any(|(seen, _)| *seen == node_id) {
                continue;
            }
            kept.push((node_id, eval));
        }

        let best_score = kept.first().map_or(0.0, |(_, e)| e.score);
        DocumentResult {
            doc_id: doc.id.clone(),
            doc_title: doc.title.clone(),
            evaluations: kept,
            best_score,
        }
    }

    /// Recursively scores the descendants of `parent_id`.
    ///
    /// Collects every child reaching `min_score`, then recurses into children
    /// reaching the exploration threshold until `current_depth` reaches
    /// `max_depth`. The future is boxed because async recursion needs a
    /// fixed-size future type.
    fn search_subtree<'a>(
        &'a self,
        tree: &'a DocumentTree,
        parent_id: NodeId,
        context: &'a RetrievalContext,
        current_depth: usize,
        max_depth: usize,
    ) -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Vec<(NodeId, NodeEvaluation)>> + Send + 'a>,
    > {
        Box::pin(async move {
            if current_depth >= max_depth {
                return Vec::new();
            }
            let children = tree.children(parent_id);
            if children.is_empty() {
                return Vec::new();
            }
            let evaluations = self.inner.evaluate_nodes(tree, &children, context).await;

            let explore_threshold = self.explore_threshold();
            let mut results = Vec::new();
            let mut explore_further = Vec::new();
            for (node_id, eval) in children.into_iter().zip(evaluations) {
                // Decide on exploration first so `eval` can be moved, not cloned.
                if eval.score >= explore_threshold {
                    explore_further.push(node_id);
                }
                if eval.score >= self.config.min_score {
                    results.push((node_id, eval));
                }
            }

            for child_id in explore_further {
                let deeper = self
                    .search_subtree(tree, child_id, context, current_depth + 1, max_depth)
                    .await;
                results.extend(deeper);
            }
            results
        })
    }

    /// Merges per-document hits according to `config.merge_strategy`.
    ///
    /// Every strategy ends with the same step: sort by descending score and
    /// truncate to `max_total_results`. For `BestPerDocument`, this means the
    /// best documents are kept, not the first registered ones.
    fn merge_results(
        &self,
        doc_results: Vec<DocumentResult>,
    ) -> Vec<(DocumentId, NodeId, NodeEvaluation)> {
        let mut merged: Vec<(DocumentId, NodeId, NodeEvaluation)> =
            match self.config.merge_strategy {
                MergeStrategy::TopK => Self::flatten(doc_results),
                MergeStrategy::BestPerDocument => doc_results
                    .into_iter()
                    .filter_map(|doc| {
                        let doc_id = doc.doc_id;
                        // Pick the maximum explicitly (first one wins on ties)
                        // rather than relying on `evaluations` being pre-sorted.
                        doc.evaluations
                            .into_iter()
                            .reduce(|best, candidate| {
                                if candidate.1.score > best.1.score {
                                    candidate
                                } else {
                                    best
                                }
                            })
                            .map(|(node_id, eval)| (doc_id, node_id, eval))
                    })
                    .collect(),
                MergeStrategy::WeightedByRelevance => {
                    let max_doc_score = doc_results
                        .iter()
                        .map(|d| d.best_score)
                        .fold(0.0_f32, f32::max);
                    doc_results
                        .into_iter()
                        .flat_map(|doc| {
                            let weight = if max_doc_score > 0.0 {
                                doc.best_score / max_doc_score
                            } else {
                                1.0
                            };
                            let doc_id = doc.doc_id;
                            doc.evaluations.into_iter().map(move |(node_id, mut eval)| {
                                eval.score *= weight;
                                (doc_id.clone(), node_id, eval)
                            })
                        })
                        .collect()
                }
                MergeStrategy::GraphBoosted => {
                    let mut all_results = Self::flatten(doc_results);
                    self.apply_graph_boost(&mut all_results, Self::GRAPH_BOOST_FACTOR);
                    all_results
                }
            };

        merged.sort_by(|a, b| Self::by_score_desc(&a.2, &b.2));
        merged.truncate(self.config.max_total_results);
        merged
    }
}
#[async_trait]
impl RetrievalStrategy for CrossDocumentStrategy {
    /// Scores a single node by delegating to the wrapped strategy.
    async fn evaluate_node(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        context: &RetrievalContext,
    ) -> NodeEvaluation {
        self.inner.evaluate_node(tree, node_id, context).await
    }

    /// Scores a batch of nodes by delegating to the wrapped strategy.
    async fn evaluate_nodes(
        &self,
        tree: &DocumentTree,
        node_ids: &[NodeId],
        context: &RetrievalContext,
    ) -> Vec<NodeEvaluation> {
        self.inner.evaluate_nodes(tree, node_ids, context).await
    }

    fn name(&self) -> &'static str {
        "cross_document"
    }

    /// Reports the wrapped strategy's LLM/embedding usage, claims sufficiency
    /// support, and scales its typical latency by the number of registered
    /// documents.
    fn capabilities(&self) -> StrategyCapabilities {
        let inner_caps = self.inner.capabilities();
        // The factor is at least 1, because node evaluation is delegated to
        // `inner` even when no documents are registered. It is capped at 5.
        // NOTE(review): the cap presumably models overlap between per-document
        // searches — confirm.
        let doc_factor = self.documents.len().clamp(1, 5) as u64;
        StrategyCapabilities {
            uses_llm: inner_caps.uses_llm,
            uses_embeddings: inner_caps.uses_embeddings,
            supports_sufficiency: true,
            typical_latency_ms: inner_caps.typical_latency_ms * doc_factor,
        }
    }

    /// Accepts `Simple`, `Medium` and `Complex` queries.
    fn suitable_for_complexity(&self, complexity: QueryComplexity) -> bool {
        matches!(
            complexity,
            QueryComplexity::Simple | QueryComplexity::Medium | QueryComplexity::Complex
        )
    }

    /// Scales the wrapped strategy's cost by the number of searchable
    /// documents, which is capped at `max_documents`.
    fn estimate_cost(&self, node_count: usize) -> super::r#trait::StrategyCost {
        let inner_cost = self.inner.estimate_cost(node_count);
        // The factor is at least 1: evaluating nodes still costs one `inner`
        // pass even with no registered documents.
        let doc_factor = self
            .documents
            .len()
            .min(self.config.max_documents)
            .max(1);
        super::r#trait::StrategyCost {
            llm_calls: inner_cost.llm_calls * doc_factor,
            tokens: inner_cost.tokens * doc_factor,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every documented default, including `min_score` and
    /// `parallel_search`, which were previously unchecked.
    #[test]
    fn test_config_default() {
        let config = CrossDocumentConfig::default();
        assert_eq!(config.max_documents, 10);
        assert_eq!(config.max_results_per_doc, 3);
        assert_eq!(config.max_total_results, 10);
        assert!((config.min_score - 0.3).abs() < f32::EPSILON);
        assert_eq!(config.merge_strategy, MergeStrategy::TopK);
        assert!(config.parallel_search);
    }

    #[test]
    fn test_merge_strategy_default() {
        let strategy = MergeStrategy::default();
        assert!(matches!(strategy, MergeStrategy::TopK));
    }

    /// The config default and the enum default must not drift apart.
    #[test]
    fn test_config_default_matches_merge_strategy_default() {
        assert_eq!(
            CrossDocumentConfig::default().merge_strategy,
            MergeStrategy::default()
        );
    }
}