1use std::collections::{BTreeMap, BTreeSet};
22
23use devboy_core::{ToolValueModel, ValueClass};
24
25use crate::adaptive_config::AdaptiveConfig;
26
/// Per-turn inputs the planner reads when proposing follow-up tool calls.
#[derive(Debug, Clone, Default)]
pub struct TurnContext<'a> {
    /// Names of the tools already used. Each one is looked up as a trigger for
    /// follow-up links, and none of them is proposed again by the planner.
    pub recent_tools: &'a [String],
    /// Token budget the plan starts with; admitted, budget-counted calls are
    /// subtracted from it.
    pub budget_tokens: u32,
    /// Keywords compared (ASCII case-insensitively) against the field names of
    /// a tool's opt-in field groups to boost its score. Empty means no boost.
    pub intent_keywords: Vec<String>,
}

impl<'a> TurnContext<'a> {
    /// Creates a context for the given tools and budget with no intent keywords.
    pub fn new(recent_tools: &'a [String], budget_tokens: u32) -> Self {
        // Everything not passed in takes its `Default` value (an empty keyword list).
        Self {
            recent_tools,
            budget_tokens,
            ..Self::default()
        }
    }
}
55
/// One follow-up call that `build_plan` admitted into the plan.
#[derive(Debug, Clone, PartialEq)]
pub struct PlannedCall {
    /// Name of the follow-up tool to call.
    pub tool: String,
    /// Projection copied from the highest-probability follow-up link that
    /// pointed at this tool, if that link carried one.
    pub projection: Option<String>,
    /// Probability of the follow-up link that won for this tool.
    pub probability: f32,
    /// Estimated response size in bytes: the tool's `typical_kb * 1024`,
    /// converted to `u32`.
    pub estimated_cost_bytes: u32,
    /// Estimated cost in tokens (bytes divided by `bytes_per_token`). Clamped
    /// to at least 1 unless the tool's model is excluded from the budget.
    pub estimated_cost_tokens: u32,
    /// Value class taken from the tool's value model.
    pub value_class: ValueClass,
}
77
/// Result of `build_plan`: the admitted calls plus budget bookkeeping.
#[derive(Debug, Clone, Default)]
pub struct EnrichmentPlan {
    /// Admitted calls, in the order they were admitted (highest score first).
    pub calls: Vec<PlannedCall>,
    /// Sum of `estimated_cost_tokens` over the admitted calls that count
    /// against the budget (saturating).
    pub total_cost_tokens: u32,
    /// Budget left after all budget-counted admitted calls were subtracted.
    pub remaining_budget_tokens: u32,
    /// Candidates that were considered but not admitted, with the reason.
    pub declined: Vec<DeclineReason>,
}
90
/// Records why a candidate tool was left out of the plan.
#[derive(Debug, Clone, PartialEq)]
pub struct DeclineReason {
    /// Name of the candidate tool that was declined.
    pub tool: String,
    /// Why it was declined.
    pub reason: DeclineKind,
}
97
/// Reason a candidate was declined. Marked `#[non_exhaustive]`, so downstream
/// matches need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DeclineKind {
    /// The candidate's estimated token cost was larger than the budget that
    /// remained when it was considered.
    BudgetExceeded,
}
109
/// Tuning knobs for the enrichment planner.
#[derive(Debug, Clone, Copy)]
pub struct PlannerOptions {
    /// Follow-up links whose probability is strictly below this value are ignored.
    pub min_followup_probability: f32,
    /// Divisor used to turn an estimated byte size into tokens; `0` is treated as `1`.
    pub bytes_per_token: u32,
    /// When set, a tool whose model reports a p50 latency (ms) at or above this
    /// value has its score halved. `None` disables the latency penalty.
    pub latency_penalty_ms: Option<u32>,
    /// When set, a tool whose model reports a dollar cost at or above this
    /// value has its score halved. `None` disables the dollar penalty.
    pub dollar_penalty: Option<f32>,
}

impl Default for PlannerOptions {
    /// Cost-blind defaults: probability floor 0.5, 4 bytes per token, no penalties.
    fn default() -> Self {
        PlannerOptions {
            latency_penalty_ms: None,
            dollar_penalty: None,
            bytes_per_token: 4,
            min_followup_probability: 0.5,
        }
    }
}

impl PlannerOptions {
    /// The defaults plus both penalties enabled: 5 000 ms latency knee and
    /// 0.10 dollar knee.
    pub fn cost_aware() -> Self {
        PlannerOptions {
            dollar_penalty: Some(0.10),
            latency_penalty_ms: Some(5_000),
            ..Default::default()
        }
    }
}
156
157pub fn build_plan(
173 config: &AdaptiveConfig,
174 context: &TurnContext<'_>,
175 options: PlannerOptions,
176) -> EnrichmentPlan {
177 let candidates = enumerate_candidates(config, context, options);
178
179 let mut scored: Vec<(f32, Candidate)> = candidates
183 .into_iter()
184 .map(|c| {
185 let density = if matches!(c.model.value_class, ValueClass::AuditOnly) {
186 f32::INFINITY
187 } else {
188 let cost_tokens = cost_tokens_for(&c.model, options.bytes_per_token).max(1) as f32;
189 let boost = intent_boost(&c.model, &context.intent_keywords);
190 let penalty = cost_penalty(&c.model, &options);
191 value_score(&c.model) * boost * penalty / cost_tokens
192 };
193 (density, c)
194 })
195 .collect();
196 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
199
200 let mut plan = EnrichmentPlan {
201 remaining_budget_tokens: context.budget_tokens,
202 ..EnrichmentPlan::default()
203 };
204
205 for (_, c) in scored {
206 let raw_cost_tokens = cost_tokens_for(&c.model, options.bytes_per_token);
207 let cost_bytes = (c.model.cost_model.typical_kb * 1024.0) as u32;
208
209 let is_free = c.model.excluded_from_budget();
210 let cost_tokens = if is_free {
218 raw_cost_tokens
219 } else {
220 raw_cost_tokens.max(1)
221 };
222 if !is_free && cost_tokens > plan.remaining_budget_tokens {
223 plan.declined.push(DeclineReason {
224 tool: c.tool.clone(),
225 reason: DeclineKind::BudgetExceeded,
226 });
227 continue;
228 }
229
230 plan.calls.push(PlannedCall {
231 tool: c.tool,
232 projection: c.projection,
233 probability: c.probability,
234 estimated_cost_bytes: cost_bytes,
235 estimated_cost_tokens: cost_tokens,
236 value_class: c.model.value_class,
237 });
238 if !is_free {
239 plan.total_cost_tokens = plan.total_cost_tokens.saturating_add(cost_tokens);
240 plan.remaining_budget_tokens = plan.remaining_budget_tokens.saturating_sub(cost_tokens);
241 }
242 }
243
244 plan
245}
246
/// Internal planner record: a follow-up tool that passed the enumeration
/// filters, together with the data needed to score and admit it.
struct Candidate {
    /// Name of the follow-up tool.
    tool: String,
    /// Projection from the highest-probability link pointing at this tool.
    projection: Option<String>,
    /// Probability of that link.
    probability: f32,
    /// The tool's effective value model, or the default model when the config has none.
    model: ToolValueModel,
}
255
256fn enumerate_candidates(
257 config: &AdaptiveConfig,
258 context: &TurnContext<'_>,
259 options: PlannerOptions,
260) -> Vec<Candidate> {
261 let mut by_tool: BTreeMap<String, (Option<String>, f32)> = BTreeMap::new();
264 let recent_set: BTreeSet<&str> = context.recent_tools.iter().map(String::as_str).collect();
265
266 for trigger in context.recent_tools {
267 let Some(model) = config.effective_tool_value_model(trigger) else {
268 continue;
269 };
270 for link in &model.follow_up {
271 if link.probability < options.min_followup_probability {
272 continue;
273 }
274 if link.tool == *trigger {
277 continue;
278 }
279 if recent_set.contains(link.tool.as_str()) {
281 continue;
282 }
283 let entry = by_tool
284 .entry(link.tool.clone())
285 .or_insert((link.projection.clone(), link.probability));
286 if link.probability > entry.1 {
287 entry.0 = link.projection.clone();
288 entry.1 = link.probability;
289 }
290 }
291 }
292
293 by_tool
294 .into_iter()
295 .map(|(tool, (projection, probability))| {
296 let model = config
302 .effective_tool_value_model(&tool)
303 .cloned()
304 .unwrap_or_default();
305 Candidate {
306 tool,
307 projection,
308 probability,
309 model,
310 }
311 })
312 .collect()
313}
314
315fn cost_tokens_for(model: &ToolValueModel, bytes_per_token: u32) -> u32 {
316 let bytes = (model.cost_model.typical_kb * 1024.0) as u32;
317 bytes.saturating_div(bytes_per_token.max(1))
318}
319
320fn value_score(model: &ToolValueModel) -> f32 {
321 match model.value_class {
322 ValueClass::Critical => 1.0,
323 ValueClass::Supporting => 0.5,
324 ValueClass::Optional => 0.2,
325 ValueClass::AuditOnly => 0.0,
326 }
327}
328
329fn cost_penalty(model: &ToolValueModel, options: &PlannerOptions) -> f32 {
336 let mut penalty = 1.0_f32;
337 if let (Some(knee), Some(latency)) =
338 (options.latency_penalty_ms, model.cost_model.latency_ms_p50)
339 && latency >= knee
340 {
341 penalty *= 0.5;
342 }
343 if let (Some(knee), Some(dollars)) = (options.dollar_penalty, model.cost_model.dollars)
344 && dollars >= knee
345 {
346 penalty *= 0.5;
347 }
348 penalty
349}
350
351fn intent_boost(model: &ToolValueModel, intent_keywords: &[String]) -> f32 {
367 if intent_keywords.is_empty() || model.field_groups.is_empty() {
368 return 1.0;
369 }
370 let lowered: Vec<String> = intent_keywords
371 .iter()
372 .map(|k| k.to_ascii_lowercase())
373 .collect();
374 let mut boost: f32 = 1.0;
375 for (_name, group) in model.field_groups.iter() {
376 if group.default_include {
377 continue; }
379 let any_match = group
380 .fields
381 .iter()
382 .any(|f| lowered.iter().any(|kw| f.to_ascii_lowercase().contains(kw)));
383 if any_match {
384 boost += group.estimated_value;
385 }
386 }
387 boost.min(2.5)
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool_defaults::default_tool_value_models;

    /// Config whose tool table is the built-in default value models.
    fn config_with_defaults() -> AdaptiveConfig {
        AdaptiveConfig {
            tools: default_tool_value_models(),
            ..AdaptiveConfig::default()
        }
    }

    /// With no trigger tools there is nothing to follow up on.
    #[test]
    fn empty_recent_tools_returns_empty_plan() {
        let cfg = config_with_defaults();
        let recent: Vec<String> = vec![];
        let ctx = TurnContext::new(&recent, 1024);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        assert!(plan.calls.is_empty());
        assert_eq!(plan.total_cost_tokens, 0);
    }

    /// Relies on the default models linking Grep -> Read with a "path"
    /// projection at probability >= 0.3.
    #[test]
    fn after_grep_planner_prefetches_read_with_path_projection() {
        let cfg = config_with_defaults();
        let recent = vec!["Grep".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(
            &cfg,
            &ctx,
            PlannerOptions {
                min_followup_probability: 0.3,
                ..Default::default()
            },
        );
        let read = plan
            .calls
            .iter()
            .find(|c| c.tool == "Read")
            .expect("Read should be admitted after Grep");
        assert_eq!(read.projection.as_deref(), Some("path"));
    }

    /// Relies on the default models linking WebSearch -> WebFetch with a "url"
    /// projection at probability >= the default floor.
    #[test]
    fn after_websearch_planner_prefetches_webfetch_with_url_projection() {
        let cfg = config_with_defaults();
        let recent = vec!["WebSearch".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let fetch = plan
            .calls
            .iter()
            .find(|c| c.tool == "WebFetch")
            .expect("WebFetch should be admitted after WebSearch");
        assert_eq!(fetch.projection.as_deref(), Some("url"));
    }

    /// A 50-token budget is expected to be too small for Read's default cost,
    /// so Read must show up in `declined`.
    #[test]
    fn budget_exceeded_decline_recorded() {
        let cfg = config_with_defaults();
        let recent = vec!["Glob".to_string()];
        let ctx = TurnContext::new(&recent, 50);
        let plan = build_plan(
            &cfg,
            &ctx,
            PlannerOptions {
                min_followup_probability: 0.3,
                ..Default::default()
            },
        );
        assert!(
            plan.declined
                .iter()
                .any(|d| d.tool == "Read" && d.reason == DeclineKind::BudgetExceeded),
            "expected Read to be declined for budget, got {:?}",
            plan.declined
        );
    }

    /// An audit-only follow-up is admitted without touching the budget.
    #[test]
    fn audit_only_tools_do_not_consume_budget() {
        let mut cfg = config_with_defaults();
        // Add a Grep -> TaskUpdate link; TaskUpdate is expected to be AuditOnly
        // in the default models (asserted below).
        let mut grep = cfg.tools.get("Grep").unwrap().clone();
        grep.follow_up.push(devboy_core::FollowUpLink {
            tool: "TaskUpdate".into(),
            probability: 0.9,
            ..devboy_core::FollowUpLink::default()
        });
        cfg.tools.insert("Grep".into(), grep);

        let recent = vec!["Grep".to_string()];
        let ctx = TurnContext::new(&recent, 1_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let task = plan
            .calls
            .iter()
            .find(|c| c.tool == "TaskUpdate")
            .expect("TaskUpdate should be admitted");
        assert_eq!(task.value_class, ValueClass::AuditOnly);
        // Only the non-audit calls may have been subtracted from the budget.
        assert_eq!(
            plan.remaining_budget_tokens,
            1_000 - critical_supporting_tokens(&plan)
        );
    }

    /// Sum of estimated tokens over all admitted calls that are not AuditOnly.
    fn critical_supporting_tokens(plan: &EnrichmentPlan) -> u32 {
        plan.calls
            .iter()
            .filter(|c| !matches!(c.value_class, ValueClass::AuditOnly))
            .map(|c| c.estimated_cost_tokens)
            .sum()
    }

    /// A trigger must never propose itself.
    #[test]
    fn self_loops_skipped() {
        let cfg = config_with_defaults();
        let recent = vec!["Read".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        assert!(
            !plan.calls.iter().any(|c| c.tool == "Read"),
            "Read self-loop should be skipped"
        );
    }

    /// A tool already in `recent_tools` is not proposed by another trigger.
    #[test]
    fn already_used_tools_skipped() {
        let cfg = config_with_defaults();
        let recent = vec!["Read".to_string(), "Grep".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(
            &cfg,
            &ctx,
            PlannerOptions {
                min_followup_probability: 0.3,
                ..Default::default()
            },
        );
        assert!(
            !plan.calls.iter().any(|c| c.tool == "Read"),
            "Read already used in this turn should not be re-admitted"
        );
    }

    /// A budget-counted tool with a zero size estimate still costs one token.
    #[test]
    fn zero_typical_kb_supporting_tool_costs_at_least_one_token() {
        let mut cfg = AdaptiveConfig::default();
        let trigger = ToolValueModel {
            follow_up: vec![devboy_core::FollowUpLink {
                tool: "Cheap".into(),
                probability: 1.0,
                ..devboy_core::FollowUpLink::default()
            }],
            ..ToolValueModel::default()
        };
        let cheap = ToolValueModel {
            value_class: ValueClass::Supporting,
            cost_model: devboy_core::CostModel {
                typical_kb: 0.0,
                ..Default::default()
            },
            ..ToolValueModel::default()
        };
        cfg.tools.insert("Trigger".into(), trigger);
        cfg.tools.insert("Cheap".into(), cheap);

        let recent = vec!["Trigger".to_string()];
        let ctx = TurnContext::new(&recent, 1);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());

        let cheap_call = plan
            .calls
            .iter()
            .find(|c| c.tool == "Cheap")
            .expect("Cheap must still be admitted at budget=1");
        assert_eq!(
            cheap_call.estimated_cost_tokens, 1,
            "zero-typical-kb non-AuditOnly tool must clamp to 1 token"
        );
        assert_eq!(
            plan.remaining_budget_tokens, 0,
            "budget must be drained by 1, not left at 1"
        );

        // With no budget at all, the clamped one-token cost no longer fits.
        let ctx0 = TurnContext::new(&recent, 0);
        let plan0 = build_plan(&cfg, &ctx0, PlannerOptions::default());
        assert!(
            plan0.calls.iter().all(|c| c.tool != "Cheap"),
            "Cheap must be declined at budget=0 (clamp ≥ 1)"
        );
        assert!(
            plan0.declined.iter().any(|d| d.tool == "Cheap"),
            "decline reason must be recorded"
        );
    }

    /// Supporting model with one default-included group ("title", "url") and
    /// one opt-in group holding `field`, worth `est`.
    fn model_with_optin_group(field: &str, est: f32) -> ToolValueModel {
        let mut groups = std::collections::BTreeMap::new();
        groups.insert(
            "must_have".into(),
            devboy_core::FieldGroup {
                fields: vec!["title".into(), "url".into()],
                estimated_value: 1.0,
                default_include: true,
            },
        );
        groups.insert(
            "nice_to_have".into(),
            devboy_core::FieldGroup {
                fields: vec![field.into()],
                estimated_value: est,
                default_include: false,
            },
        );
        ToolValueModel {
            value_class: ValueClass::Supporting,
            field_groups: groups,
            ..Default::default()
        }
    }

    #[test]
    fn intent_boost_neutral_with_no_keywords() {
        let m = model_with_optin_group("snippet", 0.3);
        assert!((intent_boost(&m, &[]) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn intent_boost_neutral_when_keyword_misses_optin_groups() {
        let m = model_with_optin_group("snippet", 0.3);
        let kw = vec!["totally_unrelated".to_string()];
        assert!((intent_boost(&m, &kw) - 1.0).abs() < 1e-6);
    }

    /// Matching is case-insensitive: an upper-case keyword hits a lower-case field.
    #[test]
    fn intent_boost_lifts_score_when_keyword_hits_optin_field() {
        let m = model_with_optin_group("snippet", 0.3);
        let kw = vec!["SNIPPET".to_string()];
        let b = intent_boost(&m, &kw);
        assert!((b - 1.3).abs() < 1e-6, "expected 1.3, got {b}");
    }

    /// Five matching groups worth 1.0 each would give 6.0 without the cap.
    #[test]
    fn intent_boost_caps_at_2_5x() {
        let mut groups = std::collections::BTreeMap::new();
        for i in 0..5 {
            groups.insert(
                format!("g{i}"),
                devboy_core::FieldGroup {
                    fields: vec!["foo".into()],
                    estimated_value: 1.0,
                    default_include: false,
                },
            );
        }
        let m = ToolValueModel {
            field_groups: groups,
            ..Default::default()
        };
        let kw = vec!["foo".to_string()];
        let b = intent_boost(&m, &kw);
        assert!((b - 2.5).abs() < 1e-6, "boost must clamp at 2.5, got {b}");
    }

    /// Same value class, but only one model has an opt-in group matching the keyword.
    #[test]
    fn intent_boost_changes_admit_order() {
        let plain = ToolValueModel {
            value_class: ValueClass::Supporting,
            ..Default::default()
        };
        let intent_match = model_with_optin_group("snippet", 0.4);
        let kw = vec!["snippet".to_string()];

        let p_score = value_score(&plain) * intent_boost(&plain, &kw);
        let i_score = value_score(&intent_match) * intent_boost(&intent_match, &kw);
        assert!(
            i_score > p_score,
            "intent-matching tool must outrank the plain one: {i_score} vs {p_score}"
        );
    }

    /// Supporting model with a 1 KB typical size and the given latency / dollar cost.
    fn model_with_costs(latency_ms: Option<u32>, dollars: Option<f32>) -> ToolValueModel {
        ToolValueModel {
            value_class: ValueClass::Supporting,
            cost_model: devboy_core::CostModel {
                typical_kb: 1.0,
                latency_ms_p50: latency_ms,
                dollars,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    #[test]
    fn cost_penalty_neutral_when_options_are_none() {
        let m = model_with_costs(Some(60_000), Some(1.0));
        let opts = PlannerOptions::default();
        assert!((cost_penalty(&m, &opts) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_halves_for_slow_tool_when_latency_aware() {
        let m = model_with_costs(Some(7_000), None);
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_halves_for_expensive_tool_when_dollar_aware() {
        let m = model_with_costs(None, Some(0.50));
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_compounds_for_slow_and_expensive() {
        let m = model_with_costs(Some(7_000), Some(0.50));
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 0.25).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_no_penalty_below_knee() {
        let m = model_with_costs(Some(800), Some(0.01));
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 1.0).abs() < 1e-6);
    }

    /// The latency penalty must change the admission order.
    ///
    /// Candidates are enumerated in tool-name order and the density sort is
    /// stable, so on a density tie the tool that sorts first by name is
    /// admitted first. The slow tool is deliberately named to win that tie:
    /// only a working latency penalty can move the fast tool ahead of it.
    #[test]
    fn cost_aware_planner_demotes_slow_tool_below_fast_one() {
        const SLOW: &str = "A_SlowTool";
        const FAST: &str = "B_FastTool";

        let mut cfg = AdaptiveConfig::default();
        let trigger = ToolValueModel {
            follow_up: vec![
                devboy_core::FollowUpLink {
                    tool: FAST.into(),
                    probability: 0.9,
                    ..Default::default()
                },
                devboy_core::FollowUpLink {
                    tool: SLOW.into(),
                    probability: 0.9,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        cfg.tools.insert("Trigger".into(), trigger);
        cfg.tools
            .insert(FAST.into(), model_with_costs(Some(200), None));
        cfg.tools
            .insert(SLOW.into(), model_with_costs(Some(20_000), None));

        let recent = vec!["Trigger".to_string()];
        let ctx = TurnContext::new(&recent, 1024);
        let plan_blind = build_plan(&cfg, &ctx, PlannerOptions::default());
        let plan_aware = build_plan(&cfg, &ctx, PlannerOptions::cost_aware());

        let admitted = |plan: &EnrichmentPlan| -> Vec<String> {
            plan.calls.iter().map(|c| c.tool.clone()).collect()
        };
        assert_eq!(
            admitted(&plan_blind),
            [SLOW, FAST],
            "precondition: without penalties the tie is broken by tool name"
        );
        assert_eq!(
            admitted(&plan_aware),
            [FAST, SLOW],
            "cost-aware planner must admit the fast tool first"
        );
    }

    /// Two triggers point at the same tool; the higher-probability link
    /// supplies both the probability and the projection.
    #[test]
    fn high_probability_link_wins_over_low_probability_for_same_tool() {
        let mut cfg = AdaptiveConfig::default();
        let a = ToolValueModel {
            follow_up: vec![devboy_core::FollowUpLink {
                tool: "Target".into(),
                probability: 0.55,
                projection: Some("low".into()),
                ..devboy_core::FollowUpLink::default()
            }],
            ..ToolValueModel::default()
        };
        let b = ToolValueModel {
            follow_up: vec![devboy_core::FollowUpLink {
                tool: "Target".into(),
                probability: 0.85,
                projection: Some("high".into()),
                ..devboy_core::FollowUpLink::default()
            }],
            ..ToolValueModel::default()
        };
        cfg.tools.insert("A".into(), a);
        cfg.tools.insert("B".into(), b);
        cfg.tools
            .insert("Target".into(), ToolValueModel::critical_with_size(0.1));

        let recent = vec!["A".to_string(), "B".to_string()];
        let ctx = TurnContext::new(&recent, 1_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let t = plan.calls.iter().find(|c| c.tool == "Target").unwrap();
        assert_eq!(t.projection.as_deref(), Some("high"));
        assert!((t.probability - 0.85).abs() < 1e-6);
    }
}