Skip to main content

harn_vm/
call_budget.rs

1//! Per-dispatch ceilings on outbound *call counts* — MCP tool calls and
2//! Postgres queries — mirroring the LLM cost/token budgets in
3//! [`crate::llm::cost`]. A `.harn` handler exported through `harn-serve`
4//! declares `@budget(mcp_calls: 20, pg_queries: 50)`; the dispatcher
5//! installs the matching guards for the lifetime of the call. Each
6//! charge increments a per-thread counter and, once the ceiling is
7//! crossed, raises a structured `BudgetExceeded`-categorised error that
8//! adapter codecs render as HTTP 429.
9//!
10//! Counters only advance while a budget is installed, so dispatches
11//! without a `@budget` declaration pay nothing and never accumulate
12//! cross-call state. Guards restore the prior ceiling and count on drop,
13//! keeping nested dispatches (a handler that re-enters the dispatcher)
14//! from leaking a tighter budget outward or a wider one back into a
15//! finished inner scope.
16
17use crate::value::VmDictExt;
18use std::cell::RefCell;
19use std::collections::BTreeMap;
20use std::thread::LocalKey;
21
22use crate::value::{VmError, VmValue};
23
// Per-thread MCP tool-call state: the installed ceiling (`None` when no
// `@budget(mcp_calls: …)` is active) and the number of calls charged so far.
thread_local! {
    static MCP_CALL_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    static MCP_CALL_COUNT: RefCell<u64> = const { RefCell::new(0) };
}

// Per-thread Postgres query state: the installed ceiling (`None` when no
// `@budget(pg_queries: …)` is active) and the number of queries charged so far.
thread_local! {
    static PG_QUERY_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    static PG_QUERY_COUNT: RefCell<u64> = const { RefCell::new(0) };
}
30
31/// Reset thread-local call-budget state. Call between test runs so a
32/// guard that outlived an unwinding test cannot leak a ceiling.
33pub(crate) fn reset_call_budget_state() {
34    MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = None);
35    MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
36    PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = None);
37    PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
38}
39
/// The two call-count dimensions a `@budget(...)` annotation can cap.
///
/// Each variant corresponds to the `@budget(...)` field it backs. The
/// structured error therefore carries the dimension that fired, and
/// `harn-serve`'s `budget_category_from_error` can recover it from the
/// `limit` field without inspecting the message.
#[derive(Copy, Clone)]
enum CallBudgetKind {
    McpCalls,
    PgQueries,
}

impl CallBudgetKind {
    /// The `@budget(...)` field name, reported as the error's `limit`.
    fn limit_label(self) -> &'static str {
        match self {
            Self::McpCalls => "mcp_calls",
            Self::PgQueries => "pg_queries",
        }
    }

    /// Human-readable noun for the error message. `plural` selects the form
    /// that agrees with the ceiling count.
    fn noun(self, plural: bool) -> &'static str {
        let (one, many) = match self {
            Self::McpCalls => ("MCP call", "MCP calls"),
            Self::PgQueries => ("Postgres query", "Postgres queries"),
        };
        if plural {
            many
        } else {
            one
        }
    }
}
70
71/// Increment the counter behind `budget`/`count` and raise once the
72/// ceiling is crossed. A `None` budget short-circuits — no install, no
73/// charge. The counter only advances while a ceiling is present so
74/// budget-free dispatches stay zero-cost.
75fn charge(
76    budget: &'static LocalKey<RefCell<Option<u64>>>,
77    count: &'static LocalKey<RefCell<u64>>,
78    kind: CallBudgetKind,
79) -> Result<(), VmError> {
80    let Some(max) = budget.with(|b| *b.borrow()) else {
81        return Ok(());
82    };
83    let spent = count.with(|c| {
84        let mut slot = c.borrow_mut();
85        *slot = slot.saturating_add(1);
86        *slot
87    });
88    if spent > max {
89        return Err(budget_exceeded_error(kind, spent, max));
90    }
91    Ok(())
92}
93
94/// Build the structured error rendered as HTTP 429. The `category` field
95/// routes it through `ErrorCategory::BudgetExceeded`; the `limit` field
96/// names the dimension so adapters report `code: "budget_exceeded"` with
97/// the precise `@budget(...)` field that fired.
98fn budget_exceeded_error(kind: CallBudgetKind, spent: u64, max: u64) -> VmError {
99    let mut dict = BTreeMap::new();
100    dict.put_str("category", "budget_exceeded");
101    dict.put_str("kind", "terminal");
102    dict.put_str("reason", "budget_exceeded");
103    dict.put_str("limit", kind.limit_label());
104    dict.insert("limit_value".to_string(), VmValue::Int(max as i64));
105    dict.insert("spent".to_string(), VmValue::Int(spent as i64));
106    dict.put_str(
107        "message",
108        format!(
109            "{} budget exceeded: this dispatch attempted {} of {} permitted {}",
110            kind.limit_label(),
111            spent,
112            max,
113            kind.noun(max != 1),
114        ),
115    );
116    VmError::Thrown(VmValue::dict(dict))
117}
118
119/// RAII guard for [`install_mcp_call_budget`]. Restores the prior MCP
120/// call ceiling and count on drop.
121#[must_use = "dropping the guard immediately restores the prior MCP call budget"]
122pub struct McpCallBudgetGuard {
123    previous_budget: Option<u64>,
124    previous_count: u64,
125}
126
127impl Drop for McpCallBudgetGuard {
128    fn drop(&mut self) {
129        MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
130        MCP_CALL_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
131    }
132}
133
134/// Pin the per-dispatch MCP tool-call ceiling at `max` for the lifetime
135/// of the returned guard. Sourced from `@budget(mcp_calls: …)` on
136/// `.harn` handlers in `harn-serve`; the `(max + 1)`-th call raises a
137/// `BudgetExceeded`-categorised error adapters render as HTTP 429.
138pub fn install_mcp_call_budget(max: u64) -> McpCallBudgetGuard {
139    let previous_budget = MCP_CALL_BUDGET.with(|b| *b.borrow());
140    let previous_count = MCP_CALL_COUNT.with(|c| *c.borrow());
141    MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = Some(max));
142    MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
143    McpCallBudgetGuard {
144        previous_budget,
145        previous_count,
146    }
147}
148
149/// Charge one MCP tool call against the active `@budget(mcp_calls: …)`
150/// ceiling, if any. Called once per logical `mcp.call` dispatch.
151pub fn charge_mcp_call() -> Result<(), VmError> {
152    charge(&MCP_CALL_BUDGET, &MCP_CALL_COUNT, CallBudgetKind::McpCalls)
153}
154
155/// RAII guard for [`install_pg_query_budget`]. Restores the prior
156/// Postgres query ceiling and count on drop.
157#[must_use = "dropping the guard immediately restores the prior Postgres query budget"]
158pub struct PgQueryBudgetGuard {
159    previous_budget: Option<u64>,
160    previous_count: u64,
161}
162
163impl Drop for PgQueryBudgetGuard {
164    fn drop(&mut self) {
165        PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
166        PG_QUERY_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
167    }
168}
169
170/// Pin the per-dispatch Postgres query ceiling at `max` for the lifetime
171/// of the returned guard. Sourced from `@budget(pg_queries: …)` on
172/// `.harn` handlers in `harn-serve`; the `(max + 1)`-th query raises a
173/// `BudgetExceeded`-categorised error adapters render as HTTP 429.
174pub fn install_pg_query_budget(max: u64) -> PgQueryBudgetGuard {
175    let previous_budget = PG_QUERY_BUDGET.with(|b| *b.borrow());
176    let previous_count = PG_QUERY_COUNT.with(|c| *c.borrow());
177    PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = Some(max));
178    PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
179    PgQueryBudgetGuard {
180        previous_budget,
181        previous_count,
182    }
183}
184
185/// Charge one Postgres query against the active `@budget(pg_queries: …)`
186/// ceiling, if any. Called once per `pg_query` / `pg_query_one` /
187/// `pg_execute` statement (including mock-pool statements).
188pub fn charge_pg_query() -> Result<(), VmError> {
189    charge(&PG_QUERY_BUDGET, &PG_QUERY_COUNT, CallBudgetKind::PgQueries)
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::{error_to_category, ErrorCategory};

    /// Extract the rendered `message` from a structured budget error,
    /// panicking if the error is not the expected `Thrown` dict.
    fn thrown_message(err: &VmError) -> String {
        match err {
            VmError::Thrown(VmValue::Dict(d)) => {
                d.get("message").map(|v| v.display()).unwrap_or_default()
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
    }

    #[test]
    fn charge_is_noop_without_installed_budget() {
        reset_call_budget_state();
        for _ in 0..1000 {
            assert!(charge_mcp_call().is_ok());
            assert!(charge_pg_query().is_ok());
        }
        // No guard installed → counters never advance.
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
    }

    #[test]
    fn mcp_budget_admits_up_to_ceiling_then_rejects() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(2);
        assert!(charge_mcp_call().is_ok());
        assert!(charge_mcp_call().is_ok());
        let third = charge_mcp_call();
        let err = third.expect_err("third call must exceed mcp_calls: 2");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(
                    d.get("limit").map(|v| v.display()).as_deref(),
                    Some("mcp_calls")
                );
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(2));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(3));
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn zero_budget_rejects_first_call() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(0);
        let err = charge_mcp_call().expect_err("mcp_calls: 0 must admit nothing");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        // A ceiling of zero takes the plural noun.
        assert_eq!(
            thrown_message(&err),
            "mcp_calls budget exceeded: this dispatch attempted 1 of 0 permitted MCP calls"
        );
        reset_call_budget_state();
    }

    #[test]
    fn pg_budget_message_pluralises_and_names_dimension() {
        reset_call_budget_state();
        {
            // A ceiling of one takes the singular noun. An exact match is
            // required: `contains("Postgres query")` would also accept
            // "Postgres queries" and so could not detect a pluralisation bug.
            let _guard = install_pg_query_budget(1);
            assert!(charge_pg_query().is_ok());
            let err = charge_pg_query().expect_err("second query must exceed pg_queries: 1");
            assert_eq!(
                thrown_message(&err),
                "pg_queries budget exceeded: this dispatch attempted 2 of 1 permitted Postgres query"
            );
        }
        {
            // Any ceiling other than one takes the plural noun.
            let _guard = install_pg_query_budget(2);
            assert!(charge_pg_query().is_ok());
            assert!(charge_pg_query().is_ok());
            let err = charge_pg_query().expect_err("third query must exceed pg_queries: 2");
            assert_eq!(
                thrown_message(&err),
                "pg_queries budget exceeded: this dispatch attempted 3 of 2 permitted Postgres queries"
            );
        }
        reset_call_budget_state();
    }

    #[test]
    fn nested_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_mcp_call_budget(5);
        assert!(charge_mcp_call().is_ok());
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);

        {
            // Nested dispatch installs a tighter ceiling and starts fresh.
            let _inner = install_mcp_call_budget(1);
            assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
            assert!(charge_mcp_call().is_ok());
            assert!(charge_mcp_call().is_err());
        }

        // Inner drop restores the outer ceiling and its accumulated count.
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), Some(5));
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);
        drop(outer);
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), None);
        reset_call_budget_state();
    }
}