use std::collections::HashSet;
use std::sync::Arc;
use crate::core::error::GraphRAGError;
use crate::lightrag::keyword_extraction::{DualLevelKeywords, KeywordExtractor};
use crate::retrieval::SearchResult;
/// Outcome of one dual-level retrieval pass.
///
/// Carries the raw hits of each level, the merged list built from them, and
/// the keywords that were used to query both stores.
#[derive(Debug, Clone)]
pub struct DualRetrievalResults {
/// Hits returned for the high-level keywords, before merging. Empty when no
/// high-level keywords were extracted.
pub high_level_chunks: Vec<SearchResult>,
/// Hits returned for the low-level keywords, before merging. Empty when no
/// low-level keywords were extracted.
pub low_level_chunks: Vec<SearchResult>,
/// Combination of both lists produced by the configured merge strategy:
/// de-duplicated by result id and limited to the requested `top_k`.
pub merged_chunks: Vec<SearchResult>,
/// Keywords extracted from the query that drove both searches.
pub keywords: DualLevelKeywords,
}
/// Tuning knobs for how high-level and low-level hits are combined.
///
/// The `Default` implementation uses weights `0.6` / `0.4` and
/// [`MergeStrategy::Interleave`].
#[derive(Debug, Clone)]
pub struct DualRetrievalConfig {
/// Multiplier applied to the score of each high-level hit when ranking
/// under [`MergeStrategy::Weighted`].
pub high_level_weight: f32,
/// Multiplier applied to the score of each low-level hit when ranking
/// under [`MergeStrategy::Weighted`].
pub low_level_weight: f32,
/// Selects which merge routine combines the two hit lists.
pub merge_strategy: MergeStrategy,
}
impl Default for DualRetrievalConfig {
fn default() -> Self {
Self {
high_level_weight: 0.6, low_level_weight: 0.4, merge_strategy: MergeStrategy::Interleave,
}
}
}
/// How the high-level and low-level hit lists are combined into one list.
///
/// Every strategy de-duplicates by result id and stops at the requested
/// `top_k`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the strategy can be
/// used as a map/set key and in `Eq`-bounded contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MergeStrategy {
    /// Alternate between the two lists, starting with the high-level one.
    Interleave,
    /// Take high-level hits first, then fill up with low-level hits.
    HighFirst,
    /// Take low-level hits first, then fill up with high-level hits.
    LowFirst,
    /// Rank all hits by their score multiplied by the per-level weight.
    Weighted,
}
/// Text-query search backend used by [`DualLevelRetriever`] for each
/// retrieval level.
///
/// Implementors must be `Send + Sync`; the retriever stores them as
/// `Arc<dyn SemanticSearcher>`.
#[async_trait::async_trait]
pub trait SemanticSearcher: Send + Sync {
/// Runs `query` against the backend and returns the matching results.
///
/// `top_k` is the result budget requested by the caller — presumably an
/// upper bound on the number of results; how strictly it is honoured is up
/// to the implementation.
///
/// # Errors
///
/// Returns a [`GraphRAGError`] when the search fails.
async fn search(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>, GraphRAGError>;
}
/// Retriever that extracts high-level and low-level keywords from a query,
/// searches a dedicated store for each level, and merges the two hit lists.
pub struct DualLevelRetriever {
    /// Extracts the dual-level keywords from the incoming query.
    keyword_extractor: Arc<KeywordExtractor>,
    /// Store searched with the high-level keywords.
    high_level_store: Arc<dyn SemanticSearcher>,
    /// Store searched with the low-level keywords.
    low_level_store: Arc<dyn SemanticSearcher>,
    /// Merge strategy and per-level weights.
    config: DualRetrievalConfig,
}
impl DualLevelRetriever {
pub fn new(
keyword_extractor: Arc<KeywordExtractor>,
high_level_store: Arc<dyn SemanticSearcher>,
low_level_store: Arc<dyn SemanticSearcher>,
config: DualRetrievalConfig,
) -> Self {
Self {
keyword_extractor,
high_level_store,
low_level_store,
config,
}
}
/// Runs the full dual-level pipeline for `query`.
///
/// Extracts keywords, searches both levels concurrently (each with a budget
/// of `top_k`), then merges the two hit lists with the configured strategy.
///
/// # Errors
///
/// Propagates a failure from keyword extraction, from either store search
/// (a high-level failure is reported in preference to a low-level one), or
/// from the merge step.
pub async fn retrieve(
    &self,
    query: &str,
    top_k: usize,
) -> Result<DualRetrievalResults, GraphRAGError> {
    let keywords = self.keyword_extractor.extract_with_fallback(query).await?;
    log::debug!(
        "Dual-level keywords - High: {:?}, Low: {:?}",
        keywords.high_level,
        keywords.low_level
    );

    // Both searches are polled together; neither result is inspected until
    // both have finished.
    let high_search = self.retrieve_high_level(&keywords.high_level, top_k);
    let low_search = self.retrieve_low_level(&keywords.low_level, top_k);
    let (high_outcome, low_outcome) = tokio::join!(high_search, low_search);
    let (high_level_chunks, low_level_chunks) = (high_outcome?, low_outcome?);

    let merged_chunks = self.merge_results(&high_level_chunks, &low_level_chunks, top_k)?;
    log::info!(
        "Dual retrieval: {} high-level, {} low-level → {} merged",
        high_level_chunks.len(),
        low_level_chunks.len(),
        merged_chunks.len()
    );

    Ok(DualRetrievalResults {
        high_level_chunks,
        low_level_chunks,
        merged_chunks,
        keywords,
    })
}
/// Searches the high-level store with all `keywords` joined by spaces.
///
/// Returns an empty list without touching the store when `keywords` is
/// empty.
///
/// # Errors
///
/// Propagates any error returned by the high-level store.
async fn retrieve_high_level(
    &self,
    keywords: &[String],
    top_k: usize,
) -> Result<Vec<SearchResult>, GraphRAGError> {
    if keywords.is_empty() {
        log::debug!("No high-level keywords, skipping high-level retrieval");
        Ok(Vec::new())
    } else {
        let query_text = keywords.join(" ");
        log::debug!("High-level query: '{}'", query_text);
        self.high_level_store.search(&query_text, top_k).await
    }
}
/// Searches the low-level store with all `keywords` joined by spaces.
///
/// Returns an empty list without touching the store when `keywords` is
/// empty.
///
/// # Errors
///
/// Propagates any error returned by the low-level store.
async fn retrieve_low_level(
    &self,
    keywords: &[String],
    top_k: usize,
) -> Result<Vec<SearchResult>, GraphRAGError> {
    if keywords.is_empty() {
        log::debug!("No low-level keywords, skipping low-level retrieval");
        Ok(Vec::new())
    } else {
        let query_text = keywords.join(" ");
        log::debug!("Low-level query: '{}'", query_text);
        self.low_level_store.search(&query_text, top_k).await
    }
}
/// Dispatches to the merge routine selected by `config.merge_strategy`.
///
/// `HighFirst` and `LowFirst` share the concatenating merge and differ only
/// in which list is consumed first.
fn merge_results(
    &self,
    high: &[SearchResult],
    low: &[SearchResult],
    top_k: usize,
) -> Result<Vec<SearchResult>, GraphRAGError> {
    let strategy = self.config.merge_strategy;
    match strategy {
        MergeStrategy::Weighted => self.merge_weighted(high, low, top_k),
        MergeStrategy::Interleave => self.merge_interleave(high, low, top_k),
        MergeStrategy::HighFirst | MergeStrategy::LowFirst => {
            let (first, second) = if strategy == MergeStrategy::HighFirst {
                (high, low)
            } else {
                (low, high)
            };
            self.merge_concat(first, second, top_k)
        }
    }
}
/// Merges by alternating between the two lists, starting with `high`.
///
/// Candidates are visited as `high[0], low[0], high[1], low[1], …`; once
/// the shorter list runs out, the rest of the longer one follows. A chunk
/// whose id was already taken is skipped, and the merge stops as soon as
/// `top_k` chunks have been collected.
fn merge_interleave(
    &self,
    high: &[SearchResult],
    low: &[SearchResult],
    top_k: usize,
) -> Result<Vec<SearchResult>, GraphRAGError> {
    let rounds = high.len().max(low.len());
    let mut seen_ids = HashSet::new();

    let merged: Vec<SearchResult> = (0..rounds)
        // One round offers the next high-level hit, then the next low-level
        // hit; a list that is already exhausted contributes `None`.
        .flat_map(|round| [high.get(round), low.get(round)])
        .flatten()
        .filter(|candidate| seen_ids.insert(candidate.id.clone()))
        .take(top_k)
        .cloned()
        .collect();

    Ok(merged)
}
/// Merges by taking every chunk of `first`, then every chunk of `second`.
///
/// A chunk whose id was already taken is skipped, and the merge stops as
/// soon as `top_k` chunks have been collected.
fn merge_concat(
    &self,
    first: &[SearchResult],
    second: &[SearchResult],
    top_k: usize,
) -> Result<Vec<SearchResult>, GraphRAGError> {
    let mut seen_ids = HashSet::new();

    let merged: Vec<SearchResult> = first
        .iter()
        .chain(second)
        .filter(|candidate| seen_ids.insert(candidate.id.clone()))
        .take(top_k)
        .cloned()
        .collect();

    Ok(merged)
}
/// Merges by ranking every chunk on `score * level weight`, best first.
///
/// High-level chunks are weighted by `config.high_level_weight`, low-level
/// chunks by `config.low_level_weight`. The weighted value is used for
/// ranking only: returned chunks keep their original `score`. When the same
/// id occurs more than once, only its best-ranked occurrence is kept. At
/// most `top_k` chunks are returned.
///
/// A weighted score that is NaN is ranked last. Without that, the
/// comparator would not be a total order, and `sort_by` would then yield an
/// unspecified order or panic on recent toolchains.
fn merge_weighted(
    &self,
    high: &[SearchResult],
    low: &[SearchResult],
    top_k: usize,
) -> Result<Vec<SearchResult>, GraphRAGError> {
    // Map NaN to the lowest possible rank so the comparison below is total.
    let rank_value = |chunk: &SearchResult, weight: f32| -> f32 {
        let weighted = chunk.score * weight;
        if weighted.is_nan() {
            f32::NEG_INFINITY
        } else {
            weighted
        }
    };
    let high_weight = self.config.high_level_weight;
    let low_weight = self.config.low_level_weight;

    // Rank borrowed chunks; only the survivors are cloned at the end.
    let mut ranked: Vec<(&SearchResult, f32)> = high
        .iter()
        .map(|chunk| (chunk, rank_value(chunk, high_weight)))
        .chain(low.iter().map(|chunk| (chunk, rank_value(chunk, low_weight))))
        .collect();

    // Stable descending sort: on ties, high-level chunks stay ahead of
    // low-level ones. No NaN remains, so `partial_cmp` never returns `None`.
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut seen_ids = HashSet::new();
    let merged: Vec<SearchResult> = ranked
        .into_iter()
        .map(|(chunk, _rank)| chunk)
        .filter(|chunk| seen_ids.insert(&chunk.id))
        .take(top_k)
        .cloned()
        .collect();

    Ok(merged)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::ResultType;

    /// Builds a chunk-typed `SearchResult` with the given id and score, a
    /// content string derived from the id, and no entities or source chunks.
    fn create_test_result(id: &str, score: f32) -> SearchResult {
        let content = format!("Content of {}", id);
        SearchResult {
            id: id.to_string(),
            content,
            score,
            result_type: ResultType::Chunk,
            entities: Vec::new(),
            source_chunks: Vec::new(),
        }
    }

    /// The default configuration interleaves and uses positive weights.
    #[test]
    fn test_merge_strategies_basic() {
        let DualRetrievalConfig {
            high_level_weight,
            low_level_weight,
            merge_strategy,
        } = DualRetrievalConfig::default();

        assert_eq!(merge_strategy, MergeStrategy::Interleave);
        assert!(high_level_weight > 0.0);
        assert!(low_level_weight > 0.0);
    }

    /// The helper copies id and score through and marks the result as a chunk.
    #[test]
    fn test_search_result_creation() {
        let sample = create_test_result("test_1", 0.95);

        assert_eq!(sample.id, "test_1");
        assert_eq!(sample.score, 0.95);
        assert_eq!(sample.result_type, ResultType::Chunk);
    }
}