use car_ir::ActionProposal;
use car_verify::VerifyResult;
use serde::Serialize;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
/// Historical per-tool success rates used to bias proposal scores.
///
/// Build it with [`ToolFeedback::from_trajectories`] or populate
/// `tool_success_rates` directly. Tools missing from the map are treated as
/// having a `0.5` success rate by [`ToolFeedback::rate`].
#[derive(Debug, Clone, Default)]
pub struct ToolFeedback {
    /// Tool name -> observed success rate. Rates produced by
    /// `from_trajectories` lie in `[0, 1]`; values inserted directly are not
    /// validated.
    pub tool_success_rates: HashMap<String, f64>,
}
impl ToolFeedback {
pub fn from_trajectories(trajectories: &[car_memgine::Trajectory]) -> Self {
let mut tool_outcomes: HashMap<String, (u64, u64)> = HashMap::new(); for traj in trajectories {
for event in &traj.events {
if let Some(ref tool) = event.tool {
let entry = tool_outcomes.entry(tool.clone()).or_default();
entry.1 += 1; if event.kind == "action_succeeded" {
entry.0 += 1; }
}
}
}
let tool_success_rates = tool_outcomes
.into_iter()
.map(|(tool, (success, total))| {
(
tool,
if total > 0 {
success as f64 / total as f64
} else {
0.5
},
)
})
.collect();
Self { tool_success_rates }
}
pub fn rate(&self, tool: &str) -> f64 {
self.tool_success_rates.get(tool).copied().unwrap_or(0.5)
}
pub fn proposal_tool_confidence(&self, proposal: &ActionProposal) -> f64 {
let tool_calls: Vec<&str> = proposal
.actions
.iter()
.filter(|a| a.action_type == car_ir::ActionType::ToolCall)
.filter_map(|a| a.tool.as_deref())
.collect();
if tool_calls.is_empty() {
return 1.0; }
let sum: f64 = tool_calls.iter().map(|t| self.rate(t)).sum();
sum / tool_calls.len() as f64
}
}
/// Rough token-count estimate for `proposal`.
///
/// Divides the byte length of the proposal's JSON serialization by 4. If
/// serialization fails, it falls back to 32 tokens per action.
pub fn estimate_proposal_tokens(proposal: &ActionProposal) -> usize {
    match serde_json::to_string(proposal) {
        Ok(json) => json.len() / 4,
        Err(_) => proposal.actions.len() * 32,
    }
}
/// Tuning knobs for [`Planner`] scoring.
///
/// The `*_weight` fields are clamped to `[0, 1]` where they are used. A
/// budget of `0` disables the corresponding cost-efficiency penalty.
#[derive(Debug, Clone)]
pub struct PlannerConfig {
    /// Share of the final score taken from `cost_efficiency`; the rest comes
    /// from `validity`. Clamped to `[0, 1]` when scoring.
    pub cost_weight: f64,
    /// Action count at which the action-count component of cost efficiency
    /// reaches 0. `0` disables that component's penalty.
    pub action_budget: usize,
    /// Tool-call count at which the tool-call component of cost efficiency
    /// reaches 0. `0` disables that component's penalty.
    pub tool_call_budget: usize,
    /// Amount subtracted from `validity` for each write conflict reported by
    /// the verifier. Not clamped.
    pub conflict_penalty: f64,
    /// Strength of the historical-feedback adjustment. The score is multiplied
    /// by `1 - w + w * historical_confidence`. Clamped to `[0, 1]` when scoring.
    pub feedback_weight: f64,
    /// Estimated token count at which the token component of cost efficiency
    /// reaches 0. `0` disables that component's penalty.
    pub token_budget: usize,
    /// Share of `cost_efficiency` taken from the token component. The
    /// remainder is split 40/40/20 between action count, tool-call count and
    /// parallelism. Clamped to `[0, 1]` when scoring.
    pub token_weight: f64,
}
impl Default for PlannerConfig {
fn default() -> Self {
Self {
cost_weight: 0.2,
action_budget: 20,
tool_call_budget: 10,
conflict_penalty: 0.15,
feedback_weight: 0.3,
token_budget: 4000,
token_weight: 0.25,
}
}
}
impl PlannerConfig {
    /// Builds a config from a caller-supplied cost target.
    ///
    /// `cost_weight` is clamped to `[0, 1]`. `f64::clamp` returns NaN
    /// unchanged, and a NaN weight would turn every score into NaN, so a NaN
    /// weight falls back to the default weight instead. The action and
    /// tool-call budgets come from the target. All other fields use
    /// [`PlannerConfig::default`].
    ///
    /// NOTE(review): `target_actions` / `target_tool_calls` are converted with
    /// `as usize`. If they are signed or floating-point in `car_ir`, negative
    /// values will not map to a sensible budget. Confirm their types.
    pub fn from_cost_target(target: &car_ir::CostTarget) -> Self {
        let defaults = Self::default();
        let cost_weight = if target.cost_weight.is_nan() {
            defaults.cost_weight
        } else {
            target.cost_weight.clamp(0.0, 1.0)
        };
        Self {
            cost_weight,
            action_budget: target.target_actions as usize,
            tool_call_budget: target.target_tool_calls as usize,
            ..defaults
        }
    }
}
/// Scoring breakdown for one candidate proposal, as produced by [`Planner`].
#[derive(Debug, Clone, Serialize)]
pub struct ScoredProposal {
    /// Position of the proposal in the candidate slice passed to the ranking
    /// methods. Always `0` for [`Planner::score`].
    pub index: usize,
    /// Final ranking score. `0.0` when the verifier reported any error.
    pub score: f64,
    /// `0.0` with errors. Otherwise `1.0 - 0.1 * warnings - conflict_penalty *
    /// conflicts`, floored at `0.1`.
    pub validity: f64,
    /// Budget/parallelism-based efficiency in `[0, 1]`.
    pub cost_efficiency: f64,
    /// Number of verifier issues with severity `"error"`.
    pub error_count: usize,
    /// Number of verifier issues with severity `"warning"`.
    pub warning_count: usize,
    /// Total number of actions in the proposal.
    pub action_count: usize,
    /// Number of actions of type `ToolCall`.
    pub tool_call_count: usize,
    /// Number of execution levels reported by the verifier.
    pub parallelism_levels: usize,
    /// The verifier's overall validity flag.
    pub valid: bool,
    /// Number of entries in the verifier's simulated state. NOTE(review): this
    /// may include keys from `initial_state`, not just keys the proposal
    /// writes. Confirm against `car_verify`.
    pub state_keys_written: usize,
    /// Whether the verifier reported any write conflicts.
    pub has_write_conflicts: bool,
    /// Mean historical success rate of the proposal's tool calls. `1.0` when
    /// no feedback was supplied or the proposal makes no tool calls.
    pub historical_confidence: f64,
    /// Output of [`estimate_proposal_tokens`] for this proposal.
    pub token_estimate: usize,
    /// `score / token_estimate`, or `score` when the estimate is `0`.
    pub quality_per_token: f64,
}
/// Scores and ranks candidate [`ActionProposal`]s using verification results,
/// cost heuristics and optional historical tool feedback.
pub struct Planner {
    /// Scoring parameters supplied to [`Planner::new`].
    config: PlannerConfig,
}
impl Planner {
    /// Creates a planner that scores proposals with `config`.
    pub fn new(config: PlannerConfig) -> Self {
        Self { config }
    }

    /// Verifies and scores a single proposal, without historical tool feedback.
    ///
    /// The returned [`ScoredProposal::index`] is always `0`.
    pub fn score(
        &self,
        proposal: &ActionProposal,
        initial_state: Option<&HashMap<String, Value>>,
        registered_tools: Option<&HashSet<String>>,
    ) -> ScoredProposal {
        self.score_indexed(0, proposal, initial_state, registered_tools, None)
    }

    /// Runs the verifier on `proposal` and scores the result, tagging it with `index`.
    fn score_indexed(
        &self,
        index: usize,
        proposal: &ActionProposal,
        initial_state: Option<&HashMap<String, Value>>,
        registered_tools: Option<&HashSet<String>>,
        feedback: Option<&ToolFeedback>,
    ) -> ScoredProposal {
        // NOTE(review): `car_verify::verify` defines what the `100` argument
        // means; it is not visible here. Confirm before tuning.
        let vr = car_verify::verify(proposal, initial_state, registered_tools, 100);
        self.score_from_verify(index, proposal, &vr, feedback)
    }

    /// Fraction of `budget` still unused after spending `used`, in `[0, 1]`.
    ///
    /// A `budget` of `0` disables the penalty and yields `1.0`.
    fn remaining_budget_fraction(used: usize, budget: usize) -> f64 {
        if budget == 0 {
            return 1.0;
        }
        1.0 - (used as f64 / budget as f64).min(1.0)
    }

    /// Ranking key. NaN scores map to `-inf`, so they sort last and the sort
    /// comparator remains a total order.
    fn rank_key(score: f64) -> f64 {
        if score.is_nan() {
            f64::NEG_INFINITY
        } else {
            score
        }
    }

    /// Turns a verification result into a [`ScoredProposal`].
    ///
    /// * `validity`: `0.0` if the verifier reported any error. Otherwise
    ///   `1.0 - 0.1 * warnings - conflict_penalty * conflicts`, floored at `0.1`.
    /// * `cost_efficiency`: a `token_weight` blend of the remaining token budget
    ///   with (40% actions, 40% tool calls, 20% parallelism), clamped to `[0, 1]`.
    /// * `score`: `0.0` on errors. Otherwise a `cost_weight` blend of validity
    ///   and cost efficiency, scaled by `1 - fw + fw * historical_confidence`.
    fn score_from_verify(
        &self,
        index: usize,
        proposal: &ActionProposal,
        vr: &VerifyResult,
        feedback: Option<&ToolFeedback>,
    ) -> ScoredProposal {
        let error_count = vr.issues.iter().filter(|i| i.severity == "error").count();
        let warning_count = vr.issues.iter().filter(|i| i.severity == "warning").count();
        let action_count = proposal.actions.len();
        let tool_call_count = proposal
            .actions
            .iter()
            .filter(|a| a.action_type == car_ir::ActionType::ToolCall)
            .count();
        let state_keys_written = vr.simulated_state.len();
        let has_write_conflicts = !vr.conflicts.is_empty();

        let validity = if error_count > 0 {
            0.0
        } else {
            let mut v = 1.0 - warning_count as f64 * 0.1;
            if has_write_conflicts {
                v -= vr.conflicts.len() as f64 * self.config.conflict_penalty;
            }
            v.max(0.1)
        };

        let action_ratio =
            Self::remaining_budget_fraction(action_count, self.config.action_budget);
        let tool_ratio =
            Self::remaining_budget_fraction(tool_call_count, self.config.tool_call_budget);

        // Fewer execution levels relative to action count means more actions
        // can run side by side.
        let parallelism_levels = vr.execution_levels.len();
        let parallelism_bonus = if action_count > 1 && parallelism_levels > 0 {
            1.0 - (parallelism_levels as f64 / action_count as f64).min(1.0)
        } else {
            0.0
        };

        let token_estimate = estimate_proposal_tokens(proposal);
        let token_ratio =
            Self::remaining_budget_fraction(token_estimate, self.config.token_budget);

        let tw = self.config.token_weight.clamp(0.0, 1.0);
        let rest = 1.0 - tw;
        let cost_efficiency = (action_ratio * (0.4 * rest)
            + tool_ratio * (0.4 * rest)
            + parallelism_bonus * (0.2 * rest)
            + token_ratio * tw)
            .clamp(0.0, 1.0);

        let historical_confidence = feedback
            .map(|f| f.proposal_tool_confidence(proposal))
            .unwrap_or(1.0);

        let score = if error_count > 0 {
            0.0
        } else {
            let cw = self.config.cost_weight.clamp(0.0, 1.0);
            let base = validity * (1.0 - cw) + cost_efficiency * cw;
            let fw = self.config.feedback_weight.clamp(0.0, 1.0);
            base * (1.0 - fw + fw * historical_confidence)
        };

        let quality_per_token = if token_estimate > 0 {
            score / token_estimate as f64
        } else {
            score
        };

        ScoredProposal {
            index,
            score,
            validity,
            cost_efficiency,
            error_count,
            warning_count,
            action_count,
            tool_call_count,
            parallelism_levels,
            valid: vr.valid,
            state_keys_written,
            has_write_conflicts,
            historical_confidence,
            token_estimate,
            quality_per_token,
        }
    }

    /// Scores every candidate without historical feedback and returns them
    /// best-first. See [`Planner::rank_with_feedback`] for the ordering.
    pub fn rank(
        &self,
        candidates: &[ActionProposal],
        initial_state: Option<&HashMap<String, Value>>,
        registered_tools: Option<&HashSet<String>>,
    ) -> Vec<ScoredProposal> {
        self.rank_with_feedback(candidates, initial_state, registered_tools, None)
    }

    /// Scores every candidate and returns them best-first.
    ///
    /// The ordering is:
    /// * higher `score` first;
    /// * ties broken by fewer actions;
    /// * then by original position, since the sort is stable.
    ///
    /// NaN scores, which are possible when config weights are NaN, rank last
    /// rather than corrupting the order.
    pub fn rank_with_feedback(
        &self,
        candidates: &[ActionProposal],
        initial_state: Option<&HashMap<String, Value>>,
        registered_tools: Option<&HashSet<String>>,
        feedback: Option<&ToolFeedback>,
    ) -> Vec<ScoredProposal> {
        let mut scored: Vec<ScoredProposal> = candidates
            .iter()
            .enumerate()
            .map(|(i, p)| self.score_indexed(i, p, initial_state, registered_tools, feedback))
            .collect();
        // After `rank_key` no value is NaN, so `partial_cmp` always returns
        // `Some` and the comparator is a total order. The `Equal` fallback is
        // never taken.
        scored.sort_by(|a, b| {
            Self::rank_key(b.score)
                .partial_cmp(&Self::rank_key(a.score))
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(a.action_count.cmp(&b.action_count))
        });
        scored
    }

    /// Returns the highest-ranked candidate the verifier marked valid, with its
    /// index into `candidates`, without historical feedback. Returns `None` if
    /// no candidate is valid.
    pub fn pick_best(
        &self,
        candidates: &[ActionProposal],
        initial_state: Option<&HashMap<String, Value>>,
        registered_tools: Option<&HashSet<String>>,
    ) -> Option<(usize, ScoredProposal)> {
        self.pick_best_with_feedback(candidates, initial_state, registered_tools, None)
    }

    /// Like [`Planner::pick_best`], but ranks with the given historical tool
    /// `feedback`.
    pub fn pick_best_with_feedback(
        &self,
        candidates: &[ActionProposal],
        initial_state: Option<&HashMap<String, Value>>,
        registered_tools: Option<&HashSet<String>>,
        feedback: Option<&ToolFeedback>,
    ) -> Option<(usize, ScoredProposal)> {
        self.rank_with_feedback(candidates, initial_state, registered_tools, feedback)
            .into_iter()
            .find(|s| s.valid)
            .map(|s| (s.index, s))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for proposal scoring, ranking and tool feedback.
    use super::*;
    use car_ir::*;

    /// Builds a `ToolCall` action for `tool` with the given parameters.
    fn tool_call(tool: &str, params: HashMap<String, Value>) -> Action {
        Action {
            id: format!("a-{}", tool),
            action_type: ActionType::ToolCall,
            tool: Some(tool.to_string()),
            parameters: params,
            preconditions: vec![],
            expected_effects: HashMap::new(),
            state_dependencies: vec![],
            idempotent: false,
            max_retries: 3,
            failure_behavior: FailureBehavior::Abort,
            timeout_ms: None,
            metadata: HashMap::new(),
        }
    }

    /// Builds a `StateWrite` action that sets `key` to `value`.
    fn state_write(key: &str, value: Value) -> Action {
        Action {
            id: format!("sw-{}", key),
            action_type: ActionType::StateWrite,
            tool: None,
            parameters: [
                ("key".to_string(), Value::from(key)),
                ("value".to_string(), value),
            ]
            .into(),
            preconditions: vec![],
            expected_effects: HashMap::new(),
            state_dependencies: vec![],
            idempotent: false,
            max_retries: 0,
            failure_behavior: FailureBehavior::Abort,
            timeout_ms: None,
            metadata: HashMap::new(),
        }
    }

    /// Wraps `actions` in a proposal with the given id.
    fn proposal(id: &str, actions: Vec<Action>) -> ActionProposal {
        ActionProposal {
            id: id.to_string(),
            source: "test".to_string(),
            actions,
            timestamp: chrono::Utc::now(),
            context: HashMap::new(),
        }
    }

    #[test]
    fn score_clean_proposal() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["search".into()].into();
        let p = proposal(
            "p1",
            vec![tool_call(
                "search",
                [("q".into(), Value::from("rust"))].into(),
            )],
        );
        let scored = planner.score(&p, None, Some(&tools));
        assert!(scored.valid);
        assert!(scored.score > 0.5);
        assert_eq!(scored.error_count, 0);
        assert_eq!(scored.action_count, 1);
        assert_eq!(scored.tool_call_count, 1);
    }

    #[test]
    fn score_invalid_proposal_unregistered_tool() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["search".into()].into();
        let p = proposal("p1", vec![tool_call("nonexistent", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        assert!(!scored.valid);
        assert_eq!(scored.validity, 0.0);
        assert!(scored.error_count > 0);
    }

    #[test]
    fn rank_prefers_valid_over_invalid() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["search".into()].into();
        let valid = proposal(
            "valid",
            vec![tool_call(
                "search",
                [("q".into(), Value::from("test"))].into(),
            )],
        );
        let invalid = proposal("invalid", vec![tool_call("nonexistent", HashMap::new())]);
        let ranked = planner.rank(&[invalid, valid], None, Some(&tools));
        assert!(ranked[0].valid);
        assert!(!ranked[1].valid);
        // The valid proposal was passed second.
        assert_eq!(ranked[0].index, 1);
    }

    #[test]
    fn rank_prefers_cheaper_among_valid() {
        let planner = Planner::new(PlannerConfig {
            cost_weight: 0.5,
            action_budget: 10,
            tool_call_budget: 5,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into(), "b".into(), "c".into()].into();
        let cheap = proposal("cheap", vec![tool_call("a", HashMap::new())]);
        let expensive = proposal(
            "expensive",
            vec![
                tool_call("a", HashMap::new()),
                tool_call("b", HashMap::new()),
                tool_call("c", HashMap::new()),
            ],
        );
        let ranked = planner.rank(&[expensive, cheap], None, Some(&tools));
        // The cheap proposal was passed second.
        assert_eq!(ranked[0].index, 1);
        assert!(ranked[0].cost_efficiency > ranked[1].cost_efficiency);
    }

    #[test]
    fn pick_best_skips_invalid() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["ok".into()].into();
        let bad = proposal("bad", vec![tool_call("nonexistent", HashMap::new())]);
        let good = proposal("good", vec![tool_call("ok", HashMap::new())]);
        let result = planner.pick_best(&[bad, good], None, Some(&tools));
        assert!(result.is_some());
        let (idx, scored) = result.unwrap();
        assert_eq!(idx, 1);
        assert!(scored.valid);
    }

    #[test]
    fn pick_best_returns_none_when_all_invalid() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = HashSet::new();
        let bad1 = proposal("bad1", vec![tool_call("x", HashMap::new())]);
        let bad2 = proposal("bad2", vec![tool_call("y", HashMap::new())]);
        let result = planner.pick_best(&[bad1, bad2], None, Some(&tools));
        assert!(result.is_none());
    }

    #[test]
    fn score_state_write_only() {
        let planner = Planner::new(PlannerConfig::default());
        let p = proposal("sw", vec![state_write("key", Value::from("value"))]);
        let scored = planner.score(&p, None, None);
        assert!(scored.valid);
        assert_eq!(scored.tool_call_count, 0);
        assert_eq!(scored.action_count, 1);
    }

    #[test]
    fn parallelism_bonus_rewards_independent_actions() {
        let planner = Planner::new(PlannerConfig {
            cost_weight: 0.5,
            action_budget: 10,
            tool_call_budget: 5,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into(), "b".into()].into();
        let parallel = proposal(
            "par",
            vec![
                tool_call("a", HashMap::new()),
                tool_call("b", HashMap::new()),
            ],
        );
        // Make the second action depend on a key the first one produces, so
        // the two cannot run at the same level.
        let mut seq_actions = vec![
            tool_call("a", HashMap::new()),
            tool_call("b", HashMap::new()),
        ];
        seq_actions[1].state_dependencies.push("key".into());
        seq_actions[0]
            .expected_effects
            .insert("key".into(), Value::from("v"));
        let sequential = proposal("seq", seq_actions);
        let par_score = planner.score(&parallel, None, Some(&tools));
        let seq_score = planner.score(&sequential, None, Some(&tools));
        assert!(
            par_score.cost_efficiency >= seq_score.cost_efficiency,
            "parallel={:.3} should >= sequential={:.3}",
            par_score.cost_efficiency,
            seq_score.cost_efficiency
        );
    }

    #[test]
    fn state_write_tracks_keys() {
        let planner = Planner::new(PlannerConfig::default());
        let p = proposal(
            "sw",
            vec![
                state_write("key_a", Value::from("val_a")),
                state_write("key_b", Value::from("val_b")),
            ],
        );
        let scored = planner.score(&p, None, None);
        assert!(scored.valid);
        assert_eq!(scored.state_keys_written, 2);
        assert!(!scored.has_write_conflicts);
    }

    #[test]
    fn write_conflict_penalizes_score() {
        let planner = Planner::new(PlannerConfig::default());
        let p = proposal(
            "conflict",
            vec![
                state_write("shared_key", Value::from("v1")),
                state_write("shared_key", Value::from("v2")),
            ],
        );
        let scored = planner.score(&p, None, None);
        assert!(scored.has_write_conflicts);
        assert!(
            scored.validity < 1.0,
            "expected conflict penalty, got validity={:.3}",
            scored.validity
        );
    }

    #[test]
    fn feedback_penalizes_tools_that_fail_often() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["reliable".into(), "flaky".into()].into();
        let reliable_plan = proposal("reliable", vec![tool_call("reliable", HashMap::new())]);
        let flaky_plan = proposal("flaky", vec![tool_call("flaky", HashMap::new())]);
        let feedback = ToolFeedback {
            tool_success_rates: [("reliable".into(), 0.95), ("flaky".into(), 0.2)].into(),
        };
        let ranked = planner.rank_with_feedback(
            &[flaky_plan, reliable_plan],
            None,
            Some(&tools),
            Some(&feedback),
        );
        assert_eq!(ranked[0].index, 1, "reliable plan should rank first");
        assert!(ranked[0].historical_confidence > ranked[1].historical_confidence);
        assert!(
            ranked[0].score > ranked[1].score,
            "reliable={:.3} should > flaky={:.3}",
            ranked[0].score,
            ranked[1].score
        );
    }

    #[test]
    fn feedback_from_trajectories() {
        use car_memgine::{TraceEvent, Trajectory, TrajectoryOutcome};
        let trajectories = vec![
            Trajectory {
                proposal_id: "t1".into(),
                source: "test".into(),
                action_count: 1,
                events: vec![TraceEvent {
                    kind: "action_succeeded".into(),
                    action_id: Some("a1".into()),
                    tool: Some("good_tool".into()),
                    data: serde_json::json!({}),
                    ..Default::default()
                }],
                outcome: TrajectoryOutcome::Success,
                timestamp: chrono::Utc::now(),
                duration_ms: 100.0,
                replan_attempts: 0,
            },
            Trajectory {
                proposal_id: "t2".into(),
                source: "test".into(),
                action_count: 1,
                events: vec![TraceEvent {
                    kind: "action_failed".into(),
                    action_id: Some("a2".into()),
                    tool: Some("bad_tool".into()),
                    data: serde_json::json!({}),
                    ..Default::default()
                }],
                outcome: TrajectoryOutcome::Failed,
                timestamp: chrono::Utc::now(),
                duration_ms: 50.0,
                replan_attempts: 0,
            },
        ];
        let feedback = ToolFeedback::from_trajectories(&trajectories);
        assert!((feedback.rate("good_tool") - 1.0).abs() < 0.01);
        assert!((feedback.rate("bad_tool") - 0.0).abs() < 0.01);
        // Tools with no history fall back to the neutral prior.
        assert!((feedback.rate("unknown") - 0.5).abs() < 0.01);
    }

    #[test]
    fn token_estimate_surfaced_and_nonzero() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        assert!(scored.token_estimate > 0);
        assert!(scored.quality_per_token > 0.0);
        assert!(
            (scored.quality_per_token - scored.score / scored.token_estimate as f64).abs() < 1e-9
        );
    }

    #[test]
    fn tiny_token_budget_penalizes_proposals() {
        let planner = Planner::new(PlannerConfig {
            token_budget: 10,
            token_weight: 0.8,
            cost_weight: 0.5,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        assert!(scored.valid);
        assert!(
            scored.cost_efficiency < 0.5,
            "small token budget should tank cost_efficiency, got {:.3}",
            scored.cost_efficiency
        );
    }

    #[test]
    fn token_weight_zero_matches_legacy_blend() {
        let planner = Planner::new(PlannerConfig {
            token_weight: 0.0,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        // Default budgets are 20 actions and 10 tool calls. A single action
        // earns no parallelism bonus.
        let action_ratio = 1.0 - (1.0 / 20.0);
        let tool_ratio = 1.0 - (1.0 / 10.0);
        let expected = action_ratio * 0.4 + tool_ratio * 0.4 + 0.0 * 0.2;
        assert!(
            (scored.cost_efficiency - expected).abs() < 1e-6,
            "cost_efficiency={:.6} expected={:.6}",
            scored.cost_efficiency,
            expected
        );
    }

    #[test]
    fn no_feedback_means_full_confidence() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        assert!((scored.historical_confidence - 1.0).abs() < 0.01);
    }
}