use async_trait::async_trait;
use serde::Deserialize;
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, QueryComplexity};
use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
use crate::document::{DocumentTree, NodeId, TocView};
use crate::llm::LlmClient;
/// One scored entry of a batch LLM response.
#[derive(Debug, Clone, Deserialize)]
struct NodeScore {
    // 1-based index echoing the `[N]` bracket marker from the batch prompt.
    index: usize,
    // Relevance score 0-100 as returned by the model.
    relevance: u8,
    // Expected values: "answer", "explore", or "skip" (matched case-insensitively downstream).
    action: String,
    // Model's per-node explanation. The batch prompt's example JSON uses the
    // key "reason", so accept that spelling as an alias; previously only
    // "reasoning" was deserialized, which left this field always `None`.
    #[serde(default, alias = "reason")]
    reasoning: Option<String>,
}
/// Top-level shape of a batch scoring reply from the LLM.
#[derive(Debug, Clone, Deserialize)]
struct BatchResponse {
    // Query-level analysis from the model; defaults to empty if omitted.
    // Not consumed by `parse_batch_response` in this file — kept so the
    // documented response format deserializes without error.
    #[serde(default)]
    reasoning: String,
    // Per-node scores; entries reference prompt indices, not positions.
    nodes: Vec<NodeScore>,
}
/// Shape of a single-node scoring reply from the LLM, mirroring the JSON
/// format requested in `default_system_prompt`.
#[derive(Debug, Clone, Deserialize)]
struct NavigationResponse {
    // Relevance score 0-100 as returned by the model.
    relevance: u8,
    // Expected values: "answer", "explore", or "skip" (matched case-insensitively).
    action: String,
    // Optional brief explanation; `None` if the model omits it.
    #[serde(default)]
    reasoning: Option<String>,
}
/// Retrieval strategy that asks an LLM to score document-tree nodes
/// against a query, one node at a time or in batches.
#[derive(Clone)]
pub struct LlmStrategy {
    // Client used for all completions.
    client: LlmClient,
    // System prompt for single-node evaluation.
    system_prompt: String,
    // System prompt for batch evaluation (strict-JSON variant).
    batch_system_prompt: String,
    // Generates Table-of-Contents excerpts embedded in prompts.
    toc_view: TocView,
    // Whether prompts include a ToC excerpt (on by default).
    include_toc: bool,
}
impl LlmStrategy {
    /// Creates a strategy backed by `client`, with the default prompts and
    /// ToC context enabled.
    pub fn new(client: LlmClient) -> Self {
        Self {
            client,
            system_prompt: Self::default_system_prompt(),
            batch_system_prompt: Self::default_batch_system_prompt(),
            toc_view: TocView::new(),
            include_toc: true,
        }
    }

    /// Convenience constructor using `LlmClient::with_defaults()`.
    pub fn with_defaults() -> Self {
        Self::new(LlmClient::with_defaults())
    }

    /// Builder: overrides the single-node system prompt.
    pub fn with_system_prompt(mut self, prompt: String) -> Self {
        self.system_prompt = prompt;
        self
    }

    /// Builder: enables or disables ToC excerpts in prompts.
    pub fn with_toc_context(mut self, include: bool) -> Self {
        self.include_toc = include;
        self
    }

    /// Returns the longest prefix of `s` that is at most `max_bytes` bytes
    /// long and ends on a UTF-8 character boundary.
    ///
    /// Plain byte slicing (`&s[..max_bytes]`) panics when the cut lands
    /// inside a multi-byte character — likely with non-ASCII document
    /// content or LLM replies — so all prompt/reasoning truncation in this
    /// impl goes through this helper.
    fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // `end < s.len()` here, so indexing is in range; walking backwards
        // always terminates because index 0 is a char boundary.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    /// Builds the user prompt for single-node evaluation: query, optional
    /// ToC excerpt, and a description of the node under consideration.
    fn build_prompt(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        context: &RetrievalContext,
    ) -> String {
        let node = tree.get(node_id);
        let children = tree.children(node_id);
        let node_info = if let Some(n) = node {
            // Fall back to a content excerpt (≤200 bytes, char-safe) when
            // the node has no precomputed summary.
            let summary = if n.summary.is_empty() {
                Self::truncate_at_char_boundary(&n.content, 200)
            } else {
                &n.summary
            };
            format!(
                "Title: {}\nSummary: {}\nDepth: {}\nChildren: {}",
                n.title,
                summary,
                n.depth,
                children.len()
            )
        } else {
            "Node not found".to_string()
        };
        let toc_context = if self.include_toc {
            let toc = self.toc_view.generate_from(tree, node_id);
            let toc_markdown = self.toc_view.format_markdown(&toc);
            // Cap the ToC excerpt to keep the prompt small; chars() is
            // already boundary-safe.
            let toc_preview: String = toc_markdown.chars().take(1000).collect();
            format!(
                "\n\nDocument ToC (from this node):\n```\n{}\n```\n",
                toc_preview
            )
        } else {
            String::new()
        };
        format!(
            "Query: {}\n{}Current Node:\n{}\n\nWhat is the relevance and action?",
            context.query, toc_context, node_info
        )
    }

    /// Builds the user prompt for batch evaluation. Each node is rendered
    /// with a 1-based `[N]` marker that the model must echo back as `index`.
    fn build_batch_prompt(
        &self,
        tree: &DocumentTree,
        node_ids: &[NodeId],
        context: &RetrievalContext,
    ) -> String {
        let node_descriptions: Vec<String> = node_ids
            .iter()
            .enumerate()
            // Nodes missing from the tree are skipped; note the prompt
            // header still reports `node_ids.len()` entries below.
            .filter_map(|(i, &node_id)| {
                let node = tree.get(node_id)?;
                let children = tree.children(node_id);
                // Char-safe content excerpt when no summary is available.
                let summary = if node.summary.is_empty() {
                    Self::truncate_at_char_boundary(&node.content, 200)
                } else {
                    &node.summary
                };
                Some(format!(
                    "[{}] Title: \"{}\"\n    Summary: \"{}\"\n    Depth: {}, Children: {}",
                    i + 1,
                    node.title,
                    summary,
                    node.depth,
                    children.len()
                ))
            })
            .collect();
        let nodes_str = node_descriptions.join("\n\n");
        let toc_context = if self.include_toc && !node_ids.is_empty() {
            // Anchor the ToC at the first node of the batch.
            let toc = self.toc_view.generate_from(tree, node_ids[0]);
            let toc_markdown = self.toc_view.format_markdown(&toc);
            let toc_preview: String = toc_markdown.chars().take(800).collect();
            format!("\n\nDocument ToC:\n{}\n", toc_preview)
        } else {
            String::new()
        };
        format!(
            "USER QUERY: {}\n{}SECTIONS TO SCORE ({} entries):\n{}\n\nScore ALL sections. Respond with ONLY the JSON object:",
            context.query,
            toc_context,
            node_ids.len(),
            nodes_str
        )
    }

    /// Parses a single-node LLM reply. Tries strict JSON first; on failure,
    /// scans lines for a "relevance"/"score" number as a heuristic fallback,
    /// defaulting to 0.5 when nothing parses.
    fn parse_response(
        &self,
        response: &str,
        tree: &DocumentTree,
        node_id: NodeId,
    ) -> NodeEvaluation {
        if let Ok(parsed) = serde_json::from_str::<NavigationResponse>(response) {
            let score = (parsed.relevance as f32 / 100.0).clamp(0.0, 1.0);
            let decision = match parsed.action.to_lowercase().as_str() {
                "answer" => NavigationDecision::ThisIsTheAnswer,
                "explore" => {
                    // A leaf has no children to explore; treat it as terminal.
                    if tree.is_leaf(node_id) {
                        NavigationDecision::ThisIsTheAnswer
                    } else {
                        NavigationDecision::ExploreMore
                    }
                }
                // Unknown actions are treated as "skip".
                _ => NavigationDecision::Skip,
            };
            return NodeEvaluation {
                score,
                decision,
                reasoning: parsed.reasoning,
            };
        }
        // Heuristic fallback: find the first 0-100 number on a line that
        // mentions relevance/score.
        let score = response
            .lines()
            .find_map(|line| {
                let lower = line.to_lowercase();
                if lower.contains("relevance") || lower.contains("score") {
                    lower
                        .split(|c: char| !c.is_numeric() && c != '.')
                        .filter_map(|s| s.parse::<f32>().ok())
                        .filter(|&s| (0.0..=100.0).contains(&s))
                        .map(|v| v / 100.0)
                        .next()
                } else {
                    None
                }
            })
            .unwrap_or(0.5);
        NodeEvaluation {
            score,
            decision: if tree.is_leaf(node_id) {
                NavigationDecision::ThisIsTheAnswer
            } else {
                NavigationDecision::ExploreMore
            },
            // Char-safe excerpt of the unparseable reply for diagnostics.
            reasoning: Some(format!(
                "Parsed from response: {}...",
                Self::truncate_at_char_boundary(response, 100)
            )),
        }
    }

    /// Parses a batch LLM reply into one evaluation per requested node.
    ///
    /// Nodes the model omitted keep a conservative default (score 0.3,
    /// explore). Out-of-range indices are ignored. If the reply is not
    /// valid JSON at all, every node gets a neutral 0.5 fallback.
    fn parse_batch_response(
        &self,
        response: &str,
        tree: &DocumentTree,
        node_ids: &[NodeId],
    ) -> Vec<NodeEvaluation> {
        if let Ok(batch) = serde_json::from_str::<BatchResponse>(response) {
            // Pre-fill with a conservative default so unscored nodes are
            // still explorable rather than dropped.
            let mut evaluations = vec![
                NodeEvaluation {
                    score: 0.3,
                    decision: NavigationDecision::ExploreMore,
                    reasoning: Some("Not scored by LLM (batch fallback)".to_string()),
                };
                node_ids.len()
            ];
            for node_score in batch.nodes {
                // Model indices are 1-based; convert without underflow.
                let idx = node_score.index.saturating_sub(1);
                if idx < node_ids.len() {
                    let node_id = node_ids[idx];
                    let score = (node_score.relevance as f32 / 100.0).clamp(0.0, 1.0);
                    let decision = match node_score.action.to_lowercase().as_str() {
                        "answer" => NavigationDecision::ThisIsTheAnswer,
                        "explore" => {
                            if tree.is_leaf(node_id) {
                                NavigationDecision::ThisIsTheAnswer
                            } else {
                                NavigationDecision::ExploreMore
                            }
                        }
                        _ => NavigationDecision::Skip,
                    };
                    evaluations[idx] = NodeEvaluation {
                        score,
                        decision,
                        reasoning: node_score.reasoning,
                    };
                }
            }
            return evaluations;
        }
        tracing::warn!(
            "Failed to parse batch LLM response, using defaults for {} nodes",
            node_ids.len()
        );
        node_ids
            .iter()
            .map(|&node_id| NodeEvaluation {
                score: 0.5,
                decision: if tree.is_leaf(node_id) {
                    NavigationDecision::ThisIsTheAnswer
                } else {
                    NavigationDecision::ExploreMore
                },
                reasoning: Some("Batch parse fallback".to_string()),
            })
            .collect()
    }

    /// Default system prompt for single-node evaluation.
    fn default_system_prompt() -> String {
        r#"You are a document navigation assistant. Your task is to help find the most relevant sections in a document tree.
Given a query and document context (Table of Contents + current node), determine:
1. The relevance of this node (0-100)
2. The best action: "answer" (this node contains the answer), "explore" (check children), or "skip" (not relevant)
Respond in JSON format:
{"relevance": <0-100>, "action": "<answer|explore|skip>", "reasoning": "<brief explanation>"}
Be concise and focused on finding the most relevant information."#.to_string()
    }

    /// Default system prompt for batch evaluation. The example JSON uses
    /// the key "reasoning" so it deserializes into `NodeScore` directly
    /// (the previous example said "reason", which did not match the struct).
    fn default_batch_system_prompt() -> String {
        r#"You are a document navigation assistant. Score the relevance of multiple document sections against a user query.
CRITICAL: Respond with ONLY valid JSON (no markdown code blocks).
Response format:
{
"reasoning": "Brief analysis of the query",
"nodes": [
{"index": 1, "relevance": 85, "action": "answer", "reasoning": "Why relevant"},
{"index": 2, "relevance": 30, "action": "skip", "reasoning": "Why not relevant"}
]
}
Rules:
- index: MUST be the number from [N] brackets in the input
- relevance: 0-100 (how relevant this section is to the query)
- action: one of "answer", "explore", "skip"
- Score ALL provided nodes, not just the top ones
- Be concise in reasons"#.to_string()
    }
}
#[async_trait]
impl RetrievalStrategy for LlmStrategy {
    /// Scores one node by prompting the LLM; on a transport/client error,
    /// returns a neutral 0.5 evaluation instead of failing the retrieval.
    async fn evaluate_node(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        context: &RetrievalContext,
    ) -> NodeEvaluation {
        let user_prompt = self.build_prompt(tree, node_id, context);
        let completion = self.client.complete(&self.system_prompt, &user_prompt).await;
        match completion {
            Ok(text) => self.parse_response(&text, tree, node_id),
            Err(err) => {
                tracing::warn!("LLM evaluation failed: {}", err);
                // Neutral fallback: leaves are treated as answers,
                // interior nodes stay explorable.
                let decision = if tree.is_leaf(node_id) {
                    NavigationDecision::ThisIsTheAnswer
                } else {
                    NavigationDecision::ExploreMore
                };
                NodeEvaluation {
                    score: 0.5,
                    decision,
                    reasoning: Some(format!("LLM error: {}", err)),
                }
            }
        }
    }

    /// Scores a group of nodes with one batch completion. Empty input
    /// short-circuits; a single node reuses `evaluate_node`; a failed batch
    /// call falls back to sequential per-node evaluation.
    async fn evaluate_nodes(
        &self,
        tree: &DocumentTree,
        node_ids: &[NodeId],
        context: &RetrievalContext,
    ) -> Vec<NodeEvaluation> {
        match node_ids {
            [] => Vec::new(),
            [only] => vec![self.evaluate_node(tree, *only, context).await],
            _ => {
                let batch_prompt = self.build_batch_prompt(tree, node_ids, context);
                let reply = self
                    .client
                    .complete(&self.batch_system_prompt, &batch_prompt)
                    .await;
                match reply {
                    Ok(text) => self.parse_batch_response(&text, tree, node_ids),
                    Err(err) => {
                        tracing::warn!(
                            "Batch LLM evaluation failed ({}), falling back to single evaluation: {}",
                            node_ids.len(),
                            err
                        );
                        // Sequential fallback; awaits one node at a time.
                        let mut evaluations = Vec::with_capacity(node_ids.len());
                        for &id in node_ids {
                            evaluations.push(self.evaluate_node(tree, id, context).await);
                        }
                        evaluations
                    }
                }
            }
        }
    }

    fn name(&self) -> &'static str {
        "llm"
    }

    /// Static description of this strategy's cost/capability profile.
    fn capabilities(&self) -> StrategyCapabilities {
        StrategyCapabilities {
            uses_llm: true,
            uses_embeddings: false,
            supports_sufficiency: true,
            typical_latency_ms: 500,
        }
    }

    /// LLM calls are only worthwhile for medium/complex queries.
    fn suitable_for_complexity(&self, complexity: QueryComplexity) -> bool {
        match complexity {
            QueryComplexity::Medium | QueryComplexity::Complex => true,
            _ => false,
        }
    }
}