use crate::core::ChunkId;
use crate::entity::BidirectionalIndex;
use crate::lightrag::concept_graph::ConceptGraph;
use crate::lightrag::query_refinement::{QueryRefinementConfig, QueryRefiner, RefinedQuery};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Limits and switches that control an iterative-deepening search.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Upper bound on the number of depth levels the search loop runs
    /// (the loop iterates over `0..max_depth`).
    pub max_depth: usize,
    /// Number of distinct chunks that must have been collected before an
    /// adaptive early stop is considered.
    pub min_chunks: usize,
    /// The search stops once at least this many distinct chunks have been
    /// collected. The check runs after a depth level has been processed, so
    /// the final count may exceed this value.
    pub max_chunks: usize,
    /// Maximum number of related concepts added to the concept set after
    /// each depth level.
    pub concepts_per_depth: usize,
    /// Enables the quality-based early stop.
    pub use_adaptive_depth: bool,
    /// Chunks-per-concept ratio at or above which an adaptive search stops.
    pub adaptive_quality_threshold: f64,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
max_depth: 3,
min_chunks: 5,
max_chunks: 50,
concepts_per_depth: 5,
use_adaptive_depth: true,
adaptive_quality_threshold: 0.7,
}
}
}
/// Iterative-deepening retrieval over a concept graph.
///
/// Starting from the concepts extracted from a query, the search collects the
/// chunks indexed under those concepts and then widens the concept set one
/// depth level at a time.
pub struct IterativeDeepeningSearch {
    /// Depth, chunk and expansion limits applied by the search loop.
    config: SearchConfig,
    /// Refiner that turns the raw query into the initial concept set.
    query_refiner: QueryRefiner,
}
impl IterativeDeepeningSearch {
    /// Creates a search engine driven by `config`.
    ///
    /// The internal [`QueryRefiner`] is configured from the same limits:
    /// `max_depth` becomes its iteration budget, `concepts_per_depth` its
    /// per-iteration budget, and their product its total concept budget.
    pub fn new(config: SearchConfig) -> Self {
        let refinement_config = QueryRefinementConfig {
            max_iterations: config.max_depth,
            concepts_per_iteration: config.concepts_per_depth,
            min_cooccurrence: 1,
            // Saturate instead of overflowing (a panic in debug builds) when
            // both limits are extremely large.
            max_total_concepts: config.concepts_per_depth.saturating_mul(config.max_depth),
            use_relevance_feedback: true,
        };
        Self {
            config,
            query_refiner: QueryRefiner::new(refinement_config),
        }
    }

    /// Creates a search engine configured with [`SearchConfig::default`].
    ///
    /// This is an inherent associated function rather than an implementation
    /// of the [`Default`] trait; it is kept under this name so existing
    /// callers of `IterativeDeepeningSearch::default()` keep working.
    pub fn default() -> Self {
        Self::new(SearchConfig::default())
    }

    /// Runs an iterative-deepening search for `query`.
    ///
    /// The query is first handed to the query refiner to obtain the initial
    /// concept set. Then, for every depth in `0..max_depth`:
    ///
    /// 1. all chunks indexed under a current concept and not seen before are
    ///    collected;
    /// 2. the search stops if `max_chunks` distinct chunks have been reached,
    ///    or — when adaptive depth is enabled, at least `min_chunks` chunks
    ///    were found and the depth is greater than 0 — if the
    ///    chunks-per-concept ratio meets `adaptive_quality_threshold`;
    /// 3. otherwise the concept set is widened by up to `concepts_per_depth`
    ///    related concepts, and the search stops if none can be found.
    ///
    /// # Returns
    ///
    /// A [`SearchResults`] holding the per-depth breakdown, every distinct
    /// chunk id found (in unspecified order) and the reason the search ended.
    /// When the search stops early, `depth_reached` is the zero-based depth at
    /// which it stopped; when every level ran it equals `max_depth`. If the
    /// refiner yields no initial concepts, an empty result is returned.
    pub fn search(
        &self,
        query: &str,
        concept_graph: &ConceptGraph,
        bidirectional_index: &BidirectionalIndex,
    ) -> SearchResults {
        let mut results = SearchResults::new(query.to_string());

        let refined_query =
            self.query_refiner
                .refine_query(query, concept_graph, bidirectional_index);
        if refined_query.initial_concepts.is_empty() {
            return results;
        }

        let mut current_concepts: HashSet<String> =
            refined_query.initial_concepts.iter().cloned().collect();
        let mut visited_chunks: HashSet<ChunkId> = HashSet::new();

        for depth in 0..self.config.max_depth {
            let depth_results = self.search_at_depth(
                depth,
                &current_concepts,
                concept_graph,
                bidirectional_index,
                &mut visited_chunks,
            );
            results.add_depth_results(depth, depth_results);

            // Refresh the running totals before the stop checks:
            // `should_stop_early` reads them, and they would otherwise still
            // be zero inside the loop, so the adaptive stop could never fire.
            results.total_chunks = visited_chunks.len();
            results.total_concepts_explored = current_concepts.len();

            // `max_chunks` is checked only after a whole depth level has been
            // collected, so the final count may exceed it.
            if visited_chunks.len() >= self.config.max_chunks {
                results.depth_reached = depth;
                results.stop_reason = StopReason::MaxChunksReached;
                break;
            }

            if visited_chunks.len() >= self.config.min_chunks
                && self.config.use_adaptive_depth
                && self.should_stop_early(&results, depth)
            {
                results.depth_reached = depth;
                results.stop_reason = StopReason::QualityThresholdMet;
                break;
            }

            let expanded_concepts = self.expand_concepts(
                &current_concepts,
                concept_graph,
                self.config.concepts_per_depth,
            );
            if expanded_concepts.is_empty() {
                results.depth_reached = depth;
                results.stop_reason = StopReason::NoMoreConcepts;
                break;
            }

            current_concepts.extend(expanded_concepts);
            results.depth_reached = depth + 1;
        }

        results.total_chunks = visited_chunks.len();
        results.total_concepts_explored = current_concepts.len();
        results.chunk_ids = visited_chunks.into_iter().collect();
        results
    }

    /// Collects the chunks reachable from `concepts` that were not seen before.
    ///
    /// Each concept is normalized, wrapped in an `EntityId` and looked up in
    /// `bidirectional_index`. Newly seen chunk ids are recorded in
    /// `visited_chunks` and returned in the [`DepthResults`] for `depth`.
    ///
    /// `_concept_graph` is not used; the parameter is retained so the
    /// signature stays unchanged.
    fn search_at_depth(
        &self,
        depth: usize,
        concepts: &HashSet<String>,
        _concept_graph: &ConceptGraph,
        bidirectional_index: &BidirectionalIndex,
        visited_chunks: &mut HashSet<ChunkId>,
    ) -> DepthResults {
        let mut new_chunk_ids = Vec::new();

        for concept in concepts {
            let entity_id = crate::core::EntityId::new(self.normalize_concept(concept));
            for chunk_id in bidirectional_index.get_chunks_for_entity(&entity_id) {
                // `insert` returns true only for ids not seen at any depth.
                if visited_chunks.insert(chunk_id.clone()) {
                    new_chunk_ids.push(chunk_id);
                }
            }
        }

        DepthResults {
            depth,
            concepts_explored: concepts.len(),
            new_chunks_found: new_chunk_ids.len(),
            chunk_ids: new_chunk_ids,
        }
    }

    /// Picks up to `max_expand` new concepts related to `current_concepts`.
    ///
    /// Every concept the graph reports as related to a current concept, and
    /// that is not already in the set, is scored with
    /// [`Self::score_concept`]. A candidate reported for several current
    /// concepts accumulates its score once per report. Candidates are returned
    /// best first; equal scores are ordered by name so the selection does not
    /// depend on hash-map iteration order.
    fn expand_concepts(
        &self,
        current_concepts: &HashSet<String>,
        concept_graph: &ConceptGraph,
        max_expand: usize,
    ) -> Vec<String> {
        let mut candidate_scores: HashMap<String, f64> = HashMap::new();

        for concept in current_concepts {
            for candidate in concept_graph.get_related_concepts(concept, max_expand) {
                if current_concepts.contains(&candidate) {
                    continue;
                }
                let score = self.score_concept(&candidate, current_concepts, concept_graph);
                *candidate_scores.entry(candidate).or_insert(0.0) += score;
            }
        }

        let mut ranked: Vec<(String, f64)> = candidate_scores.into_iter().collect();
        ranked.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                // Deterministic tie-break: alphabetical by concept name.
                .then_with(|| a.0.cmp(&b.0))
        });

        ranked
            .into_iter()
            .take(max_expand)
            .map(|(concept, _)| concept)
            .collect()
    }

    /// Scores `concept` as an expansion candidate.
    ///
    /// The base score is `ln(frequency) + 1`. It is multiplied by
    /// `1 + 0.5 * k`, where `k` is the number of current concepts that share
    /// at least one relation (in either direction) with `concept`.
    ///
    /// Returns `0.0` when `concept` is not present in the graph.
    fn score_concept(
        &self,
        concept: &str,
        current_concepts: &HashSet<String>,
        concept_graph: &ConceptGraph,
    ) -> f64 {
        let concept_data = match concept_graph.concepts.get(concept) {
            Some(data) => data,
            None => return 0.0,
        };

        let frequency = concept_data.frequency as f64;
        // `ln(0)` is -inf, which would poison every sum it is added to; treat
        // a non-positive frequency as contributing no log term.
        let log_frequency = if frequency > 0.0 { frequency.ln() } else { 0.0 };
        let base_score = log_frequency + 1.0;

        let connections = current_concepts
            .iter()
            .filter(|current| {
                concept_graph.relations.iter().any(|rel| {
                    (rel.source == *concept && rel.target == **current)
                        || (rel.source == **current && rel.target == *concept)
                })
            })
            .count();

        base_score * (1.0 + connections as f64 * 0.5)
    }

    /// Decides whether an adaptive search may stop at `current_depth`.
    ///
    /// Never stops at depth 0. Otherwise the chunks-per-concept ratio taken
    /// from `results` is compared against `adaptive_quality_threshold`.
    fn should_stop_early(&self, results: &SearchResults, current_depth: usize) -> bool {
        if current_depth == 0 {
            return false;
        }

        let quality = if results.total_concepts_explored > 0 {
            results.total_chunks as f64 / results.total_concepts_explored as f64
        } else {
            0.0
        };
        quality >= self.config.adaptive_quality_threshold
    }

    /// Lower-cases `concept` and keeps only alphanumeric characters and `_`.
    ///
    /// NOTE(review): whitespace is dropped, not converted to `_`. The original
    /// code ended with `.replace(' ', "_")`, which could never match because
    /// the filter had already removed every space. Behaviour is unchanged
    /// here; confirm how entity ids are generated at indexing time before
    /// switching to underscore-separated ids.
    fn normalize_concept(&self, concept: &str) -> String {
        concept
            .to_lowercase()
            .chars()
            .filter(|c| c.is_alphanumeric() || *c == '_')
            .collect()
    }
}
/// Outcome of one iterative-deepening search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResults {
    /// The query string the search was started with.
    pub query: String,
    /// Zero-based depth at which the search stopped early, or the configured
    /// maximum depth when every level was processed.
    pub depth_reached: usize,
    /// Number of distinct chunks collected.
    pub total_chunks: usize,
    /// Size of the concept set when the search ended.
    pub total_concepts_explored: usize,
    /// Per-depth breakdown, in the order the levels were searched.
    pub depth_results: Vec<DepthResults>,
    /// Every distinct chunk id collected; the order is unspecified.
    pub chunk_ids: Vec<ChunkId>,
    /// Why the search ended.
    pub stop_reason: StopReason,
}
impl SearchResults {
    /// Creates an empty result set for `query`.
    ///
    /// All counters start at zero and `stop_reason` starts as
    /// [`StopReason::MaxDepthReached`].
    fn new(query: String) -> Self {
        Self {
            query,
            depth_reached: 0,
            total_chunks: 0,
            total_concepts_explored: 0,
            depth_results: Vec::new(),
            chunk_ids: Vec::new(),
            stop_reason: StopReason::MaxDepthReached,
        }
    }

    /// Appends the results of one depth level.
    ///
    /// `_depth` is intentionally unused (the level is already carried by
    /// `results.depth`); the parameter is kept so existing call sites compile
    /// unchanged, and the underscore silences the unused-variable warning.
    fn add_depth_results(&mut self, _depth: usize, results: DepthResults) {
        self.depth_results.push(results);
    }

    /// Returns the number of distinct chunks collected.
    pub fn chunk_count(&self) -> usize {
        self.total_chunks
    }

    /// Returns the size of the concept set when the search ended.
    pub fn concept_count(&self) -> usize {
        self.total_concepts_explored
    }

    /// Returns the first recorded entry whose `depth` equals `depth`, or
    /// `None` if that level was never searched.
    pub fn get_depth_results(&self, depth: usize) -> Option<&DepthResults> {
        self.depth_results.iter().find(|entry| entry.depth == depth)
    }
}
/// What a single depth level of the search contributed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DepthResults {
    /// Zero-based depth level these results belong to.
    pub depth: usize,
    /// Number of concepts in the concept set searched at this level.
    pub concepts_explored: usize,
    /// Number of chunks seen for the first time at this level.
    pub new_chunks_found: usize,
    /// Ids of the chunks seen for the first time at this level.
    pub chunk_ids: Vec<ChunkId>,
}
/// Why an iterative-deepening search ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
    /// Every configured depth level was processed. This is also the initial
    /// value of a fresh result set, so it is reported when the query yields
    /// no initial concepts.
    MaxDepthReached,
    /// The number of collected chunks reached the configured maximum.
    MaxChunksReached,
    /// The adaptive quality check decided that enough chunks were found.
    QualityThresholdMet,
    /// No further related concepts were available to widen the search.
    NoMoreConcepts,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lightrag::concept_graph::{ConceptExtractor, ConceptGraphBuilder};

    /// Smoke test: a search over a tiny graph and an empty index echoes the
    /// query and never reports a depth beyond the configured maximum.
    #[test]
    fn test_iterative_deepening_basic() {
        let engine = IterativeDeepeningSearch::new(SearchConfig {
            max_depth: 2,
            min_chunks: 1,
            max_chunks: 10,
            concepts_per_depth: 3,
            use_adaptive_depth: false,
            adaptive_quality_threshold: 0.7,
        });

        let mut graph_builder = ConceptGraphBuilder::new();
        graph_builder.add_document_concepts(
            "doc1",
            vec![String::from("machine"), String::from("learning")],
        );
        graph_builder.add_chunk_concepts("chunk1", vec![String::from("machine")]);
        let graph = graph_builder.build();
        let index = BidirectionalIndex::new();

        let outcome = engine.search("machine", &graph, &index);

        assert_eq!(outcome.query, "machine");
        assert!(outcome.depth_reached <= 2);
    }

    /// The default configuration exposes the documented stock values.
    #[test]
    fn test_search_config_default() {
        let defaults = SearchConfig::default();

        assert_eq!(defaults.max_depth, 3);
        assert_eq!(defaults.min_chunks, 5);
        assert!(defaults.use_adaptive_depth);
    }

    /// Stop reasons compare by variant.
    #[test]
    fn test_stop_reasons() {
        assert_ne!(StopReason::MaxDepthReached, StopReason::QualityThresholdMet);
        assert_eq!(StopReason::NoMoreConcepts, StopReason::NoMoreConcepts);
    }
}