use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::learn::offline::LearnedActionOrder;
use crate::types::LoraConfig;
/// Chooses how many LLM votes to request based on how well a learned action
/// order matches the current action set.
#[derive(Debug, Clone)]
pub struct VotingStrategy {
    /// Match rate at (or above) which a LoRA-backed partial match needs only one vote.
    pub high_threshold: f64,
    /// Match rate below which a partial match is not used at all. Not read by
    /// [`VotingStrategy::determine`]; callers compare against it themselves.
    pub medium_threshold: f64,
}

impl Default for VotingStrategy {
    /// `high_threshold = 0.8`, `medium_threshold = 0.6`.
    fn default() -> Self {
        VotingStrategy {
            medium_threshold: 0.6,
            high_threshold: 0.8,
        }
    }
}

impl VotingStrategy {
    /// Returns the number of LLM votes to request.
    ///
    /// * `0` when `match_rate >= 1.0` (exact match, no voting).
    /// * `1` when `match_rate >= high_threshold` and a LoRA adapter is available.
    /// * `3` otherwise (including NaN rates, which fail every comparison).
    pub fn determine(&self, match_rate: f64, has_lora: bool) -> u8 {
        if match_rate >= 1.0 {
            return 0;
        }
        let trusted_partial = has_lora && match_rate >= self.high_threshold;
        if trusted_partial {
            1
        } else {
            3
        }
    }
}
/// Outcome of selecting a learned dependency graph for a request.
#[derive(Debug, Clone)]
pub enum SelectResult {
    /// A learned graph covers the request; no LLM call is needed.
    UseLearnedGraph {
        graph: Box<DependencyGraph>,
        lora: Option<LoraConfig>,
    },
    /// The LLM must be consulted, optionally with a LoRA adapter and a learned hint.
    UseLlm {
        lora: Option<LoraConfig>,
        hint: Option<LearnedActionOrder>,
        vote_count: u8,
        match_rate: f64,
    },
}

impl SelectResult {
    /// `true` exactly for [`SelectResult::UseLlm`].
    pub fn needs_llm(&self) -> bool {
        match self {
            SelectResult::UseLearnedGraph { .. } => false,
            SelectResult::UseLlm { .. } => true,
        }
    }

    /// Number of LLM votes requested; always `0` for a learned graph.
    pub fn vote_count(&self) -> u8 {
        if let SelectResult::UseLlm { vote_count, .. } = self {
            *vote_count
        } else {
            0
        }
    }

    /// LoRA adapter carried by either variant, if any.
    pub fn lora(&self) -> Option<&LoraConfig> {
        let adapter = match self {
            SelectResult::UseLearnedGraph { lora, .. } | SelectResult::UseLlm { lora, .. } => lora,
        };
        adapter.as_ref()
    }
}
/// A directed "`from` may be followed by `to`" edge with a confidence score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencyEdge {
    pub from: String,
    pub to: String,
    pub confidence: f64,
}

impl DependencyEdge {
    /// Creates an edge, clamping `confidence` into `[0.0, 1.0]`.
    ///
    /// A NaN confidence stays NaN (`f64::clamp` propagates NaN).
    pub fn new(from: impl Into<String>, to: impl Into<String>, confidence: f64) -> Self {
        let confidence = confidence.clamp(0.0, 1.0);
        DependencyEdge {
            from: from.into(),
            to: to.into(),
            confidence,
        }
    }
}
/// Directed graph describing which actions may follow which for a task.
///
/// Construct with [`DependencyGraph::builder`]; fields are private.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DependencyGraph {
    // Edges in insertion order (duplicates are not rejected).
    edges: Vec<DependencyEdge>,
    // Unordered sets; `start_actions` / `terminal_actions` return sorted copies.
    start_nodes: HashSet<String>,
    terminal_nodes: HashSet<String>,
    task: String,
    // Actions the graph is allowed to reference; checked by `validate`.
    available_actions: Vec<String>,
    // Per-action `(key, values)` pairs, looked up by action name.
    // NOTE(review): meaning of key/values is defined by callers — confirm.
    #[serde(default)]
    param_variants: HashMap<String, (String, Vec<String>)>,
    // Learned action orders; set via `set_action_order` or the builder's `with_orders`.
    #[serde(default)]
    discover_order: Vec<String>,
    #[serde(default)]
    not_discover_order: Vec<String>,
    // Attached learning record; never serialized.
    #[serde(skip)]
    learn_record: Option<crate::learn::DependencyGraphRecord>,
}
impl DependencyGraph {
pub fn new() -> Self {
Self::default()
}
pub fn builder() -> DependencyGraphBuilder {
DependencyGraphBuilder::new()
}
pub fn valid_next_actions(&self, current_action: &str) -> Vec<String> {
let mut edges: Vec<_> = self
.edges
.iter()
.filter(|e| e.from == current_action)
.collect();
edges.sort_by(|a, b| {
b.confidence
.partial_cmp(&a.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
edges.iter().map(|e| e.to.clone()).collect()
}
pub fn start_actions(&self) -> Vec<String> {
let mut actions: Vec<_> = self.start_nodes.iter().cloned().collect();
actions.sort();
actions
}
pub fn terminal_actions(&self) -> Vec<String> {
let mut actions: Vec<_> = self.terminal_nodes.iter().cloned().collect();
actions.sort();
actions
}
pub fn is_terminal(&self, action: &str) -> bool {
self.terminal_nodes.contains(action)
}
pub fn is_start(&self, action: &str) -> bool {
self.start_nodes.contains(action)
}
pub fn can_transition(&self, from: &str, to: &str) -> bool {
self.edges.iter().any(|e| e.from == from && e.to == to)
}
pub fn transition_confidence(&self, from: &str, to: &str) -> Option<f64> {
self.edges
.iter()
.find(|e| e.from == from && e.to == to)
.map(|e| e.confidence)
}
pub fn edges(&self) -> &[DependencyEdge] {
&self.edges
}
pub fn task(&self) -> &str {
&self.task
}
pub fn available_actions(&self) -> &[String] {
&self.available_actions
}
pub fn param_variants(&self, action: &str) -> Option<(&str, &[String])> {
self.param_variants
.get(action)
.map(|(key, values)| (key.as_str(), values.as_slice()))
}
pub fn all_param_variants(&self) -> &HashMap<String, (String, Vec<String>)> {
&self.param_variants
}
pub fn discover_order(&self) -> &[String] {
&self.discover_order
}
pub fn not_discover_order(&self) -> &[String] {
&self.not_discover_order
}
pub fn set_action_order(&mut self, discover: Vec<String>, not_discover: Vec<String>) {
self.discover_order = discover;
self.not_discover_order = not_discover;
}
pub fn has_action_order(&self) -> bool {
!self.discover_order.is_empty() || !self.not_discover_order.is_empty()
}
pub fn set_learn_record(&mut self, record: crate::learn::DependencyGraphRecord) {
self.learn_record = Some(record);
}
pub fn learn_record(&self) -> Option<&crate::learn::DependencyGraphRecord> {
self.learn_record.as_ref()
}
pub fn take_learn_record(&mut self) -> Option<crate::learn::DependencyGraphRecord> {
self.learn_record.take()
}
pub fn validate(&self) -> Result<(), DependencyGraphError> {
if self.start_nodes.is_empty() {
return Err(DependencyGraphError::NoStartNodes);
}
if self.terminal_nodes.is_empty() {
return Err(DependencyGraphError::NoTerminalNodes);
}
for node in &self.start_nodes {
if !self.available_actions.contains(node) {
return Err(DependencyGraphError::UnknownAction(node.clone()));
}
}
for node in &self.terminal_nodes {
if !self.available_actions.contains(node) {
return Err(DependencyGraphError::UnknownAction(node.clone()));
}
}
for edge in &self.edges {
if !self.available_actions.contains(&edge.from) {
return Err(DependencyGraphError::UnknownAction(edge.from.clone()));
}
if !self.available_actions.contains(&edge.to) {
return Err(DependencyGraphError::UnknownAction(edge.to.clone()));
}
}
Ok(())
}
pub fn to_mermaid(&self) -> String {
let mut lines = vec!["graph LR".to_string()];
for edge in &self.edges {
let label = format!("{:.0}%", edge.confidence * 100.0);
lines.push(format!(" {} -->|{}| {}", edge.from, label, edge.to));
}
for start in &self.start_nodes {
lines.push(format!(" style {} fill:#9f9", start));
}
for terminal in &self.terminal_nodes {
lines.push(format!(" style {} fill:#f99", terminal));
}
lines.join("\n")
}
}
/// Errors produced while building, validating, or parsing dependency graphs.
#[derive(Debug, Clone, thiserror::Error)]
pub enum DependencyGraphError {
    /// The graph has no start node (also returned by the static planner for an empty action list).
    #[error("No start nodes defined")]
    NoStartNodes,
    /// The graph has no terminal node.
    #[error("No terminal nodes defined")]
    NoTerminalNodes,
    /// A node or edge references an action not listed in `available_actions`.
    #[error("Unknown action: {0}")]
    UnknownAction(String),
    /// An LLM response could not be parsed; carries a description.
    #[error("Parse error: {0}")]
    ParseError(String),
    /// LLM-related failure; carries a description.
    #[error("LLM error: {0}")]
    LlmError(String),
}
/// Source of ready-made dependency graphs for a task and its action set.
pub trait DependencyGraphProvider: Send + Sync {
    /// Returns a graph for `task` over `available_actions`, or `None` when this
    /// provider cannot supply one.
    fn provide_graph(&self, task: &str, available_actions: &[String]) -> Option<DependencyGraph>;
    /// Optional richer selection hook. The default implementation returns `None`
    /// (no selection information available).
    fn select(&self, _task: &str, _available_actions: &[String]) -> Option<SelectResult> {
        None
    }
}
/// Fluent builder for [`DependencyGraph`]; every field starts empty.
#[derive(Debug, Clone, Default)]
pub struct DependencyGraphBuilder {
    // Mirrors the corresponding `DependencyGraph` fields one-to-one.
    edges: Vec<DependencyEdge>,
    start_nodes: HashSet<String>,
    terminal_nodes: HashSet<String>,
    task: String,
    available_actions: Vec<String>,
    param_variants: HashMap<String, (String, Vec<String>)>,
    discover_order: Vec<String>,
    not_discover_order: Vec<String>,
}
impl DependencyGraphBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn task(mut self, task: impl Into<String>) -> Self {
self.task = task.into();
self
}
pub fn available_actions<I, S>(mut self, actions: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.available_actions = actions.into_iter().map(|s| s.into()).collect();
self
}
pub fn edge(mut self, from: impl Into<String>, to: impl Into<String>, confidence: f64) -> Self {
self.edges.push(DependencyEdge::new(from, to, confidence));
self
}
pub fn start_node(mut self, action: impl Into<String>) -> Self {
self.start_nodes.insert(action.into());
self
}
pub fn start_nodes<I, S>(mut self, actions: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.start_nodes
.extend(actions.into_iter().map(|s| s.into()));
self
}
pub fn terminal_node(mut self, action: impl Into<String>) -> Self {
self.terminal_nodes.insert(action.into());
self
}
pub fn terminal_nodes<I, S>(mut self, actions: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.terminal_nodes
.extend(actions.into_iter().map(|s| s.into()));
self
}
pub fn param_variants<I, S>(
mut self,
action: impl Into<String>,
key: impl Into<String>,
values: I,
) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.param_variants.insert(
action.into(),
(key.into(), values.into_iter().map(|s| s.into()).collect()),
);
self
}
pub fn with_orders(
mut self,
discover_order: Vec<String>,
not_discover_order: Vec<String>,
) -> Self {
self.discover_order = discover_order;
self.not_discover_order = not_discover_order;
self
}
pub fn build(self) -> DependencyGraph {
DependencyGraph {
edges: self.edges,
start_nodes: self.start_nodes,
terminal_nodes: self.terminal_nodes,
task: self.task,
available_actions: self.available_actions,
param_variants: self.param_variants,
discover_order: self.discover_order,
not_discover_order: self.not_discover_order,
learn_record: None,
}
}
pub fn build_validated(self) -> Result<DependencyGraph, DependencyGraphError> {
let graph = self.build();
graph.validate()?;
Ok(graph)
}
}
/// Dependency graph in the JSON shape expected from an LLM, or reconstructed from
/// a plain-text answer by `LlmDependencyResponse::parse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmDependencyResponse {
    pub edges: Vec<LlmEdge>,
    // Start node names.
    pub start: Vec<String>,
    // Terminal node names.
    pub terminal: Vec<String>,
    // Optional free-text explanation; may be absent in the JSON.
    #[serde(default)]
    pub reasoning: Option<String>,
}
/// One edge as reported by the LLM; confidence is not range-checked here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmEdge {
    pub from: String,
    pub to: String,
    pub confidence: f64,
}
impl LlmDependencyResponse {
    /// Converts this response into a [`DependencyGraph`] for `task`.
    ///
    /// Start/terminal nodes and edges are copied verbatim (confidences are clamped by
    /// [`DependencyEdge::new`]). No validation is performed; call
    /// [`DependencyGraph::validate`] if the LLM may have named unknown actions.
    pub fn into_graph(
        self,
        task: impl Into<String>,
        available_actions: Vec<String>,
    ) -> DependencyGraph {
        let mut builder = DependencyGraphBuilder::new()
            .task(task)
            .available_actions(available_actions)
            .start_nodes(self.start)
            .terminal_nodes(self.terminal);
        for edge in self.edges {
            builder = builder.edge(edge.from, edge.to, edge.confidence);
        }
        builder.build()
    }

    /// Parses an LLM reply in any supported format.
    ///
    /// Formats are tried from most to least structured:
    /// 1. the whole reply is a JSON object of this type;
    /// 2. the first balanced `{...}` object embedded in the reply (e.g. inside a code
    ///    fence), if it deserializes;
    /// 3. an arrow chain (`A -> B -> C`, `→` accepted) or a numbered list (`1. A 2. B`).
    ///
    /// JSON is tried before the text heuristics so that free text inside a valid JSON
    /// answer (a `reasoning` field containing `->` or `1.`) is not misread as a chain.
    ///
    /// # Errors
    /// [`DependencyGraphError::ParseError`] carrying the embedded JSON's
    /// deserialization error when such an object was found, otherwise a message with
    /// the first 200 characters of `text`.
    pub fn parse(text: &str) -> Result<Self, DependencyGraphError> {
        if let Ok(parsed) = serde_json::from_str::<Self>(text) {
            return Ok(parsed);
        }
        let embedded_error =
            match Self::extract_json(text).map(|json| serde_json::from_str::<Self>(&json)) {
                Some(Ok(parsed)) => return Ok(parsed),
                Some(Err(e)) => Some(e.to_string()),
                None => None,
            };
        if let Some(response) = Self::parse_arrow_format(text) {
            return Ok(response);
        }
        Err(DependencyGraphError::ParseError(embedded_error.unwrap_or_else(
            || {
                format!(
                    "No valid format found in response: {}",
                    text.chars().take(200).collect::<String>()
                )
            },
        )))
    }

    /// Tries the arrow-chain parser, then the numbered-list parser.
    fn parse_arrow_format(text: &str) -> Option<Self> {
        if let Some(result) = Self::parse_arrow_only(text) {
            return Some(result);
        }
        if let Some(result) = Self::parse_numbered_list(text) {
            return Some(result);
        }
        None
    }

    /// Parses the first line containing `->` (after normalizing `→`). Each segment
    /// contributes its last whitespace-separated word, keeping alphabetic characters
    /// only. Needs at least two actions.
    fn parse_arrow_only(text: &str) -> Option<Self> {
        let normalized = text.replace('→', "->");
        let arrow_line = normalized.lines().find(|line| line.contains("->"))?;
        let parts: Vec<&str> = arrow_line.split("->").collect();
        if parts.len() < 2 {
            return None;
        }
        let actions_in_order: Vec<String> = parts
            .iter()
            .filter_map(|part| {
                let trimmed = part.trim();
                let last_word = trimmed.split_whitespace().last()?;
                let action: String = last_word.chars().filter(|c| c.is_alphabetic()).collect();
                if action.is_empty() {
                    None
                } else {
                    Some(action)
                }
            })
            .collect();
        if actions_in_order.len() < 2 {
            return None;
        }
        Self::build_response(actions_in_order)
    }

    /// Parses list markers `1.` through `10.`, taking the first word after each
    /// (alphabetic characters only, duplicates skipped). Needs at least two actions.
    fn parse_numbered_list(text: &str) -> Option<Self> {
        let mut actions_in_order: Vec<String> = Vec::new();
        for i in 1..=10 {
            let marker = format!("{}.", i);
            if let Some(pos) = Self::find_list_marker(text, &marker) {
                // `marker` is ASCII, so `pos + marker.len()` is a char boundary.
                let after = &text[pos + marker.len()..];
                if let Some(word) = after.split_whitespace().next() {
                    let action: String = word.chars().filter(|c| c.is_alphabetic()).collect();
                    if !action.is_empty() && !actions_in_order.contains(&action) {
                        actions_in_order.push(action);
                    }
                }
            }
        }
        if actions_in_order.len() < 2 {
            return None;
        }
        Self::build_response(actions_in_order)
    }

    /// Byte offset of the first `marker` (e.g. `"1."`) that is not the tail of a
    /// longer number such as `"11."` or `"0.1."`.
    fn find_list_marker(text: &str, marker: &str) -> Option<usize> {
        text.match_indices(marker)
            .map(|(pos, _)| pos)
            .find(|&pos| !text[..pos].ends_with(|c: char| c.is_ascii_digit() || c == '.'))
    }

    /// Chains `actions_in_order` linearly with 0.9-confidence edges; first is the
    /// start node, last is the terminal node.
    fn build_response(actions_in_order: Vec<String>) -> Option<Self> {
        let mut edges = Vec::new();
        for window in actions_in_order.windows(2) {
            edges.push(LlmEdge {
                from: window[0].clone(),
                to: window[1].clone(),
                confidence: 0.9,
            });
        }
        Some(Self {
            edges,
            start: vec![actions_in_order.first()?.clone()],
            terminal: vec![actions_in_order.last()?.clone()],
            reasoning: Some("Parsed from text format".to_string()),
        })
    }

    /// Returns the first balanced `{...}` substring, ignoring braces inside JSON
    /// string literals (with backslash escapes). `None` if no `{` or unbalanced.
    fn extract_json(text: &str) -> Option<String> {
        let start = text.find('{')?;
        let candidate = &text[start..];
        // Starts at 1 on the first char and we return as soon as it reaches 0,
        // so it can never underflow.
        let mut depth: usize = 0;
        let mut in_string = false;
        let mut escape_next = false;
        for (offset, ch) in candidate.char_indices() {
            if escape_next {
                escape_next = false;
                continue;
            }
            match ch {
                '\\' if in_string => escape_next = true,
                '"' => in_string = !in_string,
                '{' if !in_string => depth += 1,
                '}' if !in_string => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(candidate[..=offset].to_string());
                    }
                }
                _ => {}
            }
        }
        None
    }
}
/// Produces a [`DependencyGraph`] for a task from the actions available to it.
pub trait DependencyPlanner: Send + Sync {
    /// Plans a graph for `task` over `available_actions`.
    fn plan(
        &self,
        task: &str,
        available_actions: &[String],
    ) -> Result<DependencyGraph, DependencyGraphError>;
    /// Human-readable planner name.
    fn name(&self) -> &str;
}
/// Planner that serves pre-registered graph patterns instead of calling an LLM.
#[derive(Debug, Clone, Default)]
pub struct StaticDependencyPlanner {
    // Registered patterns keyed by name.
    patterns: HashMap<String, DependencyGraph>,
    // Name of the pattern `plan` uses; the first registered pattern unless overridden.
    default_pattern: Option<String>,
}
impl StaticDependencyPlanner {
pub fn new() -> Self {
Self::default()
}
pub fn with_pattern(mut self, name: impl Into<String>, graph: DependencyGraph) -> Self {
let name = name.into();
if self.default_pattern.is_none() {
self.default_pattern = Some(name.clone());
}
self.patterns.insert(name, graph);
self
}
pub fn with_default_pattern(mut self, name: impl Into<String>) -> Self {
self.default_pattern = Some(name.into());
self
}
pub fn with_file_exploration_pattern(self) -> Self {
let graph = DependencyGraph::builder()
.task("File exploration")
.available_actions(["Grep", "List", "Read"])
.edge("Grep", "Read", 0.95)
.edge("List", "Grep", 0.60)
.edge("List", "Read", 0.40)
.start_nodes(["Grep", "List"])
.terminal_node("Read")
.build();
self.with_pattern("file_exploration", graph)
}
pub fn with_code_search_pattern(self) -> Self {
let graph = DependencyGraph::builder()
.task("Code search")
.available_actions(["Grep", "Read"])
.edge("Grep", "Read", 0.95)
.start_node("Grep")
.terminal_node("Read")
.build();
self.with_pattern("code_search", graph)
}
}
impl DependencyPlanner for StaticDependencyPlanner {
fn plan(
&self,
task: &str,
available_actions: &[String],
) -> Result<DependencyGraph, DependencyGraphError> {
if let Some(pattern_name) = &self.default_pattern {
if let Some(graph) = self.patterns.get(pattern_name) {
let mut graph = graph.clone();
graph.task = task.to_string();
graph.available_actions = available_actions.to_vec();
return Ok(graph);
}
}
if available_actions.is_empty() {
return Err(DependencyGraphError::NoStartNodes);
}
let mut builder = DependencyGraphBuilder::new()
.task(task)
.available_actions(available_actions.to_vec())
.start_node(&available_actions[0]);
if available_actions.len() > 1 {
for window in available_actions.windows(2) {
builder = builder.edge(&window[0], &window[1], 0.80);
}
builder = builder.terminal_node(&available_actions[available_actions.len() - 1]);
} else {
builder = builder.terminal_node(&available_actions[0]);
}
Ok(builder.build())
}
fn name(&self) -> &str {
"StaticDependencyPlanner"
}
}
use crate::actions::ActionDef;
/// Builds the short text prompts used to ask an LLM about action ordering.
pub struct DependencyPromptGenerator;

impl DependencyPromptGenerator {
    /// Prompt asking for the very first step, listing actions in the given order.
    pub fn generate_prompt(task: &str, actions: &[ActionDef]) -> String {
        let actions_list = actions
            .iter()
            .map(|a| a.name.as_str())
            .collect::<Vec<_>>()
            .join(", ");
        format!(
            r#"{task}
Steps: {actions_list}
The very first step is:"#
        )
    }

    /// Prompt asking which step comes first, with actions sorted by name and their
    /// descriptions listed. The question verb comes from the first word of the
    /// alphabetically first action's description (`"CHECK"` if none).
    pub fn generate_first_prompt(_task: &str, actions: &[ActionDef]) -> String {
        let sorted_actions = Self::sorted_by_name(actions);
        let (actions_list, descriptions_block) = Self::catalog(&sorted_actions);
        let first_verb = sorted_actions
            .first()
            .map(|a| Self::extract_verb(&a.description))
            .unwrap_or_else(|| "CHECK".to_string());
        format!(
            r#"Steps: {actions_list}
{descriptions_block}
Which step {first_verb}S first?
Answer:"#
        )
    }

    /// Prompt asking which step comes last, with actions sorted by name and their
    /// descriptions listed.
    pub fn generate_last_prompt(_task: &str, actions: &[ActionDef]) -> String {
        let sorted_actions = Self::sorted_by_name(actions);
        let (actions_list, descriptions_block) = Self::catalog(&sorted_actions);
        format!(
            r#"Steps: {actions_list}
{descriptions_block}
Which step should be done last?
Answer:"#
        )
    }

    /// Prompt asking which of two actions comes first.
    pub fn generate_pair_prompt(task: &str, action_a: &str, action_b: &str) -> String {
        format!(
            r#"For {task}, which comes first: {action_a} or {action_b}?
Answer (one word):"#
        )
    }

    /// Actions sorted by name (stable, so equal names keep input order).
    fn sorted_by_name(actions: &[ActionDef]) -> Vec<&ActionDef> {
        let mut sorted_actions: Vec<&ActionDef> = actions.iter().collect();
        sorted_actions.sort_by(|a, b| a.name.cmp(&b.name));
        sorted_actions
    }

    /// Returns `("A, B, C", "- A: desc\n- B: desc\n...")` for the given actions.
    fn catalog(sorted_actions: &[&ActionDef]) -> (String, String) {
        let actions_list = sorted_actions
            .iter()
            .map(|a| a.name.as_str())
            .collect::<Vec<_>>()
            .join(", ");
        let descriptions_block = sorted_actions
            .iter()
            .map(|a| format!("- {}: {}", a.name, a.description))
            .collect::<Vec<_>>()
            .join("\n");
        (actions_list, descriptions_block)
    }

    /// Upper-cased first word of `description` with at most one trailing `s`/`S`
    /// removed (`"Checks"` -> `"CHECK"`, `"Process"` -> `"PROCES"`, so the caller's
    /// appended `S` yields `"PROCESS"`). `"CHECK"` for an empty description.
    fn extract_verb(description: &str) -> String {
        description
            .split_whitespace()
            .next()
            .map(|w| {
                // Strip a single suffix: `trim_end_matches` would remove every
                // trailing 's' and mangle verbs like "Access" or "Process".
                let word = w
                    .strip_suffix('s')
                    .or_else(|| w.strip_suffix('S'))
                    .unwrap_or(w);
                word.to_uppercase()
            })
            .unwrap_or_else(|| "CHECK".to_string())
    }
}
/// Tracks completed actions on a [`DependencyGraph`] and suggests what to run next.
#[derive(Debug, Clone)]
pub struct GraphNavigator {
    graph: DependencyGraph,
    /// Actions reported via `mark_completed`; any string is accepted.
    completed_actions: HashSet<String>,
}

impl GraphNavigator {
    /// Wraps `graph` with nothing completed yet.
    pub fn new(graph: DependencyGraph) -> Self {
        Self {
            graph,
            completed_actions: HashSet::new(),
        }
    }

    /// Records `action` as completed; repeated calls have no further effect.
    pub fn mark_completed(&mut self, action: &str) {
        self.completed_actions.insert(action.to_string());
    }

    /// Suggests the next actions to run.
    ///
    /// With nothing completed, returns the graph's start actions (sorted). Otherwise
    /// returns the not-yet-completed successors of every completed action, without
    /// duplicates. Completed actions are visited in sorted order, and each one's
    /// successors in descending edge confidence, so the same state always produces
    /// the same list (`HashSet` iteration order is unspecified).
    pub fn suggest_next(&self) -> Vec<String> {
        if self.completed_actions.is_empty() {
            return self.graph.start_actions();
        }
        let mut completed: Vec<&String> = self.completed_actions.iter().collect();
        completed.sort();
        let mut candidates = Vec::new();
        for done in completed {
            for next in self.graph.valid_next_actions(done) {
                if !self.completed_actions.contains(&next) && !candidates.contains(&next) {
                    candidates.push(next);
                }
            }
        }
        candidates
    }

    /// `true` once any terminal action has been completed.
    pub fn is_task_complete(&self) -> bool {
        self.graph
            .terminal_nodes
            .iter()
            .any(|t| self.completed_actions.contains(t))
    }

    /// Fraction of the graph's distinct available actions that are completed.
    ///
    /// Completed actions not listed in `available_actions` are ignored, so the value
    /// stays within `[0.0, 1.0]`. Returns `0.0` when the graph lists no actions.
    pub fn progress(&self) -> f64 {
        let available: HashSet<&str> = self
            .graph
            .available_actions
            .iter()
            .map(String::as_str)
            .collect();
        if available.is_empty() {
            return 0.0;
        }
        let done = available
            .iter()
            .filter(|&&action| self.completed_actions.contains(action))
            .count();
        done as f64 / available.len() as f64
    }

    /// The underlying graph.
    pub fn graph(&self) -> &DependencyGraph {
        &self.graph
    }
}
/// Builds a linear dependency graph from a learned action order.
///
/// Actions are chained in the order `discover` followed by `not_discover`: the first
/// action is the sole start node, the last is the sole terminal node, and each
/// consecutive pair is joined by an edge with confidence `0.9`. Both lists are also
/// stored on the graph as its learned action orders.
///
/// Returns `None` when both lists are empty. The graph is not validated; names
/// missing from `available_actions` will make [`DependencyGraph::validate`] fail.
pub fn build_graph_from_action_order(
    task: &str,
    available_actions: &[String],
    discover: &[String],
    not_discover: &[String],
) -> Option<DependencyGraph> {
    // One sequence covers every case: discover-only, not_discover-only, and both
    // (where the last discover action links to the first not_discover action).
    let sequence: Vec<&String> = discover.iter().chain(not_discover.iter()).collect();
    let (first, last) = (sequence.first()?, sequence.last()?);
    let mut builder = DependencyGraphBuilder::new()
        .task(task)
        .available_actions(available_actions.iter().cloned())
        .start_node(first.as_str())
        .terminal_node(last.as_str());
    for pair in sequence.windows(2) {
        builder = builder.edge(pair[0].as_str(), pair[1].as_str(), 0.9);
    }
    builder = builder.with_orders(discover.to_vec(), not_discover.to_vec());
    Some(builder.build())
}
/// Provider backed by previously learned action orders.
#[derive(Debug, Clone, Default)]
pub struct LearnedDependencyProvider {
    // Learned orders, searched in insertion order.
    entries: Vec<LearnedActionOrder>,
    // Thresholds deciding how partial matches are voted on.
    strategy: VotingStrategy,
}
impl LearnedDependencyProvider {
    /// Provider with no learned entries; `select` always falls back to the base model.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Provider with a single learned entry and the default voting strategy.
    pub fn new(action_order: LearnedActionOrder) -> Self {
        Self {
            entries: vec![action_order],
            strategy: VotingStrategy::default(),
        }
    }

    /// Provider with the given entries and the default voting strategy.
    pub fn with_entries(entries: Vec<LearnedActionOrder>) -> Self {
        Self {
            entries,
            strategy: VotingStrategy::default(),
        }
    }

    /// Replaces the voting strategy.
    pub fn with_strategy(mut self, strategy: VotingStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Appends a learned entry.
    pub fn add_entry(&mut self, entry: LearnedActionOrder) {
        self.entries.push(entry);
    }

    /// Number of learned entries.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }

    /// The first learned entry, if any.
    pub fn action_order(&self) -> Option<&LearnedActionOrder> {
        self.entries.first()
    }

    /// All learned entries in insertion order.
    pub fn entries(&self) -> &[LearnedActionOrder] {
        &self.entries
    }

    /// Chooses between a learned graph and an LLM call for `available_actions`.
    ///
    /// 1. The first entry that exactly matches yields a learned graph.
    /// 2. Otherwise the entry with the highest positive match rate is used (the
    ///    earliest wins ties). If its rate reaches `strategy.medium_threshold`, an LLM
    ///    call with that entry as a hint is requested.
    /// 3. Otherwise the base model is used with 3 votes and no hint.
    pub fn select(&self, task: &str, available_actions: &[String]) -> SelectResult {
        if let Some(entry) = self
            .entries
            .iter()
            .find(|e| e.is_exact_match(available_actions))
        {
            return self.build_learned_result(task, available_actions, entry);
        }
        let mut best_match: Option<(&LearnedActionOrder, f64)> = None;
        for entry in &self.entries {
            let rate = entry.match_rate(available_actions);
            // Strict `>` keeps the earliest entry on ties; the first candidate must be
            // positive (NaN never qualifies).
            let improves = match best_match {
                Some((_, best_rate)) => rate > best_rate,
                None => rate > 0.0,
            };
            if improves {
                best_match = Some((entry, rate));
            }
        }
        match best_match {
            Some((entry, rate)) if rate >= self.strategy.medium_threshold => {
                self.build_llm_result(entry, rate)
            }
            _ => self.build_fallback_result(),
        }
    }

    /// Builds a learned graph from an exactly matching entry. If the entry has no
    /// usable order (both lists empty), falls back to an LLM call with the entry as hint.
    fn build_learned_result(
        &self,
        task: &str,
        available_actions: &[String],
        entry: &LearnedActionOrder,
    ) -> SelectResult {
        let graph = build_graph_from_action_order(
            task,
            available_actions,
            &entry.discover,
            &entry.not_discover,
        );
        match graph {
            Some(g) => {
                tracing::info!(
                    discover = ?entry.discover,
                    not_discover = ?entry.not_discover,
                    lora = ?entry.lora.as_ref().map(|l| &l.name),
                    "Using learned action order (LLM skipped)"
                );
                SelectResult::UseLearnedGraph {
                    graph: Box::new(g),
                    lora: entry.lora.clone(),
                }
            }
            None => {
                tracing::warn!("Failed to build graph from exact match, falling back to LLM");
                self.build_llm_result(entry, 1.0)
            }
        }
    }

    /// Requests an LLM call using `entry` as a hint (and its LoRA, if any).
    fn build_llm_result(&self, entry: &LearnedActionOrder, match_rate: f64) -> SelectResult {
        // `determine` returns 0 for `match_rate >= 1.0`, meaning "no voting needed".
        // That only makes sense for a learned graph: a `UseLlm` result with zero votes
        // asks for an LLM call that would never run (e.g. the exact-match fallback
        // above), so always request at least one vote.
        let vote_count = self
            .strategy
            .determine(match_rate, entry.lora.is_some())
            .max(1);
        tracing::debug!(
            match_rate = match_rate,
            vote_count = vote_count,
            has_lora = entry.lora.is_some(),
            "LLM invocation needed (partial match)"
        );
        SelectResult::UseLlm {
            lora: entry.lora.clone(),
            hint: Some(entry.clone()),
            vote_count,
            match_rate,
        }
    }

    /// Base model, no hint, 3 votes.
    fn build_fallback_result(&self) -> SelectResult {
        tracing::debug!("No matching entry, using base model with 3 votes");
        SelectResult::UseLlm {
            lora: None,
            hint: None,
            vote_count: 3,
            match_rate: 0.0,
        }
    }
}
impl DependencyGraphProvider for LearnedDependencyProvider {
    /// Returns the learned graph only when selection yields one that validates;
    /// any case that would need the LLM yields `None`.
    fn provide_graph(&self, task: &str, available_actions: &[String]) -> Option<DependencyGraph> {
        let outcome = LearnedDependencyProvider::select(self, task, available_actions);
        match outcome {
            SelectResult::UseLearnedGraph { graph, .. } if graph.validate().is_ok() => Some(*graph),
            SelectResult::UseLearnedGraph { .. } | SelectResult::UseLlm { .. } => None,
        }
    }

    /// Always `Some`: delegates to the inherent `select`.
    fn select(&self, task: &str, available_actions: &[String]) -> Option<SelectResult> {
        let outcome = LearnedDependencyProvider::select(self, task, available_actions);
        Some(outcome)
    }
}
// Unit tests for graph construction, navigation, LLM response parsing, voting,
// and learned-order selection.
#[cfg(test)]
mod tests {
    use super::*;
    // Builder sets task, start/terminal nodes and directed edges.
    #[test]
    fn test_dependency_graph_builder() {
        let graph = DependencyGraph::builder()
            .task("Find auth function")
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();
        assert_eq!(graph.task(), "Find auth function");
        assert!(graph.is_start("Grep"));
        assert!(graph.is_start("List"));
        assert!(graph.is_terminal("Read"));
        assert!(graph.can_transition("Grep", "Read"));
        // Edges are directed.
        assert!(!graph.can_transition("Read", "Grep"));
    }
    // Successors come back ordered by descending confidence.
    #[test]
    fn test_valid_next_actions() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .edge("List", "Read", 0.40)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();
        let next = graph.valid_next_actions("Grep");
        assert_eq!(next, vec!["Read"]);
        let next = graph.valid_next_actions("List");
        assert_eq!(next, vec!["Grep", "Read"]);
        let next = graph.valid_next_actions("Read");
        assert!(next.is_empty());
    }
    // The default pattern is applied even with a reduced action list.
    #[test]
    fn test_static_planner_file_exploration() {
        let planner = StaticDependencyPlanner::new().with_file_exploration_pattern();
        let graph = planner
            .plan("Find auth.rs", &["Grep".to_string(), "Read".to_string()])
            .unwrap();
        assert!(graph.is_start("Grep"));
        assert!(graph.is_terminal("Read"));
    }
    // Walks a two-step graph from start to terminal.
    #[test]
    fn test_graph_navigator() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "Read"])
            .edge("Grep", "Read", 0.95)
            .start_node("Grep")
            .terminal_node("Read")
            .build();
        let mut nav = GraphNavigator::new(graph);
        assert_eq!(nav.suggest_next(), vec!["Grep"]);
        assert!(!nav.is_task_complete());
        nav.mark_completed("Grep");
        assert_eq!(nav.suggest_next(), vec!["Read"]);
        assert!(!nav.is_task_complete());
        nav.mark_completed("Read");
        assert!(nav.is_task_complete());
        assert!(nav.suggest_next().is_empty());
    }
    // A plain JSON response round-trips into a graph.
    #[test]
    fn test_llm_response_parsing() {
        let json = r#"{
"edges": [
{"from": "Grep", "to": "Read", "confidence": 0.95}
],
"start": ["Grep"],
"terminal": ["Read"],
"reasoning": "Search first, then read"
}"#;
        let response = LlmDependencyResponse::parse(json).unwrap();
        assert_eq!(response.edges.len(), 1);
        assert_eq!(response.start, vec!["Grep"]);
        assert_eq!(response.terminal, vec!["Read"]);
        assert!(response.reasoning.is_some());
        let graph = response.into_graph(
            "Find function",
            vec!["Grep".to_string(), "Read".to_string()],
        );
        assert!(graph.can_transition("Grep", "Read"));
    }
    // Mermaid output contains header, labelled edges and node styles.
    #[test]
    fn test_mermaid_output() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();
        let mermaid = graph.to_mermaid();
        assert!(mermaid.contains("graph LR"));
        assert!(mermaid.contains("Grep -->|95%| Read"));
        assert!(mermaid.contains("style Read fill:#f99"));
    }
    // Action-set matching ignores order but not membership.
    #[test]
    fn test_learned_action_order_hash() {
        let actions = vec![
            "Grep".to_string(),
            "Read".to_string(),
            "Restart".to_string(),
        ];
        let order = LearnedActionOrder::new(
            vec!["Grep".to_string(), "Read".to_string()],
            vec!["Restart".to_string()],
            &actions,
        );
        let actions_reordered = vec![
            "Restart".to_string(),
            "Grep".to_string(),
            "Read".to_string(),
        ];
        assert!(order.matches_actions(&actions_reordered));
        let actions_different = vec!["Grep".to_string(), "Read".to_string()];
        assert!(!order.matches_actions(&actions_different));
    }
    // Exact action set -> learned graph chaining discover then not_discover.
    #[test]
    fn test_learned_dependency_provider_cache_hit() {
        let actions = vec![
            "Grep".to_string(),
            "Read".to_string(),
            "Restart".to_string(),
        ];
        let order = LearnedActionOrder::new(
            vec!["Grep".to_string(), "Read".to_string()],
            vec!["Restart".to_string()],
            &actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph("troubleshooting", &actions);
        assert!(graph.is_some());
        let graph = graph.unwrap();
        assert!(graph.is_start("Grep"));
        assert!(graph.is_terminal("Restart"));
        assert!(graph.can_transition("Grep", "Read"));
        assert!(graph.can_transition("Read", "Restart"));
    }
    // Different action set -> no learned graph.
    #[test]
    fn test_learned_dependency_provider_cache_miss() {
        let original_actions = vec![
            "Grep".to_string(),
            "Read".to_string(),
            "Restart".to_string(),
        ];
        let order = LearnedActionOrder::new(
            vec!["Grep".to_string(), "Read".to_string()],
            vec!["Restart".to_string()],
            &original_actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let different_actions = vec!["Grep".to_string(), "Read".to_string()];
        let graph = provider.provide_graph("troubleshooting", &different_actions);
        assert!(graph.is_none());
    }
    // Only a discover order: last discover action is terminal.
    #[test]
    fn test_learned_dependency_provider_discover_only() {
        let actions = vec!["Grep".to_string(), "Read".to_string()];
        let order = LearnedActionOrder::new(
            vec!["Grep".to_string(), "Read".to_string()],
            vec![],
            &actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph("search task", &actions);
        assert!(graph.is_some());
        let graph = graph.unwrap();
        assert!(graph.is_start("Grep"));
        assert!(graph.is_terminal("Read"));
        assert!(graph.can_transition("Grep", "Read"));
    }
    // Only a not_discover order: first not_discover action is start.
    #[test]
    fn test_learned_dependency_provider_not_discover_only() {
        let actions = vec!["Restart".to_string(), "CheckStatus".to_string()];
        let order = LearnedActionOrder::new(
            vec![],
            vec!["Restart".to_string(), "CheckStatus".to_string()],
            &actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph("ops task", &actions);
        assert!(graph.is_some());
        let graph = graph.unwrap();
        assert!(graph.is_start("Restart"));
        assert!(graph.is_terminal("CheckStatus"));
        assert!(graph.can_transition("Restart", "CheckStatus"));
    }
    // Both orders empty: no graph can be built.
    #[test]
    fn test_learned_dependency_provider_empty_lists() {
        let actions = vec!["Grep".to_string(), "Read".to_string()];
        let order = LearnedActionOrder::new(
            vec![],
            vec![],
            &actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph("empty task", &actions);
        assert!(graph.is_none());
    }
    // JSON object is located after leading prose.
    #[test]
    fn test_extract_json_simple() {
        let text = r#"Here is the result: {"edges": [], "start": ["A"], "terminal": ["B"]}"#;
        let json = LlmDependencyResponse::extract_json(text);
        assert!(json.is_some());
        let json = json.unwrap();
        assert!(json.starts_with('{'));
        assert!(json.ends_with('}'));
    }
    // Nested objects are kept balanced and deserialize.
    #[test]
    fn test_extract_json_nested() {
        let text = r#"Result: {"edges": [{"from": "A", "to": "B", "confidence": 0.9}], "start": ["A"], "terminal": ["B"]}"#;
        let json = LlmDependencyResponse::extract_json(text);
        assert!(json.is_some());
        let parsed: Result<LlmDependencyResponse, _> = serde_json::from_str(&json.unwrap());
        assert!(parsed.is_ok());
    }
    // Braces inside string literals do not affect depth tracking.
    #[test]
    fn test_extract_json_with_string_braces() {
        let text =
            r#"{"edges": [], "start": ["A"], "terminal": ["B"], "reasoning": "Use {pattern}"}"#;
        let json = LlmDependencyResponse::extract_json(text);
        assert!(json.is_some());
        assert_eq!(json.unwrap(), text);
    }
    // No `{` at all -> None.
    #[test]
    fn test_extract_json_no_json() {
        let text = "This is just plain text without JSON";
        let json = LlmDependencyResponse::extract_json(text);
        assert!(json.is_none());
    }
    // Start node outside available_actions is rejected.
    #[test]
    fn test_validate_unknown_start_node() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "Read"])
            .start_node("Unknown")
            .terminal_node("Read")
            .build();
        let result = graph.validate();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DependencyGraphError::UnknownAction(name) if name == "Unknown"
        ));
    }
    // Terminal node outside available_actions is rejected.
    #[test]
    fn test_validate_unknown_terminal_node() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "Read"])
            .start_node("Grep")
            .terminal_node("Unknown")
            .build();
        let result = graph.validate();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DependencyGraphError::UnknownAction(name) if name == "Unknown"
        ));
    }
    // Well-formed graph passes validation.
    #[test]
    fn test_validate_valid_graph() {
        let graph = DependencyGraph::builder()
            .available_actions(["Grep", "Read"])
            .edge("Grep", "Read", 0.9)
            .start_node("Grep")
            .terminal_node("Read")
            .build();
        assert!(graph.validate().is_ok());
    }
    // start_actions() is sorted regardless of insertion order.
    #[test]
    fn test_start_actions_sorted() {
        let graph = DependencyGraph::builder()
            .available_actions(["Zebra", "Apple", "Mango"])
            .start_nodes(["Zebra", "Apple", "Mango"])
            .terminal_node("Zebra")
            .build();
        let actions = graph.start_actions();
        assert_eq!(actions, vec!["Apple", "Mango", "Zebra"]);
    }
    // terminal_actions() is sorted regardless of insertion order.
    #[test]
    fn test_terminal_actions_sorted() {
        let graph = DependencyGraph::builder()
            .available_actions(["Zebra", "Apple", "Mango"])
            .start_node("Apple")
            .terminal_nodes(["Zebra", "Apple", "Mango"])
            .build();
        let actions = graph.terminal_actions();
        assert_eq!(actions, vec!["Apple", "Mango", "Zebra"]);
    }
    // Exact match needs no votes.
    #[test]
    fn test_voting_strategy_exact_match() {
        let strategy = VotingStrategy::default();
        assert_eq!(strategy.determine(1.0, true), 0);
        assert_eq!(strategy.determine(1.0, false), 0);
    }
    // High match with LoRA needs a single vote (threshold inclusive).
    #[test]
    fn test_voting_strategy_high_with_lora() {
        let strategy = VotingStrategy::default();
        assert_eq!(strategy.determine(0.85, true), 1);
        assert_eq!(strategy.determine(0.80, true), 1);
    }
    // High match without LoRA still needs 3 votes.
    #[test]
    fn test_voting_strategy_high_without_lora() {
        let strategy = VotingStrategy::default();
        assert_eq!(strategy.determine(0.85, false), 3);
    }
    // Medium match needs 3 votes.
    #[test]
    fn test_voting_strategy_medium() {
        let strategy = VotingStrategy::default();
        assert_eq!(strategy.determine(0.65, true), 3);
        assert_eq!(strategy.determine(0.60, true), 3);
    }
    // Low match needs 3 votes.
    #[test]
    fn test_voting_strategy_low() {
        let strategy = VotingStrategy::default();
        assert_eq!(strategy.determine(0.5, true), 3);
        assert_eq!(strategy.determine(0.5, false), 3);
    }
    // select(): exact match skips the LLM.
    #[test]
    fn test_select_exact_match_returns_learned_graph() {
        let actions = vec![
            "CheckStatus".to_string(),
            "ReadLogs".to_string(),
            "Restart".to_string(),
        ];
        let order = LearnedActionOrder::new(
            vec!["CheckStatus".to_string(), "ReadLogs".to_string()],
            vec!["Restart".to_string()],
            &actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let result = provider.select("test task", &actions);
        assert!(!result.needs_llm());
        assert_eq!(result.vote_count(), 0);
        assert!(matches!(result, SelectResult::UseLearnedGraph { .. }));
    }
    // select(): disjoint action set falls back to base model, 3 votes, no hint.
    #[test]
    fn test_select_no_match_returns_llm_fallback() {
        let order = LearnedActionOrder::new(
            vec!["A".to_string(), "B".to_string()],
            vec!["C".to_string()],
            &["A".to_string(), "B".to_string(), "C".to_string()],
        );
        let provider = LearnedDependencyProvider::new(order);
        let result = provider.select("test task", &["X".to_string(), "Y".to_string()]);
        assert!(result.needs_llm());
        assert_eq!(result.vote_count(), 3);
        assert!(result.lora().is_none());
        assert!(matches!(
            result,
            SelectResult::UseLlm {
                hint: None,
                match_rate,
                ..
            } if match_rate == 0.0
        ));
    }
    // select(): high partial match with LoRA -> LLM with one vote.
    #[test]
    fn test_select_partial_match_with_lora_returns_1_vote() {
        use crate::types::LoraConfig;
        let lora = LoraConfig {
            id: 1,
            name: Some("test-lora".to_string()),
            scale: 1.0,
        };
        let all_actions: Vec<String> = ["A", "B", "C", "D", "E"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let order = LearnedActionOrder::new(
            vec![
                "A".to_string(),
                "B".to_string(),
                "C".to_string(),
                "D".to_string(),
            ],
            vec!["E".to_string()],
            &all_actions,
        )
        .with_lora(lora);
        let provider = LearnedDependencyProvider::new(order);
        let query_actions: Vec<String> =
            ["A", "B", "C", "D"].iter().map(|s| s.to_string()).collect();
        let result = provider.select("test task", &query_actions);
        assert!(result.needs_llm());
        assert_eq!(result.vote_count(), 1);
        assert!(result.lora().is_some());
    }
    // select(): no entries -> base-model fallback.
    #[test]
    fn test_select_empty_provider_returns_fallback() {
        let provider = LearnedDependencyProvider::empty();
        let result = provider.select("test task", &["A".to_string()]);
        assert!(result.needs_llm());
        assert_eq!(result.vote_count(), 3);
        assert!(result.lora().is_none());
    }
    // select(): the exactly matching entry (and its LoRA) wins over others.
    #[test]
    fn test_select_multiple_entries_best_match() {
        use crate::types::LoraConfig;
        let order1 = LearnedActionOrder::new(
            vec!["A".to_string(), "B".to_string()],
            vec!["C".to_string()],
            &["A".to_string(), "B".to_string(), "C".to_string()],
        );
        let lora = LoraConfig {
            id: 2,
            name: Some("better-lora".to_string()),
            scale: 1.0,
        };
        let order2 = LearnedActionOrder::new(
            vec!["A".to_string(), "B".to_string()],
            vec!["D".to_string()],
            &["A".to_string(), "B".to_string(), "D".to_string()],
        )
        .with_lora(lora);
        let provider = LearnedDependencyProvider::with_entries(vec![order1, order2]);
        let query = vec!["A".to_string(), "B".to_string(), "D".to_string()];
        let result = provider.select("test task", &query);
        assert!(!result.needs_llm());
        assert!(matches!(
            result,
            SelectResult::UseLearnedGraph { lora: Some(l), .. } if l.name == Some("better-lora".to_string())
        ));
    }
    // provide_graph() returns the learned graph for an exact match.
    #[test]
    fn test_provide_graph_exact_match_via_select() {
        let actions = vec![
            "CheckStatus".to_string(),
            "ReadLogs".to_string(),
            "Restart".to_string(),
        ];
        let order = LearnedActionOrder::new(
            vec!["CheckStatus".to_string(), "ReadLogs".to_string()],
            vec!["Restart".to_string()],
            &actions,
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph("test task", &actions);
        assert!(graph.is_some());
        let graph = graph.unwrap();
        assert!(graph.is_start("CheckStatus"));
        assert!(graph.is_terminal("Restart"));
    }
    // provide_graph() returns None when there is no match.
    #[test]
    fn test_provide_graph_no_match_returns_none() {
        let order = LearnedActionOrder::new(
            vec!["A".to_string(), "B".to_string()],
            vec!["C".to_string()],
            &["A".to_string(), "B".to_string(), "C".to_string()],
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph("test task", &["X".to_string(), "Y".to_string()]);
        assert!(graph.is_none());
    }
    // provide_graph() returns None for a partial match (LLM needed).
    #[test]
    fn test_provide_graph_partial_match_returns_none() {
        let order = LearnedActionOrder::new(
            vec![
                "A".to_string(),
                "B".to_string(),
                "C".to_string(),
                "D".to_string(),
            ],
            vec!["E".to_string()],
            &[
                "A".to_string(),
                "B".to_string(),
                "C".to_string(),
                "D".to_string(),
                "E".to_string(),
            ],
        );
        let provider = LearnedDependencyProvider::new(order);
        let graph = provider.provide_graph(
            "test task",
            &[
                "A".to_string(),
                "B".to_string(),
                "C".to_string(),
                "D".to_string(),
            ],
        );
        assert!(graph.is_none());
    }
}