1use std::cell::RefCell;
18use std::collections::BTreeMap;
19use std::thread::LocalKey;
20
21use crate::value::{VmError, VmValue};
22
thread_local! {
    /// Ceiling on MCP calls for the current thread; `None` means charges are not metered.
    static MCP_CALL_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    /// Ceiling on Postgres queries for the current thread; `None` means charges are not metered.
    static PG_QUERY_BUDGET: RefCell<Option<u64>> = const { RefCell::new(None) };
    /// MCP calls charged since the current MCP budget was installed.
    static MCP_CALL_COUNT: RefCell<u64> = const { RefCell::new(0) };
    /// Postgres queries charged since the current Postgres budget was installed.
    static PG_QUERY_COUNT: RefCell<u64> = const { RefCell::new(0) };
}
29
30pub(crate) fn reset_call_budget_state() {
33 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = None);
34 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
35 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = None);
36 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
37}
38
/// Identifies which per-thread call budget a charge is counted against.
#[derive(Copy, Clone)]
enum CallBudgetKind {
    /// Counted against the MCP call budget.
    McpCalls,
    /// Counted against the Postgres query budget.
    PgQueries,
}

impl CallBudgetKind {
    /// Machine-readable label of the limit.
    ///
    /// Used as the `limit` field of the thrown error dict and as the prefix
    /// of its message.
    fn limit_label(self) -> &'static str {
        match self {
            Self::McpCalls => "mcp_calls",
            Self::PgQueries => "pg_queries",
        }
    }

    /// Human-readable noun for one unit of this budget, singular or plural.
    fn noun(self, plural: bool) -> &'static str {
        let (one, many) = match self {
            Self::McpCalls => ("MCP call", "MCP calls"),
            Self::PgQueries => ("Postgres query", "Postgres queries"),
        };
        if plural {
            many
        } else {
            one
        }
    }
}
69
70fn charge(
75 budget: &'static LocalKey<RefCell<Option<u64>>>,
76 count: &'static LocalKey<RefCell<u64>>,
77 kind: CallBudgetKind,
78) -> Result<(), VmError> {
79 let Some(max) = budget.with(|b| *b.borrow()) else {
80 return Ok(());
81 };
82 let spent = count.with(|c| {
83 let mut slot = c.borrow_mut();
84 *slot = slot.saturating_add(1);
85 *slot
86 });
87 if spent > max {
88 return Err(budget_exceeded_error(kind, spent, max));
89 }
90 Ok(())
91}
92
93fn budget_exceeded_error(kind: CallBudgetKind, spent: u64, max: u64) -> VmError {
98 let mut dict = BTreeMap::new();
99 dict.insert(
100 "category".to_string(),
101 VmValue::String(std::sync::Arc::from("budget_exceeded")),
102 );
103 dict.insert(
104 "kind".to_string(),
105 VmValue::String(std::sync::Arc::from("terminal")),
106 );
107 dict.insert(
108 "reason".to_string(),
109 VmValue::String(std::sync::Arc::from("budget_exceeded")),
110 );
111 dict.insert(
112 "limit".to_string(),
113 VmValue::String(std::sync::Arc::from(kind.limit_label())),
114 );
115 dict.insert("limit_value".to_string(), VmValue::Int(max as i64));
116 dict.insert("spent".to_string(), VmValue::Int(spent as i64));
117 dict.insert(
118 "message".to_string(),
119 VmValue::String(std::sync::Arc::from(format!(
120 "{} budget exceeded: this dispatch attempted {} of {} permitted {}",
121 kind.limit_label(),
122 spent,
123 max,
124 kind.noun(max != 1),
125 ))),
126 );
127 VmError::Thrown(VmValue::Dict(std::sync::Arc::new(dict)))
128}
129
/// Guard returned by [`install_mcp_call_budget`].
///
/// It records the MCP budget and call count that were active before the new
/// budget was installed. Dropping it writes both back, so nested installs
/// unwind to the enclosing budget.
#[must_use = "dropping the guard immediately restores the prior MCP call budget"]
#[derive(Debug)]
pub struct McpCallBudgetGuard {
    /// MCP budget in force before the guarded install (`None` = not metered).
    previous_budget: Option<u64>,
    /// MCP call count at the moment the guarded budget was installed.
    previous_count: u64,
}
137
138impl Drop for McpCallBudgetGuard {
139 fn drop(&mut self) {
140 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
141 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
142 }
143}
144
145pub fn install_mcp_call_budget(max: u64) -> McpCallBudgetGuard {
150 let previous_budget = MCP_CALL_BUDGET.with(|b| *b.borrow());
151 let previous_count = MCP_CALL_COUNT.with(|c| *c.borrow());
152 MCP_CALL_BUDGET.with(|b| *b.borrow_mut() = Some(max));
153 MCP_CALL_COUNT.with(|c| *c.borrow_mut() = 0);
154 McpCallBudgetGuard {
155 previous_budget,
156 previous_count,
157 }
158}
159
160pub fn charge_mcp_call() -> Result<(), VmError> {
163 charge(&MCP_CALL_BUDGET, &MCP_CALL_COUNT, CallBudgetKind::McpCalls)
164}
165
/// Guard returned by [`install_pg_query_budget`].
///
/// It records the Postgres budget and query count that were active before the
/// new budget was installed. Dropping it writes both back, so nested installs
/// unwind to the enclosing budget.
#[must_use = "dropping the guard immediately restores the prior Postgres query budget"]
#[derive(Debug)]
pub struct PgQueryBudgetGuard {
    /// Postgres budget in force before the guarded install (`None` = not metered).
    previous_budget: Option<u64>,
    /// Postgres query count at the moment the guarded budget was installed.
    previous_count: u64,
}
173
174impl Drop for PgQueryBudgetGuard {
175 fn drop(&mut self) {
176 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = self.previous_budget);
177 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = self.previous_count);
178 }
179}
180
181pub fn install_pg_query_budget(max: u64) -> PgQueryBudgetGuard {
186 let previous_budget = PG_QUERY_BUDGET.with(|b| *b.borrow());
187 let previous_count = PG_QUERY_COUNT.with(|c| *c.borrow());
188 PG_QUERY_BUDGET.with(|b| *b.borrow_mut() = Some(max));
189 PG_QUERY_COUNT.with(|c| *c.borrow_mut() = 0);
190 PgQueryBudgetGuard {
191 previous_budget,
192 previous_count,
193 }
194}
195
196pub fn charge_pg_query() -> Result<(), VmError> {
200 charge(&PG_QUERY_BUDGET, &PG_QUERY_COUNT, CallBudgetKind::PgQueries)
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::{error_to_category, ErrorCategory};

    /// Current MCP call counter on this thread.
    fn mcp_count() -> u64 {
        MCP_CALL_COUNT.with(|c| *c.borrow())
    }

    /// Currently installed MCP budget on this thread.
    fn mcp_budget() -> Option<u64> {
        MCP_CALL_BUDGET.with(|b| *b.borrow())
    }

    #[test]
    fn charge_is_noop_without_installed_budget() {
        reset_call_budget_state();
        (0..1000).for_each(|_| {
            assert!(charge_mcp_call().is_ok());
            assert!(charge_pg_query().is_ok());
        });
        // Without a budget, charges must not even touch the counters.
        assert_eq!(mcp_count(), 0);
        assert_eq!(PG_QUERY_COUNT.with(|c| *c.borrow()), 0);
    }

    #[test]
    fn mcp_budget_admits_up_to_ceiling_then_rejects() {
        reset_call_budget_state();
        let _guard = install_mcp_call_budget(2);
        for _ in 0..2 {
            assert!(charge_mcp_call().is_ok());
        }
        let err = charge_mcp_call().expect_err("third call must exceed mcp_calls: 2");
        assert_eq!(error_to_category(&err), ErrorCategory::BudgetExceeded);
        let VmError::Thrown(VmValue::Dict(d)) = &err else {
            panic!("expected structured Thrown dict, got {err:?}");
        };
        assert_eq!(
            d.get("limit").map(|v| v.display()).as_deref(),
            Some("mcp_calls")
        );
        assert_eq!(d.get("limit_value").and_then(VmValue::as_int), Some(2));
        assert_eq!(d.get("spent").and_then(VmValue::as_int), Some(3));
        reset_call_budget_state();
    }

    #[test]
    fn pg_budget_message_pluralises_and_names_dimension() {
        reset_call_budget_state();
        let _guard = install_pg_query_budget(1);
        assert!(charge_pg_query().is_ok());
        let err = charge_pg_query().expect_err("second query must exceed pg_queries: 1");
        let VmError::Thrown(VmValue::Dict(d)) = &err else {
            panic!("expected structured Thrown dict, got {err:?}");
        };
        let message = d.get("message").map(|v| v.display()).unwrap_or_default();
        assert!(
            message.contains("pg_queries budget exceeded"),
            "got: {message}"
        );
        assert!(message.contains("Postgres query"), "got: {message}");
        reset_call_budget_state();
    }

    #[test]
    fn nested_guard_restores_outer_budget_and_count_on_drop() {
        reset_call_budget_state();
        let outer = install_mcp_call_budget(5);
        assert!(charge_mcp_call().is_ok());
        assert_eq!(mcp_count(), 1);

        // The inner install starts a fresh counter under its own ceiling.
        {
            let _inner = install_mcp_call_budget(1);
            assert_eq!(mcp_count(), 0);
            assert!(charge_mcp_call().is_ok());
            assert!(charge_mcp_call().is_err());
        }

        // Leaving the scope reinstates the outer ceiling and its count.
        assert_eq!(mcp_budget(), Some(5));
        assert_eq!(mcp_count(), 1);
        drop(outer);
        assert_eq!(mcp_budget(), None);
        reset_call_budget_state();
    }
}