Skip to main content

harn_vm/
call_budget.rs

1//! Per-dispatch ceilings on outbound *call counts* — MCP tool calls and
2//! Postgres queries — mirroring the LLM cost/token budgets in
3//! [`crate::llm::cost`]. A `.harn` handler exported through `harn-serve`
4//! declares `@budget(mcp_calls: 20, pg_queries: 50)`; the dispatcher
5//! installs the matching guards for the lifetime of the call. Each
6//! charge increments a per-thread counter and, once the ceiling is
7//! crossed, raises a structured `BudgetExceeded`-categorised error that
8//! adapter codecs render as HTTP 429.
9//!
10//! Counters only advance while a budget is installed, so dispatches
11//! without a `@budget` declaration pay nothing and never accumulate
12//! cross-call state. Guards restore the prior ceiling and count on drop,
13//! keeping nested dispatches (a handler that re-enters the dispatcher)
14//! from leaking a tighter budget outward or a wider one back into a
15//! finished inner scope.
16
17use std::cell::RefCell;
18use std::collections::BTreeMap;
19use std::thread::LocalKey;
20
21use crate::value::{VmError, VmValue};
22
// MCP tool-call dimension: the ceiling currently installed on this thread
// (`None` while no budget is active) and the number of calls charged
// against it so far.
thread_local! {
    static MCP_CALL_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    static MCP_CALL_COUNT: RefCell<u64> = const { RefCell::new(0) };
}

// Postgres query dimension: same layout as the MCP pair above, tracked
// independently so the two budgets never interfere.
thread_local! {
    static PG_QUERY_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    static PG_QUERY_COUNT: RefCell<u64> = const { RefCell::new(0) };
}
29
30/// Reset thread-local call-budget state. Call between test runs so a
31/// guard that outlived an unwinding test cannot leak a ceiling.
32pub(crate) fn reset_call_budget_state() {
33    MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = None);
34    MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
35    PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = None);
36    PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
37}
38
/// Identifies which call-count ceiling a charge (or the resulting error)
/// belongs to. Each variant maps to one `@budget(...)` field, so the
/// structured error can name the dimension that fired via its `limit`
/// field rather than forcing consumers (such as `harn-serve`'s
/// `budget_category_from_error`) to parse the message text.
#[derive(Clone, Copy)]
enum CallBudgetKind {
    /// Outbound MCP tool calls (`@budget(mcp_calls: …)`).
    McpCalls,
    /// Postgres statements (`@budget(pg_queries: …)`).
    PgQueries,
}

impl CallBudgetKind {
    /// Returns the `@budget(...)` field name for this dimension; this is
    /// the value placed in the error's `limit` field.
    fn limit_label(self) -> &'static str {
        match self {
            Self::McpCalls => "mcp_calls",
            Self::PgQueries => "pg_queries",
        }
    }

    /// Returns the human-readable noun used in the error message.
    ///
    /// `plural` selects the plural form so the noun agrees with the
    /// ceiling count it follows.
    fn noun(self, plural: bool) -> &'static str {
        let (one, many) = match self {
            Self::McpCalls => ("MCP call", "MCP calls"),
            Self::PgQueries => ("Postgres query", "Postgres queries"),
        };
        if plural {
            many
        } else {
            one
        }
    }
}
69
70/// Increment the counter behind `budget`/`count` and raise once the
71/// ceiling is crossed. A `None` budget short-circuits — no install, no
72/// charge. The counter only advances while a ceiling is present so
73/// budget-free dispatches stay zero-cost.
74fn charge(
75    budget: &'static LocalKey<RefCell<Option<u64>>>,
76    count: &'static LocalKey<RefCell<u64>>,
77    kind: CallBudgetKind,
78) -> Result<(), VmError> {
79    let Some(max) = budget.with(|b| *b.borrow()) else {
80        return Ok(());
81    };
82    let spent = count.with(|c| {
83        let mut slot = c.borrow_mut();
84        *slot = slot.saturating_add(1);
85        *slot
86    });
87    if spent > max {
88        return Err(budget_exceeded_error(kind, spent, max));
89    }
90    Ok(())
91}
92
93/// Build the structured error rendered as HTTP 429. The `category` field
94/// routes it through `ErrorCategory::BudgetExceeded`; the `limit` field
95/// names the dimension so adapters report `code: "budget_exceeded"` with
96/// the precise `@budget(...)` field that fired.
97fn budget_exceeded_error(kind: CallBudgetKind, spent: u64, max: u64) -> VmError {
98    let mut dict = BTreeMap::new();
99    dict.insert(
100        "category".to_string(),
101        VmValue::String(std::sync::Arc::from("budget_exceeded")),
102    );
103    dict.insert(
104        "kind".to_string(),
105        VmValue::String(std::sync::Arc::from("terminal")),
106    );
107    dict.insert(
108        "reason".to_string(),
109        VmValue::String(std::sync::Arc::from("budget_exceeded")),
110    );
111    dict.insert(
112        "limit".to_string(),
113        VmValue::String(std::sync::Arc::from(kind.limit_label())),
114    );
115    dict.insert("limit_value".to_string(), VmValue::Int(max as i64));
116    dict.insert("spent".to_string(), VmValue::Int(spent as i64));
117    dict.insert(
118        "message".to_string(),
119        VmValue::String(std::sync::Arc::from(format!(
120            "{} budget exceeded: this dispatch attempted {} of {} permitted {}",
121            kind.limit_label(),
122            spent,
123            max,
124            kind.noun(max != 1),
125        ))),
126    );
127    VmError::Thrown(VmValue::Dict(std::sync::Arc::new(dict)))
128}
129
130/// RAII guard for [`install_mcp_call_budget`]. Restores the prior MCP
131/// call ceiling and count on drop.
132#[must_use = "dropping the guard immediately restores the prior MCP call budget"]
133pub struct McpCallBudgetGuard {
134    previous_budget: Option<u64>,
135    previous_count: u64,
136}
137
138impl Drop for McpCallBudgetGuard {
139    fn drop(&mut self) {
140        MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
141        MCP_CALL_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
142    }
143}
144
145/// Pin the per-dispatch MCP tool-call ceiling at `max` for the lifetime
146/// of the returned guard. Sourced from `@budget(mcp_calls: …)` on
147/// `.harn` handlers in `harn-serve`; the `(max + 1)`-th call raises a
148/// `BudgetExceeded`-categorised error adapters render as HTTP 429.
149pub fn install_mcp_call_budget(max: u64) -> McpCallBudgetGuard {
150    let previous_budget = MCP_CALL_BUDGET.with(|b| *b.borrow());
151    let previous_count = MCP_CALL_COUNT.with(|c| *c.borrow());
152    MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = Some(max));
153    MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
154    McpCallBudgetGuard {
155        previous_budget,
156        previous_count,
157    }
158}
159
160/// Charge one MCP tool call against the active `@budget(mcp_calls: …)`
161/// ceiling, if any. Called once per logical `mcp.call` dispatch.
162pub fn charge_mcp_call() -> Result<(), VmError> {
163    charge(&MCP_CALL_BUDGET, &MCP_CALL_COUNT, CallBudgetKind::McpCalls)
164}
165
166/// RAII guard for [`install_pg_query_budget`]. Restores the prior
167/// Postgres query ceiling and count on drop.
168#[must_use = "dropping the guard immediately restores the prior Postgres query budget"]
169pub struct PgQueryBudgetGuard {
170    previous_budget: Option<u64>,
171    previous_count: u64,
172}
173
174impl Drop for PgQueryBudgetGuard {
175    fn drop(&mut self) {
176        PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
177        PG_QUERY_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
178    }
179}
180
181/// Pin the per-dispatch Postgres query ceiling at `max` for the lifetime
182/// of the returned guard. Sourced from `@budget(pg_queries: …)` on
183/// `.harn` handlers in `harn-serve`; the `(max + 1)`-th query raises a
184/// `BudgetExceeded`-categorised error adapters render as HTTP 429.
185pub fn install_pg_query_budget(max: u64) -> PgQueryBudgetGuard {
186    let previous_budget = PG_QUERY_BUDGET.with(|b| *b.borrow());
187    let previous_count = PG_QUERY_COUNT.with(|c| *c.borrow());
188    PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = Some(max));
189    PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
190    PgQueryBudgetGuard {
191        previous_budget,
192        previous_count,
193    }
194}
195
196/// Charge one Postgres query against the active `@budget(pg_queries: …)`
197/// ceiling, if any. Called once per `pg_query` / `pg_query_one` /
198/// `pg_execute` statement (including mock-pool statements).
199pub fn charge_pg_query() -> Result<(), VmError> {
200    charge(&PG_QUERY_BUDGET, &PG_QUERY_COUNT, CallBudgetKind::PgQueries)
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::{error_to_category, ErrorCategory};

    /// With no ceiling installed, charging always succeeds and the
    /// counters stay at zero.
    #[test]
    fn charge_is_noop_without_installed_budget() {
        reset_call_budget_state();
        for _ in 0..1000 {
            assert!(charge_mcp_call().is_ok());
            assert!(charge_pg_query().is_ok());
        }
        // No guard installed → counters never advance.
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
    }

    /// Exactly `max` calls are admitted; the next one fails with the
    /// structured, `BudgetExceeded`-categorised error.
    #[test]
    fn mcp_budget_admits_up_to_ceiling_then_rejects() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(2);
        assert!(charge_mcp_call().is_ok());
        assert!(charge_mcp_call().is_ok());
        let third = charge_mcp_call();
        let err = third.expect_err("third call must exceed mcp_calls: 2");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(
                    d.get("limit").map(|v| v.display()).as_deref(),
                    Some("mcp_calls")
                );
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(2));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(3));
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    /// A ceiling of zero rejects the very first call.
    #[test]
    fn zero_ceiling_rejects_first_call() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(0);
        let err = charge_mcp_call().expect_err("first call must exceed mcp_calls: 0");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(0));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(1));
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    /// The message names the `pg_queries` dimension and its noun agrees
    /// with the ceiling: singular for a ceiling of 1, plural otherwise.
    #[test]
    fn pg_budget_message_pluralises_and_names_dimension() {
        reset_call_budget_state();

        // Ceiling of 1 → singular noun.
        {
            let _guard = install_pg_query_budget(1);
            assert!(charge_pg_query().is_ok());
            let err = charge_pg_query().expect_err("second query must exceed pg_queries: 1");
            match &err {
                VmError::Thrown(VmValue::Dict(d)) => {
                    let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                    assert!(
                        message.contains("pg_queries budget exceeded"),
                        "got: {message}"
                    );
                    assert!(
                        message.ends_with("permitted Postgres query"),
                        "got: {message}"
                    );
                }
                other => panic!("expected structured Thrown dict, got {other:?}"),
            }
        }

        // Ceiling of 2 → plural noun.
        {
            let _guard = install_pg_query_budget(2);
            assert!(charge_pg_query().is_ok());
            assert!(charge_pg_query().is_ok());
            let err = charge_pg_query().expect_err("third query must exceed pg_queries: 2");
            match &err {
                VmError::Thrown(VmValue::Dict(d)) => {
                    let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                    assert!(
                        message.contains("pg_queries budget exceeded"),
                        "got: {message}"
                    );
                    assert!(
                        message.ends_with("permitted Postgres queries"),
                        "got: {message}"
                    );
                }
                other => panic!("expected structured Thrown dict, got {other:?}"),
            }
        }

        reset_call_budget_state();
    }

    /// A nested install starts from zero under its own ceiling, and
    /// dropping it hands the outer ceiling and count back unchanged.
    #[test]
    fn nested_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_mcp_call_budget(5);
        assert!(charge_mcp_call().is_ok());
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);

        {
            // Nested dispatch installs a tighter ceiling and starts fresh.
            let _inner = install_mcp_call_budget(1);
            assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
            assert!(charge_mcp_call().is_ok());
            assert!(charge_mcp_call().is_err());
        }

        // Inner drop restores the outer ceiling and its accumulated count.
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), Some(5));
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);
        drop(outer);
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), None);
        reset_call_budget_state();
    }
}
285}