1use std::sync::{Arc, Mutex};
2
3use crate::metrics::SessionStatus;
4
/// Optional resource limits for a single execution session.
///
/// Each limit is independent: `None` means that dimension is unlimited, and
/// `Budget::default()` therefore imposes no limits at all.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Budget {
    /// Maximum number of LLM calls; the budget counts as exceeded once the
    /// number of calls made is greater than or equal to this value.
    pub max_llm_calls: Option<u64>,
    /// Maximum elapsed time in milliseconds; the budget counts as exceeded
    /// once the elapsed time is greater than or equal to this value.
    pub max_elapsed_ms: Option<u64>,
}
27
28impl Budget {
29 pub fn from_ctx(ctx: &serde_json::Value) -> Option<Self> {
31 let obj = ctx.as_object()?.get("budget")?.as_object()?;
32 let max_llm_calls = obj.get("max_llm_calls").and_then(|v| v.as_u64());
33 let max_elapsed_ms = obj.get("max_elapsed_ms").and_then(|v| v.as_u64());
34 if max_llm_calls.is_none() && max_elapsed_ms.is_none() {
35 return None;
36 }
37 Some(Self {
38 max_llm_calls,
39 max_elapsed_ms,
40 })
41 }
42
43 pub fn check(&self, llm_calls: u64, elapsed_ms: u64) -> Result<(), String> {
46 if let Some(max) = self.max_llm_calls {
47 if llm_calls >= max {
48 return Err(format!(
49 "budget_exceeded: max_llm_calls ({max}) reached ({llm_calls} used)"
50 ));
51 }
52 }
53 if let Some(max_ms) = self.max_elapsed_ms {
54 if elapsed_ms >= max_ms {
55 return Err(format!(
56 "budget_exceeded: max_elapsed_ms ({max_ms}ms) reached ({elapsed_ms}ms elapsed)"
57 ));
58 }
59 }
60 Ok(())
61 }
62
63 pub fn remaining_json(&self, llm_calls: u64, elapsed_ms: u64) -> serde_json::Value {
66 serde_json::json!({
67 "llm_calls": self.max_llm_calls.map(|max| max.saturating_sub(llm_calls)),
68 "elapsed_ms": self.max_elapsed_ms.map(|max| max.saturating_sub(elapsed_ms)),
69 })
70 }
71
72 pub fn to_json(&self) -> serde_json::Value {
74 let mut map = serde_json::Map::new();
75 if let Some(max) = self.max_llm_calls {
76 map.insert("max_llm_calls".into(), max.into());
77 }
78 if let Some(max) = self.max_elapsed_ms {
79 map.insert("max_elapsed_ms".into(), max.into());
80 }
81 serde_json::Value::Object(map)
82 }
83}
84
85#[derive(Clone)]
118pub struct BudgetHandle {
119 auto: Arc<Mutex<SessionStatus>>,
120}
121
122impl BudgetHandle {
123 pub(crate) fn new(auto: Arc<Mutex<SessionStatus>>) -> Self {
124 Self { auto }
125 }
126
127 pub fn check(&self) -> Result<(), String> {
131 let m = self
132 .auto
133 .lock()
134 .map_err(|_| "budget check: mutex poisoned".to_string())?;
135 m.check_budget()
136 }
137
138 pub fn remaining(&self) -> serde_json::Value {
143 let m = match self.auto.lock() {
144 Ok(m) => m,
145 Err(_) => return serde_json::Value::Null,
146 };
147 m.budget_remaining()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionMetrics, ExecutionObserver, LlmQuery, QueryId};

    /// One-element query batch used to drive observer pause events.
    fn single_query() -> Vec<LlmQuery> {
        vec![LlmQuery {
            id: QueryId::single(),
            prompt: "p".into(),
            system: None,
            max_tokens: 10,
            grounded: false,
            underspecified: false,
        }]
    }

    // ---- Budget::from_ctx ----

    #[test]
    fn budget_from_ctx_none_when_missing() {
        let ctx = serde_json::json!({"task": "test"});
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_none_when_empty() {
        let ctx = serde_json::json!({"budget": {}});
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_none_when_not_objects() {
        // Neither a non-object context nor a non-object "budget" yields a budget.
        assert!(Budget::from_ctx(&serde_json::json!([1, 2])).is_none());
        assert!(Budget::from_ctx(&serde_json::json!({"budget": 5})).is_none());
    }

    #[test]
    fn budget_from_ctx_ignores_non_integer_limits() {
        // Negative numbers and strings are not accepted by `as_u64`.
        let ctx = serde_json::json!({"budget": {"max_llm_calls": -1, "max_elapsed_ms": "5000"}});
        assert!(Budget::from_ctx(&ctx).is_none());
    }

    #[test]
    fn budget_from_ctx_extracts_llm_calls() {
        let ctx = serde_json::json!({"budget": {"max_llm_calls": 10}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, Some(10));
        assert_eq!(budget.max_elapsed_ms, None);
    }

    #[test]
    fn budget_from_ctx_extracts_elapsed_ms() {
        let ctx = serde_json::json!({"budget": {"max_elapsed_ms": 5000}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, None);
        assert_eq!(budget.max_elapsed_ms, Some(5000));
    }

    #[test]
    fn budget_from_ctx_extracts_both() {
        let ctx = serde_json::json!({"budget": {"max_llm_calls": 5, "max_elapsed_ms": 30000}});
        let budget = Budget::from_ctx(&ctx).expect("should parse");
        assert_eq!(budget.max_llm_calls, Some(5));
        assert_eq!(budget.max_elapsed_ms, Some(30000));
    }

    // ---- Budget::check / remaining_json / to_json (no metrics involved) ----

    #[test]
    fn budget_check_llm_calls_boundary() {
        let budget = Budget {
            max_llm_calls: Some(3),
            max_elapsed_ms: None,
        };
        assert!(budget.check(2, u64::MAX).is_ok());
        let err = budget.check(3, 0).unwrap_err();
        assert_eq!(err, "budget_exceeded: max_llm_calls (3) reached (3 used)");
    }

    #[test]
    fn budget_check_elapsed_boundary() {
        let budget = Budget {
            max_llm_calls: None,
            max_elapsed_ms: Some(1000),
        };
        assert!(budget.check(u64::MAX, 999).is_ok());
        let err = budget.check(0, 1000).unwrap_err();
        assert_eq!(err, "budget_exceeded: max_elapsed_ms (1000ms) reached (1000ms elapsed)");
    }

    #[test]
    fn budget_check_reports_llm_calls_first() {
        let budget = Budget {
            max_llm_calls: Some(1),
            max_elapsed_ms: Some(1),
        };
        let err = budget.check(1, 1).unwrap_err();
        assert!(err.starts_with("budget_exceeded: max_llm_calls"));
    }

    #[test]
    fn default_budget_is_never_exceeded() {
        assert!(Budget::default().check(u64::MAX, u64::MAX).is_ok());
    }

    #[test]
    fn remaining_json_saturates_and_nulls_unlimited() {
        let budget = Budget {
            max_llm_calls: Some(2),
            max_elapsed_ms: Some(100),
        };
        let remaining = budget.remaining_json(5, 40);
        assert_eq!(remaining["llm_calls"], 0);
        assert_eq!(remaining["elapsed_ms"], 60);

        let partial = Budget {
            max_llm_calls: Some(1),
            max_elapsed_ms: None,
        };
        let remaining = partial.remaining_json(0, 0);
        assert_eq!(remaining["llm_calls"], 1);
        assert!(remaining["elapsed_ms"].is_null());
    }

    #[test]
    fn to_json_omits_unset_limits() {
        let budget = Budget {
            max_llm_calls: Some(3),
            max_elapsed_ms: None,
        };
        assert_eq!(budget.to_json(), serde_json::json!({"max_llm_calls": 3}));
        assert_eq!(Budget::default().to_json(), serde_json::json!({}));
    }

    // ---- BudgetHandle wired through ExecutionMetrics ----

    #[test]
    fn budget_check_passes_when_within_limits() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(5),
            max_elapsed_ms: None,
        });
        let handle = metrics.budget_handle();
        assert!(handle.check().is_ok());
    }

    #[test]
    fn budget_check_fails_when_llm_calls_exceeded() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(2),
            max_elapsed_ms: None,
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        let q = single_query();
        observer.on_paused(&q);
        observer.on_response_fed(&QueryId::single(), "r");
        observer.on_resumed();
        observer.on_paused(&q);

        let result = handle.check();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("budget_exceeded"));
    }

    #[test]
    fn budget_remaining_null_when_no_budget() {
        let metrics = ExecutionMetrics::new();
        let handle = metrics.budget_handle();
        assert!(handle.remaining().is_null());
    }

    #[test]
    fn budget_remaining_tracks_llm_calls() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(5),
            max_elapsed_ms: None,
        });
        let observer = metrics.create_observer();
        let handle = metrics.budget_handle();

        let q = single_query();
        observer.on_paused(&q);

        let remaining = handle.remaining();
        assert_eq!(remaining["llm_calls"], 4);
    }

    #[test]
    fn budget_in_stats_json() {
        let metrics = ExecutionMetrics::new();
        metrics.set_budget(Budget {
            max_llm_calls: Some(10),
            max_elapsed_ms: Some(60000),
        });
        let observer = metrics.create_observer();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        let budget = &json["auto"]["budget"];
        assert_eq!(budget["max_llm_calls"], 10);
        assert_eq!(budget["max_elapsed_ms"], 60000);
    }

    #[test]
    fn no_budget_in_stats_json_when_not_set() {
        let metrics = ExecutionMetrics::new();
        let observer = metrics.create_observer();
        observer.on_completed(&serde_json::json!(null));

        let json = metrics.to_json();
        assert!(json["auto"].get("budget").is_none());
    }
}
289}