1use car_ir::ActionProposal;
22use car_verify::VerifyResult;
23use serde::Serialize;
24use serde_json::Value;
25use std::collections::{HashMap, HashSet};
26
/// Per-tool success rates mined from past execution trajectories.
///
/// Passed to [`Planner::rank_with_feedback`] so that proposals relying on tools
/// with a poor track record score lower.
#[derive(Clone, Debug, Default)]
pub struct ToolFeedback {
    /// Success rate keyed by tool name. When built by
    /// `ToolFeedback::from_trajectories`, each value is `successes / total`
    /// and therefore lies in `0.0..=1.0`.
    pub tool_success_rates: HashMap<String, f64>,
}
34
35impl ToolFeedback {
36 pub fn from_trajectories(trajectories: &[car_memgine::Trajectory]) -> Self {
38 let mut tool_outcomes: HashMap<String, (u64, u64)> = HashMap::new(); for traj in trajectories {
40 for event in &traj.events {
41 if let Some(ref tool) = event.tool {
42 let entry = tool_outcomes.entry(tool.clone()).or_default();
43 entry.1 += 1; if event.kind == "action_succeeded" {
45 entry.0 += 1; }
47 }
48 }
49 }
50
51 let tool_success_rates = tool_outcomes
52 .into_iter()
53 .map(|(tool, (success, total))| {
54 (
55 tool,
56 if total > 0 {
57 success as f64 / total as f64
58 } else {
59 0.5
60 },
61 )
62 })
63 .collect();
64
65 Self { tool_success_rates }
66 }
67
68 pub fn rate(&self, tool: &str) -> f64 {
70 self.tool_success_rates.get(tool).copied().unwrap_or(0.5)
71 }
72
73 pub fn proposal_tool_confidence(&self, proposal: &ActionProposal) -> f64 {
75 let tool_calls: Vec<&str> = proposal
76 .actions
77 .iter()
78 .filter(|a| a.action_type == car_ir::ActionType::ToolCall)
79 .filter_map(|a| a.tool.as_deref())
80 .collect();
81
82 if tool_calls.is_empty() {
83 return 1.0; }
85
86 let sum: f64 = tool_calls.iter().map(|t| self.rate(t)).sum();
87 sum / tool_calls.len() as f64
88 }
89}
90
91pub fn estimate_proposal_tokens(proposal: &ActionProposal) -> usize {
96 serde_json::to_string(proposal)
97 .map(|s| s.len() / 4)
98 .unwrap_or_else(|_| proposal.actions.len() * 32)
99}
100
/// Tuning knobs for [`Planner`] scoring.
///
/// The `*_weight` fields are clamped to `0.0..=1.0` when used. A budget of `0`
/// disables the corresponding cost penalty: its headroom is treated as 1.0.
#[derive(Clone, Debug)]
pub struct PlannerConfig {
    /// Share of the final score taken by cost efficiency; the remainder goes to validity.
    pub cost_weight: f64,
    /// Action count at which the action-headroom part of cost efficiency reaches zero.
    pub action_budget: usize,
    /// Tool-call count at which the tool-call-headroom part of cost efficiency reaches zero.
    pub tool_call_budget: usize,
    /// Validity deducted for each write conflict reported by verification.
    pub conflict_penalty: f64,
    /// How strongly historical tool confidence scales the final score.
    pub feedback_weight: f64,
    /// Estimated token count at which the token-headroom part of cost efficiency reaches zero.
    pub token_budget: usize,
    /// Share of cost efficiency taken by token headroom.
    ///
    /// The remainder is split 40/40/20 across action headroom, tool-call
    /// headroom and the parallelism bonus.
    pub token_weight: f64,
}
126
127impl Default for PlannerConfig {
128 fn default() -> Self {
129 Self {
130 cost_weight: 0.2,
131 action_budget: 20,
132 tool_call_budget: 10,
133 conflict_penalty: 0.15,
134 feedback_weight: 0.3,
135 token_budget: 4000,
136 token_weight: 0.25,
137 }
138 }
139}
140
141impl PlannerConfig {
142 pub fn from_cost_target(target: &car_ir::CostTarget) -> Self {
148 Self {
149 cost_weight: target.cost_weight.clamp(0.0, 1.0),
150 action_budget: target.target_actions as usize,
151 tool_call_budget: target.target_tool_calls as usize,
152 ..Default::default()
153 }
154 }
155}
156
157#[derive(Debug, Clone, Serialize)]
159pub struct ScoredProposal {
160 pub index: usize,
162 pub score: f64,
164 pub validity: f64,
166 pub cost_efficiency: f64,
168 pub error_count: usize,
170 pub warning_count: usize,
172 pub action_count: usize,
174 pub tool_call_count: usize,
176 pub parallelism_levels: usize,
178 pub valid: bool,
180 pub state_keys_written: usize,
182 pub has_write_conflicts: bool,
184 pub historical_confidence: f64,
187 pub token_estimate: usize,
191 pub quality_per_token: f64,
194}
195
196pub struct Planner {
198 config: PlannerConfig,
199}
200
201impl Planner {
202 pub fn new(config: PlannerConfig) -> Self {
203 Self { config }
204 }
205
206 pub fn score(
208 &self,
209 proposal: &ActionProposal,
210 initial_state: Option<&HashMap<String, Value>>,
211 registered_tools: Option<&HashSet<String>>,
212 ) -> ScoredProposal {
213 self.score_indexed(0, proposal, initial_state, registered_tools, None)
214 }
215
216 fn score_indexed(
218 &self,
219 index: usize,
220 proposal: &ActionProposal,
221 initial_state: Option<&HashMap<String, Value>>,
222 registered_tools: Option<&HashSet<String>>,
223 feedback: Option<&ToolFeedback>,
224 ) -> ScoredProposal {
225 let vr = car_verify::verify(proposal, initial_state, registered_tools, 100);
226 self.score_from_verify(index, proposal, &vr, feedback)
227 }
228
229 fn score_from_verify(
232 &self,
233 index: usize,
234 proposal: &ActionProposal,
235 vr: &VerifyResult,
236 feedback: Option<&ToolFeedback>,
237 ) -> ScoredProposal {
238 let error_count = vr.issues.iter().filter(|i| i.severity == "error").count();
239 let warning_count = vr.issues.iter().filter(|i| i.severity == "warning").count();
240
241 let action_count = proposal.actions.len();
242 let tool_call_count = proposal
243 .actions
244 .iter()
245 .filter(|a| a.action_type == car_ir::ActionType::ToolCall)
246 .count();
247
248 let state_keys_written = vr.simulated_state.len();
250 let has_write_conflicts = !vr.conflicts.is_empty();
251
252 let validity = if error_count > 0 {
254 0.0
255 } else {
256 let mut v = 1.0;
257 v -= warning_count as f64 * 0.1;
258 if has_write_conflicts {
260 v -= vr.conflicts.len() as f64 * self.config.conflict_penalty;
261 }
262 v.max(0.1)
263 };
264
265 let action_ratio = if self.config.action_budget > 0 {
267 1.0 - (action_count as f64 / self.config.action_budget as f64).min(1.0)
268 } else {
269 1.0
270 };
271 let tool_ratio = if self.config.tool_call_budget > 0 {
272 1.0 - (tool_call_count as f64 / self.config.tool_call_budget as f64).min(1.0)
273 } else {
274 1.0
275 };
276 let parallelism_levels = vr.execution_levels.len();
278 let parallelism_bonus = if action_count > 1 && parallelism_levels > 0 {
279 1.0 - (parallelism_levels as f64 / action_count as f64).min(1.0)
280 } else {
281 0.0
282 };
283 let token_estimate = estimate_proposal_tokens(proposal);
286 let token_ratio = if self.config.token_budget > 0 {
287 1.0 - (token_estimate as f64 / self.config.token_budget as f64).min(1.0)
288 } else {
289 1.0
290 };
291
292 let tw = self.config.token_weight.clamp(0.0, 1.0);
295 let rest = 1.0 - tw;
296 let cost_efficiency = (action_ratio * (0.4 * rest)
297 + tool_ratio * (0.4 * rest)
298 + parallelism_bonus * (0.2 * rest)
299 + token_ratio * tw)
300 .clamp(0.0, 1.0);
301
302 let historical_confidence = feedback
304 .map(|f| f.proposal_tool_confidence(proposal))
305 .unwrap_or(1.0); let score = if error_count > 0 {
310 0.0
311 } else {
312 let cw = self.config.cost_weight.clamp(0.0, 1.0);
313 let base = validity * (1.0 - cw) + cost_efficiency * cw;
315 let fw = self.config.feedback_weight.clamp(0.0, 1.0);
317 base * (1.0 - fw + fw * historical_confidence)
318 };
319
320 let quality_per_token = if token_estimate > 0 {
321 score / token_estimate as f64
322 } else {
323 score
324 };
325
326 ScoredProposal {
327 index,
328 score,
329 validity,
330 cost_efficiency,
331 error_count,
332 warning_count,
333 action_count,
334 tool_call_count,
335 parallelism_levels,
336 valid: vr.valid,
337 state_keys_written,
338 has_write_conflicts,
339 historical_confidence,
340 token_estimate,
341 quality_per_token,
342 }
343 }
344
345 pub fn rank(
348 &self,
349 candidates: &[ActionProposal],
350 initial_state: Option<&HashMap<String, Value>>,
351 registered_tools: Option<&HashSet<String>>,
352 ) -> Vec<ScoredProposal> {
353 self.rank_with_feedback(candidates, initial_state, registered_tools, None)
354 }
355
356 pub fn rank_with_feedback(
358 &self,
359 candidates: &[ActionProposal],
360 initial_state: Option<&HashMap<String, Value>>,
361 registered_tools: Option<&HashSet<String>>,
362 feedback: Option<&ToolFeedback>,
363 ) -> Vec<ScoredProposal> {
364 let mut scored: Vec<ScoredProposal> = candidates
365 .iter()
366 .enumerate()
367 .map(|(i, p)| self.score_indexed(i, p, initial_state, registered_tools, feedback))
368 .collect();
369
370 scored.sort_by(|a, b| {
372 b.score
373 .partial_cmp(&a.score)
374 .unwrap_or(std::cmp::Ordering::Equal)
375 .then(a.action_count.cmp(&b.action_count))
376 });
377
378 scored
379 }
380
381 pub fn pick_best(
384 &self,
385 candidates: &[ActionProposal],
386 initial_state: Option<&HashMap<String, Value>>,
387 registered_tools: Option<&HashSet<String>>,
388 ) -> Option<(usize, ScoredProposal)> {
389 let ranked = self.rank_with_feedback(candidates, initial_state, registered_tools, None);
390 ranked.into_iter().find(|s| s.valid).map(|s| (s.index, s))
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use car_ir::*;

    /// Builds a `ToolCall` action for `tool` with the given parameters.
    fn tool_call(tool: &str, params: HashMap<String, Value>) -> Action {
        Action {
            id: format!("a-{}", tool),
            action_type: ActionType::ToolCall,
            tool: Some(tool.to_string()),
            parameters: params,
            preconditions: vec![],
            expected_effects: HashMap::new(),
            state_dependencies: vec![],
            idempotent: false,
            max_retries: 3,
            failure_behavior: FailureBehavior::Abort,
            timeout_ms: None,
            metadata: HashMap::new(),
        }
    }

    /// Builds a `StateWrite` action that sets `key` to `value`.
    fn state_write(key: &str, value: Value) -> Action {
        Action {
            id: format!("sw-{}", key),
            action_type: ActionType::StateWrite,
            tool: None,
            parameters: [
                ("key".to_string(), Value::from(key)),
                ("value".to_string(), value),
            ]
            .into(),
            preconditions: vec![],
            expected_effects: HashMap::new(),
            state_dependencies: vec![],
            idempotent: false,
            max_retries: 0,
            failure_behavior: FailureBehavior::Abort,
            timeout_ms: None,
            metadata: HashMap::new(),
        }
    }

    /// Wraps `actions` in a proposal with the given id.
    fn proposal(id: &str, actions: Vec<Action>) -> ActionProposal {
        ActionProposal {
            id: id.to_string(),
            source: "test".to_string(),
            actions,
            timestamp: chrono::Utc::now(),
            context: HashMap::new(),
        }
    }

    #[test]
    fn score_clean_proposal() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["search".into()].into();
        let p = proposal(
            "p1",
            vec![tool_call(
                "search",
                [("q".into(), Value::from("rust"))].into(),
            )],
        );

        let scored = planner.score(&p, None, Some(&tools));
        assert!(scored.valid);
        assert!(scored.score > 0.5);
        assert_eq!(scored.error_count, 0);
        assert_eq!(scored.action_count, 1);
        assert_eq!(scored.tool_call_count, 1);
    }

    #[test]
    fn score_invalid_proposal_unregistered_tool() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["search".into()].into();
        let p = proposal("p1", vec![tool_call("nonexistent", HashMap::new())]);

        let scored = planner.score(&p, None, Some(&tools));
        assert!(!scored.valid);
        assert_eq!(scored.validity, 0.0);
        assert!(scored.error_count > 0);
    }

    #[test]
    fn rank_prefers_valid_over_invalid() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["search".into()].into();

        let valid = proposal(
            "valid",
            vec![tool_call(
                "search",
                [("q".into(), Value::from("test"))].into(),
            )],
        );
        let invalid = proposal("invalid", vec![tool_call("nonexistent", HashMap::new())]);

        let ranked = planner.rank(&[invalid, valid], None, Some(&tools));
        assert!(ranked[0].valid);
        assert!(!ranked[1].valid);
        // The valid proposal was passed second.
        assert_eq!(ranked[0].index, 1);
    }

    #[test]
    fn rank_prefers_cheaper_among_valid() {
        let planner = Planner::new(PlannerConfig {
            cost_weight: 0.5,
            action_budget: 10,
            tool_call_budget: 5,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into(), "b".into(), "c".into()].into();

        let cheap = proposal("cheap", vec![tool_call("a", HashMap::new())]);
        let expensive = proposal(
            "expensive",
            vec![
                tool_call("a", HashMap::new()),
                tool_call("b", HashMap::new()),
                tool_call("c", HashMap::new()),
            ],
        );

        let ranked = planner.rank(&[expensive, cheap], None, Some(&tools));
        // The cheap proposal was passed second.
        assert_eq!(ranked[0].index, 1);
        assert!(ranked[0].cost_efficiency > ranked[1].cost_efficiency);
    }

    #[test]
    fn pick_best_skips_invalid() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["ok".into()].into();

        let bad = proposal("bad", vec![tool_call("nonexistent", HashMap::new())]);
        let good = proposal("good", vec![tool_call("ok", HashMap::new())]);

        let result = planner.pick_best(&[bad, good], None, Some(&tools));
        assert!(result.is_some());
        let (idx, scored) = result.unwrap();
        assert_eq!(idx, 1);
        assert!(scored.valid);
    }

    #[test]
    fn pick_best_returns_none_when_all_invalid() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = HashSet::new();

        let bad1 = proposal("bad1", vec![tool_call("x", HashMap::new())]);
        let bad2 = proposal("bad2", vec![tool_call("y", HashMap::new())]);

        let result = planner.pick_best(&[bad1, bad2], None, Some(&tools));
        assert!(result.is_none());
    }

    #[test]
    fn score_state_write_only() {
        let planner = Planner::new(PlannerConfig::default());
        let p = proposal("sw", vec![state_write("key", Value::from("value"))]);

        let scored = planner.score(&p, None, None);
        assert!(scored.valid);
        assert_eq!(scored.tool_call_count, 0);
        assert_eq!(scored.action_count, 1);
    }

    #[test]
    fn parallelism_bonus_rewards_independent_actions() {
        let planner = Planner::new(PlannerConfig {
            cost_weight: 0.5,
            action_budget: 10,
            tool_call_budget: 5,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into(), "b".into()].into();

        // Two independent tool calls.
        let parallel = proposal(
            "par",
            vec![
                tool_call("a", HashMap::new()),
                tool_call("b", HashMap::new()),
            ],
        );

        // Same two calls, but the second depends on a key the first produces.
        let mut seq_actions = vec![
            tool_call("a", HashMap::new()),
            tool_call("b", HashMap::new()),
        ];
        seq_actions[1].state_dependencies.push("key".into());
        seq_actions[0]
            .expected_effects
            .insert("key".into(), Value::from("v"));
        let sequential = proposal("seq", seq_actions);

        let par_score = planner.score(&parallel, None, Some(&tools));
        let seq_score = planner.score(&sequential, None, Some(&tools));

        assert!(
            par_score.cost_efficiency >= seq_score.cost_efficiency,
            "parallel={:.3} should >= sequential={:.3}",
            par_score.cost_efficiency,
            seq_score.cost_efficiency
        );
    }

    #[test]
    fn state_write_tracks_keys() {
        let planner = Planner::new(PlannerConfig::default());
        let p = proposal(
            "sw",
            vec![
                state_write("key_a", Value::from("val_a")),
                state_write("key_b", Value::from("val_b")),
            ],
        );

        let scored = planner.score(&p, None, None);
        assert!(scored.valid);
        assert_eq!(scored.state_keys_written, 2);
        assert!(!scored.has_write_conflicts);
    }

    #[test]
    fn write_conflict_penalizes_score() {
        let planner = Planner::new(PlannerConfig::default());
        // Two writes to the same key.
        let p = proposal(
            "conflict",
            vec![
                state_write("shared_key", Value::from("v1")),
                state_write("shared_key", Value::from("v2")),
            ],
        );

        let scored = planner.score(&p, None, None);
        assert!(scored.has_write_conflicts);
        assert!(
            scored.validity < 1.0,
            "expected conflict penalty, got validity={:.3}",
            scored.validity
        );
    }

    #[test]
    fn feedback_penalizes_tools_that_fail_often() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["reliable".into(), "flaky".into()].into();

        let reliable_plan = proposal("reliable", vec![tool_call("reliable", HashMap::new())]);
        let flaky_plan = proposal("flaky", vec![tool_call("flaky", HashMap::new())]);

        let feedback = ToolFeedback {
            tool_success_rates: [("reliable".into(), 0.95), ("flaky".into(), 0.2)].into(),
        };

        let ranked = planner.rank_with_feedback(
            &[flaky_plan, reliable_plan],
            None,
            Some(&tools),
            Some(&feedback),
        );

        assert_eq!(ranked[0].index, 1, "reliable plan should rank first");
        assert!(ranked[0].historical_confidence > ranked[1].historical_confidence);
        assert!(
            ranked[0].score > ranked[1].score,
            "reliable={:.3} should > flaky={:.3}",
            ranked[0].score,
            ranked[1].score
        );
    }

    #[test]
    fn feedback_from_trajectories() {
        use car_memgine::{TraceEvent, Trajectory, TrajectoryOutcome};

        let trajectories = vec![
            Trajectory {
                proposal_id: "t1".into(),
                source: "test".into(),
                action_count: 1,
                events: vec![TraceEvent {
                    kind: "action_succeeded".into(),
                    action_id: Some("a1".into()),
                    tool: Some("good_tool".into()),
                    data: serde_json::json!({}),
                    ..Default::default()
                }],
                outcome: TrajectoryOutcome::Success,
                timestamp: chrono::Utc::now(),
                duration_ms: 100.0,
                replan_attempts: 0,
            },
            Trajectory {
                proposal_id: "t2".into(),
                source: "test".into(),
                action_count: 1,
                events: vec![TraceEvent {
                    kind: "action_failed".into(),
                    action_id: Some("a2".into()),
                    tool: Some("bad_tool".into()),
                    data: serde_json::json!({}),
                    ..Default::default()
                }],
                outcome: TrajectoryOutcome::Failed,
                timestamp: chrono::Utc::now(),
                duration_ms: 50.0,
                replan_attempts: 0,
            },
        ];

        let feedback = ToolFeedback::from_trajectories(&trajectories);
        assert!((feedback.rate("good_tool") - 1.0).abs() < 0.01);
        assert!((feedback.rate("bad_tool") - 0.0).abs() < 0.01);
        // Unknown tools get the neutral rate.
        assert!((feedback.rate("unknown") - 0.5).abs() < 0.01);
    }

    #[test]
    fn token_estimate_surfaced_and_nonzero() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);

        let scored = planner.score(&p, None, Some(&tools));
        assert!(scored.token_estimate > 0);
        assert!(scored.quality_per_token > 0.0);
        assert!(
            (scored.quality_per_token - scored.score / scored.token_estimate as f64).abs() < 1e-9
        );
    }

    #[test]
    fn tiny_token_budget_penalizes_proposals() {
        // Any real proposal serializes to far more than 10 tokens, so token headroom is 0.
        let planner = Planner::new(PlannerConfig {
            token_budget: 10,
            token_weight: 0.8,
            cost_weight: 0.5,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        assert!(scored.valid);
        assert!(
            scored.cost_efficiency < 0.5,
            "small token budget should tank cost_efficiency, got {:.3}",
            scored.cost_efficiency
        );
    }

    #[test]
    fn token_weight_zero_matches_legacy_blend() {
        // With token_weight == 0 the blend reduces to 0.4/0.4/0.2 over
        // action headroom, tool headroom and parallelism.
        let planner = Planner::new(PlannerConfig {
            token_weight: 0.0,
            ..Default::default()
        });
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);
        let scored = planner.score(&p, None, Some(&tools));
        let action_ratio = 1.0 - (1.0 / 20.0); // default action_budget = 20
        let tool_ratio = 1.0 - (1.0 / 10.0); // default tool_call_budget = 10
        let expected = action_ratio * 0.4 + tool_ratio * 0.4 + 0.0 * 0.2;
        assert!(
            (scored.cost_efficiency - expected).abs() < 1e-6,
            "cost_efficiency={:.6} expected={:.6}",
            scored.cost_efficiency,
            expected
        );
    }

    #[test]
    fn no_feedback_means_full_confidence() {
        let planner = Planner::new(PlannerConfig::default());
        let tools: HashSet<String> = ["a".into()].into();
        let p = proposal("p1", vec![tool_call("a", HashMap::new())]);

        let scored = planner.score(&p, None, Some(&tools));
        assert!((scored.historical_confidence - 1.0).abs() < 0.01);
    }
}