use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::multi_agent::types::GoapPlan;
/// Step cap that [`PlanBudget::new`] assigns to `max_steps_per_invocation`.
pub const DEFAULT_MAX_STEPS_PER_INVOCATION: u8 = 8;
/// Concurrency cap that [`PlanBudget::new`] assigns to
/// `max_concurrent_per_specialist`.
pub const DEFAULT_MAX_CONCURRENT_PER_SPECIALIST: u8 = 1;
/// Limits that [`check_budget`] enforces for one specialist.
///
/// The type is serde-serializable; field names are part of the serialized
/// form, so renaming a field changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanBudget {
/// Optional tenant identifier. Carried as data only: `check_budget` in this
/// file does not read it.
pub tenant_id: Option<String>,
/// Name of the specialist the budget belongs to. Carried as data only:
/// `check_budget` in this file does not read it.
pub specialist: String,
/// Largest number of plan steps allowed; `check_budget` rejects a plan whose
/// step count is strictly greater than this value.
pub max_steps_per_invocation: u8,
/// Concurrency cap; `check_budget` rejects as soon as the number of in-flight
/// entries is equal to or greater than this value.
pub max_concurrent_per_specialist: u8,
}
impl PlanBudget {
pub fn new(tenant_id: Option<String>, specialist: impl Into<String>) -> Self {
Self {
tenant_id,
specialist: specialist.into(),
max_steps_per_invocation: DEFAULT_MAX_STEPS_PER_INVOCATION,
max_concurrent_per_specialist: DEFAULT_MAX_CONCURRENT_PER_SPECIALIST,
}
}
}
/// Outcome of [`check_budget`].
///
/// Serialized with serde's internally tagged representation: the variant name
/// is written in snake_case under the `verdict` key (in JSON, for example,
/// `{"verdict":"would_exceed_steps","current":9,"max":8}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "verdict", rename_all = "snake_case")]
pub enum BudgetVerdict {
/// Neither the step cap nor the concurrency cap was hit.
Ok,
/// The plan has more steps than the budget allows. `current` is the plan's
/// step count (capped at `u8::MAX` when it is larger) and `max` is the
/// budget's `max_steps_per_invocation`.
WouldExceedSteps { current: u8, max: u8 },
/// The in-flight count has already reached the concurrency cap. `current`
/// is the number of in-flight entries (capped at `u8::MAX` when larger) and
/// `max` is the budget's `max_concurrent_per_specialist`.
WouldExceedConcurrency { current: u8, max: u8 },
}
/// Checks `plan` against the limits in `budget`.
///
/// The step cap is evaluated first: a plan with more steps than
/// `budget.max_steps_per_invocation` yields
/// [`BudgetVerdict::WouldExceedSteps`]. Otherwise, if the number of
/// `in_flight` entries is already equal to or above
/// `budget.max_concurrent_per_specialist`, the result is
/// [`BudgetVerdict::WouldExceedConcurrency`]. When both limits hold the
/// verdict is [`BudgetVerdict::Ok`].
///
/// A count that does not fit in `u8` is always treated as over the limit,
/// even when the cap itself is `u8::MAX`; the `current` value reported in the
/// verdict is saturated at `u8::MAX` in that case.
pub fn check_budget(plan: &GoapPlan, budget: &PlanBudget, in_flight: &[Uuid]) -> BudgetVerdict {
    let max_steps = budget.max_steps_per_invocation;
    match u8::try_from(plan.steps.len()) {
        Ok(steps) if steps <= max_steps => {}
        // Either the count fits in `u8` and is above the cap, or it does not
        // fit at all and is therefore above every possible `u8` cap. Comparing
        // a saturated value instead would let a >255-step plan slip through
        // when the cap is 255.
        over => {
            return BudgetVerdict::WouldExceedSteps {
                current: over.unwrap_or(u8::MAX),
                max: max_steps,
            };
        }
    }

    let max_concurrent = budget.max_concurrent_per_specialist;
    match u8::try_from(in_flight.len()) {
        // Strictly below the cap: there is room for one more invocation.
        Ok(running) if running < max_concurrent => BudgetVerdict::Ok,
        at_or_over => BudgetVerdict::WouldExceedConcurrency {
            current: at_or_over.unwrap_or(u8::MAX),
            max: max_concurrent,
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_agent::goap_adapter::load_action_catalogue;
    use crate::multi_agent::planner::plan;

    /// Budget with the default caps for the tenant/specialist pair shared by
    /// every test in this module.
    fn default_budget() -> PlanBudget {
        PlanBudget::new(Some("local".into()), "cfa-equity-analyst")
    }

    /// A freshly constructed budget carries the module-level default caps.
    #[test]
    fn defaults_match_constants() {
        let budget = default_budget();
        assert_eq!(
            budget.max_steps_per_invocation,
            DEFAULT_MAX_STEPS_PER_INVOCATION
        );
        assert_eq!(
            budget.max_concurrent_per_specialist,
            DEFAULT_MAX_CONCURRENT_PER_SPECIALIST
        );
    }

    /// With default caps and nothing in flight, the "dcf for AAPL" plan passes.
    #[test]
    fn ok_when_within_budget() {
        let budget = default_budget();
        let catalogue = load_action_catalogue();
        let goap_plan = plan("dcf for AAPL", &catalogue).unwrap();

        let verdict = check_budget(&goap_plan, &budget, &[]);
        assert_eq!(verdict, BudgetVerdict::Ok);
    }

    /// Lowering the step cap to 2 makes the coverage-initiation plan fail on steps.
    #[test]
    fn rejects_when_steps_exceed_cap() {
        let budget = PlanBudget {
            max_steps_per_invocation: 2,
            ..default_budget()
        };
        let catalogue = load_action_catalogue();
        let goap_plan = plan("initiate coverage on PFE", &catalogue).unwrap();

        let verdict = check_budget(&goap_plan, &budget, &[]);
        if let BudgetVerdict::WouldExceedSteps { current, max } = verdict {
            assert!(current > max);
            assert_eq!(max, 2);
        } else {
            panic!("expected WouldExceedSteps verdict");
        }
    }

    /// One in-flight invocation already saturates the default concurrency cap.
    #[test]
    fn rejects_when_concurrency_at_cap() {
        let budget = default_budget();
        let catalogue = load_action_catalogue();
        let goap_plan = plan("dcf for AAPL", &catalogue).unwrap();
        let running = [Uuid::now_v7()];

        let verdict = check_budget(&goap_plan, &budget, &running);
        assert!(matches!(
            verdict,
            BudgetVerdict::WouldExceedConcurrency { .. }
        ));
    }
}