use async_trait::async_trait;
use serde::Deserialize;
use super::super::RetrievalContext;
use super::super::types::{NavigationDecision, QueryComplexity};
use super::r#trait::{NodeEvaluation, RetrievalStrategy, StrategyCapabilities};
use crate::document::{DocumentTree, NodeId, TocView};
use crate::llm::LlmClient;
/// Structured JSON reply expected from the LLM for one navigation step.
/// Mirrors the schema requested in `default_system_prompt`:
/// `{"relevance": <0-100>, "action": "<answer|explore|skip>", "reasoning": "..."}`.
#[derive(Debug, Clone, Deserialize)]
struct NavigationResponse {
// Relevance of the current node on a 0-100 scale.
relevance: u8,
// One of "answer", "explore", or "skip"; matched case-insensitively by the caller.
action: String,
// Optional free-text justification; a missing field deserializes to `None`.
#[serde(default)]
reasoning: Option<String>,
}
/// Retrieval strategy that asks an LLM to score document-tree nodes and
/// decide whether to answer, explore children, or skip.
#[derive(Clone)]
pub struct LlmStrategy {
// Client used to send completion requests.
client: LlmClient,
// System prompt sent with every navigation request.
system_prompt: String,
// Generates and formats the ToC preview embedded in prompts.
toc_view: TocView,
// Whether to embed a ToC preview in each prompt (enabled by default).
include_toc: bool,
}
impl LlmStrategy {
    /// Creates a strategy backed by `client`, with the default system prompt
    /// and ToC context enabled.
    pub fn new(client: LlmClient) -> Self {
        Self {
            client,
            system_prompt: Self::default_system_prompt(),
            toc_view: TocView::new(),
            include_toc: true,
        }
    }

    /// Convenience constructor using a default-configured client.
    pub fn with_defaults() -> Self {
        Self::new(LlmClient::with_defaults())
    }

    /// Replaces the system prompt sent with every navigation request.
    pub fn with_system_prompt(mut self, prompt: String) -> Self {
        self.system_prompt = prompt;
        self
    }

    /// Enables or disables embedding a Table-of-Contents preview in prompts.
    pub fn with_toc_context(mut self, include: bool) -> Self {
        self.include_toc = include;
        self
    }

    fn default_system_prompt() -> String {
        r#"You are a document navigation assistant. Your task is to help find the most relevant sections in a document tree.
Given a query and document context (Table of Contents + current node), determine:
1. The relevance of this node (0-100)
2. The best action: "answer" (this node contains the answer), "explore" (check children), or "skip" (not relevant)
Respond in JSON format:
{"relevance": <0-100>, "action": "<answer|explore|skip>", "reasoning": "<brief explanation>"}
Be concise and focused on finding the most relevant information."#.to_string()
    }

    /// Truncates `s` to at most `max_bytes` bytes without splitting a UTF-8
    /// character. A plain `&s[..max_bytes]` panics when the cut lands inside
    /// a multi-byte character, which is likely for non-ASCII documents.
    fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // is_char_boundary(0) is always true, so this terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    /// Best-effort extraction of a JSON object embedded in surrounding text
    /// (e.g. a ```json fence or prose). Takes the outermost `{...}` span;
    /// returns `None` if no parseable object is found.
    fn extract_json(response: &str) -> Option<NavigationResponse> {
        let start = response.find('{')?;
        let end = response.rfind('}')?;
        if end <= start {
            return None;
        }
        // '{' and '}' are ASCII, so this byte slice is boundary-safe.
        serde_json::from_str(&response[start..=end]).ok()
    }

    /// Builds the user prompt for `node_id`: query, optional ToC preview,
    /// and a short description of the current node.
    fn build_prompt(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        context: &RetrievalContext,
    ) -> String {
        let node = tree.get(node_id);
        let children = tree.children(node_id);
        let node_info = if let Some(n) = node {
            // Prefer the precomputed summary; otherwise preview the raw
            // content, truncated safely (byte slicing could panic mid-character).
            let summary = if n.summary.is_empty() {
                Self::truncate_at_char_boundary(&n.content, 200)
            } else {
                &n.summary
            };
            format!(
                "Title: {}\nSummary: {}\nDepth: {}\nChildren: {}",
                n.title,
                summary,
                n.depth,
                children.len()
            )
        } else {
            "Node not found".to_string()
        };
        let toc_context = if self.include_toc {
            let toc = self.toc_view.generate_from(tree, node_id);
            let toc_markdown = self.toc_view.format_markdown(&toc);
            // Cap the ToC preview (by chars, so always boundary-safe) to keep
            // the prompt within budget for large documents.
            let toc_preview: String = toc_markdown.chars().take(1000).collect();
            format!(
                "\n\nDocument ToC (from this node):\n```\n{}\n```\n",
                toc_preview
            )
        } else {
            String::new()
        };
        format!(
            "Query: {}\n{}Current Node:\n{}\n\nWhat is the relevance and action?",
            context.query, toc_context, node_info
        )
    }

    /// Turns a raw LLM reply into a `NodeEvaluation`.
    ///
    /// Tries strict JSON first, then a JSON object embedded in surrounding
    /// text, then falls back to scanning lines for a "relevance"/"score"
    /// number; defaults to a neutral 0.5 score when nothing parses.
    fn parse_response(
        &self,
        response: &str,
        tree: &DocumentTree,
        node_id: NodeId,
    ) -> NodeEvaluation {
        let parsed = serde_json::from_str::<NavigationResponse>(response)
            .ok()
            .or_else(|| Self::extract_json(response));
        if let Some(parsed) = parsed {
            let score = (parsed.relevance as f32 / 100.0).clamp(0.0, 1.0);
            let decision = match parsed.action.to_lowercase().as_str() {
                "answer" => NavigationDecision::ThisIsTheAnswer,
                "explore" => {
                    // "explore" on a leaf cannot go deeper; treat the leaf as terminal.
                    if tree.is_leaf(node_id) {
                        NavigationDecision::ThisIsTheAnswer
                    } else {
                        NavigationDecision::ExploreMore
                    }
                }
                _ => NavigationDecision::Skip,
            };
            return NodeEvaluation {
                score,
                decision,
                reasoning: parsed.reasoning,
            };
        }
        // Heuristic fallback: first 0-100 number on a line mentioning
        // "relevance" or "score", rescaled to 0.0-1.0.
        let score = response
            .lines()
            .find_map(|line| {
                let lower = line.to_lowercase();
                if lower.contains("relevance") || lower.contains("score") {
                    lower
                        .split(|c: char| !c.is_numeric() && c != '.')
                        .filter_map(|s| s.parse::<f32>().ok())
                        .filter(|&s| (0.0..=100.0).contains(&s))
                        .map(|v| v / 100.0)
                        .next()
                } else {
                    None
                }
            })
            .unwrap_or(0.5);
        NodeEvaluation {
            score,
            decision: if tree.is_leaf(node_id) {
                NavigationDecision::ThisIsTheAnswer
            } else {
                NavigationDecision::ExploreMore
            },
            reasoning: Some(format!(
                "Parsed from response: {}...",
                // Boundary-safe preview; raw byte slicing panicked on
                // multi-byte characters at the cut point.
                Self::truncate_at_char_boundary(response, 100)
            )),
        }
    }
}
#[async_trait]
impl RetrievalStrategy for LlmStrategy {
    /// Scores a single node by prompting the LLM; degrades to a neutral
    /// evaluation (score 0.5) when the completion request fails.
    async fn evaluate_node(
        &self,
        tree: &DocumentTree,
        node_id: NodeId,
        context: &RetrievalContext,
    ) -> NodeEvaluation {
        let prompt = self.build_prompt(tree, node_id, context);
        let completion = self.client.complete(&self.system_prompt, &prompt).await;
        match completion {
            Ok(text) => self.parse_response(&text, tree, node_id),
            Err(err) => {
                tracing::warn!("LLM evaluation failed: {}", err);
                // Neutral fallback: keep exploring interior nodes, accept leaves.
                let decision = if tree.is_leaf(node_id) {
                    NavigationDecision::ThisIsTheAnswer
                } else {
                    NavigationDecision::ExploreMore
                };
                NodeEvaluation {
                    score: 0.5,
                    decision,
                    reasoning: Some(format!("LLM error: {}", err)),
                }
            }
        }
    }

    /// Evaluates each node in turn, preserving input order.
    async fn evaluate_nodes(
        &self,
        tree: &DocumentTree,
        node_ids: &[NodeId],
        context: &RetrievalContext,
    ) -> Vec<NodeEvaluation> {
        let mut evaluations = Vec::with_capacity(node_ids.len());
        for &id in node_ids {
            evaluations.push(self.evaluate_node(tree, id, context).await);
        }
        evaluations
    }

    fn name(&self) -> &'static str {
        "llm"
    }

    /// Static cost/feature profile for this strategy.
    fn capabilities(&self) -> StrategyCapabilities {
        StrategyCapabilities {
            uses_llm: true,
            uses_embeddings: false,
            supports_sufficiency: true,
            typical_latency_ms: 500,
        }
    }

    /// LLM navigation is worth its latency only for medium/complex queries.
    fn suitable_for_complexity(&self, complexity: QueryComplexity) -> bool {
        matches!(
            complexity,
            QueryComplexity::Medium | QueryComplexity::Complex
        )
    }
}