use parking_lot::Mutex;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use crate::agent::guardrail::{GuardAction, Guardrail};
use crate::error::Error;
use crate::llm::types::ToolCall;
/// A call budget for the tools selected by `tool_pattern`.
///
/// `tool_pattern` is either an exact tool name or a prefix followed by a
/// single trailing `*` (for example `write*`). A lone `*` matches every tool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BudgetRule {
    /// Exact tool name, or a trailing-`*` prefix pattern, this rule applies to.
    pub tool_pattern: String,
    /// Maximum number of calls allowed for tools matching `tool_pattern`.
    pub max_calls: usize,
}
/// Guardrail that caps how many times tools may be invoked.
///
/// Budget rules are consulted in insertion order, and the first matching rule
/// decides which budget a call is charged against. Tools matched by no rule
/// fall back to `default_budget`, or are unrestricted when it is `None`.
/// Usage is tallied per pattern, so every tool matching the same pattern draws
/// from one shared budget.
pub struct ActionBudgetGuardrail {
    /// Calls consumed so far, keyed by the matched pattern string.
    counts: Mutex<HashMap<String, usize>>,
    /// Ordered budget rules; the first match wins.
    budgets: Vec<BudgetRule>,
    /// Budget for tools matched by no rule; `None` means unlimited.
    default_budget: Option<usize>,
}
impl ActionBudgetGuardrail {
    /// Starts configuring a guardrail with no rules and no default budget.
    ///
    /// Add rules with [`ActionBudgetGuardrailBuilder::rule`], optionally set
    /// [`ActionBudgetGuardrailBuilder::default_budget`], then finish with
    /// [`ActionBudgetGuardrailBuilder::build`].
    pub fn builder() -> ActionBudgetGuardrailBuilder {
        let budgets = vec![];
        let default_budget = None;
        ActionBudgetGuardrailBuilder {
            budgets,
            default_budget,
        }
    }
}
/// Reports whether the tool `name` is selected by the budget `pattern`.
///
/// Only a trailing `*` is special. It turns the pattern into a prefix match on
/// everything before the `*`, so a lone `*` selects every name, including the
/// empty one. Any other pattern, including one with a `*` elsewhere, must equal
/// `name` exactly.
fn pattern_matches(pattern: &str, name: &str) -> bool {
    match pattern.strip_suffix('*') {
        Some(prefix) => name.starts_with(prefix),
        None => name == pattern,
    }
}
impl Guardrail for ActionBudgetGuardrail {
    fn name(&self) -> &str {
        "action_budget"
    }

    /// Charges `call` against the budget of the first rule whose pattern
    /// matches the tool name, or against the default budget. The call is
    /// denied once that budget is spent.
    ///
    /// - With no matching rule and no default budget, the call is allowed
    ///   and nothing is counted.
    /// - Denied calls do not consume budget.
    /// - The check and the increment happen under one lock acquisition,
    ///   before the returned future is built. Two concurrent callers
    ///   therefore cannot both take the last slot, and the lock is never
    ///   held across an `.await`.
    fn pre_tool(
        &self,
        call: &ToolCall,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        let tool_name = &call.name;
        let matched_rule = self
            .budgets
            .iter()
            .find(|rule| pattern_matches(&rule.tool_pattern, tool_name));
        let (pattern_key, max_calls): (&str, usize) = match matched_rule {
            Some(rule) => (rule.tool_pattern.as_str(), rule.max_calls),
            None => match self.default_budget {
                // Every tool matched by no rule shares a single counter under "*".
                Some(max) => ("*", max),
                None => return Box::pin(async { Ok(GuardAction::Allow) }),
            },
        };

        let action = {
            let mut counts = self.counts.lock();
            let used = counts.entry(pattern_key.to_owned()).or_default();
            // Compare the count *before* this call against the cap. Unlike
            // `used + 1 > max_calls`, this cannot overflow when max_calls == usize::MAX.
            if *used >= max_calls {
                GuardAction::deny(format!(
                    "Tool `{}` denied: budget exhausted for pattern `{}` ({}/{})",
                    tool_name,
                    pattern_key,
                    (*used).saturating_add(1),
                    max_calls
                ))
            } else {
                *used += 1;
                GuardAction::Allow
            }
        };
        Box::pin(async move { Ok(action) })
    }
}
/// Incrementally configures an [`ActionBudgetGuardrail`].
///
/// Obtain one from [`ActionBudgetGuardrail::builder`]. The derived `Default`
/// is equivalent: no rules and no default budget. Every configuring method
/// consumes and returns the builder, so the result must be kept. Ignoring it
/// silently drops the configuration, which is why the type is `#[must_use]`.
#[derive(Default)]
#[must_use = "builder methods return the updated builder; call `build()` on it"]
pub struct ActionBudgetGuardrailBuilder {
    /// Rules in the order they were added; evaluated first-match-wins.
    budgets: Vec<BudgetRule>,
    /// Fallback budget for tools matched by no rule; `None` means unlimited.
    default_budget: Option<usize>,
}
impl ActionBudgetGuardrailBuilder {
    /// Appends a rule that caps tools matching `tool_pattern` at `max_calls` calls.
    ///
    /// Rules are checked in the order they were added and the first match wins.
    /// Add specific patterns before broad ones such as `write*` or `*`.
    pub fn rule(mut self, tool_pattern: impl Into<String>, max_calls: usize) -> Self {
        self.budgets.push(BudgetRule {
            tool_pattern: tool_pattern.into(),
            max_calls,
        });
        self
    }

    /// Sets the budget used by tools that match no rule.
    ///
    /// All such tools draw from one shared counter rather than getting a
    /// budget each. Calling this again replaces the previous value.
    pub fn default_budget(mut self, max_calls: usize) -> Self {
        self.default_budget = Some(max_calls);
        self
    }

    /// Finalizes the configuration into a guardrail with every counter at zero.
    #[must_use]
    pub fn build(self) -> ActionBudgetGuardrail {
        ActionBudgetGuardrail {
            budgets: self.budgets,
            default_budget: self.default_budget,
            counts: Mutex::new(HashMap::new()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `ToolCall` for `name` with an empty JSON input.
    fn test_call(name: &str) -> ToolCall {
        ToolCall {
            id: "c1".into(),
            name: name.into(),
            input: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn five_bash_calls_allowed_sixth_denied() {
        let g = ActionBudgetGuardrail::builder().rule("bash", 5).build();
        for i in 1..=5 {
            let action = g.pre_tool(&test_call("bash")).await.unwrap();
            assert_eq!(action, GuardAction::Allow, "call {i} should be allowed");
        }
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied(), "6th call should be denied");
    }

    #[tokio::test]
    async fn default_budget_applies_to_unconfigured_tools() {
        let g = ActionBudgetGuardrail::builder()
            .rule("bash", 10)
            .default_budget(2)
            .build();
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn default_budget_is_shared_across_unconfigured_tools() {
        let g = ActionBudgetGuardrail::builder().default_budget(2).build();
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("grep")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        // A third, different unconfigured tool hits the same exhausted counter.
        let action = g.pre_tool(&test_call("ls")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn rule_and_default_budgets_count_separately() {
        let g = ActionBudgetGuardrail::builder()
            .rule("bash", 1)
            .default_budget(1)
            .build();
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
        let action = g.pre_tool(&test_call("grep")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn different_patterns_track_independently() {
        let g = ActionBudgetGuardrail::builder()
            .rule("bash", 1)
            .rule("read", 1)
            .build();
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn glob_matching_works() {
        let g = ActionBudgetGuardrail::builder().rule("write*", 2).build();
        let action = g.pre_tool(&test_call("write_file")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("write_config")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("write_other")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn first_matching_rule_wins() {
        let g = ActionBudgetGuardrail::builder()
            .rule("bash", 1)
            .rule("*", 100)
            .build();
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn no_rules_no_default_allows_all() {
        let g = ActionBudgetGuardrail::builder().build();
        for _ in 0..100 {
            let action = g.pre_tool(&test_call("bash")).await.unwrap();
            assert_eq!(action, GuardAction::Allow);
        }
    }

    #[tokio::test]
    async fn zero_budget_denies_first_call() {
        let g = ActionBudgetGuardrail::builder().rule("bash", 0).build();
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn denied_calls_do_not_consume_budget() {
        let g = ActionBudgetGuardrail::builder().rule("bash", 1).build();
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        for _ in 0..3 {
            let action = g.pre_tool(&test_call("bash")).await.unwrap();
            assert!(action.is_denied());
        }
        // Only the single allowed call was counted.
        assert_eq!(g.counts.lock().get("bash").copied(), Some(1));
    }

    #[test]
    fn builder_pattern_works() {
        let g = ActionBudgetGuardrail::builder()
            .rule("bash", 5)
            .rule("write*", 10)
            .default_budget(20)
            .build();
        assert_eq!(g.budgets.len(), 2);
        assert_eq!(g.budgets[0].tool_pattern, "bash");
        assert_eq!(g.budgets[0].max_calls, 5);
        assert_eq!(g.budgets[1].tool_pattern, "write*");
        assert_eq!(g.budgets[1].max_calls, 10);
        assert_eq!(g.default_budget, Some(20));
    }

    #[tokio::test]
    async fn count_never_resets() {
        let g = ActionBudgetGuardrail::builder().rule("bash", 2).build();
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        // Unbudgeted tools in between must not reset the `bash` counter.
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("write")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[test]
    fn meta_name() {
        let g = ActionBudgetGuardrail::builder().build();
        assert_eq!(g.name(), "action_budget");
    }

    #[test]
    fn pattern_matches_exact() {
        assert!(pattern_matches("bash", "bash"));
        assert!(!pattern_matches("bash", "read"));
    }

    #[test]
    fn pattern_matches_glob() {
        assert!(pattern_matches("write*", "write_file"));
        assert!(pattern_matches("write*", "write"));
        assert!(!pattern_matches("write*", "read"));
    }

    #[test]
    fn pattern_matches_wildcard() {
        assert!(pattern_matches("*", "anything"));
        assert!(pattern_matches("*", ""));
    }
}