use serde::{Deserialize, Serialize};
use crate::document::NodeId;
/// Guidance produced by the pilot for one step of a search over document nodes.
///
/// Bundles a ranking of candidate nodes with a recommended [`SearchDirection`],
/// a confidence value, a free-text explanation, and the [`InterventionPoint`]
/// the decision belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PilotDecision {
    /// Candidate nodes in ranked order; `PilotDecision::top_candidate` returns the first entry.
    pub ranked_candidates: Vec<RankedCandidate>,
    /// The recommended next move.
    pub direction: SearchDirection,
    /// Confidence in this decision. The `Default` impl and `preserve_order` both set it
    /// to `0.0`. NOTE(review): presumably meant to lie in `[0, 1]`; not enforced here.
    pub confidence: f32,
    /// Human-readable explanation of the decision.
    pub reasoning: String,
    /// Where in the search this decision was produced.
    pub intervention_point: InterventionPoint,
}
/// A neutral decision: no candidates, zero confidence, and a `GoDeeper` direction.
///
/// Unlike `PilotDecision::new` and `PilotDecision::preserve_order`, which record
/// [`InterventionPoint::Fork`], the default records [`InterventionPoint::Evaluate`].
impl Default for PilotDecision {
    fn default() -> Self {
        let direction = SearchDirection::go_deeper("Default decision");
        let reasoning = String::from("No specific guidance available");
        Self {
            intervention_point: InterventionPoint::Evaluate,
            confidence: 0.0,
            reasoning,
            direction,
            ranked_candidates: Vec::new(),
        }
    }
}
impl PilotDecision {
pub fn new(
ranked_candidates: Vec<RankedCandidate>,
direction: SearchDirection,
confidence: f32,
reasoning: String,
) -> Self {
Self {
ranked_candidates,
direction,
confidence,
reasoning,
intervention_point: InterventionPoint::Fork,
}
}
pub fn preserve_order(candidates: &[NodeId]) -> Self {
Self {
ranked_candidates: candidates
.iter()
.enumerate()
.map(|(i, &id)| RankedCandidate {
node_id: id,
score: 1.0 - (i as f32 * 0.1),
reason: None,
})
.collect(),
direction: SearchDirection::GoDeeper {
reason: "Preserving original order".to_string(),
},
confidence: 0.0,
reasoning: "No intervention performed".to_string(),
intervention_point: InterventionPoint::Fork,
}
}
pub fn has_candidates(&self) -> bool {
!self.ranked_candidates.is_empty()
}
pub fn top_candidate(&self) -> Option<&RankedCandidate> {
self.ranked_candidates.first()
}
pub fn ranked_node_ids(&self) -> Vec<NodeId> {
self.ranked_candidates.iter().map(|c| c.node_id).collect()
}
pub fn found_answer(&self) -> bool {
matches!(self.direction, SearchDirection::FoundAnswer { .. })
}
pub fn needs_backtrack(&self) -> bool {
matches!(self.direction, SearchDirection::Backtrack { .. })
}
}
/// A candidate node paired with the score the ranking assigned to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedCandidate {
    /// The node being ranked.
    pub node_id: NodeId,
    /// Ranking score. NOTE(review): presumably higher means more promising
    /// (`PilotDecision::preserve_order` assigns decreasing scores in rank order).
    /// No range is enforced here.
    pub score: f32,
    /// Optional explanation for the score.
    pub reason: Option<String>,
}
impl RankedCandidate {
    /// Creates a candidate with the given score and no explanation.
    pub fn new(node_id: NodeId, score: f32) -> Self {
        Self {
            node_id,
            score,
            reason: None,
        }
    }

    /// Creates a candidate with the given score and an attached explanation.
    pub fn with_reason(node_id: NodeId, score: f32, reason: impl Into<String>) -> Self {
        Self::new(node_id, score).reason(reason)
    }

    /// Builder-style setter that replaces any existing explanation and returns `self`.
    pub fn reason(mut self, reason: impl Into<String>) -> Self {
        let text: String = reason.into();
        self.reason = Some(text);
        self
    }
}
/// The next move recommended by a [`PilotDecision`].
///
/// This type only carries the recommendation. How it is acted on is decided by
/// the caller (not visible in this file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchDirection {
    /// Recommendation to go deeper from the current position.
    GoDeeper {
        /// Explanation for the recommendation.
        reason: String,
    },
    /// Recommendation to explore sibling nodes.
    ExploreSiblings {
        /// Siblings suggested for exploration.
        recommended: Vec<NodeId>,
    },
    /// Recommendation to back out of the current branch.
    Backtrack {
        /// Explanation for the recommendation.
        reason: String,
        /// Other branches that could be tried instead.
        alternative_branches: Vec<NodeId>,
    },
    /// Recommendation to move directly to `target`.
    JumpTo {
        /// Node to jump to.
        target: NodeId,
        /// Explanation for the recommendation.
        reason: String,
    },
    /// Signals that an answer was found; `PilotDecision::found_answer` tests for this variant.
    FoundAnswer {
        /// Confidence attached to the found answer.
        confidence: f32,
    },
}
impl SearchDirection {
    /// Builds [`SearchDirection::GoDeeper`] with the given reason.
    pub fn go_deeper(reason: impl Into<String>) -> Self {
        Self::GoDeeper {
            reason: reason.into(),
        }
    }

    /// Builds [`SearchDirection::ExploreSiblings`] recommending the given sibling nodes.
    ///
    /// Added so that every variant has a matching constructor.
    pub fn explore_siblings(recommended: Vec<NodeId>) -> Self {
        Self::ExploreSiblings { recommended }
    }

    /// Builds [`SearchDirection::Backtrack`] with a reason and the alternative branches to try.
    pub fn backtrack(reason: impl Into<String>, alternatives: Vec<NodeId>) -> Self {
        Self::Backtrack {
            reason: reason.into(),
            alternative_branches: alternatives,
        }
    }

    /// Builds [`SearchDirection::JumpTo`] targeting `target`.
    pub fn jump_to(target: NodeId, reason: impl Into<String>) -> Self {
        Self::JumpTo {
            target,
            reason: reason.into(),
        }
    }

    /// Builds [`SearchDirection::FoundAnswer`] with the given confidence.
    pub fn found_answer(confidence: f32) -> Self {
        Self::FoundAnswer { confidence }
    }
}
/// Identifies where in the search a [`PilotDecision`] was produced
/// (stored in its `intervention_point` field).
///
/// `Fork` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum InterventionPoint {
    /// Named `"start"` by `InterventionPoint::name`.
    Start,
    /// Named `"fork"`; the default variant.
    #[default]
    Fork,
    /// Named `"backtrack"`.
    Backtrack,
    /// Named `"evaluate"`; used by `PilotDecision::default()`.
    Evaluate,
    /// Named `"prune"`.
    Prune,
}
impl InterventionPoint {
pub fn name(&self) -> &'static str {
match self {
Self::Start => "start",
Self::Fork => "fork",
Self::Backtrack => "backtrack",
Self::Evaluate => "evaluate",
Self::Prune => "prune",
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use indextree::Arena;

    /// Creates `count` distinct node ids for tests.
    ///
    /// Each id comes from a placeholder `TreeNode` inserted into a local arena. The
    /// arena is dropped on return, so only the ids themselves are available.
    fn create_test_node_ids(count: usize) -> Vec<NodeId> {
        let mut arena = Arena::new();
        let mut ids = Vec::new();
        for i in 0..count {
            let node = crate::document::TreeNode {
                title: format!("Node {}", i),
                structure: String::new(),
                content: String::new(),
                summary: String::new(),
                depth: 0,
                start_index: 1,
                end_index: 1,
                start_page: None,
                end_page: None,
                node_id: None,
                physical_index: None,
                token_count: None,
                references: Vec::new(),
            };
            ids.push(NodeId(arena.new_node(node)));
        }
        ids
    }

    /// Tolerance-based comparison for computed `f32` scores.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_pilot_decision_default() {
        let decision = PilotDecision::default();
        assert!(!decision.has_candidates());
        assert!(decision.top_candidate().is_none());
        assert!(!decision.found_answer());
        assert!(!decision.needs_backtrack());
        assert_eq!(decision.confidence, 0.0);
        assert_eq!(decision.intervention_point, InterventionPoint::Evaluate);
    }

    #[test]
    fn test_pilot_decision_preserve_order() {
        let node_ids = create_test_node_ids(2);
        let decision = PilotDecision::preserve_order(&node_ids);
        assert!(decision.has_candidates());
        assert_eq!(decision.ranked_candidates.len(), 2);
        assert_eq!(decision.ranked_node_ids().len(), 2);
        assert_eq!(decision.confidence, 0.0);
        assert_eq!(decision.intervention_point, InterventionPoint::Fork);
        // Scores start at 1.0 and step down by 0.1 in input order.
        assert!(approx_eq(decision.ranked_candidates[0].score, 1.0));
        assert!(approx_eq(decision.ranked_candidates[1].score, 0.9));
        assert!(decision.ranked_candidates.iter().all(|c| c.reason.is_none()));
        assert!(!decision.found_answer());
        assert!(!decision.needs_backtrack());
    }

    #[test]
    fn test_pilot_decision_preserve_order_empty() {
        let decision = PilotDecision::preserve_order(&[]);
        assert!(!decision.has_candidates());
        assert!(decision.top_candidate().is_none());
        assert!(decision.ranked_node_ids().is_empty());
    }

    #[test]
    fn test_pilot_decision_direction_predicates() {
        let found = PilotDecision::new(
            Vec::new(),
            SearchDirection::found_answer(0.95),
            0.95,
            "answer located".to_string(),
        );
        assert!(found.found_answer());
        assert!(!found.needs_backtrack());
        assert_eq!(found.intervention_point, InterventionPoint::Fork);

        let back = PilotDecision::new(
            Vec::new(),
            SearchDirection::backtrack("dead end", Vec::new()),
            0.5,
            "branch exhausted".to_string(),
        );
        assert!(back.needs_backtrack());
        assert!(!back.found_answer());
    }

    #[test]
    fn test_ranked_candidate() {
        let node_ids = create_test_node_ids(1);
        let candidate = RankedCandidate::new(node_ids[0], 0.8);
        assert_eq!(candidate.score, 0.8);
        assert!(candidate.reason.is_none());
        let candidate_with_reason = RankedCandidate::with_reason(node_ids[0], 0.9, "test reason");
        assert_eq!(candidate_with_reason.score, 0.9);
        assert_eq!(
            candidate_with_reason.reason,
            Some("test reason".to_string())
        );
        let built = RankedCandidate::new(node_ids[0], 0.7).reason("builder reason");
        assert_eq!(built.score, 0.7);
        assert_eq!(built.reason.as_deref(), Some("builder reason"));
    }

    #[test]
    fn test_search_direction_constructors() {
        let deeper = SearchDirection::go_deeper("test");
        assert!(matches!(deeper, SearchDirection::GoDeeper { .. }));
        let found = SearchDirection::found_answer(0.9);
        assert!(matches!(
            found,
            SearchDirection::FoundAnswer { confidence: 0.9 }
        ));
        let back = SearchDirection::backtrack("dead end", Vec::new());
        assert!(matches!(
            &back,
            SearchDirection::Backtrack { reason, alternative_branches }
                if reason.as_str() == "dead end" && alternative_branches.is_empty()
        ));
        let target = create_test_node_ids(1)[0];
        let jump = SearchDirection::jump_to(target, "skip ahead");
        assert!(matches!(
            &jump,
            SearchDirection::JumpTo { reason, .. } if reason.as_str() == "skip ahead"
        ));
    }

    #[test]
    fn test_intervention_point() {
        assert_eq!(InterventionPoint::Start.name(), "start");
        assert_eq!(InterventionPoint::Fork.name(), "fork");
        assert_eq!(InterventionPoint::Backtrack.name(), "backtrack");
        assert_eq!(InterventionPoint::Evaluate.name(), "evaluate");
        assert_eq!(InterventionPoint::Prune.name(), "prune");
        assert_eq!(InterventionPoint::default(), InterventionPoint::Fork);
    }
}