1use std::sync::{Arc, Mutex};
2
3use crate::metrics::SessionStatus;
4
/// Optional resource limits for one execution session.
///
/// Each limit is independent and `None` means "unlimited" for that resource, so the
/// [`Default`] value imposes no limits at all.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Budget {
    /// Maximum number of LLM calls; [`Budget::check`] fails once the count reaches this value.
    pub max_llm_calls: Option<u64>,
    /// Maximum elapsed time in milliseconds; `check` fails once elapsed time reaches this value.
    pub max_elapsed_ms: Option<u64>,
    /// Maximum total token count; `check` fails once token usage reaches this value.
    pub max_tokens: Option<u64>,
}
32
33impl Budget {
34 pub fn from_ctx(ctx: &serde_json::Value) -> Option<Self> {
36 let obj = ctx.as_object()?.get("budget")?.as_object()?;
37 let max_llm_calls = obj.get("max_llm_calls").and_then(|v| v.as_u64());
38 let max_elapsed_ms = obj.get("max_elapsed_ms").and_then(|v| v.as_u64());
39 let max_tokens = obj.get("max_tokens").and_then(|v| v.as_u64());
40 if max_llm_calls.is_none() && max_elapsed_ms.is_none() && max_tokens.is_none() {
41 return None;
42 }
43 Some(Self {
44 max_llm_calls,
45 max_elapsed_ms,
46 max_tokens,
47 })
48 }
49
50 pub fn check(&self, llm_calls: u64, elapsed_ms: u64, total_tokens: u64) -> Result<(), String> {
53 if let Some(max) = self.max_llm_calls {
54 if llm_calls >= max {
55 return Err(format!(
56 "budget_exceeded: max_llm_calls ({max}) reached ({llm_calls} used)"
57 ));
58 }
59 }
60 if let Some(max_ms) = self.max_elapsed_ms {
61 if elapsed_ms >= max_ms {
62 return Err(format!(
63 "budget_exceeded: max_elapsed_ms ({max_ms}ms) reached ({elapsed_ms}ms elapsed)"
64 ));
65 }
66 }
67 if let Some(max) = self.max_tokens {
68 if total_tokens >= max {
69 return Err(format!(
70 "budget_exceeded: max_tokens ({max}) reached ({total_tokens} used)"
71 ));
72 }
73 }
74 Ok(())
75 }
76
77 pub fn remaining_json(
80 &self,
81 llm_calls: u64,
82 elapsed_ms: u64,
83 total_tokens: u64,
84 ) -> serde_json::Value {
85 serde_json::json!({
86 "llm_calls": self.max_llm_calls.map(|max| max.saturating_sub(llm_calls)),
87 "elapsed_ms": self.max_elapsed_ms.map(|max| max.saturating_sub(elapsed_ms)),
88 "tokens": self.max_tokens.map(|max| max.saturating_sub(total_tokens)),
89 })
90 }
91
92 pub fn to_json(&self) -> serde_json::Value {
94 let mut map = serde_json::Map::new();
95 if let Some(max) = self.max_llm_calls {
96 map.insert("max_llm_calls".into(), max.into());
97 }
98 if let Some(max) = self.max_elapsed_ms {
99 map.insert("max_elapsed_ms".into(), max.into());
100 }
101 if let Some(max) = self.max_tokens {
102 map.insert("max_tokens".into(), max.into());
103 }
104 serde_json::Value::Object(map)
105 }
106}
107
108#[derive(Clone)]
141pub struct BudgetHandle {
142 auto: Arc<Mutex<SessionStatus>>,
143}
144
145impl BudgetHandle {
146 pub(crate) fn new(auto: Arc<Mutex<SessionStatus>>) -> Self {
147 Self { auto }
148 }
149
150 pub fn check(&self) -> Result<(), String> {
154 let m = self
155 .auto
156 .lock()
157 .map_err(|_| "budget check: mutex poisoned".to_string())?;
158 m.check_budget()
159 }
160
161 pub fn remaining(&self) -> serde_json::Value {
166 let m = match self.auto.lock() {
167 Ok(m) => m,
168 Err(_) => return serde_json::Value::Null,
169 };
170 m.budget_remaining()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    /// Builds an `LlmQuery` with the given prompt and neutral values for every other field.
    fn query(prompt: &str) -> LlmQuery {
        LlmQuery {
            id: QueryId::single(),
            prompt: prompt.into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }
    }

    // ---- Budget::from_ctx -------------------------------------------------------------------

    #[test]
    fn budget_from_ctx_none_when_missing() {
        let ctx = serde_json::json!({"task": "test"});
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_none_when_empty() {
        let ctx = serde_json::json!({"budget": {}});
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_none_when_budget_not_object() {
        let ctx = serde_json::json!({"budget": 5});
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_ignores_non_integer_values() {
        let ctx = serde_json::json!({
            "budget": {"max_llm_calls": -1, "max_tokens": "10", "max_elapsed_ms": 1.5}
        });
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_extracts_llm_calls() {
        let ctx = serde_json::json!({"budget": {"max_llm_calls": 10}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, Some(10));
        assert_eq!(budget.max_elapsed_ms, None);
    }

    #[test]
    fn budget_from_ctx_extracts_elapsed_ms() {
        let ctx = serde_json::json!({"budget": {"max_elapsed_ms": 5000}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, None);
        assert_eq!(budget.max_elapsed_ms, Some(5000));
    }

    #[test]
    fn budget_from_ctx_extracts_both() {
        let ctx = serde_json::json!({"budget": {"max_llm_calls": 5, "max_elapsed_ms": 30000}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, Some(5));
        assert_eq!(budget.max_elapsed_ms, Some(30000));
    }

    #[test]
    fn budget_from_ctx_extracts_max_tokens() {
        let ctx = serde_json::json!({"budget": {"max_tokens": 5000}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, None);
        assert_eq!(budget.max_elapsed_ms, None);
        assert_eq!(budget.max_tokens, Some(5000));
    }

    // ---- Budget::check / remaining_json / to_json, exercised directly ---------------------

    #[test]
    fn budget_check_limit_is_reached_at_equality() {
        let budget = Budget { max_llm_calls: Some(3), ..Budget::default() };
        assert!(budget.check(2, 0, 0).is_ok());
        assert_eq!(
            budget.check(3, 0, 0).unwrap_err(),
            "budget_exceeded: max_llm_calls (3) reached (3 used)"
        );
    }

    #[test]
    fn budget_check_fails_when_elapsed_exceeded() {
        let budget = Budget { max_elapsed_ms: Some(1000), ..Budget::default() };
        assert!(budget.check(0, 999, 0).is_ok());
        assert_eq!(
            budget.check(0, 1000, 0).unwrap_err(),
            "budget_exceeded: max_elapsed_ms (1000ms) reached (1000ms elapsed)"
        );
    }

    #[test]
    fn budget_check_reports_first_exhausted_limit() {
        let budget = Budget {
            max_llm_calls: Some(1),
            max_elapsed_ms: Some(1),
            max_tokens: Some(1),
        };
        let err = budget.check(5, 5, 5).unwrap_err();
        assert!(err.contains("max_llm_calls"), "{err}");
    }

    #[test]
    fn unlimited_budget_never_fails() {
        assert!(Budget::default().check(u64::MAX, u64::MAX, u64::MAX).is_ok());
    }

    #[test]
    fn budget_remaining_json_saturates_and_nulls_unset_limits() {
        let budget = Budget {
            max_llm_calls: Some(2),
            max_elapsed_ms: None,
            max_tokens: Some(50),
        };
        let remaining = budget.remaining_json(5, 123, 20);
        assert_eq!(remaining["llm_calls"], 0);
        assert!(remaining["elapsed_ms"].is_null());
        assert_eq!(remaining["tokens"], 30);
    }

    #[test]
    fn budget_to_json_round_trips_through_from_ctx() {
        let budget = Budget {
            max_llm_calls: Some(4),
            max_elapsed_ms: Some(2500),
            max_tokens: Some(800),
        };
        let ctx = serde_json::json!({ "budget": budget.to_json() });
        let parsed = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(parsed.max_llm_calls, Some(4));
        assert_eq!(parsed.max_elapsed_ms, Some(2500));
        assert_eq!(parsed.max_tokens, Some(800));
    }

    #[test]
    fn unlimited_budget_serializes_to_empty_object() {
        assert_eq!(Budget::default().to_json(), serde_json::json!({}));
    }

    // ---- BudgetHandle backed by ExecutionMetrics ---------------------------------------------

    #[test]
    fn budget_check_passes_when_within_limits() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget { max_llm_calls: Some(5), ..Budget::default() });
        let handle = metrics.budget_handle();
        assert!(handle.check().is_ok());
    }

    #[test]
    fn budget_check_fails_when_llm_calls_exceeded() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget { max_llm_calls: Some(2), ..Budget::default() });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        // Two pauses record two LLM calls, which reaches the limit of two.
        let q = vec![query("p")];
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();
        observer.on_paused(&q);

        let err = handle.check().expect_err("two calls should exhaust a budget of two");
        assert!(err.contains("budget_exceeded"));
    }

    #[test]
    fn budget_check_fails_when_tokens_exceeded() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget { max_tokens: Some(10), ..Budget::default() });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        // A 40-character prompt is enough to exhaust a 10-token budget.
        let q = vec![LlmQuery {
            max_tokens: 100,
            ..query("abcdefghijklmnopqrstuvwxyz0123456789abcd")
        }];
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();

        let err = handle.check().expect_err("token budget should be exhausted");
        assert!(err.contains("max_tokens"));
    }

    #[test]
    fn budget_check_passes_when_tokens_within_limit() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget { max_tokens: Some(1000), ..Budget::default() });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        let q = vec![LlmQuery { max_tokens: 100, ..query("short") }];
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), "reply", None);
        observer.on_resumed();

        assert!(handle.check().is_ok());
    }

    #[test]
    fn budget_remaining_tracks_tokens() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget { max_tokens: Some(100), ..Budget::default() });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        let q = vec![query("test")];
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), "r", None);
        observer.on_resumed();

        assert_eq!(handle.remaining()["tokens"], 98);
    }

    #[test]
    fn budget_remaining_null_when_no_budget() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.budget_handle();
        assert!(handle.remaining().is_null());
    }

    #[test]
    fn budget_remaining_tracks_llm_calls() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget { max_llm_calls: Some(5), ..Budget::default() });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        let q = vec![query("p")];
        observer.on_paused(&q);

        assert_eq!(handle.remaining()["llm_calls"], 4);
    }

    // ---- Stats JSON --------------------------------------------------------------------------

    #[test]
    fn budget_in_stats_json() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(10),
            max_elapsed_ms: Some(60000),
            max_tokens: None,
        });
        let observer = metrics.create_observer();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let budget = &json["auto"]["budget"];
        assert_eq!(budget["max_llm_calls"], 10);
        assert_eq!(budget["max_elapsed_ms"], 60000);
    }

    #[test]
    fn no_budget_in_stats_json_when_not_set() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        assert!(json["auto"].get("budget").is_none());
    }
}
410}