use std::collections::{BTreeMap, BTreeSet};
use devboy_core::{ToolValueModel, ValueClass};
use crate::adaptive_config::AdaptiveConfig;
/// Per-turn inputs to the enrichment planner.
#[derive(Debug, Clone, Default)]
pub struct TurnContext<'a> {
    /// Tools invoked so far this turn. Their follow-up links seed the
    /// candidate set, and the tools themselves are never re-suggested.
    pub recent_tools: &'a [String],
    /// Token budget available for calls that count against the budget.
    pub budget_tokens: u32,
    /// Optional intent keywords, matched (ASCII case-insensitively, as
    /// substrings) against the fields of opt-in field groups to boost tools.
    pub intent_keywords: Vec<String>,
}

impl<'a> TurnContext<'a> {
    /// Creates a context for `recent_tools` with `budget_tokens` of budget and
    /// no intent keywords.
    pub fn new(recent_tools: &'a [String], budget_tokens: u32) -> Self {
        let intent_keywords = Vec::new();
        TurnContext {
            recent_tools,
            budget_tokens,
            intent_keywords,
        }
    }
}
/// One follow-up tool call admitted into an [`EnrichmentPlan`].
#[derive(Debug, Clone, PartialEq)]
pub struct PlannedCall {
    /// Name of the tool to call.
    pub tool: String,
    /// Projection carried by the follow-up link that proposed this call.
    pub projection: Option<String>,
    /// Follow-up probability of that link.
    pub probability: f32,
    /// The tool's `cost_model.typical_kb` converted to bytes.
    pub estimated_cost_bytes: u32,
    /// Estimated token cost; clamped to at least 1 unless the tool is
    /// excluded from the budget.
    pub estimated_cost_tokens: u32,
    /// Value class of the tool's effective model.
    pub value_class: ValueClass,
}
/// Output of `build_plan`: admitted calls plus budget bookkeeping.
#[derive(Debug, Clone, Default)]
pub struct EnrichmentPlan {
    /// Admitted calls, in admission order.
    pub calls: Vec<PlannedCall>,
    /// Sum of token costs of admitted calls that count against the budget.
    pub total_cost_tokens: u32,
    /// Budget left after subtracting those costs.
    pub remaining_budget_tokens: u32,
    /// Candidates that were not admitted, each with its reason.
    pub declined: Vec<DeclineReason>,
}
/// A candidate tool the planner did not admit, and why.
#[derive(Debug, Clone, PartialEq)]
pub struct DeclineReason {
    /// Name of the declined tool.
    pub tool: String,
    /// Why it was declined.
    pub reason: DeclineKind,
}
/// Reason a candidate was declined. Non-exhaustive so more reasons can be
/// added without breaking downstream matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DeclineKind {
    /// The candidate's token cost exceeded the remaining budget when it was
    /// considered.
    BudgetExceeded,
}
/// Tunables for the enrichment planner.
#[derive(Debug, Clone, Copy)]
pub struct PlannerOptions {
    /// Follow-up links with a probability below this threshold are ignored.
    pub min_followup_probability: f32,
    /// Divisor converting estimated bytes into tokens (`0` is treated as `1`).
    pub bytes_per_token: u32,
    /// When set, tools whose p50 latency (ms) is at or above this value have
    /// their ranking density halved.
    pub latency_penalty_ms: Option<u32>,
    /// When set, tools whose `cost_model.dollars` is at or above this value
    /// have their ranking density halved.
    pub dollar_penalty: Option<f32>,
}

impl Default for PlannerOptions {
    /// Cost-blind defaults: follow links with probability ≥ 0.5, assume 4
    /// bytes per token, apply no latency or dollar penalty.
    fn default() -> Self {
        PlannerOptions {
            dollar_penalty: None,
            latency_penalty_ms: None,
            bytes_per_token: 4,
            min_followup_probability: 0.5,
        }
    }
}

impl PlannerOptions {
    /// The defaults plus cost penalties: a 5 000 ms latency knee and a 0.10
    /// dollar knee.
    pub fn cost_aware() -> Self {
        let base = Self::default();
        Self {
            dollar_penalty: Some(0.10),
            latency_penalty_ms: Some(5_000),
            ..base
        }
    }
}
pub fn build_plan(
config: &AdaptiveConfig,
context: &TurnContext<'_>,
options: PlannerOptions,
) -> EnrichmentPlan {
let candidates = enumerate_candidates(config, context, options);
let mut scored: Vec<(f32, Candidate)> = candidates
.into_iter()
.map(|c| {
let density = if matches!(c.model.value_class, ValueClass::AuditOnly) {
f32::INFINITY
} else {
let cost_tokens = cost_tokens_for(&c.model, options.bytes_per_token).max(1) as f32;
let boost = intent_boost(&c.model, &context.intent_keywords);
let penalty = cost_penalty(&c.model, &options);
value_score(&c.model) * boost * penalty / cost_tokens
};
(density, c)
})
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let mut plan = EnrichmentPlan {
remaining_budget_tokens: context.budget_tokens,
..EnrichmentPlan::default()
};
for (_, c) in scored {
let raw_cost_tokens = cost_tokens_for(&c.model, options.bytes_per_token);
let cost_bytes = (c.model.cost_model.typical_kb * 1024.0) as u32;
let is_free = c.model.excluded_from_budget();
let cost_tokens = if is_free {
raw_cost_tokens
} else {
raw_cost_tokens.max(1)
};
if !is_free && cost_tokens > plan.remaining_budget_tokens {
plan.declined.push(DeclineReason {
tool: c.tool.clone(),
reason: DeclineKind::BudgetExceeded,
});
continue;
}
plan.calls.push(PlannedCall {
tool: c.tool,
projection: c.projection,
probability: c.probability,
estimated_cost_bytes: cost_bytes,
estimated_cost_tokens: cost_tokens,
value_class: c.model.value_class,
});
if !is_free {
plan.total_cost_tokens = plan.total_cost_tokens.saturating_add(cost_tokens);
plan.remaining_budget_tokens = plan.remaining_budget_tokens.saturating_sub(cost_tokens);
}
}
plan
}
/// A follow-up tool proposed by a recent tool's links, before ranking and
/// budgeting.
struct Candidate {
    /// Tool name.
    tool: String,
    /// Projection from the highest-probability link to this tool.
    projection: Option<String>,
    /// Probability of that link.
    probability: f32,
    /// The tool's effective value model, or `ToolValueModel::default()` when
    /// the config has none.
    model: ToolValueModel,
}
fn enumerate_candidates(
config: &AdaptiveConfig,
context: &TurnContext<'_>,
options: PlannerOptions,
) -> Vec<Candidate> {
let mut by_tool: BTreeMap<String, (Option<String>, f32)> = BTreeMap::new();
let recent_set: BTreeSet<&str> = context.recent_tools.iter().map(String::as_str).collect();
for trigger in context.recent_tools {
let Some(model) = config.effective_tool_value_model(trigger) else {
continue;
};
for link in &model.follow_up {
if link.probability < options.min_followup_probability {
continue;
}
if link.tool == *trigger {
continue;
}
if recent_set.contains(link.tool.as_str()) {
continue;
}
let entry = by_tool
.entry(link.tool.clone())
.or_insert((link.projection.clone(), link.probability));
if link.probability > entry.1 {
entry.0 = link.projection.clone();
entry.1 = link.probability;
}
}
}
by_tool
.into_iter()
.map(|(tool, (projection, probability))| {
let model = config
.effective_tool_value_model(&tool)
.cloned()
.unwrap_or_default();
Candidate {
tool,
projection,
probability,
model,
}
})
.collect()
}
/// Estimated token cost of one call.
///
/// Converts `cost_model.typical_kb` to bytes and divides by `bytes_per_token`,
/// rounding down; a zero divisor is treated as 1. The float-to-int cast
/// saturates, so negative or NaN sizes yield 0 tokens.
fn cost_tokens_for(model: &ToolValueModel, bytes_per_token: u32) -> u32 {
    let divisor = bytes_per_token.max(1);
    let typical_bytes = (model.cost_model.typical_kb * 1024.0) as u32;
    typical_bytes / divisor
}
/// Base value weight of a tool's [`ValueClass`], in `[0, 1]`.
///
/// `AuditOnly` weighs `0.0`; the planner ranks those tools by a separate
/// rule instead of this score.
fn value_score(model: &ToolValueModel) -> f32 {
    match &model.value_class {
        ValueClass::AuditOnly => 0.0,
        ValueClass::Optional => 0.2,
        ValueClass::Supporting => 0.5,
        ValueClass::Critical => 1.0,
    }
}
fn cost_penalty(model: &ToolValueModel, options: &PlannerOptions) -> f32 {
let mut penalty = 1.0_f32;
if let (Some(knee), Some(latency)) =
(options.latency_penalty_ms, model.cost_model.latency_ms_p50)
&& latency >= knee
{
penalty *= 0.5;
}
if let (Some(knee), Some(dollars)) = (options.dollar_penalty, model.cost_model.dollars)
&& dollars >= knee
{
penalty *= 0.5;
}
penalty
}
/// Ranking multiplier for tools whose opt-in field groups match the user's
/// intent.
///
/// Keywords are trimmed and ASCII-lowercased, and blank ones are ignored. An
/// empty needle would `contains`-match every field and boost every opt-in
/// group. A group that is not `default_include` and has at least one field
/// containing some keyword (ASCII case-insensitively) adds its
/// `estimated_value` to a base of `1.0`.
///
/// Returns `1.0` when there are no usable keywords or the model has no field
/// groups. The result is capped at `2.5`.
fn intent_boost(model: &ToolValueModel, intent_keywords: &[String]) -> f32 {
    // Upper bound on the multiplier, however many opt-in groups match.
    const MAX_BOOST: f32 = 2.5;

    if model.field_groups.is_empty() {
        return 1.0;
    }
    let keywords: Vec<String> = intent_keywords
        .iter()
        .map(|kw| kw.trim().to_ascii_lowercase())
        .filter(|kw| !kw.is_empty())
        .collect();
    if keywords.is_empty() {
        return 1.0;
    }
    let boost = model
        .field_groups
        .values()
        // Default-included groups are already part of the response.
        .filter(|group| !group.default_include)
        .filter(|group| {
            group.fields.iter().any(|field| {
                // Lowercase each field once, not once per keyword.
                let field = field.to_ascii_lowercase();
                keywords.iter().any(|kw| field.contains(kw.as_str()))
            })
        })
        .fold(1.0_f32, |acc, group| acc + group.estimated_value);
    boost.min(MAX_BOOST)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool_defaults::default_tool_value_models;

    /// Config whose tool table is `default_tool_value_models()`; all else default.
    fn config_with_defaults() -> AdaptiveConfig {
        AdaptiveConfig {
            tools: default_tool_value_models(),
            ..AdaptiveConfig::default()
        }
    }

    /// No recent tools means no follow-up links, so nothing is planned.
    #[test]
    fn empty_recent_tools_returns_empty_plan() {
        let cfg = config_with_defaults();
        let recent: Vec<String> = vec![];
        let ctx = TurnContext::new(&recent, 1024);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        assert!(plan.calls.is_empty());
        assert_eq!(plan.total_cost_tokens, 0);
    }

    /// Expects the default models to link Grep -> Read with a `path`
    /// projection at a probability of at least 0.3.
    #[test]
    fn after_grep_planner_prefetches_read_with_path_projection() {
        let cfg = config_with_defaults();
        let recent = vec!["Grep".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(
            &cfg,
            &ctx,
            PlannerOptions {
                min_followup_probability: 0.3,
                ..Default::default()
            },
        );
        let read = plan
            .calls
            .iter()
            .find(|c| c.tool == "Read")
            .expect("Read should be admitted after Grep");
        assert_eq!(read.projection.as_deref(), Some("path"));
    }

    /// Expects the default models to link WebSearch -> WebFetch with a `url`
    /// projection at or above the default 0.5 threshold.
    #[test]
    fn after_websearch_planner_prefetches_webfetch_with_url_projection() {
        let cfg = config_with_defaults();
        let recent = vec!["WebSearch".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let fetch = plan
            .calls
            .iter()
            .find(|c| c.tool == "WebFetch")
            .expect("WebFetch should be admitted after WebSearch");
        assert_eq!(fetch.projection.as_deref(), Some("url"));
    }

    /// Expects a default Glob -> Read link (probability >= 0.3) and a Read
    /// cost estimate above 50 tokens, so Read is declined for budget.
    #[test]
    fn budget_exceeded_decline_recorded() {
        let cfg = config_with_defaults();
        let recent = vec!["Glob".to_string()];
        let ctx = TurnContext::new(&recent, 50);
        let plan = build_plan(
            &cfg,
            &ctx,
            PlannerOptions {
                min_followup_probability: 0.3,
                ..Default::default()
            },
        );
        assert!(
            plan.declined
                .iter()
                .any(|d| d.tool == "Read" && d.reason == DeclineKind::BudgetExceeded),
            "expected Read to be declined for budget, got {:?}",
            plan.declined
        );
    }

    /// Adds a Grep -> TaskUpdate link. TaskUpdate is expected to be
    /// `AuditOnly` (asserted below), so admitting it must not reduce the
    /// remaining budget beyond what the budgeted calls cost.
    #[test]
    fn audit_only_tools_do_not_consume_budget() {
        let mut cfg = AdaptiveConfig {
            tools: default_tool_value_models(),
            ..AdaptiveConfig::default()
        };
        let mut grep = cfg.tools.get("Grep").unwrap().clone();
        grep.follow_up.push(devboy_core::FollowUpLink {
            tool: "TaskUpdate".into(),
            probability: 0.9,
            ..devboy_core::FollowUpLink::default()
        });
        cfg.tools.insert("Grep".into(), grep);
        let recent = vec!["Grep".to_string()];
        let ctx = TurnContext::new(&recent, 1_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let task = plan
            .calls
            .iter()
            .find(|c| c.tool == "TaskUpdate")
            .expect("TaskUpdate should be admitted");
        assert_eq!(task.value_class, ValueClass::AuditOnly);
        assert_eq!(
            plan.remaining_budget_tokens,
            1_000 - critical_supporting_tokens(&plan)
        );
    }

    /// Sum of `estimated_cost_tokens` over admitted calls that are not `AuditOnly`.
    fn critical_supporting_tokens(plan: &EnrichmentPlan) -> u32 {
        plan.calls
            .iter()
            .filter(|c| !matches!(c.value_class, ValueClass::AuditOnly))
            .map(|c| c.estimated_cost_tokens)
            .sum()
    }

    /// A tool is never planned as its own follow-up.
    #[test]
    fn self_loops_skipped() {
        let cfg = config_with_defaults();
        let recent = vec!["Read".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        assert!(
            !plan.calls.iter().any(|c| c.tool == "Read"),
            "Read self-loop should be skipped"
        );
    }

    /// Read was already called this turn, so Grep's link to it must not
    /// re-admit it.
    #[test]
    fn already_used_tools_skipped() {
        let cfg = config_with_defaults();
        let recent = vec!["Read".to_string(), "Grep".to_string()];
        let ctx = TurnContext::new(&recent, 4_000);
        let plan = build_plan(
            &cfg,
            &ctx,
            PlannerOptions {
                min_followup_probability: 0.3,
                ..Default::default()
            },
        );
        assert!(
            !plan.calls.iter().any(|c| c.tool == "Read"),
            "Read already used in this turn should not be re-admitted"
        );
    }

    /// A budgeted tool with `typical_kb = 0` is clamped to one token. It is
    /// admitted at budget 1 (draining it) and declined at budget 0.
    #[test]
    fn zero_typical_kb_supporting_tool_costs_at_least_one_token() {
        let mut cfg = AdaptiveConfig::default();
        let trigger = ToolValueModel {
            follow_up: vec![devboy_core::FollowUpLink {
                tool: "Cheap".into(),
                probability: 1.0,
                ..devboy_core::FollowUpLink::default()
            }],
            ..ToolValueModel::default()
        };
        let cheap = ToolValueModel {
            value_class: ValueClass::Supporting,
            cost_model: devboy_core::CostModel {
                typical_kb: 0.0,
                ..Default::default()
            },
            ..ToolValueModel::default()
        };
        cfg.tools.insert("Trigger".into(), trigger);
        cfg.tools.insert("Cheap".into(), cheap);
        let recent = vec!["Trigger".to_string()];
        let ctx = TurnContext::new(&recent, 1);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let cheap_call = plan
            .calls
            .iter()
            .find(|c| c.tool == "Cheap")
            .expect("Cheap must still be admitted at budget=1");
        assert_eq!(
            cheap_call.estimated_cost_tokens, 1,
            "zero-typical-kb non-AuditOnly tool must clamp to 1 token"
        );
        assert_eq!(
            plan.remaining_budget_tokens, 0,
            "budget must be drained by 1, not left at 1"
        );
        let ctx0 = TurnContext::new(&recent, 0);
        let plan0 = build_plan(&cfg, &ctx0, PlannerOptions::default());
        assert!(
            plan0.calls.iter().all(|c| c.tool != "Cheap"),
            "Cheap must be declined at budget=0 (clamp ≥ 1)"
        );
        assert!(
            plan0.declined.iter().any(|d| d.tool == "Cheap"),
            "decline reason must be recorded"
        );
    }

    /// Supporting model with a default-included `must_have` group (`title`,
    /// `url`) and an opt-in `nice_to_have` group holding `field`, worth `est`.
    fn model_with_optin_group(field: &str, est: f32) -> ToolValueModel {
        let mut groups = std::collections::BTreeMap::new();
        groups.insert(
            "must_have".into(),
            devboy_core::FieldGroup {
                fields: vec!["title".into(), "url".into()],
                estimated_value: 1.0,
                default_include: true,
            },
        );
        groups.insert(
            "nice_to_have".into(),
            devboy_core::FieldGroup {
                fields: vec![field.into()],
                estimated_value: est,
                default_include: false,
            },
        );
        ToolValueModel {
            value_class: ValueClass::Supporting,
            field_groups: groups,
            ..Default::default()
        }
    }

    #[test]
    fn intent_boost_neutral_with_no_keywords() {
        let m = model_with_optin_group("snippet", 0.3);
        assert!((intent_boost(&m, &[]) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn intent_boost_neutral_when_keyword_misses_optin_groups() {
        let m = model_with_optin_group("snippet", 0.3);
        let kw = vec!["totally_unrelated".to_string()];
        assert!((intent_boost(&m, &kw) - 1.0).abs() < 1e-6);
    }

    /// Matching is ASCII case-insensitive: `SNIPPET` hits the `snippet` field.
    #[test]
    fn intent_boost_lifts_score_when_keyword_hits_optin_field() {
        let m = model_with_optin_group("snippet", 0.3);
        let kw = vec!["SNIPPET".to_string()];
        let b = intent_boost(&m, &kw);
        assert!((b - 1.3).abs() < 1e-6, "expected 1.3, got {b}");
    }

    /// Five matching opt-in groups would give 6.0; the result is capped at 2.5.
    #[test]
    fn intent_boost_caps_at_2_5x() {
        let mut groups = std::collections::BTreeMap::new();
        for i in 0..5 {
            groups.insert(
                format!("g{i}"),
                devboy_core::FieldGroup {
                    fields: vec!["foo".into()],
                    estimated_value: 1.0,
                    default_include: false,
                },
            );
        }
        let m = ToolValueModel {
            field_groups: groups,
            ..Default::default()
        };
        let kw = vec!["foo".to_string()];
        let b = intent_boost(&m, &kw);
        assert!((b - 2.5).abs() < 1e-6, "boost must clamp at 2.5, got {b}");
    }

    #[test]
    fn intent_boost_changes_admit_order() {
        let plain = ToolValueModel {
            value_class: ValueClass::Supporting,
            ..Default::default()
        };
        let intent_match = model_with_optin_group("snippet", 0.4);
        let kw = vec!["snippet".to_string()];
        let p_score = value_score(&plain) * intent_boost(&plain, &kw);
        let i_score = value_score(&intent_match) * intent_boost(&intent_match, &kw);
        assert!(
            i_score > p_score,
            "intent-matching tool must outrank the plain one: {i_score} vs {p_score}"
        );
    }

    /// Supporting, 1 KB model with the given p50 latency and dollar cost.
    fn model_with_costs(latency_ms: Option<u32>, dollars: Option<f32>) -> ToolValueModel {
        ToolValueModel {
            value_class: ValueClass::Supporting,
            cost_model: devboy_core::CostModel {
                typical_kb: 1.0,
                latency_ms_p50: latency_ms,
                dollars,
                ..Default::default()
            },
            ..Default::default()
        }
    }

    #[test]
    fn cost_penalty_neutral_when_options_are_none() {
        let m = model_with_costs(Some(60_000), Some(1.0));
        let opts = PlannerOptions::default();
        assert!((cost_penalty(&m, &opts) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_halves_for_slow_tool_when_latency_aware() {
        let m = model_with_costs(Some(7_000), None);
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_halves_for_expensive_tool_when_dollar_aware() {
        let m = model_with_costs(None, Some(0.50));
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_compounds_for_slow_and_expensive() {
        let m = model_with_costs(Some(7_000), Some(0.50));
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 0.25).abs() < 1e-6);
    }

    #[test]
    fn cost_penalty_no_penalty_below_knee() {
        let m = model_with_costs(Some(800), Some(0.01));
        let opts = PlannerOptions::cost_aware();
        assert!((cost_penalty(&m, &opts) - 1.0).abs() < 1e-6);
    }

    /// Checks that the latency penalty, not tie-breaking, decides the order.
    ///
    /// Both follow-ups have identical value class and size. Without penalties
    /// their densities tie, and the stable sort falls back to tool-name order.
    /// The slow tool is deliberately named to sort *before* `FastTool`, so it
    /// leads the cost-blind plan. Only a working latency penalty can then put
    /// `FastTool` first in the cost-aware plan.
    #[test]
    fn cost_aware_planner_demotes_slow_tool_below_fast_one() {
        let mut cfg = AdaptiveConfig::default();
        let trigger = ToolValueModel {
            follow_up: vec![
                devboy_core::FollowUpLink {
                    tool: "FastTool".into(),
                    probability: 0.9,
                    ..Default::default()
                },
                devboy_core::FollowUpLink {
                    tool: "AlphaSlowTool".into(),
                    probability: 0.9,
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        cfg.tools.insert("Trigger".into(), trigger);
        cfg.tools
            .insert("FastTool".into(), model_with_costs(Some(200), None));
        cfg.tools
            .insert("AlphaSlowTool".into(), model_with_costs(Some(20_000), None));
        let recent = vec!["Trigger".to_string()];
        let ctx = TurnContext::new(&recent, 1024);
        let plan_blind = build_plan(&cfg, &ctx, PlannerOptions::default());
        let plan_aware = build_plan(&cfg, &ctx, PlannerOptions::cost_aware());
        assert_eq!(
            plan_blind.calls.first().map(|c| c.tool.as_str()),
            Some("AlphaSlowTool"),
            "cost-blind planner should break the tie by tool name; got {:?}",
            plan_blind.calls.iter().map(|c| &c.tool).collect::<Vec<_>>()
        );
        let fast_first = plan_aware.calls.first().map(|c| c.tool.as_str());
        assert_eq!(
            fast_first,
            Some("FastTool"),
            "cost-aware planner must admit FastTool first; got {:?}",
            plan_aware.calls.iter().map(|c| &c.tool).collect::<Vec<_>>()
        );
        assert_eq!(plan_aware.calls.len(), 2);
        assert_eq!(plan_blind.calls.len(), 2);
    }

    /// Two triggers link to the same target. The higher-probability link's
    /// projection and probability must win.
    #[test]
    fn high_probability_link_wins_over_low_probability_for_same_tool() {
        let mut cfg = AdaptiveConfig::default();
        let a = ToolValueModel {
            follow_up: vec![devboy_core::FollowUpLink {
                tool: "Target".into(),
                probability: 0.55,
                projection: Some("low".into()),
                ..devboy_core::FollowUpLink::default()
            }],
            ..ToolValueModel::default()
        };
        let b = ToolValueModel {
            follow_up: vec![devboy_core::FollowUpLink {
                tool: "Target".into(),
                probability: 0.85,
                projection: Some("high".into()),
                ..devboy_core::FollowUpLink::default()
            }],
            ..ToolValueModel::default()
        };
        cfg.tools.insert("A".into(), a);
        cfg.tools.insert("B".into(), b);
        cfg.tools
            .insert("Target".into(), ToolValueModel::critical_with_size(0.1));
        let recent = vec!["A".to_string(), "B".to_string()];
        let ctx = TurnContext::new(&recent, 1_000);
        let plan = build_plan(&cfg, &ctx, PlannerOptions::default());
        let t = plan.calls.iter().find(|c| c.tool == "Target").unwrap();
        assert_eq!(t.projection.as_deref(), Some("high"));
        assert!((t.probability - 0.85).abs() < 1e-6);
    }
}