use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::document::{DocumentTree, NodeId, ReasoningIndex, RetrievalIndex};
use crate::graph::DocumentGraph;
use crate::retrieval::cache::{HotNodeTracker, ReasoningCache};
use crate::retrieval::pilot::Pilot;
use crate::retrieval::pipeline::budget::RetrievalBudgetController;
use crate::retrieval::types::{
NavigationDecision, QueryComplexity, ReasoningChain, ReasoningStep, RetrieveOptions,
RetrieveResponse, SearchPath, StageName, StrategyPreference, SufficiencyLevel,
};
/// Search algorithm that can be selected for a retrieval run.
///
/// The [`Default`] implementation yields [`SearchAlgorithm::Beam`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchAlgorithm {
    /// Reported as `"greedy"` by [`SearchAlgorithm::name`].
    Greedy,
    /// Reported as `"beam"`; this is the default variant.
    Beam,
    /// Reported as `"mcts"`.
    Mcts,
}

impl Default for SearchAlgorithm {
    /// Returns [`SearchAlgorithm::Beam`].
    fn default() -> Self {
        SearchAlgorithm::Beam
    }
}

impl SearchAlgorithm {
    /// Returns the lowercase static identifier of this algorithm.
    pub fn name(&self) -> &'static str {
        // The enum is `Copy`, so matching on the dereferenced value is free.
        match *self {
            SearchAlgorithm::Mcts => "mcts",
            SearchAlgorithm::Beam => "beam",
            SearchAlgorithm::Greedy => "greedy",
        }
    }
}
/// Numeric limits that parameterise a search run.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Beam width; defaults to `3`.
    pub beam_width: usize,
    /// Maximum depth; defaults to `10`.
    pub max_depth: usize,
    /// Minimum score; defaults to `0.1`.
    pub min_score: f32,
    /// Maximum number of iterations; defaults to `5`.
    pub max_iterations: usize,
}

impl Default for SearchConfig {
    /// Returns the stock configuration: width 3, depth 10, score floor 0.1,
    /// and 5 iterations.
    fn default() -> Self {
        const DEFAULT_BEAM_WIDTH: usize = 3;
        const DEFAULT_MAX_DEPTH: usize = 10;
        const DEFAULT_MIN_SCORE: f32 = 0.1;
        const DEFAULT_MAX_ITERATIONS: usize = 5;

        SearchConfig {
            beam_width: DEFAULT_BEAM_WIDTH,
            max_depth: DEFAULT_MAX_DEPTH,
            min_score: DEFAULT_MIN_SCORE,
            max_iterations: DEFAULT_MAX_ITERATIONS,
        }
    }
}
/// A node under consideration during search, paired with its score and
/// position information.
#[derive(Debug, Clone)]
pub struct CandidateNode {
    /// Identifier of the node this candidate refers to.
    pub node_id: NodeId,
    /// Score assigned to the node.
    pub score: f32,
    /// Depth recorded for the node.
    pub depth: usize,
    /// Whether the node was recorded as a leaf.
    pub is_leaf: bool,
}

impl CandidateNode {
    /// Builds a candidate from its four components; each argument is stored
    /// unchanged in the field of the same name.
    pub fn new(node_id: NodeId, score: f32, depth: usize, is_leaf: bool) -> Self {
        CandidateNode {
            is_leaf,
            depth,
            score,
            node_id,
        }
    }
}
/// Outcome record for a single pipeline stage.
#[derive(Debug, Clone)]
pub struct StageResult {
    /// Name of the stage this record belongs to.
    pub stage: String,
    /// `true` when the stage completed successfully.
    pub success: bool,
    /// Stage duration in milliseconds; `0` until set.
    pub duration_ms: u64,
    /// Optional message; the `failure` constructor always sets it.
    pub message: Option<String>,
}

impl StageResult {
    /// Creates a successful record for `stage` with zero duration and no
    /// message.
    pub fn success(stage: impl Into<String>) -> Self {
        let stage = stage.into();
        StageResult {
            stage,
            success: true,
            duration_ms: 0,
            message: None,
        }
    }

    /// Creates a failed record for `stage` carrying `message`, with zero
    /// duration.
    pub fn failure(stage: impl Into<String>, message: impl Into<String>) -> Self {
        let stage = stage.into();
        let message = Some(message.into());
        StageResult {
            stage,
            success: false,
            duration_ms: 0,
            message,
        }
    }

    /// Returns the record with `duration_ms` replaced by `ms`; all other
    /// fields are carried over unchanged.
    pub fn with_duration(self, ms: u64) -> Self {
        StageResult {
            duration_ms: ms,
            ..self
        }
    }
}
/// Timing and counter metrics collected over a retrieval run.
#[derive(Debug, Clone, Default)]
pub struct RetrievalMetrics {
    /// Milliseconds attributed to the `analyze` stage.
    pub analyze_time_ms: u64,
    /// Milliseconds attributed to the `plan` stage.
    pub plan_time_ms: u64,
    /// Milliseconds attributed to the `search` stage.
    pub search_time_ms: u64,
    /// Milliseconds attributed to the `evaluate` stage.
    pub evaluate_time_ms: u64,
    /// Total time in milliseconds. Not touched by [`RetrievalMetrics::merge`].
    pub total_time_ms: u64,
    /// Number of nodes visited.
    pub nodes_visited: usize,
    /// Number of LLM calls made.
    pub llm_calls: usize,
    /// Number of tokens used.
    pub tokens_used: usize,
    /// Number of cache hits.
    pub cache_hits: usize,
    /// Number of cache misses.
    pub cache_misses: usize,
    /// Number of search iterations performed.
    pub search_iterations: usize,
    /// Number of backtracks performed.
    pub backtracks: usize,
}

impl RetrievalMetrics {
    /// Creates a metrics record with every field set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds the per-stage timings and counters of `other` onto `self`.
    ///
    /// Every merged field is accumulated. `tokens_used` and
    /// `search_iterations` were previously overwritten with `other`'s value,
    /// which silently discarded the receiver's own counts; they are now
    /// summed like the rest.
    ///
    /// `total_time_ms` is left unchanged, exactly as before.
    /// NOTE(review): confirm whether totals are meant to be summed here or
    /// measured separately by the caller.
    pub fn merge(&mut self, other: &RetrievalMetrics) {
        // Per-stage timings.
        self.analyze_time_ms += other.analyze_time_ms;
        self.plan_time_ms += other.plan_time_ms;
        self.search_time_ms += other.search_time_ms;
        self.evaluate_time_ms += other.evaluate_time_ms;

        // Work counters.
        self.nodes_visited += other.nodes_visited;
        self.llm_calls += other.llm_calls;
        self.tokens_used += other.tokens_used;
        self.cache_hits += other.cache_hits;
        self.cache_misses += other.cache_misses;
        self.search_iterations += other.search_iterations;
        self.backtracks += other.backtracks;
    }
}
/// Mutable state carried through one retrieval run.
///
/// Constructed by `PipelineContext::new`, which fills `query`, `tree`,
/// `retrieval_index`, `options`, `budget_controller` and `reasoning_cache`
/// and leaves every other field empty, zeroed or `None`. `finalize` consumes
/// the context and turns it into a `RetrieveResponse`.
pub struct PipelineContext {
/// Query text, converted to `String` by the constructor.
pub query: String,
/// Shared document tree the run operates on.
pub tree: Arc<DocumentTree>,
/// Retrieval index; the constructor sets it to the result of
/// `tree.build_retrieval_index()`.
pub retrieval_index: Option<RetrievalIndex>,
/// Caller-supplied options; `max_tokens` and `max_iterations` are read by
/// the limit checks on this type.
pub options: RetrieveOptions,
/// Optional pilot; `None` unless supplied via `with_pilot` or `set_pilot`.
pub pilot: Option<Arc<dyn Pilot>>,
/// Budget controller created from `options.max_tokens` in the constructor.
pub budget_controller: RetrievalBudgetController,
/// Reasoning cache; the constructor creates a fresh one per context.
pub reasoning_cache: Arc<ReasoningCache>,
/// Optional reasoning index; `None` until `with_reasoning_index` is called.
pub reasoning_index: Option<Arc<ReasoningIndex>>,
/// Optional hot-node tracker; `None` until `with_hot_tracker` is called.
pub hot_tracker: Option<Arc<HotNodeTracker>>,
/// Optional document graph; `None` until `with_document_graph` is called.
pub document_graph: Option<Arc<DocumentGraph>>,
/// Query complexity, if determined; `finalize` falls back to the type's
/// default when this is `None`.
pub complexity: Option<QueryComplexity>,
/// Starts empty. Presumably filled by an analysis stage — TODO confirm
/// against the stage implementations.
pub keywords: Vec<String>,
/// Starts empty. Presumably section names targeted by the query — TODO
/// confirm against the stage implementations.
pub target_sections: Vec<String>,
/// Starts empty. Pairs of a string and a node id; presumably path hints
/// already resolved to nodes — TODO confirm.
pub resolved_path_hints: Vec<(String, NodeId)>,
/// Query decomposition result, if any; `None` after construction.
pub decomposition: Option<crate::retrieval::decompose::DecompositionResult>,
/// Selected strategy, if any; `finalize` reports its `Debug` rendering, or
/// `"unknown"` when unset.
pub selected_strategy: Option<StrategyPreference>,
/// Selected search algorithm, if any; `None` after construction.
pub selected_algorithm: Option<SearchAlgorithm>,
/// Search configuration, if any; `None` after construction.
pub search_config: Option<SearchConfig>,
/// Current candidate nodes; their node ids feed the stagnation fingerprint.
pub candidates: Vec<CandidateNode>,
/// Search paths recorded so far; starts empty.
pub search_paths: Vec<SearchPath>,
/// Reasoning steps recorded so far; appended to by `push_reasoning_step`.
pub reasoning_chain: ReasoningChain,
/// Number of search iterations performed; bumped by
/// `increment_search_iteration` and compared with `options.max_iterations`.
pub search_iterations: usize,
/// Current sufficiency level; `finalize` checks it against
/// `SufficiencyLevel::Sufficient`.
pub sufficiency: SufficiencyLevel,
/// Content accumulated so far; becomes `content` in the fallback response.
pub accumulated_content: String,
/// Tokens counted so far; compared with `options.max_tokens`.
pub token_count: usize,
/// Fingerprint stored by the previous `check_candidates_stagnant` call, or
/// `None` if it has not been called yet.
pub prev_candidate_fingerprint: Option<u64>,
/// Final response, if one has been produced; `finalize` returns it as-is
/// when present.
pub result: Option<RetrieveResponse>,
/// Latest `StageResult` per stage name, written by `end_stage`.
pub stage_results: HashMap<String, StageResult>,
/// Timing and counter metrics for this run.
pub metrics: RetrievalMetrics,
/// Start instant of the stage being timed; set by `start_stage` and
/// cleared by `end_stage`.
pub stage_start: Option<Instant>,
}
impl PipelineContext {
pub fn new(
tree: Arc<DocumentTree>,
query: impl Into<String>,
options: RetrieveOptions,
) -> Self {
let retrieval_index = Some(tree.build_retrieval_index());
let budget_controller = RetrievalBudgetController::new(options.max_tokens);
Self {
query: query.into(),
tree,
retrieval_index,
options,
pilot: None,
budget_controller,
reasoning_cache: Arc::new(ReasoningCache::new()),
reasoning_index: None,
hot_tracker: None,
document_graph: None,
complexity: None,
keywords: Vec::new(),
target_sections: Vec::new(),
resolved_path_hints: Vec::new(),
decomposition: None,
selected_strategy: None,
selected_algorithm: None,
search_config: None,
candidates: Vec::new(),
search_paths: Vec::new(),
reasoning_chain: ReasoningChain::new(),
search_iterations: 0,
sufficiency: SufficiencyLevel::default(),
accumulated_content: String::new(),
token_count: 0,
prev_candidate_fingerprint: None,
result: None,
stage_results: HashMap::new(),
metrics: RetrievalMetrics::default(),
stage_start: None,
}
}
pub fn with_pilot(
tree: Arc<DocumentTree>,
query: impl Into<String>,
options: RetrieveOptions,
pilot: Option<Arc<dyn Pilot>>,
) -> Self {
let mut ctx = Self::new(tree, query, options);
ctx.pilot = pilot;
ctx
}
pub fn set_pilot(&mut self, pilot: Option<Arc<dyn Pilot>>) {
self.pilot = pilot;
}
pub fn with_reasoning_index(mut self, index: ReasoningIndex) -> Self {
self.reasoning_index = Some(Arc::new(index));
self
}
pub fn with_hot_tracker(mut self, tracker: HotNodeTracker) -> Self {
self.hot_tracker = Some(Arc::new(tracker));
self
}
pub fn with_document_graph(mut self, graph: Arc<DocumentGraph>) -> Self {
self.document_graph = Some(graph);
self
}
pub fn pilot(&self) -> Option<&dyn Pilot> {
self.pilot.as_deref()
}
pub fn start_stage(&mut self) {
self.stage_start = Some(Instant::now());
}
pub fn end_stage(&mut self, stage_name: &str, success: bool, message: Option<String>) {
let duration_ms = self
.stage_start
.map(|s| s.elapsed().as_millis() as u64)
.unwrap_or(0);
let result = StageResult {
stage: stage_name.to_string(),
success,
duration_ms,
message,
};
match stage_name {
"analyze" => self.metrics.analyze_time_ms += duration_ms,
"plan" => self.metrics.plan_time_ms += duration_ms,
"search" => self.metrics.search_time_ms += duration_ms,
"evaluate" => self.metrics.evaluate_time_ms += duration_ms,
_ => {}
}
self.stage_results.insert(stage_name.to_string(), result);
self.stage_start = None;
}
pub fn can_search_more(&self) -> bool {
self.search_iterations < self.options.max_iterations
}
pub fn increment_search_iteration(&mut self) {
self.search_iterations += 1;
self.metrics.search_iterations = self.search_iterations;
}
pub fn increment_backtrack(&mut self) {
self.metrics.backtracks += 1;
}
fn candidate_fingerprint(&self) -> u64 {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
for c in &self.candidates {
format!("{:?}", c.node_id).hash(&mut hasher);
}
hasher.finish()
}
pub fn check_candidates_stagnant(&mut self) -> bool {
let fp = self.candidate_fingerprint();
let stagnant = self.prev_candidate_fingerprint == Some(fp);
self.prev_candidate_fingerprint = Some(fp);
stagnant
}
pub fn is_token_limit_reached(&self) -> bool {
self.token_count >= self.options.max_tokens
}
pub fn token_utilization(&self) -> f32 {
if self.options.max_tokens == 0 {
0.0
} else {
(self.token_count as f32 / self.options.max_tokens as f32).min(1.0)
}
}
pub fn push_reasoning_step(&mut self, step: ReasoningStep) {
self.reasoning_chain.push(step);
}
pub fn record_reasoning(
&mut self,
stage: StageName,
reasoning: impl Into<String>,
decision: NavigationDecision,
) {
self.push_reasoning_step(ReasoningStep {
stage,
node_id: None,
title: None,
score: 0.0,
decision,
depth: 0,
reasoning: reasoning.into(),
candidates: Vec::new(),
strategy_used: None,
llm_call: None,
references_followed: Vec::new(),
});
}
pub fn finalize(self) -> RetrieveResponse {
self.result.unwrap_or_else(|| RetrieveResponse {
results: Vec::new(),
content: self.accumulated_content,
confidence: 0.0,
is_sufficient: self.sufficiency == SufficiencyLevel::Sufficient,
strategy_used: self
.selected_strategy
.map(|s| format!("{:?}", s))
.unwrap_or_else(|| "unknown".to_string()),
complexity: self.complexity.unwrap_or_default(),
reasoning_chain: self.reasoning_chain,
tokens_used: self.token_count,
})
}
}