use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::LazyLock;
use crate::document::{DocumentTree, NodeId};
use super::{InterventionPoint, PilotConfig, PilotDecision};
/// Shared, lazily-initialised empty visited set.
///
/// `SearchState::for_start` borrows this so a starting state can hold a
/// `'static` reference without the caller having to supply a set.
static EMPTY_VISITED: LazyLock<HashSet<NodeId>> = LazyLock::new(HashSet::default);
/// Read-only snapshot of a search over a [`DocumentTree`], handed to a [`Pilot`]
/// so it can decide whether and how to intervene.
///
/// Every field is either a borrow or a plain scalar; the state owns no data.
#[derive(Debug, Clone)]
pub struct SearchState<'a> {
    /// Document tree being searched.
    pub tree: &'a DocumentTree,
    /// Query text driving the search.
    pub query: &'a str,
    /// Node ids walked so far. The last entry is the current node
    /// (see `SearchState::current_node`); an empty path counts as "at root".
    pub path: &'a [NodeId],
    /// Nodes under consideration at this step; more than one makes the state a
    /// fork point (see `SearchState::is_fork_point`).
    pub candidates: &'a [NodeId],
    /// Node ids the search tracks as visited.
    pub visited: &'a HashSet<NodeId>,
    /// Search depth. Both constructors set it to `path.len()`; it is a plain
    /// field, so it is not kept in sync automatically if `path` is replaced later.
    pub depth: usize,
    /// Iteration counter. Constructors start it at `0`, which
    /// `PilotExt::intervention_point` classifies as `InterventionPoint::Start`.
    pub iteration: usize,
    /// Best score seen so far; constructors start it at `0.0`.
    /// NOTE(review): the scale/range of scores is not visible here — confirm.
    pub best_score: f32,
    /// When set, `PilotExt::intervention_point` reports
    /// `InterventionPoint::Backtrack` (unless the state already counts as a start).
    pub is_backtracking: bool,
    /// Optional per-step explanations; constructors leave this `None`.
    /// NOTE(review): presumably aligned index-by-index with `path` — confirm.
    pub step_reasons: Option<&'a [Option<String>]>,
}
impl<'a> SearchState<'a> {
    /// Builds a state positioned at the end of `path`.
    ///
    /// `depth` is derived from `path.len()`. `iteration` starts at `0` and
    /// `best_score` at `0.0`. Backtracking is off and no step reasons are attached.
    pub fn new(
        tree: &'a DocumentTree,
        query: &'a str,
        path: &'a [NodeId],
        candidates: &'a [NodeId],
        visited: &'a HashSet<NodeId>,
    ) -> Self {
        let depth = path.len();
        Self {
            tree,
            query,
            path,
            candidates,
            visited,
            depth,
            iteration: 0,
            best_score: 0.0,
            is_backtracking: false,
            step_reasons: None,
        }
    }

    /// Builds the initial state before any node has been chosen.
    ///
    /// The path and candidate list are empty, and `visited` points at a shared
    /// empty set. All other fields take the same defaults as [`SearchState::new`].
    pub fn for_start(tree: &'a DocumentTree, query: &'a str) -> Self {
        Self::new(tree, query, &[], &[], &EMPTY_VISITED)
    }

    /// `true` when no node has been walked yet (empty `path`).
    pub fn is_at_root(&self) -> bool {
        matches!(self.path, [])
    }

    /// `true` when there are at least two candidates to choose between.
    pub fn is_fork_point(&self) -> bool {
        matches!(self.candidates, [_, _, ..])
    }

    /// The node at the end of `path`, or `None` at the root.
    pub fn current_node(&self) -> Option<NodeId> {
        if let [.., last] = self.path {
            Some(*last)
        } else {
            None
        }
    }
}
/// Strategy object that can steer a document-tree search at chosen points.
///
/// The `Send + Sync` supertraits allow a pilot to be shared across threads.
/// The async methods are expanded by `#[async_trait]`.
#[async_trait]
pub trait Pilot: Send + Sync {
    /// Name of this pilot (presumably for logging/identification — confirm).
    fn name(&self) -> &str;
    /// Synchronous check of whether the pilot wants to act on `state`.
    ///
    /// `PilotExt::can_intervene` consults this only when `is_active` returns `true`.
    fn should_intervene(&self, state: &SearchState<'_>) -> bool;
    /// Produces a decision for `state`.
    async fn decide(&self, state: &SearchState<'_>) -> PilotDecision;
    /// Offers guidance for a search of `tree` for `query` starting at `start_node`.
    ///
    /// NOTE(review): `None` presumably means "no guidance" — confirm with callers.
    async fn guide_start(
        &self,
        tree: &DocumentTree,
        query: &str,
        start_node: NodeId,
    ) -> Option<PilotDecision>;
    /// Offers guidance for `state`, presumably while the search is backtracking;
    /// `None` presumably means "no guidance" — confirm.
    async fn guide_backtrack(&self, state: &SearchState<'_>) -> Option<PilotDecision>;
    /// Optionally returns a list of node ids for `state`.
    ///
    /// NOTE(review): presumably the subset of `state.candidates` to keep, with
    /// `None` meaning "no pruning" — confirm against implementors.
    async fn binary_prune(&self, state: &SearchState<'_>) -> Option<Vec<NodeId>>;
    /// Returns this pilot's configuration.
    fn config(&self) -> &PilotConfig;
    /// Whether the pilot is enabled; the default implementation always returns `true`.
    fn is_active(&self) -> bool {
        true
    }
    /// Resets the pilot.
    ///
    /// Takes `&self`, so implementors that hold mutable state need interior
    /// mutability to clear it.
    fn reset(&self);
    /// Type-erased view of the pilot, intended for downcasting.
    ///
    /// The default returns a reference to `()`, so `downcast_ref::<Impl>()` on the
    /// result yields `None` for any implementor type. Implementors that need to be
    /// downcast must override this, e.g. `fn as_any(&self) -> &dyn Any { self }`
    /// (which requires the implementor to be `'static`).
    fn as_any(&self) -> &dyn std::any::Any {
        &()
    }
}
/// Convenience helpers available on every [`Pilot`] (including `dyn Pilot`)
/// through the blanket implementation that follows this trait.
pub trait PilotExt: Pilot {
    /// Returns `true` only if the pilot is active *and* wants to intervene.
    ///
    /// `should_intervene` is never called for an inactive pilot.
    fn can_intervene(&self, state: &SearchState<'_>) -> bool {
        if !self.is_active() {
            return false;
        }
        self.should_intervene(state)
    }

    /// Classifies `state` into the kind of intervention point it represents.
    ///
    /// Precedence (first match wins):
    /// 1. `Start`: the path is empty or `iteration` is `0`;
    /// 2. `Backtrack`: `is_backtracking` is set;
    /// 3. `Fork`: more than one candidate;
    /// 4. `Evaluate`: anything else.
    ///
    /// NOTE(review): `SearchState::new` initialises `iteration` to `0`, so a
    /// mid-search state is reported as `Start` until its caller bumps `iteration`.
    fn intervention_point(&self, state: &SearchState<'_>) -> InterventionPoint {
        let at_start = state.is_at_root() || state.iteration == 0;
        match (at_start, state.is_backtracking, state.is_fork_point()) {
            (true, _, _) => InterventionPoint::Start,
            (_, true, _) => InterventionPoint::Backtrack,
            (_, _, true) => InterventionPoint::Fork,
            _ => InterventionPoint::Evaluate,
        }
    }
}

impl<T> PilotExt for T where T: Pilot + ?Sized {}