use async_trait::async_trait;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::document::{DocumentTree, NodeId};
use crate::llm::{LlmClient, LlmExecutor};
use crate::memo::{MemoKey, MemoStore, MemoValue};
use crate::utils::fingerprint::Fingerprint;
use super::budget::BudgetController;
use super::builder::ContextBuilder;
use super::config::PilotConfig;
use super::decision::{InterventionPoint, PilotDecision};
use super::feedback::{FeedbackRecord, FeedbackStore, PilotLearner};
use super::parser::ResponseParser;
use super::prompts::PromptBuilder;
use super::r#trait::{Pilot, SearchState};
/// LLM-backed search pilot: consults a language model at key points of a
/// document-tree search (start, fork, backtrack, evaluate, prune) to rank
/// candidate nodes and suggest a direction, subject to token-budget limits,
/// with optional decision memoization and feedback-driven learning.
pub struct LlmPilot {
    // Direct LLM client; used when no `executor` is configured.
    client: LlmClient,
    // Optional shared executor; preferred over `client` when present.
    executor: Option<Arc<LlmExecutor>>,
    // Static pilot configuration (mode, intervention thresholds, budget caps).
    config: PilotConfig,
    // Per-pilot token/call budget; cleared by `reset()` for each new query.
    budget: BudgetController,
    // Optional pipeline-wide budget, linked after construction via
    // `set_pipeline_budget`; when its status says stop, the pilot goes quiet.
    pipeline_budget:
parking_lot::RwLock<Option<Arc<crate::retrieval::pipeline::RetrievalBudgetController>>>,
    // Builds the textual search context handed to the prompt builder.
    context_builder: ContextBuilder,
    // Renders (system, user) prompts for each intervention point.
    prompt_builder: PromptBuilder,
    // Parses raw LLM responses into `PilotDecision`s.
    response_parser: ResponseParser,
    // Optional learner that adjusts confidence or skips interventions
    // based on historical feedback accuracy.
    learner: Option<Arc<PilotLearner>>,
    // Optional memo cache so identical (context, query) pairs reuse decisions.
    memo_store: Option<MemoStore>,
}
/// Manual `Debug`: only the configuration and current budget usage are
/// printed; the client, builders, and caches carry no useful debug state.
impl std::fmt::Debug for LlmPilot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LlmPilot");
        out.field("config", &self.config);
        out.field("budget", &self.budget.usage());
        out.finish()
    }
}
impl LlmPilot {
pub fn new(client: LlmClient, config: PilotConfig) -> Self {
let budget = BudgetController::new(config.budget.clone());
let token_budget = config.budget.max_tokens_per_call;
Self {
client,
executor: None,
config,
budget,
pipeline_budget: parking_lot::RwLock::new(None),
context_builder: ContextBuilder::new(token_budget),
prompt_builder: PromptBuilder::new(),
response_parser: ResponseParser::new(),
learner: None,
memo_store: None,
}
}
pub fn with_executor(executor: LlmExecutor, config: PilotConfig) -> Self {
let budget = BudgetController::new(config.budget.clone());
let token_budget = config.budget.max_tokens_per_call;
let client = LlmClient::for_model(&executor.config().model);
Self {
client,
executor: Some(Arc::new(executor)),
config,
budget,
pipeline_budget: parking_lot::RwLock::new(None),
context_builder: ContextBuilder::new(token_budget),
prompt_builder: PromptBuilder::new(),
response_parser: ResponseParser::new(),
learner: None,
memo_store: None,
}
}
pub fn with_shared_executor(executor: Arc<LlmExecutor>, config: PilotConfig) -> Self {
let budget = BudgetController::new(config.budget.clone());
let token_budget = config.budget.max_tokens_per_call;
let client = LlmClient::for_model(&executor.config().model);
Self {
client,
executor: Some(executor),
config,
budget,
pipeline_budget: parking_lot::RwLock::new(None),
context_builder: ContextBuilder::new(token_budget),
prompt_builder: PromptBuilder::new(),
response_parser: ResponseParser::new(),
learner: None,
memo_store: None,
}
}
pub fn with_builders(
client: LlmClient,
config: PilotConfig,
context_builder: ContextBuilder,
prompt_builder: PromptBuilder,
) -> Self {
let budget = BudgetController::new(config.budget.clone());
Self {
client,
executor: None,
config,
budget,
pipeline_budget: parking_lot::RwLock::new(None),
context_builder,
prompt_builder,
response_parser: ResponseParser::new(),
learner: None,
memo_store: None,
}
}
pub fn with_executor_mut(mut self, executor: LlmExecutor) -> Self {
self.executor = Some(Arc::new(executor));
self
}
pub fn with_learner(mut self, learner: Arc<PilotLearner>) -> Self {
self.learner = Some(learner);
self
}
pub fn with_feedback_store(mut self, store: Arc<FeedbackStore>) -> Self {
self.learner = Some(Arc::new(PilotLearner::new(store)));
self
}
pub fn with_memo_store(mut self, store: MemoStore) -> Self {
self.memo_store = Some(store);
self
}
pub fn set_pipeline_budget(
&self,
budget: Arc<crate::retrieval::pipeline::RetrievalBudgetController>,
) {
*self.pipeline_budget.write() = Some(budget);
}
pub fn has_executor(&self) -> bool {
self.executor.is_some()
}
pub fn has_learner(&self) -> bool {
self.learner.is_some()
}
pub fn has_memo_store(&self) -> bool {
self.memo_store.is_some()
}
pub fn learner(&self) -> Option<&PilotLearner> {
self.learner.as_deref()
}
pub fn memo_store(&self) -> Option<&MemoStore> {
self.memo_store.as_ref()
}
pub fn record_feedback(&self, record: FeedbackRecord) {
if let Some(ref learner) = self.learner {
let decision_id = record.decision_id;
learner.store().record(record);
debug!("Recorded feedback for decision {:?}", decision_id);
}
}
fn compute_cache_key(
&self,
context: &super::builder::PilotContext,
_point: InterventionPoint,
) -> Option<MemoKey> {
let _store = self.memo_store.as_ref()?;
let context_str = context.to_string();
let context_fp = Fingerprint::from_str(&context_str);
let query_fp = Fingerprint::from_str(&context.query_section);
Some(MemoKey::pilot_decision(&context_fp, &query_fp))
}
fn has_budget(&self) -> bool {
if let Some(ref pb) = *self.pipeline_budget.read() {
if pb.status().should_stop() {
return false;
}
}
self.budget.can_call()
}
fn scores_are_close(&self, state: &SearchState<'_>) -> bool {
state.candidates.len() >= 2
&& state.best_score < self.config.intervention.score_gap_threshold
}
fn get_intervention_point(&self, state: &SearchState<'_>) -> InterventionPoint {
if state.is_at_root() || state.iteration == 0 {
InterventionPoint::Start
} else if state.is_backtracking {
InterventionPoint::Backtrack
} else if state.is_fork_point() {
InterventionPoint::Fork
} else {
InterventionPoint::Evaluate
}
}
async fn call_llm(
&self,
point: InterventionPoint,
context: &super::builder::PilotContext,
candidates: &[super::parser::CandidateInfo],
) -> PilotDecision {
if let Some(ref store) = self.memo_store {
if let Some(cache_key) = self.compute_cache_key(context, point) {
if let Some(cached) = store.get(&cache_key) {
if let MemoValue::PilotDecision(decision_value) = cached {
debug!("Memo cache hit for pilot decision at {:?}", point);
let decision =
self.cached_value_to_decision(decision_value, candidates, point);
return decision;
}
}
}
}
let prompt = self.prompt_builder.build(point, context);
if !self.budget.can_afford(prompt.estimated_tokens) {
warn!(
"Budget cannot afford LLM call (estimated: {} tokens)",
prompt.estimated_tokens
);
return self.default_decision(candidates, point);
}
let adjustment = if let Some(ref learner) = self.learner {
let query_hash = context.query_hash();
let path_hash = context.path_hash();
Some(learner.get_adjustment(point, query_hash, path_hash))
} else {
None
};
if let Some(ref adj) = adjustment {
if adj.skip_intervention {
debug!("Learner suggests skipping intervention (low historical accuracy)");
return self.default_decision(candidates, point);
}
}
debug!(
"Calling LLM for {:?} point (estimated: {} tokens)",
point, prompt.estimated_tokens
);
let result = if let Some(ref executor) = self.executor {
executor.complete(&prompt.system, &prompt.user).await
} else {
self.client.complete(&prompt.system, &prompt.user).await
};
match result {
Ok(response) => {
let output_tokens = self.estimate_tokens(&response);
let total_tokens = prompt.estimated_tokens + output_tokens;
self.budget
.record_usage(prompt.estimated_tokens, output_tokens, 0);
if let Some(ref pb) = *self.pipeline_budget.read() {
pb.record_tokens(total_tokens);
}
let mut decision = self.response_parser.parse(&response, candidates, point);
if let Some(ref adj) = adjustment {
decision.confidence =
(decision.confidence + adj.confidence_delta as f32).clamp(0.0, 1.0);
debug!(
"Applied learner adjustment: confidence_delta={:.2}, algorithm_weight={:.2}",
adj.confidence_delta, adj.algorithm_weight
);
}
info!(
"LLM decision: direction={:?}, confidence={:.2}, candidates={}",
std::mem::discriminant(&decision.direction),
decision.confidence,
decision.ranked_candidates.len()
);
if let Some(ref store) = self.memo_store {
if let Some(cache_key) = self.compute_cache_key(context, point) {
let decision_value = self.decision_to_cached_value(&decision);
let tokens_saved = prompt.estimated_tokens as u64 + output_tokens as u64;
store.put_with_tokens(
cache_key,
MemoValue::PilotDecision(decision_value),
tokens_saved,
);
debug!("Memo cache stored for pilot decision at {:?}", point);
}
}
decision
}
Err(e) => {
warn!("LLM call failed: {}", e);
self.default_decision(candidates, point)
}
}
}
fn decision_to_cached_value(
&self,
decision: &PilotDecision,
) -> crate::memo::PilotDecisionValue {
crate::memo::PilotDecisionValue {
selected_idx: decision
.ranked_candidates
.first()
.map(|c| c.node_id.0.into())
.unwrap_or(0),
confidence: decision.confidence,
reasoning: decision.reasoning.clone(),
}
}
fn cached_value_to_decision(
&self,
value: crate::memo::PilotDecisionValue,
candidates: &[super::parser::CandidateInfo],
point: InterventionPoint,
) -> PilotDecision {
let ranked = candidates
.iter()
.enumerate()
.map(|(i, c)| super::decision::RankedCandidate {
node_id: c.node_id,
score: if i == value.selected_idx {
1.0
} else {
0.5 / (i + 1) as f32
},
reason: None,
})
.collect();
PilotDecision {
ranked_candidates: ranked,
direction: super::decision::SearchDirection::GoDeeper {
reason: "Cached decision".to_string(),
},
confidence: value.confidence,
reasoning: value.reasoning,
intervention_point: point,
}
}
fn default_decision(
&self,
candidates: &[super::parser::CandidateInfo],
point: InterventionPoint,
) -> PilotDecision {
let ranked = candidates
.iter()
.enumerate()
.map(|(i, c)| super::decision::RankedCandidate {
node_id: c.node_id,
score: 1.0 / (i + 1) as f32,
reason: None,
})
.collect();
PilotDecision {
ranked_candidates: ranked,
direction: super::decision::SearchDirection::GoDeeper {
reason: "Default decision (LLM unavailable)".to_string(),
},
confidence: 0.0,
reasoning: "LLM call failed or budget exhausted".to_string(),
intervention_point: point,
}
}
fn estimate_tokens(&self, text: &str) -> usize {
let char_count = text.chars().count();
let chinese_count = text
.chars()
.filter(|c| ('\u{4E00}'..='\u{9FFF}').contains(c))
.count();
let english_count = char_count - chinese_count;
(chinese_count as f32 / 1.5 + english_count as f32 / 4.0).ceil() as usize
}
}
#[async_trait]
impl Pilot for LlmPilot {
    fn name(&self) -> &str {
        "llm_pilot"
    }

    /// Decides whether the LLM should be consulted for the current state.
    fn should_intervene(&self, state: &SearchState<'_>) -> bool {
        // Hard gates first: LLM disabled by mode, or no budget left.
        if !self.config.mode.uses_llm() {
            return false;
        }
        if !self.has_budget() {
            debug!("Budget exhausted, skipping intervention");
            return false;
        }
        let cfg = &self.config.intervention;
        // Wide fork: too many siblings to rank heuristically.
        if state.candidates.len() > cfg.fork_threshold {
            debug!(
                "Intervening: fork point with {} candidates",
                state.candidates.len()
            );
            return true;
        }
        // Ambiguous ranking: top scores nearly tied.
        if self.scores_are_close(state) {
            debug!("Intervening: scores are close");
            return true;
        }
        // Weak best match overall.
        if cfg.is_low_confidence(state.best_score) {
            debug!(
                "Intervening: low confidence (best_score={:.2})",
                state.best_score
            );
            return true;
        }
        // Backtracking, when configured to guide it.
        if state.is_backtracking && self.config.guide_at_backtrack {
            debug!("Intervening: backtracking");
            return true;
        }
        false
    }

    /// Produces a decision for the current state by classifying the
    /// intervention point, building context, and consulting the LLM.
    async fn decide(&self, state: &SearchState<'_>) -> PilotDecision {
        let point = self.get_intervention_point(state);
        let context = self.context_builder.build(state);
        // Resolve each candidate id to its node title; ids that no longer
        // resolve in the tree are dropped.
        let mut infos = Vec::with_capacity(state.candidates.len());
        for (idx, &id) in state.candidates.iter().enumerate() {
            if let Some(node) = state.tree.get(id) {
                infos.push(super::parser::CandidateInfo {
                    node_id: id,
                    title: node.title.clone(),
                    index: idx,
                });
            }
        }
        self.call_llm(point, &context, &infos).await
    }

    /// Optionally ranks the start node's children before the search begins.
    /// Returns `None` when guidance is disabled, budget is gone, or the
    /// start node is a leaf.
    async fn guide_start(
        &self,
        tree: &DocumentTree,
        query: &str,
        start_node: NodeId,
    ) -> Option<PilotDecision> {
        if !self.config.guide_at_start {
            return None;
        }
        if !self.has_budget() {
            debug!("Budget exhausted, cannot guide start");
            return None;
        }
        let context = self.context_builder.build_start_context(tree, query);
        let children = tree.children(start_node);
        if children.is_empty() {
            debug!("Start node has no children, no guidance needed");
            return None;
        }
        let mut infos = Vec::with_capacity(children.len());
        for (idx, &id) in children.iter().enumerate() {
            if let Some(node) = tree.get(id) {
                infos.push(super::parser::CandidateInfo {
                    node_id: id,
                    title: node.title.clone(),
                    index: idx,
                });
            }
        }
        let decision = self
            .call_llm(InterventionPoint::Start, &context, &infos)
            .await;
        info!(
            "Pilot start guidance: confidence={:.2}, candidates={}",
            decision.confidence,
            decision.ranked_candidates.len()
        );
        Some(decision)
    }

    /// Optionally re-ranks candidates when the search backtracks. Returns
    /// `None` when backtrack guidance is disabled or budget is exhausted.
    async fn guide_backtrack(&self, state: &SearchState<'_>) -> Option<PilotDecision> {
        if !self.config.guide_at_backtrack {
            return None;
        }
        if !self.has_budget() {
            return None;
        }
        let context = self
            .context_builder
            .build_backtrack_context(state, state.path);
        let mut infos = Vec::with_capacity(state.candidates.len());
        for (idx, &id) in state.candidates.iter().enumerate() {
            if let Some(node) = state.tree.get(id) {
                infos.push(super::parser::CandidateInfo {
                    node_id: id,
                    title: node.title.clone(),
                    index: idx,
                });
            }
        }
        let decision = self
            .call_llm(InterventionPoint::Backtrack, &context, &infos)
            .await;
        Some(decision)
    }

    /// Asks the LLM to prune the candidate set down to relevant nodes
    /// (score > 0.5). Returns `None` when budget is gone or the LLM marks
    /// nothing as relevant, so callers fall back to the full set.
    async fn binary_prune(&self, state: &SearchState<'_>) -> Option<Vec<NodeId>> {
        if !self.has_budget() {
            debug!("Budget exhausted, cannot binary prune");
            return None;
        }
        let context = self.context_builder.build(state);
        let mut infos = Vec::with_capacity(state.candidates.len());
        for (idx, &id) in state.candidates.iter().enumerate() {
            if let Some(node) = state.tree.get(id) {
                infos.push(super::parser::CandidateInfo {
                    node_id: id,
                    title: node.title.clone(),
                    index: idx,
                });
            }
        }
        let decision = self
            .call_llm(InterventionPoint::Prune, &context, &infos)
            .await;
        // Keep only candidates scored above the relevance cutoff.
        let mut relevant = Vec::new();
        for ranked in &decision.ranked_candidates {
            if ranked.score > 0.5 {
                relevant.push(ranked.node_id);
            }
        }
        if relevant.is_empty() {
            debug!("Binary prune: LLM marked no candidates as relevant");
            return None;
        }
        debug!(
            "Binary prune: {} of {} candidates marked relevant",
            relevant.len(),
            state.candidates.len()
        );
        Some(relevant)
    }

    fn config(&self) -> &PilotConfig {
        &self.config
    }

    fn is_active(&self) -> bool {
        self.config.mode.uses_llm() && self.has_budget()
    }

    /// Clears per-query state: the call budget and the pipeline-budget link.
    fn reset(&self) {
        self.budget.reset();
        *self.pipeline_budget.write() = None;
        debug!("LlmPilot reset for new query");
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::document::NodeId;
    use indextree::Arena;

    /// Builds a throwaway pilot with the given config for assertions.
    fn make_pilot(config: PilotConfig) -> LlmPilot {
        LlmPilot::new(LlmClient::for_model("gpt-4o-mini"), config)
    }

    /// Allocates `count` empty tree nodes and returns their ids.
    fn create_test_node_ids(count: usize) -> Vec<NodeId> {
        let mut arena = Arena::new();
        (0..count)
            .map(|i| {
                NodeId(arena.new_node(crate::document::TreeNode {
                    title: format!("Node {}", i),
                    structure: String::new(),
                    content: String::new(),
                    summary: String::new(),
                    depth: 0,
                    start_index: 1,
                    end_index: 1,
                    start_page: None,
                    end_page: None,
                    node_id: None,
                    physical_index: None,
                    token_count: None,
                    references: Vec::new(),
                }))
            })
            .collect()
    }

    #[test]
    fn test_llm_pilot_creation() {
        let pilot = make_pilot(PilotConfig::default());
        assert_eq!(pilot.name(), "llm_pilot");
        assert!(pilot.is_active());
    }

    #[test]
    fn test_llm_pilot_algorithm_only_mode() {
        let pilot = make_pilot(PilotConfig::algorithm_only());
        assert!(!pilot.config().mode.uses_llm());
    }

    #[test]
    fn test_llm_pilot_budget_exhausted() {
        let pilot = make_pilot(PilotConfig::default());
        pilot.budget.record_usage(3000, 500, 0);
        assert!(!pilot.has_budget());
    }

    #[test]
    fn test_reset() {
        let pilot = make_pilot(PilotConfig::default());
        pilot.budget.record_usage(100, 50, 0);
        assert!(pilot.budget.total_tokens() > 0);
        pilot.reset();
        assert_eq!(pilot.budget.total_tokens(), 0);
    }
}