use async_trait::async_trait;
use crate::{Document, GroundingStrategy, KeywordRetriever, Retriever};
/// Retriever that blends the results of a [`KeywordRetriever`] with those of a
/// second, caller-supplied [`Retriever`].
///
/// Each source has a weight that scales its normalised scores when the two
/// result lists are merged.
pub struct HybridRetriever<E: Retriever> {
    /// Keyword-based source.
    keyword: KeywordRetriever,
    /// Second source; any `Retriever` is accepted (presumably an
    /// embedding-backed one, judging by the field name).
    embedding: E,
    /// Multiplier applied to normalised scores from `keyword`.
    keyword_weight: f64,
    /// Multiplier applied to normalised scores from `embedding`.
    embedding_weight: f64,
}

impl<E: Retriever> HybridRetriever<E> {
    /// Builds a hybrid retriever with the default weighting: `0.3` for the
    /// keyword source and `0.7` for the embedding source.
    #[must_use]
    pub fn new(keyword: KeywordRetriever, embedding: E) -> Self {
        HybridRetriever {
            keyword_weight: 0.3,
            embedding_weight: 0.7,
            keyword,
            embedding,
        }
    }

    /// Returns the retriever with both source weights replaced.
    ///
    /// The weights are not required to sum to one.
    ///
    /// # Panics
    ///
    /// Panics if either weight is negative or NaN (NaN fails the `>= 0.0`
    /// check). `keyword_weight` is validated first.
    #[must_use]
    pub fn with_weights(self, keyword_weight: f64, embedding_weight: f64) -> Self {
        assert!(keyword_weight >= 0.0, "keyword_weight must be non-negative");
        assert!(
            embedding_weight >= 0.0,
            "embedding_weight must be non-negative"
        );
        // Keep both retrievers, swap only the two weights.
        Self {
            keyword_weight,
            embedding_weight,
            ..self
        }
    }
}
#[async_trait]
impl<E: Retriever> Retriever for HybridRetriever<E> {
    /// Retrieves up to `limit` documents ranked by a weighted blend of the
    /// keyword and embedding sources.
    ///
    /// Both sources are queried concurrently for an over-sized candidate pool
    /// (`limit * 5`, but at least 10). Within each source every score is divided
    /// by that source's highest score, multiplied by the source weight, and the
    /// contributions are summed per document id. The returned documents carry
    /// the combined score and are sorted by it in descending order; documents
    /// with equal scores are ordered by id so the result is deterministic.
    ///
    /// # Errors
    ///
    /// Returns the keyword source's error when *both* sources fail. If only one
    /// source fails, its error is dropped and the ranking is built from the
    /// other source alone.
    async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>> {
        if limit == 0 {
            return Ok(Vec::new());
        }

        // Over-fetch so documents ranked low by one source can still win on the
        // blended score; saturate so a huge `limit` cannot overflow.
        let candidate = limit.saturating_mul(5).max(10);
        let (kw_result, emb_result) = tokio::join!(
            self.keyword.retrieve(query, candidate),
            self.embedding.retrieve(query, candidate),
        );

        // Degrade gracefully when a single source is down, but do not disguise a
        // total failure as "no matches".
        let (kw_docs, emb_docs) = match (kw_result, emb_result) {
            (Err(err), Err(_)) => return Err(err),
            (kw, emb) => (kw.unwrap_or_default(), emb.unwrap_or_default()),
        };

        let mut scores: std::collections::HashMap<String, (f64, Document)> =
            std::collections::HashMap::with_capacity(kw_docs.len() + emb_docs.len());
        // Keyword results go first so that, for an id present in both sources,
        // the keyword copy of the document is the one returned.
        accumulate_weighted_scores(&mut scores, kw_docs, self.keyword_weight);
        accumulate_weighted_scores(&mut scores, emb_docs, self.embedding_weight);

        let mut combined: Vec<Document> = scores
            .into_values()
            .map(|(combined_score, mut doc)| {
                doc.score = combined_score;
                doc
            })
            .collect();

        // `total_cmp` is a genuine total order (unlike `partial_cmp` with an
        // `Equal` fallback), and the id tie-break removes the dependence on
        // `HashMap` iteration order.
        combined.sort_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| a.id.cmp(&b.id))
        });
        combined.truncate(limit);
        Ok(combined)
    }
}

/// Adds one source's weighted, max-normalised scores to `scores`, keyed by
/// document id. The first document seen for an id is the one kept.
fn accumulate_weighted_scores(
    scores: &mut std::collections::HashMap<String, (f64, Document)>,
    docs: Vec<Document>,
    weight: f64,
) {
    use std::collections::hash_map::Entry;

    let max_score = docs
        .iter()
        .map(|d| d.score)
        .fold(f64::NEG_INFINITY, f64::max);

    for doc in docs {
        // Without a positive maximum there is nothing meaningful to scale by.
        let normalised = if max_score > 0.0 {
            doc.score / max_score
        } else {
            0.0
        };
        let weighted = weight * normalised;
        // A NaN or infinite contribution would poison the whole ranking; treat
        // it as "no signal" instead.
        let contribution = if weighted.is_finite() { weighted } else { 0.0 };

        match scores.entry(doc.id.clone()) {
            Entry::Occupied(mut slot) => slot.get_mut().0 += contribution,
            Entry::Vacant(slot) => {
                slot.insert((contribution, doc));
            }
        }
    }
}
/// Grounding strategy that lists every document as a numbered entry followed
/// by its id, so individual entries can be cited as `[n]`.
pub struct CitationStrategy;

impl GroundingStrategy for CitationStrategy {
    /// Renders `documents` as a `Context:` header followed by one line per
    /// document in the form `[n] <content> (Source: <id>)`, numbered from 1.
    ///
    /// Returns an empty string when `documents` is empty.
    fn ground(&self, documents: &[Document]) -> String {
        use std::fmt::Write as _;

        match documents {
            [] => String::new(),
            docs => docs.iter().zip(1usize..).fold(
                String::from("Context:\n\n"),
                |mut rendered, (doc, position)| {
                    // Writing into a `String` cannot fail, so the result is ignored.
                    let _ = writeln!(
                        rendered,
                        "[{}] {} (Source: {})",
                        position, doc.content, doc.id
                    );
                    rendered
                },
            ),
        }
    }
}
/// Grounding strategy that caps the size of the generated context.
pub struct ContextWindowStrategy {
    /// Token allowance; `ground` converts it to an output budget at roughly
    /// four characters per token.
    max_tokens: usize,
}

impl ContextWindowStrategy {
    /// Creates a strategy limited to `max_tokens` tokens of context.
    ///
    /// # Panics
    ///
    /// Panics if `max_tokens` is zero.
    #[must_use]
    pub fn new(max_tokens: usize) -> Self {
        assert!(max_tokens != 0, "max_tokens must be > 0");
        ContextWindowStrategy { max_tokens }
    }
}
impl GroundingStrategy for ContextWindowStrategy {
    /// Renders `documents` as a `Context:` header followed by numbered entries
    /// (`[n] <content>`, separated by blank lines) until the size budget is
    /// exhausted.
    ///
    /// The budget is `max_tokens * 4` UTF-8 bytes. Documents are added whole
    /// while they fit. The first one that does not fit is cut at a character
    /// boundary and marked with `…`, provided more than 20 bytes of its
    /// content can be shown; every later document is dropped.
    ///
    /// Apart from the header, which is always emitted, the output never
    /// exceeds the budget. Returns an empty string when `documents` is empty.
    fn ground(&self, documents: &[Document]) -> String {
        // Marker appended to a document that was cut short.
        const TRUNCATION_SUFFIX: &str = "…\n\n";
        // Separator after a document that fits in full.
        const ENTRY_SUFFIX: &str = "\n\n";
        // A truncated excerpt this short (in bytes) is not worth emitting.
        const MIN_EXCERPT_BYTES: usize = 20;

        if documents.is_empty() {
            return String::new();
        }

        // Four bytes per token is a rough heuristic; saturate so that a huge
        // `max_tokens` means "unlimited" instead of overflowing.
        let byte_budget = self.max_tokens.saturating_mul(4);
        let mut ctx = String::from("Context:\n\n");

        for (i, doc) in documents.iter().enumerate() {
            let label = format!("[{}] ", i + 1);
            let content: &str = &doc.content;

            let entry_len = label.len() + content.len() + ENTRY_SUFFIX.len();
            if ctx.len() + entry_len <= byte_budget {
                ctx.push_str(&label);
                ctx.push_str(content);
                ctx.push_str(ENTRY_SUFFIX);
                continue;
            }

            // Reserve exactly what the label and the truncation marker need, so
            // the budget also holds for multi-digit indices.
            let overhead = label.len() + TRUNCATION_SUFFIX.len();
            let available = byte_budget.saturating_sub(ctx.len() + overhead);
            if available > MIN_EXCERPT_BYTES {
                // The budget is in bytes, so cut by bytes, backing up to the
                // nearest character boundary to keep the slice valid UTF-8.
                let mut end = available.min(content.len());
                while !content.is_char_boundary(end) {
                    end -= 1;
                }
                ctx.push_str(&label);
                ctx.push_str(&content[..end]);
                ctx.push_str(TRUNCATION_SUFFIX);
            }
            break;
        }
        ctx
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// Builds a `KeywordRetriever` pre-loaded with `(id, content)` documents.
    fn kw_retriever_with(docs: Vec<(&str, &str)>) -> KeywordRetriever {
        let mut r = KeywordRetriever::new();
        for (id, content) in docs {
            r.add(Document::new(id, content));
        }
        r
    }

    #[tokio::test]
    async fn test_hybrid_returns_from_keyword_source() {
        let kw = kw_retriever_with(vec![("k1", "Rust programming"), ("k2", "Python code")]);
        // An empty second source: every hit must come from the keyword side.
        let emb = KeywordRetriever::new();
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust", 5).await.unwrap();
        assert!(!results.is_empty(), "expected keyword results");
        assert!(results.iter().any(|d| d.id == "k1"));
    }

    #[tokio::test]
    async fn test_hybrid_merges_both_sources() {
        let kw = kw_retriever_with(vec![("k1", "Rust keyword hit")]);
        let emb = kw_retriever_with(vec![("e1", "Rust embedding hit")]);
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust hit", 10).await.unwrap();
        let ids: Vec<_> = results.iter().map(|d| d.id.as_str()).collect();
        // Merging means a hit from *each* source must be present, not just one.
        assert!(
            ids.contains(&"k1") && ids.contains(&"e1"),
            "should contain results from both: {ids:?}"
        );
    }

    #[tokio::test]
    async fn test_hybrid_respects_limit() {
        let kw = kw_retriever_with(vec![("k1", "Rust a"), ("k2", "Rust b"), ("k3", "Rust c")]);
        let emb = kw_retriever_with(vec![("e1", "Rust d"), ("e2", "Rust e")]);
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust", 2).await.unwrap();
        assert!(results.len() <= 2);
    }

    #[tokio::test]
    async fn test_hybrid_combined_score_sorted_desc() {
        let kw = kw_retriever_with(vec![("k1", "Rust programming")]);
        let emb = kw_retriever_with(vec![("e1", "Rust embedding search")]);
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust", 10).await.unwrap();
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn test_citation_strategy_format() {
        let docs = vec![
            Document::new("paper-42", "Important finding."),
            Document::new("blog-7", "Another insight."),
        ];
        let ctx = CitationStrategy.ground(&docs);
        assert!(ctx.contains("[1]"));
        assert!(ctx.contains("Source: paper-42"));
        assert!(ctx.contains("[2]"));
        assert!(ctx.contains("Source: blog-7"));
    }

    #[test]
    fn test_citation_strategy_empty() {
        assert!(CitationStrategy.ground(&[]).is_empty());
    }

    #[test]
    fn test_context_window_small_docs_fit() {
        let docs = vec![
            Document::new("d1", "Short."),
            Document::new("d2", "Also short."),
        ];
        let ctx = ContextWindowStrategy::new(1000).ground(&docs);
        assert!(ctx.contains("[1]"));
        assert!(ctx.contains("[2]"));
    }

    #[test]
    fn test_context_window_truncates_large_doc() {
        // 2500 characters of content against a 50-token budget.
        let large = "word ".repeat(500);
        let docs = vec![Document::new("big", &large)];
        let strategy = ContextWindowStrategy::new(50);
        let ctx = strategy.ground(&docs);
        assert!(
            ctx.chars().count() < 500,
            "expected truncation, got {} chars",
            ctx.chars().count()
        );
    }

    #[test]
    fn test_context_window_empty_docs() {
        assert!(ContextWindowStrategy::new(100).ground(&[]).is_empty());
    }
}