1use std::cell::RefCell;
18use std::collections::BTreeMap;
19use std::rc::Rc;
20use std::thread::LocalKey;
21
22use crate::value::{VmError, VmValue};
23
// Per-thread MCP call accounting. A `None` budget means charges are not
// counted at all (see `charge`).
thread_local! {
    /// Ceiling on MCP calls currently installed on this thread, if any.
    static MCP_CALL_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    /// MCP calls charged since the current budget was installed or reset.
    static MCP_CALL_COUNT: RefCell<u64> = const { RefCell::new(0) };
}

// Per-thread Postgres query accounting, mirroring the MCP pair above.
thread_local! {
    /// Ceiling on Postgres queries currently installed on this thread, if any.
    static PG_QUERY_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    /// Postgres queries charged since the current budget was installed or reset.
    static PG_QUERY_COUNT: RefCell<u64> = const { RefCell::new(0) };
}
30
31pub(crate) fn reset_call_budget_state() {
34 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = None);
35 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
36 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = None);
37 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
38}
39
/// The dimensions of external work that can be capped per thread.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CallBudgetKind {
    /// MCP calls, reported under the `mcp_calls` label.
    McpCalls,
    /// Postgres queries, reported under the `pg_queries` label.
    PgQueries,
}

impl CallBudgetKind {
    /// Stable identifier of this budget, used as the error's `limit` field and
    /// as the prefix of its message.
    fn limit_label(self) -> &'static str {
        match self {
            CallBudgetKind::McpCalls => "mcp_calls",
            CallBudgetKind::PgQueries => "pg_queries",
        }
    }

    /// Human-readable noun for the charged unit: singular when `plural` is
    /// `false`, plural otherwise.
    fn noun(self, plural: bool) -> &'static str {
        match (self, plural) {
            (CallBudgetKind::McpCalls, false) => "MCP call",
            (CallBudgetKind::McpCalls, true) => "MCP calls",
            (CallBudgetKind::PgQueries, false) => "Postgres query",
            (CallBudgetKind::PgQueries, true) => "Postgres queries",
        }
    }
}
70
71fn charge(
76 budget: &'static LocalKey<RefCell<Option<u64>>>,
77 count: &'static LocalKey<RefCell<u64>>,
78 kind: CallBudgetKind,
79) -> Result<(), VmError> {
80 let Some(max) = budget.with(|b| *b.borrow()) else {
81 return Ok(());
82 };
83 let spent = count.with(|c| {
84 let mut slot = c.borrow_mut();
85 *slot = slot.saturating_add(1);
86 *slot
87 });
88 if spent > max {
89 return Err(budget_exceeded_error(kind, spent, max));
90 }
91 Ok(())
92}
93
94fn budget_exceeded_error(kind: CallBudgetKind, spent: u64, max: u64) -> VmError {
99 let mut dict = BTreeMap::new();
100 dict.insert(
101 "category".to_string(),
102 VmValue::String(Rc::from("budget_exceeded")),
103 );
104 dict.insert("kind".to_string(), VmValue::String(Rc::from("terminal")));
105 dict.insert(
106 "reason".to_string(),
107 VmValue::String(Rc::from("budget_exceeded")),
108 );
109 dict.insert(
110 "limit".to_string(),
111 VmValue::String(Rc::from(kind.limit_label())),
112 );
113 dict.insert("limit_value".to_string(), VmValue::Int(max as i64));
114 dict.insert("spent".to_string(), VmValue::Int(spent as i64));
115 dict.insert(
116 "message".to_string(),
117 VmValue::String(Rc::from(format!(
118 "{} budget exceeded: this dispatch attempted {} of {} permitted {}",
119 kind.limit_label(),
120 spent,
121 max,
122 kind.noun(max != 1),
123 ))),
124 );
125 VmError::Thrown(VmValue::Dict(Rc::new(dict)))
126}
127
/// Guard returned by [`install_mcp_call_budget`].
///
/// It remembers the MCP budget and running count that were active before the
/// install. Dropping the guard writes both back, so nested installs unwind to
/// the outer budget.
#[must_use = "dropping the guard immediately restores the prior MCP call budget"]
#[derive(Debug)]
pub struct McpCallBudgetGuard {
    /// Budget in effect before the install (`None` means no budget).
    previous_budget: Option<u64>,
    /// Charge count in effect before the install, restored on drop.
    previous_count: u64,
}
135
136impl Drop for McpCallBudgetGuard {
137 fn drop(&mut self) {
138 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
139 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
140 }
141}
142
143pub fn install_mcp_call_budget(max: u64) -> McpCallBudgetGuard {
148 let previous_budget = MCP_CALL_BUDGET.with(|b| *b.borrow());
149 let previous_count = MCP_CALL_COUNT.with(|c| *c.borrow());
150 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = Some(max));
151 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
152 McpCallBudgetGuard {
153 previous_budget,
154 previous_count,
155 }
156}
157
158pub fn charge_mcp_call() -> Result<(), VmError> {
161 charge(&MCP_CALL_BUDGET, &MCP_CALL_COUNT, CallBudgetKind::McpCalls)
162}
163
/// Guard returned by [`install_pg_query_budget`].
///
/// It remembers the Postgres query budget and running count that were active
/// before the install. Dropping the guard writes both back, so nested installs
/// unwind to the outer budget.
#[must_use = "dropping the guard immediately restores the prior Postgres query budget"]
#[derive(Debug)]
pub struct PgQueryBudgetGuard {
    /// Budget in effect before the install (`None` means no budget).
    previous_budget: Option<u64>,
    /// Query count in effect before the install, restored on drop.
    previous_count: u64,
}
171
172impl Drop for PgQueryBudgetGuard {
173 fn drop(&mut self) {
174 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
175 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
176 }
177}
178
179pub fn install_pg_query_budget(max: u64) -> PgQueryBudgetGuard {
184 let previous_budget = PG_QUERY_BUDGET.with(|b| *b.borrow());
185 let previous_count = PG_QUERY_COUNT.with(|c| *c.borrow());
186 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = Some(max));
187 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
188 PgQueryBudgetGuard {
189 previous_budget,
190 previous_count,
191 }
192}
193
194pub fn charge_pg_query() -> Result<(), VmError> {
198 charge(&PG_QUERY_BUDGET, &PG_QUERY_COUNT, CallBudgetKind::PgQueries)
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::{error_to_category, ErrorCategory};

    /// Returns the `message` field of a structured budget error, panicking if
    /// `err` is not a thrown dict.
    fn thrown_message(err: &VmError) -> String {
        match err {
            VmError::Thrown(VmValue::Dict(d)) => {
                d.get("message").map(|v| v.display()).unwrap_or_default()
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
    }

    #[test]
    fn charge_is_noop_without_installed_budget() {
        reset_call_budget_state();
        for _ in 0..1000 {
            assert!(charge_mcp_call().is_ok());
            assert!(charge_pg_query().is_ok());
        }
        // With no budget installed, charges must not even be counted.
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
    }

    #[test]
    fn mcp_budget_admits_up_to_ceiling_then_rejects() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(2);
        assert!(charge_mcp_call().is_ok());
        assert!(charge_mcp_call().is_ok());
        let third = charge_mcp_call();
        let err = third.expect_err("third call must exceed mcp_calls: 2");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        match &err {
            VmError::Thrown(VmValue::Dict(d)) => {
                assert_eq!(
                    d.get("limit").map(|v| v.display()).as_deref(),
                    Some("mcp_calls")
                );
                assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(2));
                assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(3));
            }
            other => panic!("expected structured Thrown dict, got {other:?}"),
        }
        reset_call_budget_state();
    }

    #[test]
    fn zero_budget_rejects_first_call_with_plural_noun() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(0);
        let err = charge_mcp_call().expect_err("any call must exceed mcp_calls: 0");
        assert_eq!(
            thrown_message(&err),
            "mcp_calls budget exceeded: this dispatch attempted 1 of 0 permitted MCP calls"
        );
        reset_call_budget_state();
    }

    #[test]
    fn pg_budget_message_pluralises_and_names_dimension() {
        reset_call_budget_state();
        {
            // A ceiling of exactly one uses the singular noun.
            let _guard = install_pg_query_budget(1);
            assert!(charge_pg_query().is_ok());
            let err = charge_pg_query().expect_err("second query must exceed pg_queries: 1");
            assert_eq!(
                thrown_message(&err),
                "pg_queries budget exceeded: this dispatch attempted 2 of 1 permitted Postgres query"
            );
        }
        {
            // Any other ceiling uses the plural noun.
            let _guard = install_pg_query_budget(2);
            assert!(charge_pg_query().is_ok());
            assert!(charge_pg_query().is_ok());
            let err = charge_pg_query().expect_err("third query must exceed pg_queries: 2");
            assert_eq!(
                thrown_message(&err),
                "pg_queries budget exceeded: this dispatch attempted 3 of 2 permitted Postgres queries"
            );
        }
        reset_call_budget_state();
    }

    #[test]
    fn nested_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_mcp_call_budget(5);
        assert!(charge_mcp_call().is_ok());
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);

        {
            let _inner = install_mcp_call_budget(1);
            // The inner budget starts its own count from zero.
            assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 0);
            assert!(charge_mcp_call().is_ok());
            assert!(charge_mcp_call().is_err());
        }

        // Inner charges are discarded; the outer budget and count come back.
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), Some(5));
        assert_eq!(MCP_CALL_COUNT.with(|c| *c.borrow()), 1);
        drop(outer);
        assert_eq!(MCP_CALL_BUDGET.with(|b| *b.borrow()), None);
        reset_call_budget_state();
    }
}