Skip to main content

harn_vm/
call_budget.rs

//! Per-dispatch ceilings on outbound *call counts* — MCP tool calls and
//! Postgres queries — mirroring the LLM cost/token budgets in
//! [`crate::llm::cost`]. A `.harn` handler exported through `harn-serve`
//! declares `@budget(mcp_calls: 20, pg_queries: 50)`; the dispatcher
//! installs the matching guards for the lifetime of the call. Each
//! charge increments a per-thread counter and, once the ceiling is
//! crossed, raises a structured `BudgetExceeded`-categorised error that
//! adapter codecs render as HTTP 429.
//!
//! Counters only advance while a budget is installed, so dispatches
//! without a `@budget` declaration pay nothing and never accumulate
//! cross-call state. Guards restore the prior ceiling and count on drop,
//! keeping nested dispatches (a handler that re-enters the dispatcher)
//! from leaking a tighter budget outward or a wider one back into a
//! finished inner scope.

17use std::cell::RefCell;
18use std::collections::BTreeMap;
19use std::rc::Rc;
20use std::thread::LocalKey;
21
22use crate::value::{VmError, VmValue};
23
// Per-thread call-budget state. Each dimension pairs a ceiling slot with a
// counter slot; the counter is only advanced by `charge` while the matching
// ceiling is `Some`.
thread_local! {
    // Active `@budget(mcp_calls: …)` ceiling. `None` means no ceiling is
    // installed, in which case `charge_mcp_call` returns `Ok(())` untouched.
    static MCP_CALL_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    // MCP calls charged since the current ceiling was installed (reset to 0
    // by `install_mcp_call_budget`).
    static MCP_CALL_COUNT: RefCell<u64> = const { RefCell::new(0) };
    // Active `@budget(pg_queries: …)` ceiling. `None` disables charging.
    static PG_QUERY_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    // Postgres queries charged since the current ceiling was installed
    // (reset to 0 by `install_pg_query_budget`).
    static PG_QUERY_COUNT: RefCell<u64> = const { RefCell::new(0) };
}
30
31/// Reset thread-local call-budget state. Call between test runs so a
32/// guard that outlived an unwinding test cannot leak a ceiling.
33pub(crate) fn reset_call_budget_state() {
34    MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = None);
35    MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
36    PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = None);
37    PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
38}
39
/// The two call-count dimensions. Each names the `@budget(...)` field it
/// backs so the structured error carries the dimension that fired and
/// `harn-serve`'s `budget_category_from_error` can recover it from the
/// `limit` field without inspecting the message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CallBudgetKind {
    /// MCP tool calls, backing `@budget(mcp_calls: …)`.
    McpCalls,
    /// Postgres queries, backing `@budget(pg_queries: …)`.
    PgQueries,
}

impl CallBudgetKind {
    /// The `@budget(...)` field name. It is surfaced as the error's `limit`
    /// field and as the prefix of its message.
    fn limit_label(self) -> &'static str {
        match self {
            Self::McpCalls => "mcp_calls",
            Self::PgQueries => "pg_queries",
        }
    }

    /// Human-readable noun for the error message. `plural` selects the
    /// plural form so the noun agrees with the ceiling count.
    fn noun(self, plural: bool) -> &'static str {
        match (self, plural) {
            (Self::McpCalls, false) => "MCP call",
            (Self::McpCalls, true) => "MCP calls",
            (Self::PgQueries, false) => "Postgres query",
            (Self::PgQueries, true) => "Postgres queries",
        }
    }
}
70
71/// Increment the counter behind `budget`/`count` and raise once the
72/// ceiling is crossed. A `None` budget short-circuits — no install, no
73/// charge. The counter only advances while a ceiling is present so
74/// budget-free dispatches stay zero-cost.
75fn charge(
76    budget: &'static LocalKey<RefCell<Option<u64>>>,
77    count: &'static LocalKey<RefCell<u64>>,
78    kind: CallBudgetKind,
79) -> Result<(), VmError> {
80    let Some(max) = budget.with(|b| *b.borrow()) else {
81        return Ok(());
82    };
83    let spent = count.with(|c| {
84        let mut slot = c.borrow_mut();
85        *slot = slot.saturating_add(1);
86        *slot
87    });
88    if spent > max {
89        return Err(budget_exceeded_error(kind, spent, max));
90    }
91    Ok(())
92}
93
94/// Build the structured error rendered as HTTP 429. The `category` field
95/// routes it through `ErrorCategory::BudgetExceeded`; the `limit` field
96/// names the dimension so adapters report `code: "budget_exceeded"` with
97/// the precise `@budget(...)` field that fired.
98fn budget_exceeded_error(kind: CallBudgetKind, spent: u64, max: u64) -> VmError {
99    let mut dict = BTreeMap::new();
100    dict.insert(
101        "category".to_string(),
102        VmValue::String(Rc::from("budget_exceeded")),
103    );
104    dict.insert("kind".to_string(), VmValue::String(Rc::from("terminal")));
105    dict.insert(
106        "reason".to_string(),
107        VmValue::String(Rc::from("budget_exceeded")),
108    );
109    dict.insert(
110        "limit".to_string(),
111        VmValue::String(Rc::from(kind.limit_label())),
112    );
113    dict.insert("limit_value".to_string(), VmValue::Int(max as i64));
114    dict.insert("spent".to_string(), VmValue::Int(spent as i64));
115    dict.insert(
116        "message".to_string(),
117        VmValue::String(Rc::from(format!(
118            "{} budget exceeded: this dispatch attempted {} of {} permitted {}",
119            kind.limit_label(),
120            spent,
121            max,
122            kind.noun(max != 1),
123        ))),
124    );
125    VmError::Thrown(VmValue::Dict(Rc::new(dict)))
126}
127
/// RAII guard for [`install_mcp_call_budget`]. Restores the prior MCP
/// call ceiling and count on drop.
#[derive(Debug)]
#[must_use = "dropping the guard immediately restores the prior MCP call budget"]
pub struct McpCallBudgetGuard {
    /// Ceiling that was active before this guard was installed.
    previous_budget: Option<u64>,
    /// Count that had accumulated under that prior ceiling.
    previous_count: u64,
}
135
136impl Drop for McpCallBudgetGuard {
137    fn drop(&mut self) {
138        MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
139        MCP_CALL_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
140    }
141}
142
143/// Pin the per-dispatch MCP tool-call ceiling at `max` for the lifetime
144/// of the returned guard. Sourced from `@budget(mcp_calls: …)` on
145/// `.harn` handlers in `harn-serve`; the `(max + 1)`-th call raises a
146/// `BudgetExceeded`-categorised error adapters render as HTTP 429.
147pub fn install_mcp_call_budget(max: u64) -> McpCallBudgetGuard {
148    let previous_budget = MCP_CALL_BUDGET.with(|b| *b.borrow());
149    let previous_count = MCP_CALL_COUNT.with(|c| *c.borrow());
150    MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = Some(max));
151    MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
152    McpCallBudgetGuard {
153        previous_budget,
154        previous_count,
155    }
156}
157
158/// Charge one MCP tool call against the active `@budget(mcp_calls: …)`
159/// ceiling, if any. Called once per logical `mcp.call` dispatch.
160pub fn charge_mcp_call() -> Result<(), VmError> {
161    charge(&MCP_CALL_BUDGET, &MCP_CALL_COUNT, CallBudgetKind::McpCalls)
162}
163
/// RAII guard for [`install_pg_query_budget`]. Restores the prior
/// Postgres query ceiling and count on drop.
#[derive(Debug)]
#[must_use = "dropping the guard immediately restores the prior Postgres query budget"]
pub struct PgQueryBudgetGuard {
    /// Ceiling that was active before this guard was installed.
    previous_budget: Option<u64>,
    /// Count that had accumulated under that prior ceiling.
    previous_count: u64,
}
171
172impl Drop for PgQueryBudgetGuard {
173    fn drop(&mut self) {
174        PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
175        PG_QUERY_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
176    }
177}
178
179/// Pin the per-dispatch Postgres query ceiling at `max` for the lifetime
180/// of the returned guard. Sourced from `@budget(pg_queries: …)` on
181/// `.harn` handlers in `harn-serve`; the `(max + 1)`-th query raises a
182/// `BudgetExceeded`-categorised error adapters render as HTTP 429.
183pub fn install_pg_query_budget(max: u64) -> PgQueryBudgetGuard {
184    let previous_budget = PG_QUERY_BUDGET.with(|b| *b.borrow());
185    let previous_count = PG_QUERY_COUNT.with(|c| *c.borrow());
186    PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = Some(max));
187    PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
188    PgQueryBudgetGuard {
189        previous_budget,
190        previous_count,
191    }
192}
193
194/// Charge one Postgres query against the active `@budget(pg_queries: …)`
195/// ceiling, if any. Called once per `pg_query` / `pg_query_one` /
196/// `pg_execute` statement (including mock-pool statements).
197pub fn charge_pg_query() -> Result<(), VmError> {
198    charge(&PG_QUERY_BUDGET, &PG_QUERY_COUNT, CallBudgetKind::PgQueries)
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::{error_to_category, ErrorCategory};

    #[test]
    fn charge_is_noop_without_installed_budget() {
        reset_call_budget_state();
        for _ in 0..1000 {
            assert!(charge_mcp_call().is_ok());
            assert!(charge_pg_query().is_ok());
        }
        // No guard installed → counters never advance.
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
    }

    #[test]
    fn mcp_budget_admits_up_to_ceiling_then_rejects() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(2);
        assert!(charge_mcp_call().is_ok());
        assert!(charge_mcp_call().is_ok());
        let third = charge_mcp_call();
        let err = third.expect_err("third call must exceed mcp_calls: 2");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(
                    d.get("limit").map(|v| v.display()).as_deref(),
                    Some("mcp_calls")
                );
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(2));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(3));
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn pg_budget_message_pluralises_and_names_dimension() {
        reset_call_budget_state();
        let _guard = install_pg_query_budget(1);
        assert!(charge_pg_query().is_ok());
        let err = charge_pg_query().expect_err("second query must exceed pg_queries: 1");
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                assert!(
                    message.contains("pg_queries budget exceeded"),
                    "got: {message}"
                );
                assert!(message.contains("Postgres query"), "got: {message}");
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn zero_budget_rejects_first_call_with_plural_noun() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(0);
        let err = charge_mcp_call().expect_err("mcp_calls: 0 must reject the first call");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(0));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(1));
                let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                // A ceiling of 0 is not 1, so the noun must be plural.
                assert!(message.contains("of 0 permitted MCP calls"), "got: {message}");
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn calls_past_the_ceiling_keep_failing() {
        reset_call_budget_state();
        let _guard = install_pg_query_budget(1);
        assert!(charge_pg_query().is_ok());
        assert!(charge_pg_query().is_err());
        // Once over budget, every further charge is rejected and still counted.
        assert!(charge_pg_query().is_err());
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 3);
        reset_call_budget_state();
    }

    #[test]
    fn budgets_are_independent_per_dimension() {
        reset_call_budget_state();
        let _mcp = install_mcp_call_budget(1);
        // No pg ceiling is installed, so pg charges stay free and uncounted.
        for _ in 0..10 {
            assert!(charge_pg_query().is_ok());
        }
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
        assert!(charge_mcp_call().is_ok());
        assert!(charge_mcp_call().is_err());
        reset_call_budget_state();
    }

    #[test]
    fn nested_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_mcp_call_budget(5);
        assert!(charge_mcp_call().is_ok());
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);

        {
            // Nested dispatch installs a tighter ceiling and starts fresh.
            let _inner = install_mcp_call_budget(1);
            assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
            assert!(charge_mcp_call().is_ok());
            assert!(charge_mcp_call().is_err());
        }

        // Inner drop restores the outer ceiling and its accumulated count.
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), Some(5));
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);
        drop(outer);
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), None);
        reset_call_budget_state();
    }

    #[test]
    fn pg_guard_restores_prior_state_on_drop() {
        reset_call_budget_state();
        {
            let _guard = install_pg_query_budget(3);
            assert!(charge_pg_query().is_ok());
            assert_eq!(PG_QUERY_BUDGET.with(|b| *b.borrow()), Some(3));
            assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 1);
        }
        // Dropping the only guard returns to the budget-free state.
        assert_eq!(PG_QUERY_BUDGET.with(|b| *b.borrow()), None);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
        reset_call_budget_state();
    }
}