1use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9use tokio::sync::mpsc;
10use tokio_util::sync::CancellationToken;
11
12use arcp_core::envelope::Envelope;
13use arcp_core::error::ARCPError;
14use arcp_core::ids::{JobId, MessageId, SessionId};
15use arcp_core::messages::{
16 CostBudget, JobResultChunkPayload, LeaseRequest, MessageType, MetricPayload,
17 ResultChunkEncoding,
18};
19
/// Context handed to a tool for the duration of a single job.
///
/// Bundles the identifiers of the job, the outbound envelope channel, the
/// (shared) budget tracker and the optional lease the job runs under.
pub struct ToolContext {
    /// Cancellation signal for this job. NOTE(review): tools presumably poll
    /// or `select!` on this to stop early — confirm with the runtime that
    /// triggers it.
    pub cancel: CancellationToken,
    /// Job this context belongs to; stamped on every envelope it emits.
    pub(crate) job_id: JobId,
    /// Session the job runs in; stamped on every envelope it emits.
    pub(crate) session_id: SessionId,
    /// Message id that every emitted envelope is correlated to.
    pub(crate) correlation_id: MessageId,
    /// Outbound channel used for metrics and result chunks.
    pub(crate) out: mpsc::Sender<Envelope>,
    /// Budget tracker debited by `ToolContext::charge`.
    pub(crate) budget: BudgetTracker,
    /// Lease attached to the job, if any; consulted for `model.use` checks.
    pub(crate) lease: Option<LeaseRequest>,
}
34
/// Thread-safe, per-currency spend tracker for a job's cost budget.
///
/// The state lives behind an [`Arc`], so clones are cheap and every clone
/// reads and debits the very same counters.
#[derive(Debug, Clone, Default)]
pub struct BudgetTracker {
    inner: Arc<BudgetTrackerInner>,
}

/// Fixed-point scale for budget arithmetic.
///
/// Amounts are held as integer millionths ("micros") of a currency unit so
/// that summing many fractional charges cannot drift the way `f64` would.
const BUDGET_SCALE: i128 = 1_000_000;

/// Lock-protected state shared by every clone of a [`BudgetTracker`].
#[derive(Default, Debug)]
struct BudgetTrackerInner {
    /// Currency identifier -> `(maximum, consumed)`, both in micros.
    state: Mutex<HashMap<String, (i128, i128)>>,
}

/// Converts a decimal amount into micros, rounding to the nearest micro.
///
/// Yields `None` for NaN, either infinity, negative values, and values whose
/// micro representation would not fit the `i128` counters. `-0.0` maps to `0`.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
fn to_micros(amount: f64) -> Option<i128> {
    let ceiling = (i128::MAX / BUDGET_SCALE) as f64;
    // `RangeInclusive::contains` is false for NaN and for both infinities, so
    // this one test rejects every non-finite, negative, or oversized input.
    (0.0..=ceiling)
        .contains(&amount)
        .then(|| (amount * BUDGET_SCALE as f64).round() as i128)
}

/// Converts a micro-denominated counter back into a decimal `f64` amount.
#[allow(clippy::cast_precision_loss)]
fn from_micros(micros: i128) -> f64 {
    let numerator = micros as f64;
    numerator / BUDGET_SCALE as f64
}
83
84impl BudgetTracker {
85 #[must_use]
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 #[must_use]
94 pub fn from_budget(budget: &CostBudget) -> Self {
95 let mut state = HashMap::new();
96 for a in &budget.amounts {
97 let max = to_micros(a.amount).unwrap_or(0);
98 state.insert(a.currency.clone(), (max, 0i128));
99 }
100 Self {
101 inner: Arc::new(BudgetTrackerInner {
102 state: Mutex::new(state),
103 }),
104 }
105 }
106
107 #[must_use]
109 pub fn is_disabled(&self) -> bool {
110 self.inner.state.lock().map_or(true, |s| s.is_empty())
111 }
112
113 #[must_use]
117 pub fn remaining(&self, currency: &str) -> Option<f64> {
118 let s = self.inner.state.lock().ok()?;
119 s.get(currency).map(|(max, cons)| from_micros(max - cons))
120 }
121
122 #[must_use]
124 pub fn snapshot_remaining(&self) -> HashMap<String, f64> {
125 self.inner
126 .state
127 .lock()
128 .map(|s| {
129 s.iter()
130 .map(|(k, (max, cons))| (k.clone(), from_micros(max - cons)))
131 .collect()
132 })
133 .unwrap_or_default()
134 }
135
136 pub fn charge(&self, currency: &str, amount: f64) -> Result<f64, ARCPError> {
169 let Some(amount_micros) = to_micros(amount) else {
170 return Err(ARCPError::InvalidArgument {
171 detail: format!("negative, non-finite, or out-of-range cost amount: {amount}"),
172 });
173 };
174 let Ok(mut s) = self.inner.state.lock() else {
175 return Err(ARCPError::Internal {
176 detail: "budget tracker mutex poisoned".into(),
177 });
178 };
179 let Some(entry) = s.get_mut(currency) else {
180 return Ok(f64::INFINITY);
183 };
184 let remaining = entry.0.saturating_sub(entry.1);
185 if amount_micros > remaining {
186 return Err(ARCPError::BudgetExhausted {
187 detail: format!(
188 "{currency} budget exhausted (remaining={}, attempted={amount})",
189 from_micros(remaining)
190 ),
191 });
192 }
193 entry.1 = entry.1.saturating_add(amount_micros);
194 Ok(from_micros(entry.0 - entry.1))
195 }
196}
197
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod budget_tracker_tests {
    use super::*;
    use arcp_core::messages::CostBudgetAmount;

    /// Builds a `CostBudget` from `(currency, amount)` pairs.
    fn budget(items: &[(&str, f64)]) -> CostBudget {
        CostBudget {
            amounts: items
                .iter()
                .map(|(c, a)| CostBudgetAmount {
                    currency: (*c).to_owned(),
                    amount: *a,
                })
                .collect(),
        }
    }

    #[test]
    fn fresh_tracker_reports_max_remaining() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        assert_eq!(t.remaining("USD"), Some(5.0));
    }

    #[test]
    fn charge_decrements_remaining() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        let r = t.charge("USD", 1.5).expect("charge ok");
        assert!((r - 3.5).abs() < f64::EPSILON);
        assert!((t.remaining("USD").unwrap() - 3.5).abs() < f64::EPSILON);
    }

    #[test]
    fn negative_charge_rejected() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        assert!(matches!(
            t.charge("USD", -0.5),
            Err(ARCPError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn non_finite_and_oversized_charges_rejected_without_side_effects() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 1e40] {
            assert!(
                matches!(t.charge("USD", bad), Err(ARCPError::InvalidArgument { .. })),
                "amount {bad} should be rejected"
            );
        }
        assert!((t.remaining("USD").unwrap() - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn invalid_amount_rejected_even_for_unbudgeted_currency() {
        // Validation happens before the currency lookup.
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        assert!(matches!(
            t.charge("EUR", -1.0),
            Err(ARCPError::InvalidArgument { .. })
        ));
    }

    #[test]
    fn unrepresentable_budget_amount_fails_closed() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", -3.0)]));
        assert_eq!(t.remaining("USD"), Some(0.0));
        assert!(matches!(
            t.charge("USD", 0.01),
            Err(ARCPError::BudgetExhausted { .. })
        ));
    }

    #[test]
    fn oversized_single_charge_is_rejected_and_counter_unchanged() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 1.0)]));
        let err = t.charge("USD", 1.5).unwrap_err();
        assert!(matches!(err, ARCPError::BudgetExhausted { .. }));
        let remaining = t.remaining("USD").expect("currency tracked");
        assert!((remaining - 1.0).abs() < f64::EPSILON);
        let after = t.charge("USD", 0.4).expect("in-budget charge ok");
        assert!((after - 0.6).abs() < f64::EPSILON);
    }

    #[test]
    fn exact_exhaustion_succeeds_and_next_charge_fails() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 1.0)]));
        let after = t.charge("USD", 1.0).expect("exact-exhaustion ok");
        assert!(after.abs() < f64::EPSILON);
        let err = t.charge("USD", 0.000_001).unwrap_err();
        assert!(matches!(err, ARCPError::BudgetExhausted { .. }));
    }

    #[test]
    fn fractional_decimal_charges_sum_without_floating_point_drift() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 1.0)]));
        t.charge("USD", 0.10).expect("first slice");
        t.charge("USD", 0.20).expect("second slice");
        let after = t.charge("USD", 0.70).expect("third slice ok");
        assert!(after.abs() < f64::EPSILON);
    }

    #[test]
    fn multi_currency_charges_are_tracked_independently() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0), ("EUR", 2.0)]));
        t.charge("USD", 3.0).expect("usd in budget");
        t.charge("EUR", 1.5).expect("eur in budget");
        let usd_err = t.charge("USD", 2.5).unwrap_err();
        assert!(matches!(usd_err, ARCPError::BudgetExhausted { .. }));
        assert!((t.remaining("USD").unwrap() - 2.0).abs() < f64::EPSILON);
        assert!((t.remaining("EUR").unwrap() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn unbudgeted_currency_returns_infinity() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        let r = t.charge("EUR", 2.0).expect("charge ok");
        assert!(r.is_infinite());
    }

    #[test]
    fn empty_budget_is_disabled() {
        assert!(BudgetTracker::new().is_disabled());
        assert!(BudgetTracker::from_budget(&budget(&[])).is_disabled());
        assert!(!BudgetTracker::from_budget(&budget(&[("USD", 1.0)])).is_disabled());
    }

    #[test]
    fn clones_share_consumption() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 2.0)]));
        let shared = t.clone();
        shared.charge("USD", 0.5).expect("charge via clone");
        assert!((t.remaining("USD").unwrap() - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn snapshot_reports_every_currency() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0), ("EUR", 2.0)]));
        t.charge("EUR", 0.5).expect("eur in budget");
        let snap = t.snapshot_remaining();
        assert_eq!(snap.len(), 2);
        assert!((snap["USD"] - 5.0).abs() < f64::EPSILON);
        assert!((snap["EUR"] - 1.5).abs() < f64::EPSILON);
        assert!(BudgetTracker::new().snapshot_remaining().is_empty());
    }
}
294
295impl std::fmt::Debug for ToolContext {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 f.debug_struct("ToolContext")
298 .field("job_id", &self.job_id)
299 .field("session_id", &self.session_id)
300 .finish_non_exhaustive()
301 }
302}
303
304impl ToolContext {
305 #[must_use]
307 pub const fn correlation_id(&self) -> &MessageId {
308 &self.correlation_id
309 }
310
311 #[must_use]
313 pub const fn job_id(&self) -> &JobId {
314 &self.job_id
315 }
316
317 #[must_use]
324 pub const fn budget(&self) -> &BudgetTracker {
325 &self.budget
326 }
327
328 #[must_use]
330 pub const fn lease(&self) -> Option<&LeaseRequest> {
331 self.lease.as_ref()
332 }
333
334 pub fn enforce_model_use(&self, model: &str) -> Result<(), ARCPError> {
341 let Some(model_use) = self
342 .lease
343 .as_ref()
344 .and_then(|lease| lease.model_use.as_ref())
345 else {
346 return Ok(());
347 };
348 if model_use.matches(model) {
349 Ok(())
350 } else {
351 Err(ARCPError::PermissionDenied {
352 detail: format!("model {model} not permitted by lease model.use"),
353 })
354 }
355 }
356
357 #[must_use]
360 pub fn translate_upstream_budget_exhausted(&self, detail: impl Into<String>) -> ARCPError {
361 ARCPError::BudgetExhausted {
362 detail: detail.into(),
363 }
364 }
365
366 pub async fn charge(&self, name: &str, amount: f64, currency: &str) -> Result<(), ARCPError> {
381 let remaining = self.budget.charge(currency, amount)?;
382 let mut metric = Envelope::new(MessageType::Metric(MetricPayload {
385 name: name.to_owned(),
386 value: amount,
387 unit: currency.to_owned(),
388 dims: None,
389 }));
390 metric.session_id = Some(self.session_id.clone());
391 metric.job_id = Some(self.job_id.clone());
392 metric.correlation_id = Some(self.correlation_id.clone());
393 let _ = self.out.send(metric).await;
394 if remaining.is_finite() {
399 let mut rem = Envelope::new(MessageType::Metric(MetricPayload {
400 name: "cost.budget.remaining".into(),
401 value: remaining,
402 unit: currency.to_owned(),
403 dims: None,
404 }));
405 rem.session_id = Some(self.session_id.clone());
406 rem.job_id = Some(self.job_id.clone());
407 rem.correlation_id = Some(self.correlation_id.clone());
408 let _ = self.out.send(rem).await;
409 }
410 Ok(())
411 }
412
413 pub async fn emit_result_chunk(
425 &self,
426 result_id: impl Into<String>,
427 chunk_seq: u64,
428 data: impl Into<String>,
429 encoding: ResultChunkEncoding,
430 more: bool,
431 ) -> Result<(), ARCPError> {
432 let mut env = Envelope::new(MessageType::JobResultChunk(JobResultChunkPayload {
433 result_id: result_id.into(),
434 chunk_seq,
435 data: data.into(),
436 encoding,
437 more,
438 }));
439 env.session_id = Some(self.session_id.clone());
440 env.job_id = Some(self.job_id.clone());
441 env.correlation_id = Some(self.correlation_id.clone());
442 self.out
443 .send(env)
444 .await
445 .map_err(|_| ARCPError::Unavailable {
446 detail: "outbound channel closed".into(),
447 })
448 }
449}
450
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    clippy::panic,
    clippy::missing_panics_doc
)]
mod tests {
    use tokio::sync::mpsc;

    use arcp_core::messages::CostBudgetAmount;

    use super::*;

    /// Context with no budget and an outbound channel of capacity 8.
    fn build_ctx() -> (ToolContext, mpsc::Receiver<Envelope>) {
        build_ctx_with_budget(BudgetTracker::new())
    }

    /// Context using `budget`, with an outbound channel of capacity 8.
    fn build_ctx_with_budget(budget: BudgetTracker) -> (ToolContext, mpsc::Receiver<Envelope>) {
        let (out_tx, out_rx) = mpsc::channel(8);
        let ctx = ToolContext {
            cancel: CancellationToken::new(),
            job_id: JobId::new(),
            session_id: SessionId::new(),
            correlation_id: MessageId::new(),
            out: out_tx,
            budget,
            lease: None,
        };
        (ctx, out_rx)
    }

    /// Tracker with a single USD budget of `amount`.
    fn usd_budget(amount: f64) -> BudgetTracker {
        BudgetTracker::from_budget(&CostBudget {
            amounts: vec![CostBudgetAmount {
                currency: "USD".to_owned(),
                amount,
            }],
        })
    }

    #[tokio::test]
    async fn accessors_return_internal_ids() {
        let (ctx, _rx) = build_ctx();
        assert!(ctx.correlation_id().as_str().starts_with("msg_"));
        assert!(ctx.job_id().as_str().starts_with("job_"));
    }

    #[tokio::test]
    async fn unbudgeted_charge_emits_only_the_cost_metric() {
        let (ctx, mut rx) = build_ctx();
        ctx.charge("cost.llm", 0.25, "USD")
            .await
            .expect("unbudgeted charge ok");
        let metric = rx.try_recv().expect("cost metric emitted");
        assert_eq!(
            metric.job_id.as_ref().map(|id| id.as_str()),
            Some(ctx.job_id().as_str())
        );
        assert_eq!(
            metric.correlation_id.as_ref().map(|id| id.as_str()),
            Some(ctx.correlation_id().as_str())
        );
        assert!(
            rx.try_recv().is_err(),
            "no remaining-budget metric for an untracked currency"
        );
    }

    #[tokio::test]
    async fn budgeted_charge_emits_cost_and_remaining_metrics() {
        let (ctx, mut rx) = build_ctx_with_budget(usd_budget(5.0));
        ctx.charge("cost.llm", 1.5, "USD")
            .await
            .expect("in-budget charge ok");
        assert!(rx.try_recv().is_ok(), "cost metric emitted");
        assert!(rx.try_recv().is_ok(), "remaining-budget metric emitted");
        assert!(rx.try_recv().is_err());
        assert!((ctx.budget().remaining("USD").unwrap() - 3.5).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn rejected_charge_emits_nothing() {
        let (ctx, mut rx) = build_ctx_with_budget(usd_budget(1.0));
        let err = ctx.charge("cost.llm", 2.0, "USD").await.unwrap_err();
        assert!(matches!(err, ARCPError::BudgetExhausted { .. }));
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn model_use_is_unrestricted_without_a_lease() {
        let (ctx, _rx) = build_ctx();
        assert!(ctx.lease().is_none());
        assert!(ctx.enforce_model_use("any-model").is_ok());
    }

    #[tokio::test]
    async fn upstream_budget_exhaustion_is_translated() {
        let (ctx, _rx) = build_ctx();
        match ctx.translate_upstream_budget_exhausted("upstream quota exceeded") {
            ARCPError::BudgetExhausted { detail } => {
                assert_eq!(detail, "upstream quota exceeded");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}