use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::join_all;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::retrievers::BaseRetriever;
/// Produces alternative search queries derived from a single input query.
///
/// Implementations are shared behind `Arc<dyn QueryGenerator>` by the
/// retrievers in this file, which run every returned query against an inner
/// retriever. Those retrievers optionally prepend the original query
/// themselves, so an implementation does not have to return it.
#[async_trait]
pub trait QueryGenerator: Send + Sync {
/// Returns the query strings generated for `query`, or an error if
/// generation fails. The returned list may be empty.
async fn generate_queries(&self, query: &str) -> Result<Vec<String>>;
}
/// Rule-based query generator that rephrases a query using fixed templates.
///
/// Question-style prefixes (`what is`, `what are`, `how does`, `how do`,
/// `why`) are detected case-insensitively and turned into targeted
/// rephrasings; every query additionally receives a set of generic
/// perspective and specificity rephrasings.
pub struct SimpleQueryGenerator {
    /// Upper bound on the number of variations returned per query.
    pub num_queries: usize,
}

impl Default for SimpleQueryGenerator {
    /// Creates a generator limited to three variations.
    fn default() -> Self {
        Self { num_queries: 3 }
    }
}

impl SimpleQueryGenerator {
    /// Creates a generator that returns at most `num_queries` variations.
    pub fn new(num_queries: usize) -> Self {
        Self { num_queries }
    }

    /// Removes leading whitespace and any trailing run of whitespace and
    /// question marks, so `"  What is Rust ? "` becomes `"What is Rust"`.
    fn normalize(query: &str) -> &str {
        query
            .trim_start()
            .trim_end_matches(|c: char| c == '?' || c.is_whitespace())
    }

    /// Returns the non-empty text following `prefix` when `text` starts with
    /// `prefix`, compared ASCII case-insensitively.
    ///
    /// Returns `None` when `text` is shorter than the prefix, when the prefix
    /// length does not fall on a character boundary, when the prefix does not
    /// match, or when nothing but whitespace follows it.
    fn strip_prefix_ignore_case<'a>(text: &'a str, prefix: &str) -> Option<&'a str> {
        // `is_char_boundary` is false for indices past the end, so a too-short
        // `text` is rejected here instead of panicking inside `split_at`.
        if !text.is_char_boundary(prefix.len()) {
            return None;
        }
        let (head, tail) = text.split_at(prefix.len());
        let subject = tail.trim_start();
        if head.eq_ignore_ascii_case(prefix) && !subject.is_empty() {
            Some(subject)
        } else {
            None
        }
    }

    /// Tries each prefix in order and returns the subject after the first match.
    fn strip_any_prefix<'a>(text: &'a str, prefixes: &[&str]) -> Option<&'a str> {
        prefixes
            .iter()
            .copied()
            .find_map(|prefix| Self::strip_prefix_ignore_case(text, prefix))
    }

    /// Builds rephrasings that approach `query` from different angles.
    ///
    /// Prefix-specific rephrasings (if the query is a recognised question
    /// form with a non-empty subject) come first, followed by three generic
    /// rephrasings of the whole normalised query. Detection and slicing both
    /// operate on the same normalised string, so inputs such as `"what is "`
    /// or `"why "` simply fall through to the generic rephrasings.
    fn perspective_variations(query: &str) -> Vec<String> {
        let trimmed = Self::normalize(query);
        let mut variations = Vec::new();

        if let Some(subject) = Self::strip_any_prefix(trimmed, &["what is ", "what are "]) {
            variations.push(format!("Define {}", subject));
            variations.push(format!("Explain {}", subject));
            variations.push(format!("Describe {}", subject));
        }
        if let Some(subject) = Self::strip_any_prefix(trimmed, &["how does ", "how do "]) {
            variations.push(format!("Explain how {} works", subject));
            variations.push(format!("Mechanism of {}", subject));
        }
        if let Some(rest) = Self::strip_any_prefix(trimmed, &["why "]) {
            variations.push(format!("Reason for {}", rest));
            variations.push(format!("Explain why {}", rest));
        }

        variations.push(format!("Tell me about {}", trimmed));
        variations.push(format!("{} overview", trimmed));
        variations.push(format!("Detailed information on {}", trimmed));
        variations
    }

    /// Builds rephrasings that ask for more or less detail about `query`.
    fn specificity_variations(query: &str) -> Vec<String> {
        let trimmed = Self::normalize(query);
        vec![
            format!("{} in detail", trimmed),
            format!("comprehensive guide to {}", trimmed),
            format!("{} summary", trimmed),
            format!("brief explanation of {}", trimmed),
        ]
    }
}
#[async_trait]
impl QueryGenerator for SimpleQueryGenerator {
    /// Returns up to `num_queries` distinct rephrasings of `query`.
    ///
    /// Perspective rephrasings are considered before specificity ones.
    /// Candidates are compared case-insensitively: anything equal to the
    /// original query or to an earlier candidate is skipped.
    async fn generate_queries(&self, query: &str) -> Result<Vec<String>> {
        let original_key = query.to_lowercase();
        let mut emitted: HashSet<String> = HashSet::new();

        let candidates = Self::perspective_variations(query)
            .into_iter()
            .chain(Self::specificity_variations(query));

        let variations: Vec<String> = candidates
            .filter(|candidate| {
                let key = candidate.to_lowercase();
                // `insert` returns false for keys that were already emitted.
                key != original_key && emitted.insert(key)
            })
            .take(self.num_queries)
            .collect();

        Ok(variations)
    }
}
/// Query generator driven by caller-supplied templates.
///
/// Each template is expected to contain the literal placeholder `{query}`,
/// which is substituted with the incoming query when queries are generated.
pub struct TemplateQueryGenerator {
    /// Templates in the order their expansions are returned.
    templates: Vec<String>,
}

impl TemplateQueryGenerator {
    /// Creates a generator that produces one query per entry in `templates`.
    pub fn new(templates: Vec<String>) -> Self {
        TemplateQueryGenerator { templates }
    }
}
#[async_trait]
impl QueryGenerator for TemplateQueryGenerator {
    /// Expands every template by replacing each `{query}` occurrence with
    /// `query`, preserving template order.
    async fn generate_queries(&self, query: &str) -> Result<Vec<String>> {
        let mut expanded = Vec::with_capacity(self.templates.len());
        for template in &self.templates {
            expanded.push(template.replace("{query}", query));
        }
        Ok(expanded)
    }
}
/// Retriever that runs several variations of a query against an inner
/// retriever and merges the results.
pub struct MultiQueryRetriever {
    /// Retriever every query variation is sent to.
    inner: Arc<dyn BaseRetriever>,
    /// Source of query variations.
    query_generator: Arc<dyn QueryGenerator>,
    /// Maximum number of documents returned.
    k: usize,
    /// Whether merged results are deduplicated before truncation.
    deduplicate: bool,
    /// Whether the original query is searched in addition to the variations.
    include_original: bool,
}

impl MultiQueryRetriever {
    /// Starts a builder around `inner` with the defaults: a
    /// `SimpleQueryGenerator::default()`, `k = 4`, deduplication on and the
    /// original query included.
    pub fn builder(inner: Arc<dyn BaseRetriever>) -> MultiQueryRetrieverBuilder {
        let query_generator: Arc<dyn QueryGenerator> =
            Arc::new(SimpleQueryGenerator::default());
        MultiQueryRetrieverBuilder {
            inner,
            query_generator,
            k: 4,
            deduplicate: true,
            include_original: true,
        }
    }

    /// Drops repeated documents while keeping first-occurrence order.
    ///
    /// A document is discarded when its ID (if it has one) was seen before,
    /// or otherwise when its `page_content` was seen before. An ID is
    /// recorded as soon as it is first encountered, even if that document is
    /// then discarded because of its content.
    fn deduplicate_docs(docs: Vec<Document>) -> Vec<Document> {
        let mut known_ids = HashSet::<String>::new();
        let mut known_bodies = HashSet::<String>::new();

        docs.into_iter()
            .filter(|doc| {
                let id_is_fresh = match &doc.id {
                    Some(id) => known_ids.insert(id.clone()),
                    None => true,
                };
                // Short-circuit: content is only recorded when the ID check passed.
                id_is_fresh && known_bodies.insert(doc.page_content.clone())
            })
            .collect()
    }
}
/// Builder for [`MultiQueryRetriever`]; each setter consumes and returns the
/// builder so calls can be chained.
pub struct MultiQueryRetrieverBuilder {
    inner: Arc<dyn BaseRetriever>,
    query_generator: Arc<dyn QueryGenerator>,
    k: usize,
    deduplicate: bool,
    include_original: bool,
}

impl MultiQueryRetrieverBuilder {
    /// Replaces the query generator.
    pub fn query_generator(self, gen: Arc<dyn QueryGenerator>) -> Self {
        Self {
            query_generator: gen,
            ..self
        }
    }

    /// Sets the maximum number of documents returned.
    pub fn k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Enables or disables deduplication of merged results.
    pub fn deduplicate(self, deduplicate: bool) -> Self {
        Self { deduplicate, ..self }
    }

    /// Chooses whether the original query is searched alongside the variations.
    pub fn include_original(self, include: bool) -> Self {
        Self {
            include_original: include,
            ..self
        }
    }

    /// Finishes the builder, moving every setting into the retriever.
    pub fn build(self) -> MultiQueryRetriever {
        let MultiQueryRetrieverBuilder {
            inner,
            query_generator,
            k,
            deduplicate,
            include_original,
        } = self;
        MultiQueryRetriever {
            inner,
            query_generator,
            k,
            deduplicate,
            include_original,
        }
    }
}
#[async_trait]
impl BaseRetriever for MultiQueryRetriever {
    /// Retrieves documents for `query` and its generated variations.
    ///
    /// All queries are issued to the inner retriever concurrently. Results
    /// are concatenated in query order (original first when it is included),
    /// optionally deduplicated, and cut to at most `k` documents.
    ///
    /// # Errors
    ///
    /// Returns the query generator's error, or the first error (in query
    /// order) reported by the inner retriever.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        let mut queries = self.query_generator.generate_queries(query).await?;
        if self.include_original {
            queries.insert(0, query.to_string());
        }

        let lookups: Vec<_> = queries
            .iter()
            .map(|q| self.inner.get_relevant_documents(q))
            .collect();

        // Collecting into `Result` surfaces the first failed lookup, if any.
        let batches = join_all(lookups)
            .await
            .into_iter()
            .collect::<Result<Vec<_>>>()?;
        let merged: Vec<Document> = batches.into_iter().flatten().collect();

        let mut docs = if self.deduplicate {
            Self::deduplicate_docs(merged)
        } else {
            merged
        };
        docs.truncate(self.k);
        Ok(docs)
    }
}
/// Retriever that runs several variations of a query against an inner
/// retriever and merges the ranked results with reciprocal rank fusion (RRF).
pub struct FusionRetriever {
    /// Retriever every query variation is sent to.
    inner: Arc<dyn BaseRetriever>,
    /// Source of query variations.
    query_generator: Arc<dyn QueryGenerator>,
    /// Maximum number of documents returned.
    k: usize,
    /// Constant added to each rank in the RRF score denominator.
    rrf_k: usize,
    /// Whether the original query is searched in addition to the variations.
    include_original: bool,
}

impl FusionRetriever {
    /// Starts a builder around `inner` with the defaults: a
    /// `SimpleQueryGenerator::default()`, `k = 4`, `rrf_k = 60` and the
    /// original query included.
    pub fn builder(inner: Arc<dyn BaseRetriever>) -> FusionRetrieverBuilder {
        FusionRetrieverBuilder {
            inner,
            query_generator: Arc::new(SimpleQueryGenerator::default()),
            k: 4,
            rrf_k: 60,
            include_original: true,
        }
    }

    /// Fuses ranked result lists into a single list sorted by descending score.
    ///
    /// A document at zero-based position `rank` in a list contributes
    /// `1 / (rrf_k + rank + 1)`; contributions are summed across lists.
    /// Documents are identified by their ID when present and by their
    /// `page_content` otherwise; the two kinds of key never match each other.
    /// The first occurrence of a document is the copy that is returned.
    ///
    /// Documents with equal scores keep the order in which they were first
    /// encountered, so the output is deterministic for a given input.
    fn reciprocal_rank_fusion(result_sets: &[Vec<Document>], rrf_k: usize) -> Vec<(Document, f64)> {
        // Key -> index into `fused`. The bool distinguishes ID keys from
        // content keys so an ID can never collide with another document's text.
        let mut slot_by_key: HashMap<(bool, &str), usize> = HashMap::new();
        // Kept in first-seen order; a map's iteration order would be arbitrary.
        let mut fused: Vec<(Document, f64)> = Vec::new();

        for docs in result_sets {
            for (rank, doc) in docs.iter().enumerate() {
                let score = 1.0 / (rrf_k as f64 + (rank + 1) as f64);
                let key = match &doc.id {
                    Some(id) => (true, id.as_str()),
                    None => (false, doc.page_content.as_str()),
                };
                let slot = *slot_by_key.entry(key).or_insert_with(|| {
                    fused.push((doc.clone(), 0.0));
                    fused.len() - 1
                });
                fused[slot].1 += score;
            }
        }

        // `sort_by` is stable, so ties retain first-seen order.
        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        fused
    }
}
/// Builder for [`FusionRetriever`]; each setter consumes and returns the
/// builder so calls can be chained.
pub struct FusionRetrieverBuilder {
    inner: Arc<dyn BaseRetriever>,
    query_generator: Arc<dyn QueryGenerator>,
    k: usize,
    rrf_k: usize,
    include_original: bool,
}

impl FusionRetrieverBuilder {
    /// Replaces the query generator.
    pub fn query_generator(self, gen: Arc<dyn QueryGenerator>) -> Self {
        Self {
            query_generator: gen,
            ..self
        }
    }

    /// Sets the maximum number of documents returned.
    pub fn k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Sets the constant added to each rank when scoring.
    pub fn rrf_k(self, rrf_k: usize) -> Self {
        Self { rrf_k, ..self }
    }

    /// Chooses whether the original query is searched alongside the variations.
    pub fn include_original(self, include: bool) -> Self {
        Self {
            include_original: include,
            ..self
        }
    }

    /// Finishes the builder, moving every setting into the retriever.
    pub fn build(self) -> FusionRetriever {
        let FusionRetrieverBuilder {
            inner,
            query_generator,
            k,
            rrf_k,
            include_original,
        } = self;
        FusionRetriever {
            inner,
            query_generator,
            k,
            rrf_k,
            include_original,
        }
    }
}
#[async_trait]
impl BaseRetriever for FusionRetriever {
    /// Retrieves documents for `query` and its generated variations and
    /// returns the `k` best after reciprocal rank fusion.
    ///
    /// All queries are issued to the inner retriever concurrently; each
    /// query's result list is treated as one ranking when fusing.
    ///
    /// # Errors
    ///
    /// Returns the query generator's error, or the first error (in query
    /// order) reported by the inner retriever.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        let mut queries = self.query_generator.generate_queries(query).await?;
        if self.include_original {
            queries.insert(0, query.to_string());
        }

        let lookups: Vec<_> = queries
            .iter()
            .map(|q| self.inner.get_relevant_documents(q))
            .collect();

        // Collecting into `Result` surfaces the first failed lookup, if any.
        let result_sets = join_all(lookups)
            .await
            .into_iter()
            .collect::<Result<Vec<_>>>()?;

        let top: Vec<Document> = Self::reciprocal_rank_fusion(&result_sets, self.rrf_k)
            .into_iter()
            .take(self.k)
            .map(|(doc, _score)| doc)
            .collect();
        Ok(top)
    }
}
// Unit tests for the query generators and the multi-query / fusion retrievers.
#[cfg(test)]
mod tests {
use super::*;
use cognis_core::error::CognisError;
// Retriever stub that returns the same fixed documents for every query.
struct MockRetriever {
docs: Vec<Document>,
}
impl MockRetriever {
// Builds ID-less documents from plain content strings.
fn new(contents: &[&str]) -> Self {
Self {
docs: contents.iter().map(|c| Document::new(*c)).collect(),
}
}
// Builds documents from (id, content) pairs.
fn with_ids(pairs: &[(&str, &str)]) -> Self {
Self {
docs: pairs
.iter()
.map(|(id, content)| Document::new(*content).with_id(*id))
.collect(),
}
}
}
#[async_trait]
impl BaseRetriever for MockRetriever {
async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
Ok(self.docs.clone())
}
}
// Retriever stub that returns one document whose content embeds the query text.
struct EchoRetriever;
#[async_trait]
impl BaseRetriever for EchoRetriever {
async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
Ok(vec![Document::new(format!("result for: {}", query))])
}
}
// Retriever stub that always returns an error.
struct FailingRetriever;
#[async_trait]
impl BaseRetriever for FailingRetriever {
async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
Err(CognisError::Other("retriever failed".into()))
}
}
// Query generator stub that ignores its input and returns a preset list.
struct FixedQueryGenerator {
queries: Vec<String>,
}
impl FixedQueryGenerator {
fn new(queries: Vec<&str>) -> Self {
Self {
queries: queries.into_iter().map(String::from).collect(),
}
}
}
#[async_trait]
impl QueryGenerator for FixedQueryGenerator {
async fn generate_queries(&self, _query: &str) -> Result<Vec<String>> {
Ok(self.queries.clone())
}
}
// The generator returns exactly `num_queries` variations when enough candidates exist.
#[tokio::test]
async fn test_simple_generator_produces_correct_count() {
let gen = SimpleQueryGenerator::new(3);
let queries = gen.generate_queries("What is Rust?").await.unwrap();
assert_eq!(queries.len(), 3);
}
// A "What is ..." query yields the "Define" and "Explain" rephrasings.
#[tokio::test]
async fn test_simple_generator_what_is_variations() {
let gen = SimpleQueryGenerator::new(10);
let queries = gen
.generate_queries("What is machine learning?")
.await
.unwrap();
assert!(
queries.iter().any(|q| q.contains("Define")),
"Expected a 'Define' variation, got: {:?}",
queries
);
assert!(
queries.iter().any(|q| q.contains("Explain")),
"Expected an 'Explain' variation, got: {:?}",
queries
);
}
// No returned variation equals the original query (case-insensitively).
#[tokio::test]
async fn test_simple_generator_excludes_original() {
let gen = SimpleQueryGenerator::new(10);
let original = "What is Rust?";
let queries = gen.generate_queries(original).await.unwrap();
let original_lower = original.to_lowercase();
for q in &queries {
assert_ne!(q.to_lowercase(), original_lower);
}
}
// Returned variations are pairwise distinct when compared case-insensitively.
#[tokio::test]
async fn test_simple_generator_deduplicates() {
let gen = SimpleQueryGenerator::new(20);
let queries = gen.generate_queries("test query").await.unwrap();
let unique: HashSet<String> = queries.iter().map(|q| q.to_lowercase()).collect();
assert_eq!(unique.len(), queries.len(), "Variations should be unique");
}
// A "Why ..." query yields the "Reason for" rephrasing.
#[tokio::test]
async fn test_simple_generator_why_variations() {
let gen = SimpleQueryGenerator::new(10);
let queries = gen.generate_queries("Why is the sky blue?").await.unwrap();
assert!(
queries.iter().any(|q| q.contains("Reason for")),
"Expected a 'Reason for' variation, got: {:?}",
queries
);
}
// Each template produces one query with `{query}` substituted, in template order.
#[tokio::test]
async fn test_template_generator_basic() {
let gen = TemplateQueryGenerator::new(vec![
"Rephrase: {query}".to_string(),
"Summarize: {query}".to_string(),
]);
let queries = gen.generate_queries("What is Rust?").await.unwrap();
assert_eq!(queries.len(), 2);
assert_eq!(queries[0], "Rephrase: What is Rust?");
assert_eq!(queries[1], "Summarize: What is Rust?");
}
// Every occurrence of `{query}` in a template is substituted, not just the first.
#[tokio::test]
async fn test_template_generator_multiple_placeholders() {
let gen = TemplateQueryGenerator::new(vec![
"Given '{query}', find documents about {query}".to_string(),
]);
let queries = gen.generate_queries("neural networks").await.unwrap();
assert_eq!(
queries[0],
"Given 'neural networks', find documents about neural networks"
);
}
// With no templates the generator returns an empty list.
#[tokio::test]
async fn test_template_generator_empty_templates() {
let gen = TemplateQueryGenerator::new(vec![]);
let queries = gen.generate_queries("test").await.unwrap();
assert!(queries.is_empty());
}
// Three queries (original + 2 generated) each return the same three docs;
// deduplication collapses them to three.
#[tokio::test]
async fn test_multi_query_retriever_basic() {
let base = Arc::new(MockRetriever::new(&["doc1", "doc2", "doc3"]));
let gen = Arc::new(FixedQueryGenerator::new(vec!["q1", "q2"]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let docs = retriever.get_relevant_documents("original").await.unwrap();
assert_eq!(docs.len(), 3);
}
// Both the original and the generated query reach the inner retriever.
#[tokio::test]
async fn test_multi_query_retriever_with_echo() {
let base: Arc<dyn BaseRetriever> = Arc::new(EchoRetriever);
let gen = Arc::new(FixedQueryGenerator::new(vec!["alt query"]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let docs = retriever
.get_relevant_documents("main query")
.await
.unwrap();
assert_eq!(docs.len(), 2);
assert!(docs
.iter()
.any(|d| d.page_content == "result for: main query"));
assert!(docs
.iter()
.any(|d| d.page_content == "result for: alt query"));
}
// The output is truncated to `k` documents.
#[tokio::test]
async fn test_multi_query_retriever_k_limits_output() {
let base = Arc::new(MockRetriever::new(&["a", "b", "c", "d", "e"]));
let gen = Arc::new(FixedQueryGenerator::new(vec![]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(2)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs.len(), 2);
}
// Two queries return the same two ID'd documents; deduplication leaves two.
#[tokio::test]
async fn test_multi_query_retriever_dedup_by_id() {
let base = Arc::new(MockRetriever::with_ids(&[
("id1", "content A"),
("id2", "content B"),
]));
let gen = Arc::new(FixedQueryGenerator::new(vec!["q1"]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs.len(), 2);
}
// With deduplication off, the same document returned for two queries appears twice.
#[tokio::test]
async fn test_multi_query_retriever_no_dedup() {
let base = Arc::new(MockRetriever::new(&["doc1"]));
let gen = Arc::new(FixedQueryGenerator::new(vec!["q1"]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(10)
.deduplicate(false)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs.len(), 2);
}
// With `include_original(false)` only the generated query is searched.
#[tokio::test]
async fn test_multi_query_retriever_exclude_original() {
let base: Arc<dyn BaseRetriever> = Arc::new(EchoRetriever);
let gen = Arc::new(FixedQueryGenerator::new(vec!["alt"]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(10)
.include_original(false)
.build();
let docs = retriever.get_relevant_documents("original").await.unwrap();
assert_eq!(docs.len(), 1);
assert_eq!(docs[0].page_content, "result for: alt");
}
// An error from the inner retriever is returned to the caller.
#[tokio::test]
async fn test_multi_query_retriever_propagates_error() {
let base: Arc<dyn BaseRetriever> = Arc::new(FailingRetriever);
let gen = Arc::new(FixedQueryGenerator::new(vec![]));
let retriever = MultiQueryRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let result = retriever.get_relevant_documents("q").await;
assert!(result.is_err());
}
// Identical result lists from two queries fuse into two distinct documents.
#[tokio::test]
async fn test_fusion_retriever_basic() {
let base = Arc::new(MockRetriever::new(&["doc1", "doc2"]));
let gen = Arc::new(FixedQueryGenerator::new(vec!["q1"]));
let retriever = FusionRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs.len(), 2);
}
// Every query returns the same ordered list, so the first-ranked document
// accumulates the highest fused score and comes out first.
#[tokio::test]
async fn test_fusion_retriever_ranks_by_frequency() {
let base = Arc::new(MockRetriever::new(&["common", "unique_to_each"]));
let gen = Arc::new(FixedQueryGenerator::new(vec!["q1", "q2"]));
let retriever = FusionRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs[0].page_content, "common");
}
// The fused output is limited to `k` documents.
#[tokio::test]
async fn test_fusion_retriever_k_limits_output() {
let base = Arc::new(MockRetriever::new(&["a", "b", "c", "d", "e"]));
let gen = Arc::new(FixedQueryGenerator::new(vec![]));
let retriever = FusionRetriever::builder(base)
.query_generator(gen)
.k(3)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs.len(), 3);
}
// A non-default `rrf_k` is accepted and retrieval still returns the document.
#[tokio::test]
async fn test_fusion_retriever_custom_rrf_k() {
let base = Arc::new(MockRetriever::new(&["doc1"]));
let gen = Arc::new(FixedQueryGenerator::new(vec![]));
let retriever = FusionRetriever::builder(base)
.query_generator(gen)
.rrf_k(10)
.k(10)
.build();
let docs = retriever.get_relevant_documents("q").await.unwrap();
assert_eq!(docs.len(), 1);
}
// An error from the inner retriever is returned to the caller.
#[tokio::test]
async fn test_fusion_retriever_propagates_error() {
let base: Arc<dyn BaseRetriever> = Arc::new(FailingRetriever);
let gen = Arc::new(FixedQueryGenerator::new(vec![]));
let retriever = FusionRetriever::builder(base)
.query_generator(gen)
.k(10)
.build();
let result = retriever.get_relevant_documents("q").await;
assert!(result.is_err());
}
// "Y" appears in both result sets, so its summed score beats "X" and "Z",
// which each appear once.
#[tokio::test]
async fn test_fusion_rrf_scoring_correctness() {
let docs_a = vec![Document::new("X"), Document::new("Y")];
let docs_b = vec![Document::new("Y"), Document::new("Z")];
let result_sets = vec![docs_a, docs_b];
let scored = FusionRetriever::reciprocal_rank_fusion(&result_sets, 60);
assert_eq!(scored[0].0.page_content, "Y");
assert_eq!(scored.len(), 3);
let y_score = scored
.iter()
.find(|(d, _)| d.page_content == "Y")
.unwrap()
.1;
let x_score = scored
.iter()
.find(|(d, _)| d.page_content == "X")
.unwrap()
.1;
assert!(y_score > x_score);
}
// With `include_original(false)` only the generated query is searched.
#[tokio::test]
async fn test_fusion_retriever_exclude_original() {
let base: Arc<dyn BaseRetriever> = Arc::new(EchoRetriever);
let gen = Arc::new(FixedQueryGenerator::new(vec!["alt"]));
let retriever = FusionRetriever::builder(base)
.query_generator(gen)
.k(10)
.include_original(false)
.build();
let docs = retriever.get_relevant_documents("original").await.unwrap();
assert_eq!(docs.len(), 1);
assert_eq!(docs[0].page_content, "result for: alt");
}
// ID-less documents with identical content are collapsed to one.
#[tokio::test]
async fn test_dedup_by_content() {
let docs = vec![
Document::new("same content"),
Document::new("same content"),
Document::new("different"),
];
let deduped = MultiQueryRetriever::deduplicate_docs(docs);
assert_eq!(deduped.len(), 2);
}
// A repeated ID drops the later document even though its content differs;
// the first document carrying that ID is the one kept.
#[tokio::test]
async fn test_dedup_by_id_takes_priority() {
let docs = vec![
Document::new("content A").with_id("id1"),
Document::new("content B").with_id("id1"), Document::new("content C").with_id("id2"),
];
let deduped = MultiQueryRetriever::deduplicate_docs(docs);
assert_eq!(deduped.len(), 2);
assert_eq!(deduped[0].page_content, "content A");
assert_eq!(deduped[1].page_content, "content C");
}
}