use regex::Regex;
use serde::{Deserialize, Serialize};
use tracing::warn;
use super::decision::{InterventionPoint, PilotDecision, RankedCandidate, SearchDirection};
use crate::document::NodeId;
/// JSON payload the parser expects from the LLM.
///
/// Every field carries a serde default, so any subset of keys (including `{}`) deserializes.
/// The parser reads `ranked_candidates` first; the alternative selection fields
/// (`best_entry_points`, `selected_nodes`, `selected_node`, `recommended_node`, `analysis`,
/// `entry_points`) are only consulted when `ranked_candidates` yields no in-range candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Scored candidates, each addressed by its zero-based position in the candidate list.
    #[serde(default)]
    pub ranked_candidates: Vec<CandidateScore>,
    /// Candidate titles, resolved to candidates by title matching.
    #[serde(default)]
    pub entry_points: Vec<String>,
    /// Position-based entry points with optional scores.
    #[serde(default)]
    pub best_entry_points: Vec<EntryPoint>,
    /// Selected candidate titles, resolved by title matching.
    #[serde(default)]
    pub selected_nodes: Vec<String>,
    /// A single selected candidate title.
    #[serde(default)]
    pub selected_node: Option<String>,
    /// A single recommended candidate title.
    #[serde(default)]
    pub recommended_node: Option<String>,
    /// Optional nested analysis object carrying further selection fields.
    #[serde(default)]
    pub analysis: Option<AnalysisWrapper>,
    /// Requested search direction; `GoDeeper` when the key is absent.
    #[serde(default)]
    pub direction: DirectionResponse,
    /// Confidence score; 0.5 when the key is absent. Accepts a JSON number or a string
    /// (see `deserialize_confidence`).
    #[serde(default = "default_confidence", deserialize_with = "deserialize_confidence")]
    pub confidence: f32,
    /// Free-text explanation; empty when the key is absent.
    #[serde(default)]
    pub reasoning: String,
}
/// Deserializes a `confidence` value that an LLM may emit in several shapes.
///
/// Accepted forms:
/// - a JSON number (`0.8`);
/// - a numeric string (`"0.8"`), parsed after trimming;
/// - a qualitative label (`"high"`, `"very high"`, `"strong"` → 0.9;
///   `"medium"`, `"moderate"` → 0.6; `"low"`, `"very low"`, `"weak"` → 0.3).
///
/// Anything else (unknown labels, `null`, booleans, arrays, objects) falls back to 0.5
/// instead of failing, so one odd field does not reject the whole response.
/// The result is always clamped to `[0.0, 1.0]`.
///
/// # Errors
/// Only when the underlying deserializer cannot produce a JSON value at all.
fn deserialize_confidence<'de, D>(deserializer: D) -> Result<f32, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let confidence = match serde_json::Value::deserialize(deserializer)? {
        serde_json::Value::Number(n) => n.as_f64().map_or(0.5, |v| v as f32),
        serde_json::Value::String(s) => {
            let label = s.trim().to_lowercase();
            match label.parse::<f32>() {
                // Numeric strings are common in LLM output; reject NaN/inf spellings.
                Ok(v) if v.is_finite() => v,
                _ => match label.as_str() {
                    "high" | "very high" | "strong" => 0.9,
                    "medium" | "moderate" => 0.6,
                    "low" | "very low" | "weak" => 0.3,
                    _ => 0.5,
                },
            }
        }
        _ => 0.5,
    };
    // Clamp here so every consumer (including `FoundAnswer`) sees a valid probability.
    Ok(confidence.clamp(0.0, 1.0))
}
/// Nested `analysis` object an `LlmResponse` may carry.
///
/// The parser uses its selection fields as fallbacks for candidate resolution, and its
/// `reasoning` may fill in an empty top-level `reasoning`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisWrapper {
    /// Echo of the query, if provided; not read by `ResponseParser`.
    #[serde(default)]
    pub query: Option<String>,
    /// Stated intent, if provided; not read by `ResponseParser`.
    #[serde(default)]
    pub intent: Option<String>,
    /// A single selected candidate title.
    #[serde(default)]
    pub selected_node: Option<String>,
    /// Selected candidate titles.
    #[serde(default)]
    pub selected_nodes: Vec<String>,
    /// Explanation used when the top-level `reasoning` is empty.
    #[serde(default)]
    pub reasoning: Option<String>,
}
/// One entry of `LlmResponse::ranked_candidates`.
///
/// `index` and `score` carry no serde default, so an entry missing either one makes the
/// whole response fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandidateScore {
    /// Zero-based position in the candidate list; out-of-range indices are ignored.
    pub index: usize,
    /// Relevance score; the parser clamps it to `[0.0, 1.0]`.
    pub score: f32,
    /// Optional explanation, carried through to the ranked candidate.
    #[serde(default)]
    pub reason: Option<String>,
}
/// A candidate node that LLM output is resolved against, either by position or by title.
#[derive(Debug, Clone)]
pub struct CandidateInfo {
    /// Node this candidate stands for.
    pub node_id: NodeId,
    /// Title compared against titles named by the LLM.
    pub title: String,
    /// Position label of this candidate. The parser resolves positional references by slice
    /// order and reads this field only for debug output; presumably it equals the slice
    /// position — confirm against callers.
    pub index: usize,
}
/// One entry of `LlmResponse::best_entry_points`, identifying a candidate by position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntryPoint {
    /// 1-based candidate position (the parser subtracts one; 0 is treated like 1).
    /// Takes precedence over `index`.
    #[serde(default)]
    pub node_id: Option<usize>,
    /// 0-based candidate position, used only when `node_id` is absent.
    #[serde(default)]
    pub index: Option<usize>,
    /// Optional title, carried through as the ranked candidate's reason.
    #[serde(default)]
    pub title: Option<String>,
    /// Preferred score field. The parser divides provided scores by 5
    /// (presumably a 1–5 scale — TODO confirm against the prompt).
    #[serde(default)]
    pub relevance_score: Option<f32>,
    /// Alternative score field, used when `relevance_score` is absent.
    #[serde(default)]
    pub score: Option<f32>,
}
/// A scored candidate reference with a mandatory reason.
///
/// NOTE(review): not referenced by `ResponseParser` in this module; presumably part of an
/// alternative response schema — confirm whether it is still used elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Top3Candidate {
    /// Candidate identifier; whether it is 0- or 1-based is not established here.
    pub node_id: usize,
    /// Relevance score; its scale is not established here.
    pub relevance_score: f32,
    /// Explanation for the score.
    pub reason: String,
}
/// Search direction as spelled in the LLM's JSON `direction` field.
///
/// The canonical spelling is `snake_case` (e.g. `"go_deeper"`), which is also what
/// serialization emits. The aliases accept other common casings, so a single stylistic
/// deviation does not make the entire JSON response fail to deserialize and fall through
/// to the much weaker regex parser. `GoDeeper` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DirectionResponse {
    #[default]
    #[serde(alias = "GoDeeper", alias = "GO_DEEPER", alias = "go-deeper", alias = "go deeper")]
    GoDeeper,
    #[serde(
        alias = "ExploreSiblings",
        alias = "EXPLORE_SIBLINGS",
        alias = "explore-siblings",
        alias = "explore siblings"
    )]
    ExploreSiblings,
    #[serde(alias = "Backtrack", alias = "BACKTRACK")]
    Backtrack,
    #[serde(
        alias = "FoundAnswer",
        alias = "FOUND_ANSWER",
        alias = "found-answer",
        alias = "found answer"
    )]
    FoundAnswer,
}
/// Serde default for `LlmResponse::confidence` when the key is missing: the midpoint 0.5.
fn default_confidence() -> f32 {
    0.5_f32
}
pub struct ResponseParser {
json_block_regex: Regex,
confidence_regex: Regex,
direction_regex: Regex,
}
impl Default for ResponseParser {
fn default() -> Self {
Self::new()
}
}
impl ResponseParser {
pub fn new() -> Self {
Self {
json_block_regex: Regex::new(r"```(?:json)?\s*([\s\S]*?)```").unwrap(),
confidence_regex: Regex::new(r"(?i)confidence[:\s]+([0-9.]+)").unwrap(),
direction_regex: Regex::new(
r"(?i)(go.?deeper|explore.?siblings|backtrack|found.?answer)",
)
.unwrap(),
}
}
pub fn parse(
&self,
response: &str,
candidates: &[CandidateInfo],
point: InterventionPoint,
) -> PilotDecision {
println!("[DEBUG] ResponseParser::parse() - candidates.len()={}", candidates.len());
if let Some(decision) = self.try_json_parse(response, candidates, point) {
println!("[DEBUG] ResponseParser::parse() - JSON parse succeeded, ranked={}", decision.ranked_candidates.len());
return decision;
}
println!("[DEBUG] ResponseParser::parse() - JSON parse failed, trying regex...");
if let Some(decision) = self.try_regex_parse(response, candidates, point) {
println!("[DEBUG] ResponseParser::parse() - Regex parse succeeded, ranked={}", decision.ranked_candidates.len());
return decision;
}
println!("[DEBUG] ResponseParser::parse() - Regex parse failed, using default decision");
self.default_decision(candidates, point)
}
fn try_json_parse(
&self,
response: &str,
candidates: &[CandidateInfo],
point: InterventionPoint,
) -> Option<PilotDecision> {
let json_str = if let Some(caps) = self.json_block_regex.captures(response) {
let extracted = caps.get(1)?.as_str().trim().to_string();
println!("[DEBUG] ResponseParser::try_json_parse() - Found JSON in code block");
extracted
} else {
let start = response.find('{')?;
let end = response.rfind('}')? + 1;
let extracted = response[start..end].to_string();
println!("[DEBUG] ResponseParser::try_json_parse() - Found raw JSON (no code block)");
extracted
};
println!("[DEBUG] ResponseParser::try_json_parse() - Extracted JSON:\n{}", json_str);
let llm_response: LlmResponse = match serde_json::from_str::<LlmResponse>(&json_str) {
Ok(r) => {
println!("[DEBUG] ResponseParser::try_json_parse() - JSON parsed successfully");
println!("[DEBUG] ResponseParser::try_json_parse() - ranked_candidates count: {}", r.ranked_candidates.len());
r
},
Err(e) => {
println!("[DEBUG] ResponseParser::try_json_parse() - JSON parse FAILED: {}", e);
warn!("Failed to parse LLM response as JSON: {}", e);
return None;
}
};
Some(self.llm_response_to_decision(llm_response, candidates, point))
}
fn try_regex_parse(
&self,
response: &str,
candidates: &[CandidateInfo],
point: InterventionPoint,
) -> Option<PilotDecision> {
let confidence = self
.confidence_regex
.captures(response)
.and_then(|caps| caps.get(1)?.as_str().parse::<f32>().ok())
.unwrap_or(0.5)
.clamp(0.0, 1.0);
let direction = self
.direction_regex
.captures(response)
.map(|caps| {
let dir = caps.get(1)?.as_str().to_lowercase();
match dir.as_str() {
d if d.contains("deeper") => Some(SearchDirection::GoDeeper {
reason: String::new(),
}),
d if d.contains("sibling") => Some(SearchDirection::ExploreSiblings {
recommended: vec![],
}),
d if d.contains("backtrack") => Some(SearchDirection::Backtrack {
reason: String::new(),
alternative_branches: vec![],
}),
d if d.contains("found") || d.contains("answer") => {
Some(SearchDirection::FoundAnswer { confidence })
}
_ => None,
}
})
.flatten()
.unwrap_or_else(|| SearchDirection::GoDeeper {
reason: String::new(),
});
let ranked = self.extract_ranked_candidates(response, candidates);
if ranked.is_empty() && candidates.len() > 1 {
return None; }
Some(PilotDecision {
ranked_candidates: ranked,
direction,
confidence,
reasoning: "Extracted via regex".to_string(),
intervention_point: point,
})
}
fn extract_ranked_candidates(
&self,
response: &str,
candidates: &[CandidateInfo],
) -> Vec<RankedCandidate> {
let mut ranked = Vec::new();
let ranking_pattern =
Regex::new(r"(\d+)[.\)]\s*(?:Candidate\s*)?(\d+)[\s:]+(?:score[:\s]*)?([0-9.]+)?")
.unwrap();
for caps in ranking_pattern.captures_iter(response) {
if let Some(index_match) = caps.get(2) {
if let Ok(index) = index_match.as_str().parse::<usize>() {
let score: f32 = caps
.get(3)
.and_then(|m| m.as_str().parse().ok())
.unwrap_or(0.5);
if index < candidates.len() {
ranked.push(RankedCandidate {
node_id: candidates[index].node_id,
score: score.clamp(0.0, 1.0),
reason: None,
});
}
}
}
}
if !ranked.is_empty() {
return ranked;
}
let number_pattern = Regex::new(r"\b(\d+)\b").unwrap();
let mut seen = std::collections::HashSet::new();
for caps in number_pattern.captures_iter(response) {
if let Some(match_1) = caps.get(1) {
if let Ok(idx) = match_1.as_str().parse::<usize>() {
if idx < candidates.len() && seen.insert(idx) {
ranked.push(RankedCandidate {
node_id: candidates[idx].node_id,
score: 1.0 - (ranked.len() as f32 * 0.1), reason: None,
});
}
}
}
if ranked.len() >= candidates.len() {
break;
}
}
ranked
}
fn llm_response_to_decision(
&self,
mut llm_response: LlmResponse,
candidates: &[CandidateInfo],
point: InterventionPoint,
) -> PilotDecision {
println!("[DEBUG] ResponseParser::llm_response_to_decision() - point={:?}", point);
println!("[DEBUG] ResponseParser::llm_response_to_decision() - ranked_candidates.len()={}", llm_response.ranked_candidates.len());
println!("[DEBUG] ResponseParser::llm_response_to_decision() - best_entry_points.len()={}", llm_response.best_entry_points.len());
println!("[DEBUG] ResponseParser::llm_response_to_decision() - entry_points.len()={}", llm_response.entry_points.len());
println!("[DEBUG] ResponseParser::llm_response_to_decision() - selected_nodes.len()={}", llm_response.selected_nodes.len());
println!("[DEBUG] ResponseParser::llm_response_to_decision() - selected_node={:?}", llm_response.selected_node);
println!("[DEBUG] ResponseParser::llm_response_to_decision() - analysis={:?}", llm_response.analysis.as_ref().map(|a| (&a.selected_node, &a.selected_nodes)));
let mut ranked_candidates: Vec<RankedCandidate> = llm_response
.ranked_candidates
.iter()
.filter_map(|cs| {
if cs.index < candidates.len() {
Some(RankedCandidate {
node_id: candidates[cs.index].node_id,
score: cs.score.clamp(0.0, 1.0),
reason: cs.reason.clone(),
})
} else {
None
}
})
.collect();
if ranked_candidates.is_empty() {
for entry in &llm_response.best_entry_points {
let idx = if let Some(nid) = entry.node_id {
if nid > 0 { nid - 1 } else { nid }
} else if let Some(idx) = entry.index {
idx
} else {
continue; };
if idx < candidates.len() {
let score = entry.relevance_score
.or(entry.score)
.unwrap_or(0.5)
/ 5.0; ranked_candidates.push(RankedCandidate {
node_id: candidates[idx].node_id,
score: score.clamp(0.0, 1.0),
reason: entry.title.clone(),
});
println!("[DEBUG] ResponseParser - converted best_entry_point[{}] to ranked_candidate (idx={}, score={:.2})",
idx, idx, score);
}
}
for selected_title in &llm_response.selected_nodes {
for candidate in candidates {
if Self::titles_match(selected_title, &candidate.title) {
ranked_candidates.push(RankedCandidate {
node_id: candidate.node_id,
score: 0.9, reason: Some(format!("Title match: {}", selected_title)),
});
println!("[DEBUG] ResponseParser - matched selected_node '{}' to candidate '{}' (index={})",
selected_title, candidate.title, candidate.index);
break; }
}
}
if let Some(ref single_node) = llm_response.selected_node {
for candidate in candidates {
if Self::titles_match(single_node, &candidate.title) {
if !ranked_candidates.iter().any(|rc| rc.node_id == candidate.node_id) {
ranked_candidates.push(RankedCandidate {
node_id: candidate.node_id,
score: 0.9,
reason: Some(format!("Title match (singular): {}", single_node)),
});
println!("[DEBUG] ResponseParser - matched selected_node (singular) '{}' to candidate '{}' (index={})",
single_node, candidate.title, candidate.index);
}
break;
}
}
}
if let Some(ref recommended) = llm_response.recommended_node {
for candidate in candidates {
if Self::titles_match(recommended, &candidate.title) {
if !ranked_candidates.iter().any(|rc| rc.node_id == candidate.node_id) {
ranked_candidates.push(RankedCandidate {
node_id: candidate.node_id,
score: 0.85,
reason: Some(format!("Recommended node: {}", recommended)),
});
println!("[DEBUG] ResponseParser - matched recommended_node '{}' to candidate '{}' (index={})",
recommended, candidate.title, candidate.index);
}
break;
}
}
}
if let Some(ref analysis) = llm_response.analysis {
for selected_title in &analysis.selected_nodes {
for candidate in candidates {
if Self::titles_match(selected_title, &candidate.title) {
if !ranked_candidates.iter().any(|rc| rc.node_id == candidate.node_id) {
ranked_candidates.push(RankedCandidate {
node_id: candidate.node_id,
score: 0.85,
reason: Some(format!("Analysis selected_nodes: {}", selected_title)),
});
println!("[DEBUG] ResponseParser - matched analysis.selected_nodes '{}' to candidate '{}' (index={})",
selected_title, candidate.title, candidate.index);
}
break;
}
}
}
if let Some(ref single_node) = analysis.selected_node {
for candidate in candidates {
if Self::titles_match(single_node, &candidate.title) {
if !ranked_candidates.iter().any(|rc| rc.node_id == candidate.node_id) {
ranked_candidates.push(RankedCandidate {
node_id: candidate.node_id,
score: 0.85,
reason: Some(format!("Analysis selected_node: {}", single_node)),
});
println!("[DEBUG] ResponseParser - matched analysis.selected_node (singular) '{}' to candidate '{}' (index={})",
single_node, candidate.title, candidate.index);
}
break;
}
}
}
if llm_response.reasoning.is_empty() {
if let Some(ref r) = analysis.reasoning {
llm_response.reasoning = r.clone();
}
}
}
for entry_title in &llm_response.entry_points {
for candidate in candidates {
if Self::titles_match(entry_title, &candidate.title) {
if !ranked_candidates.iter().any(|rc| rc.node_id == candidate.node_id) {
ranked_candidates.push(RankedCandidate {
node_id: candidate.node_id,
score: 0.8, reason: Some(format!("Entry point: {}", entry_title)),
});
println!("[DEBUG] ResponseParser - matched entry_point '{}' to candidate '{}' (index={})",
entry_title, candidate.title, candidate.index);
}
break;
}
}
}
}
let direction = match llm_response.direction {
DirectionResponse::GoDeeper => SearchDirection::GoDeeper {
reason: llm_response.reasoning.clone(),
},
DirectionResponse::ExploreSiblings => SearchDirection::ExploreSiblings {
recommended: ranked_candidates
.iter()
.take(3)
.map(|c| c.node_id)
.collect(),
},
DirectionResponse::Backtrack => SearchDirection::Backtrack {
reason: llm_response.reasoning.clone(),
alternative_branches: ranked_candidates
.iter()
.take(3)
.map(|c| c.node_id)
.collect(),
},
DirectionResponse::FoundAnswer => SearchDirection::FoundAnswer {
confidence: llm_response.confidence,
},
};
println!("[DEBUG] ResponseParser::llm_response_to_decision() - final ranked_candidates.len()={}", ranked_candidates.len());
PilotDecision {
ranked_candidates,
direction,
confidence: llm_response.confidence.clamp(0.0, 1.0),
reasoning: llm_response.reasoning,
intervention_point: point,
}
}
fn titles_match(llm_title: &str, candidate_title: &str) -> bool {
let llm_lower = llm_title.to_lowercase().trim().to_string();
let candidate_lower = candidate_title.to_lowercase().trim().to_string();
if llm_lower == candidate_lower {
return true;
}
if llm_lower.contains(&candidate_lower) || candidate_lower.contains(&llm_lower) {
return true;
}
let llm_words: std::collections::HashSet<&str> = llm_lower.split_whitespace().collect();
let candidate_words: std::collections::HashSet<&str> = candidate_lower.split_whitespace().collect();
let overlap = llm_words.intersection(&candidate_words).count();
let min_words = llm_words.len().min(candidate_words.len());
if min_words > 0 && overlap as f32 / min_words as f32 >= 0.5 {
return true;
}
false
}
fn default_decision(&self, candidates: &[CandidateInfo], point: InterventionPoint) -> PilotDecision {
let ranked: Vec<RankedCandidate> = candidates
.iter()
.enumerate()
.map(|(i, c)| RankedCandidate {
node_id: c.node_id,
score: 1.0 / (i + 1) as f32, reason: None,
})
.collect();
PilotDecision {
ranked_candidates: ranked,
direction: SearchDirection::GoDeeper {
reason: String::new(),
},
confidence: 0.0,
reasoning: "Default decision (parsing failed)".to_string(),
intervention_point: point,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use indextree::Arena;

    /// Creates `count` distinct node ids backed by a scratch arena. The nodes are titled
    /// `"Node 0"`, `"Node 1"`, …, and only the ids are returned.
    fn create_test_node_ids(count: usize) -> Vec<NodeId> {
        let mut arena: Arena<crate::document::TreeNode> = Arena::new();
        (0..count)
            .map(|i| {
                let node = crate::document::TreeNode {
                    title: format!("Node {}", i),
                    structure: String::new(),
                    content: String::new(),
                    summary: String::new(),
                    depth: 0,
                    start_index: 1,
                    end_index: 1,
                    start_page: None,
                    end_page: None,
                    node_id: None,
                    physical_index: None,
                    token_count: None,
                    references: Vec::new(),
                };
                NodeId(arena.new_node(node))
            })
            .collect()
    }

    /// Builds `count` candidates titled `"Node <i>"`, whose `index` equals their position.
    fn test_candidates(count: usize) -> Vec<CandidateInfo> {
        create_test_node_ids(count)
            .into_iter()
            .enumerate()
            .map(|(index, node_id)| CandidateInfo {
                node_id,
                title: format!("Node {}", index),
                index,
            })
            .collect()
    }

    #[test]
    fn titles_match_ignores_case_and_surrounding_whitespace() {
        assert!(ResponseParser::titles_match("Introduction", "  introduction "));
        assert!(!ResponseParser::titles_match("Methods", "Results"));
    }

    #[test]
    fn extracts_numbered_ranking_lines() {
        let parser = ResponseParser::new();
        let candidates = test_candidates(3);
        let ranked = parser.extract_ranked_candidates("1. Candidate 2: 0.9", &candidates);
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].node_id, candidates[2].node_id);
        assert!((ranked[0].score - 0.9).abs() < 1e-6);
    }

    #[test]
    fn falls_back_to_bare_numbers_in_order_of_appearance() {
        let parser = ResponseParser::new();
        let candidates = test_candidates(3);
        let ranked = parser.extract_ranked_candidates("pick 1 then 0", &candidates);
        let ids: Vec<NodeId> = ranked.iter().map(|r| r.node_id).collect();
        assert_eq!(ids, vec![candidates[1].node_id, candidates[0].node_id]);
        assert!(ranked[0].score > ranked[1].score);
    }

    #[test]
    fn llm_response_defaults_and_confidence_labels() {
        let parsed: LlmResponse = serde_json::from_str(r#"{"confidence": "high"}"#).unwrap();
        assert!((parsed.confidence - 0.9).abs() < 1e-6);

        let empty: LlmResponse = serde_json::from_str("{}").unwrap();
        assert!((empty.confidence - 0.5).abs() < 1e-6);
        assert!(matches!(empty.direction, DirectionResponse::GoDeeper));
        assert!(empty.ranked_candidates.is_empty());
    }
}