1use crate::value::VmDictExt;
18use std::cell::RefCell;
19use std::collections::BTreeMap;
20use std::thread::LocalKey;
21
22use crate::value::{VmError, VmValue};
23
// Per-thread MCP call accounting: the active ceiling (`None` means no ceiling
// is installed) and the number of calls charged since the ceiling was installed.
thread_local! {
    static MCP_CALL_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    static MCP_CALL_COUNT: RefCell<u64> = const { RefCell::new(0) };
}

// Per-thread Postgres query accounting, laid out exactly like the MCP pair:
// an optional ceiling plus a running count of charged queries.
thread_local! {
    static PG_QUERY_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    static PG_QUERY_COUNT: RefCell<u64> = const { RefCell::new(0) };
}
30
31pub(crate) fn reset_call_budget_state() {
34 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = None);
35 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
36 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = None);
37 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
38}
39
/// Selects which per-thread counter a charge applies to and how it is named
/// in the resulting error.
#[derive(Clone, Copy)]
enum CallBudgetKind {
    /// MCP calls, reported under the `mcp_calls` limit label.
    McpCalls,
    /// Postgres queries, reported under the `pg_queries` limit label.
    PgQueries,
}

impl CallBudgetKind {
    /// Machine-readable limit name; this module uses it for the error dict's
    /// `limit` field and as the prefix of the error message.
    fn limit_label(self) -> &'static str {
        match self {
            Self::McpCalls => "mcp_calls",
            Self::PgQueries => "pg_queries",
        }
    }

    /// Human-readable noun for the counted operation, in plural form when
    /// `plural` is true and singular form otherwise.
    fn noun(self, plural: bool) -> &'static str {
        let (one, many) = match self {
            Self::McpCalls => ("MCP call", "MCP calls"),
            Self::PgQueries => ("Postgres query", "Postgres queries"),
        };
        if plural {
            many
        } else {
            one
        }
    }
}
70
71fn charge(
76 budget: &'static LocalKey<RefCell<Option<u64>>>,
77 count: &'static LocalKey<RefCell<u64>>,
78 kind: CallBudgetKind,
79) -> Result<(), VmError> {
80 let Some(max) = budget.with(|b| *b.borrow()) else {
81 return Ok(());
82 };
83 let spent = count.with(|c| {
84 let mut slot = c.borrow_mut();
85 *slot = slot.saturating_add(1);
86 *slot
87 });
88 if spent > max {
89 return Err(budget_exceeded_error(kind, spent, max));
90 }
91 Ok(())
92}
93
94fn budget_exceeded_error(kind: CallBudgetKind, spent: u64, max: u64) -> VmError {
99 let mut dict = BTreeMap::new();
100 dict.put_str("category", "budget_exceeded");
101 dict.put_str("kind", "terminal");
102 dict.put_str("reason", "budget_exceeded");
103 dict.put_str("limit", kind.limit_label());
104 dict.insert("limit_value".to_string(), VmValue::Int(max as i64));
105 dict.insert("spent".to_string(), VmValue::Int(spent as i64));
106 dict.put_str(
107 "message",
108 format!(
109 "{} budget exceeded: this dispatch attempted {} of {} permitted {}",
110 kind.limit_label(),
111 spent,
112 max,
113 kind.noun(max != 1),
114 ),
115 );
116 VmError::Thrown(VmValue::dict(dict))
117}
118
119#[must_use = "dropping the guard immediately restores the prior MCP call budget"]
122pub struct McpCallBudgetGuard {
123 previous_budget: Option<u64>,
124 previous_count: u64,
125}
126
127impl Drop for McpCallBudgetGuard {
128 fn drop(&mut self) {
129 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
130 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
131 }
132}
133
134pub fn install_mcp_call_budget(max: u64) -> McpCallBudgetGuard {
139 let previous_budget = MCP_CALL_BUDGET.with(|b| *b.borrow());
140 let previous_count = MCP_CALL_COUNT.with(|c| *c.borrow());
141 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = Some(max));
142 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
143 McpCallBudgetGuard {
144 previous_budget,
145 previous_count,
146 }
147}
148
149pub fn charge_mcp_call() -> Result<(), VmError> {
152 charge(&MCP_CALL_BUDGET, &MCP_CALL_COUNT, CallBudgetKind::McpCalls)
153}
154
155#[must_use = "dropping the guard immediately restores the prior Postgres query budget"]
158pub struct PgQueryBudgetGuard {
159 previous_budget: Option<u64>,
160 previous_count: u64,
161}
162
163impl Drop for PgQueryBudgetGuard {
164 fn drop(&mut self) {
165 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
166 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
167 }
168}
169
170pub fn install_pg_query_budget(max: u64) -> PgQueryBudgetGuard {
175 let previous_budget = PG_QUERY_BUDGET.with(|b| *b.borrow());
176 let previous_count = PG_QUERY_COUNT.with(|c| *c.borrow());
177 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = Some(max));
178 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
179 PgQueryBudgetGuard {
180 previous_budget,
181 previous_count,
182 }
183}
184
185pub fn charge_pg_query() -> Result<(), VmError> {
189 charge(&PG_QUERY_BUDGET, &PG_QUERY_COUNT, CallBudgetKind::PgQueries)
190}
191
#[cfg(test)]
mod tests {
    // Every test starts and ends with `reset_call_budget_state()` so that
    // per-thread counters never leak from one test into another.
    use super::*;
    use crate::value::{error_to_category, ErrorCategory};

    #[test]
    fn charge_is_noop_without_installed_budget() {
        reset_call_budget_state();
        for _ in 0..1000 {
            assert!(charge_mcp_call().is_ok());
            assert!(charge_pg_query().is_ok());
        }
        // Without a ceiling nothing is counted at all.
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
    }

    #[test]
    fn mcp_budget_admits_up_to_ceiling_then_rejects() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(2);
        assert!(charge_mcp_call().is_ok());
        assert!(charge_mcp_call().is_ok());
        let third = charge_mcp_call();
        let err = third.expect_err("third call must exceed mcp_calls: 2");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(
                    d.get("limit").map(|v| v.display()).as_deref(),
                    Some("mcp_calls")
                );
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(2));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(3));
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn pg_budget_message_pluralises_and_names_dimension() {
        reset_call_budget_state();
        let _guard = install_pg_query_budget(1);
        assert!(charge_pg_query().is_ok());
        let err = charge_pg_query().expect_err("second query must exceed pg_queries: 1");
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                assert!(
                    message.contains("pg_queries budget exceeded"),
                    "got: {message}"
                );
                assert!(message.contains("Postgres query"), "got: {message}");
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn singular_ceiling_uses_singular_noun() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(1);
        assert!(charge_mcp_call().is_ok());
        let err = charge_mcp_call().expect_err("second call must exceed mcp_calls: 1");
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                // The noun is the final word, so `ends_with` rules out "MCP calls".
                assert!(
                    message.ends_with("attempted 2 of 1 permitted MCP call"),
                    "got: {message}"
                );
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn zero_budget_rejects_first_call_with_plural_noun() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(0);
        let err = charge_mcp_call().expect_err("mcp_calls: 0 must reject the first call");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(0));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(1));
                let message = d.get("message").map(|v| v.display()).unwrap_or_default();
                assert!(
                    message.ends_with("attempted 1 of 0 permitted MCP calls"),
                    "got: {message}"
                );
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn nested_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_mcp_call_budget(5);
        assert!(charge_mcp_call().is_ok());
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);

        {
            let _inner = install_mcp_call_budget(1);
            // The inner install starts its own count from zero.
            assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
            assert!(charge_mcp_call().is_ok());
            assert!(charge_mcp_call().is_err());
        }

        // Dropping the inner guard restores the outer ceiling and its count;
        // calls made under the inner guard are not added to the outer count.
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), Some(5));
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);
        drop(outer);
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), None);
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
        reset_call_budget_state();
    }

    #[test]
    fn nested_pg_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_pg_query_budget(3);
        assert!(charge_pg_query().is_ok());
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 1);

        {
            let _inner = install_pg_query_budget(0);
            assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
            assert!(charge_pg_query().is_err());
        }

        assert_eq!(PG_QUERY_BUDGET.with(|b| *b.borrow()), Some(3));
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 1);
        drop(outer);
        assert_eq!(PG_QUERY_BUDGET.with(|b| *b.borrow()), None);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
        reset_call_budget_state();
    }
}