use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use arcp_core::envelope::Envelope;
use arcp_core::error::ARCPError;
use arcp_core::ids::{JobId, MessageId, SessionId};
use arcp_core::messages::{
CostBudget, JobResultChunkPayload, LeaseRequest, MessageType, MetricPayload,
ResultChunkEncoding,
};
/// Per-invocation handle a tool uses to report back to the job that invoked it.
///
/// Every envelope emitted through a `ToolContext` (cost metrics, result chunks)
/// is stamped with this context's session id, job id, and correlation id
/// before being pushed onto `out`.
pub struct ToolContext {
    /// Cancellation signal for this invocation.
    /// NOTE(review): presumably cancelled when the owning job is cancelled —
    /// confirm against the code that constructs the context.
    pub cancel: CancellationToken,
    /// Job this invocation belongs to; copied onto every outbound envelope.
    pub(crate) job_id: JobId,
    /// Session the job runs in; copied onto every outbound envelope.
    pub(crate) session_id: SessionId,
    /// Message id used as the `correlation_id` of every outbound envelope.
    pub(crate) correlation_id: MessageId,
    /// Outbound envelope channel (metrics and result chunks are sent here).
    pub(crate) out: mpsc::Sender<Envelope>,
    /// Cost budget debited by `ToolContext::charge`.
    pub(crate) budget: BudgetTracker,
    /// Lease attached to the job, if any; consulted by `enforce_model_use`.
    pub(crate) lease: Option<LeaseRequest>,
}
/// Thread-safe, cloneable per-currency spending tracker.
///
/// The counters live behind an `Arc<Mutex<..>>`, so every clone of a tracker
/// shares the same state: a charge made through one clone is visible to all.
/// A tracker with no budgeted currencies (as produced by `new`/`Default`)
/// reports itself as disabled and lets every charge through.
#[derive(Clone, Debug, Default)]
pub struct BudgetTracker {
    inner: Arc<BudgetTrackerInner>,
}
/// Number of integer micro-units per whole currency unit.
///
/// Budgets are tracked as integers scaled by this factor (six decimal places)
/// so that sums of fractional charges are exact rather than accumulating
/// floating-point drift.
const BUDGET_SCALE: i128 = 10_i128.pow(6);
/// Shared state behind a `BudgetTracker`.
#[derive(Debug, Default)]
struct BudgetTrackerInner {
    /// Currency -> `(max, consumed)`, both in micro-units (amount multiplied
    /// by `BUDGET_SCALE`). `consumed` starts at 0 and only grows via
    /// `BudgetTracker::charge`, which rejects any charge that would push it
    /// past `max`.
    state: Mutex<HashMap<String, (i128, i128)>>,
}
/// Converts a decimal currency amount into integer micro-units.
///
/// Returns `None` for NaN, infinities, negative values, and amounts too large
/// for their scaled form to fit in an `i128`. The scaled value is rounded to
/// the nearest micro-unit, which absorbs tiny binary floating-point
/// representation error (and means amounts below half a micro-unit map to 0).
fn to_micros(amount: f64) -> Option<i128> {
    // Largest decimal amount whose scaled form still fits in an i128.
    #[allow(clippy::cast_precision_loss)]
    let ceiling = (i128::MAX / BUDGET_SCALE) as f64;
    let representable = amount.is_finite() && (0.0..=ceiling).contains(&amount);
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    let scale = |value: f64| (value * BUDGET_SCALE as f64).round() as i128;
    representable.then(|| scale(amount))
}
/// Converts integer micro-units back into a decimal `f64` amount.
///
/// Magnitudes beyond the 53-bit `f64` mantissa lose precision in the cast.
#[allow(clippy::cast_precision_loss)]
fn from_micros(micros: i128) -> f64 {
    let scale = BUDGET_SCALE as f64;
    let whole = micros as f64;
    whole / scale
}
impl BudgetTracker {
    /// Creates a tracker with no budgeted currencies; every charge is allowed.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a tracker whose per-currency limits come from `budget`.
    ///
    /// A limit that is negative, non-finite, or out of range is stored as 0,
    /// so that currency fails closed (only zero-cost charges succeed). If a
    /// currency appears more than once, the last entry wins.
    #[must_use]
    pub fn from_budget(budget: &CostBudget) -> Self {
        let state = budget
            .amounts
            .iter()
            .map(|line| {
                let ceiling = to_micros(line.amount).unwrap_or(0);
                (line.currency.clone(), (ceiling, 0_i128))
            })
            .collect::<HashMap<_, _>>();
        Self {
            inner: Arc::new(BudgetTrackerInner {
                state: Mutex::new(state),
            }),
        }
    }

    /// Returns `true` when no currency is budgeted (or the state mutex is
    /// poisoned), i.e. when this tracker enforces nothing.
    #[must_use]
    pub fn is_disabled(&self) -> bool {
        match self.inner.state.lock() {
            Ok(guard) => guard.is_empty(),
            Err(_) => true,
        }
    }

    /// Remaining budget for `currency`, or `None` if that currency is not
    /// budgeted (or the state mutex is poisoned).
    #[must_use]
    pub fn remaining(&self, currency: &str) -> Option<f64> {
        let guard = self.inner.state.lock().ok()?;
        let (max, consumed) = guard.get(currency)?;
        Some(from_micros(max - consumed))
    }

    /// Remaining budget for every budgeted currency. Returns an empty map if
    /// the state mutex is poisoned.
    #[must_use]
    pub fn snapshot_remaining(&self) -> HashMap<String, f64> {
        let Ok(guard) = self.inner.state.lock() else {
            return HashMap::new();
        };
        guard
            .iter()
            .map(|(currency, (max, consumed))| (currency.clone(), from_micros(max - consumed)))
            .collect()
    }

    /// Debits `amount` from the `currency` budget and returns what remains.
    ///
    /// Charges in currencies that are not budgeted are not recorded and
    /// return `f64::INFINITY`. A rejected charge leaves the counters
    /// untouched.
    ///
    /// # Errors
    ///
    /// - `ARCPError::InvalidArgument` if `amount` is negative, non-finite, or
    ///   too large to represent.
    /// - `ARCPError::Internal` if the state mutex is poisoned.
    /// - `ARCPError::BudgetExhausted` if `amount` exceeds the remaining budget.
    pub fn charge(&self, currency: &str, amount: f64) -> Result<f64, ARCPError> {
        let amount_micros = to_micros(amount).ok_or_else(|| ARCPError::InvalidArgument {
            detail: format!("negative, non-finite, or out-of-range cost amount: {amount}"),
        })?;
        let mut guard = self.inner.state.lock().map_err(|_| ARCPError::Internal {
            detail: "budget tracker mutex poisoned".into(),
        })?;
        let Some((max, consumed)) = guard.get_mut(currency) else {
            return Ok(f64::INFINITY);
        };
        let headroom = max.saturating_sub(*consumed);
        if amount_micros > headroom {
            return Err(ARCPError::BudgetExhausted {
                detail: format!(
                    "{currency} budget exhausted (remaining={}, attempted={amount})",
                    from_micros(headroom)
                ),
            });
        }
        *consumed = consumed.saturating_add(amount_micros);
        Ok(from_micros(*max - *consumed))
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod budget_tracker_tests {
    use super::*;
    use arcp_core::messages::CostBudgetAmount;
    // Builds a `CostBudget` from `(currency, amount)` pairs.
    fn budget(items: &[(&str, f64)]) -> CostBudget {
        CostBudget {
            amounts: items
                .iter()
                .map(|(c, a)| CostBudgetAmount {
                    currency: (*c).to_owned(),
                    amount: *a,
                })
                .collect(),
        }
    }
    #[test]
    fn fresh_tracker_reports_max_remaining() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        assert_eq!(t.remaining("USD"), Some(5.0));
    }
    #[test]
    fn charge_decrements_remaining() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        let r = t.charge("USD", 1.5).expect("charge ok");
        assert!((r - 3.5).abs() < f64::EPSILON);
        assert!((t.remaining("USD").unwrap() - 3.5).abs() < f64::EPSILON);
    }
    #[test]
    fn negative_charge_rejected() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        assert!(matches!(
            t.charge("USD", -0.5),
            Err(ARCPError::InvalidArgument { .. })
        ));
    }
    // NaN and infinities must be rejected without touching the counters.
    #[test]
    fn non_finite_charges_rejected_and_counter_unchanged() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(matches!(
                t.charge("USD", bad),
                Err(ARCPError::InvalidArgument { .. })
            ));
        }
        assert_eq!(t.remaining("USD"), Some(5.0));
    }
    #[test]
    fn oversized_single_charge_is_rejected_and_counter_unchanged() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 1.0)]));
        let err = t.charge("USD", 1.5).unwrap_err();
        assert!(matches!(err, ARCPError::BudgetExhausted { .. }));
        let remaining = t.remaining("USD").expect("currency tracked");
        assert!((remaining - 1.0).abs() < f64::EPSILON);
        let after = t.charge("USD", 0.4).expect("in-budget charge ok");
        assert!((after - 0.6).abs() < f64::EPSILON);
    }
    #[test]
    fn exact_exhaustion_succeeds_and_next_charge_fails() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 1.0)]));
        let after = t.charge("USD", 1.0).expect("exact-exhaustion ok");
        assert!(after.abs() < f64::EPSILON);
        let err = t.charge("USD", 0.000_001).unwrap_err();
        assert!(matches!(err, ARCPError::BudgetExhausted { .. }));
    }
    #[test]
    fn fractional_decimal_charges_sum_without_floating_point_drift() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 1.0)]));
        t.charge("USD", 0.10).expect("first slice");
        t.charge("USD", 0.20).expect("second slice");
        let after = t.charge("USD", 0.70).expect("third slice ok");
        assert!(after.abs() < f64::EPSILON);
    }
    #[test]
    fn multi_currency_charges_are_tracked_independently() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0), ("EUR", 2.0)]));
        t.charge("USD", 3.0).expect("usd in budget");
        t.charge("EUR", 1.5).expect("eur in budget");
        let usd_err = t.charge("USD", 2.5).unwrap_err();
        assert!(matches!(usd_err, ARCPError::BudgetExhausted { .. }));
        assert!((t.remaining("USD").unwrap() - 2.0).abs() < f64::EPSILON);
        assert!((t.remaining("EUR").unwrap() - 0.5).abs() < f64::EPSILON);
    }
    #[test]
    fn unbudgeted_currency_returns_infinity() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        let r = t.charge("EUR", 2.0).expect("charge ok");
        assert!(r.is_infinite());
    }
    // `remaining` distinguishes "not budgeted" (None) from "budgeted".
    #[test]
    fn unbudgeted_currency_has_no_remaining() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        assert_eq!(t.remaining("EUR"), None);
    }
    // Unparseable limits are stored as 0, so the currency fails closed.
    #[test]
    fn invalid_budget_amount_fails_closed() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", -1.0), ("EUR", f64::NAN)]));
        assert_eq!(t.remaining("USD"), Some(0.0));
        assert_eq!(t.remaining("EUR"), Some(0.0));
        assert!(matches!(
            t.charge("USD", 0.01),
            Err(ARCPError::BudgetExhausted { .. })
        ));
    }
    #[test]
    fn disabled_only_when_no_currency_is_budgeted() {
        assert!(BudgetTracker::new().is_disabled());
        assert!(BudgetTracker::from_budget(&budget(&[])).is_disabled());
        assert!(!BudgetTracker::from_budget(&budget(&[("USD", 5.0)])).is_disabled());
    }
    #[test]
    fn snapshot_reports_every_budgeted_currency() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0), ("EUR", 2.0)]));
        t.charge("USD", 1.0).expect("usd in budget");
        let snap = t.snapshot_remaining();
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("USD").copied(), Some(4.0));
        assert_eq!(snap.get("EUR").copied(), Some(2.0));
    }
    // Clones share one Arc'd state, so consumption is visible across them.
    #[test]
    fn clones_share_consumption() {
        let t = BudgetTracker::from_budget(&budget(&[("USD", 5.0)]));
        let shared = t.clone();
        shared.charge("USD", 2.0).expect("charge via clone");
        assert_eq!(t.remaining("USD"), Some(3.0));
    }
}
/// Debug output shows only the job and session ids; the channel, budget,
/// lease, and cancellation token are elided (rendered as `..`).
impl std::fmt::Debug for ToolContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ToolContext");
        builder.field("job_id", &self.job_id);
        builder.field("session_id", &self.session_id);
        builder.finish_non_exhaustive()
    }
}
impl ToolContext {
#[must_use]
pub const fn correlation_id(&self) -> &MessageId {
&self.correlation_id
}
#[must_use]
pub const fn job_id(&self) -> &JobId {
&self.job_id
}
#[must_use]
pub const fn budget(&self) -> &BudgetTracker {
&self.budget
}
#[must_use]
pub const fn lease(&self) -> Option<&LeaseRequest> {
self.lease.as_ref()
}
pub fn enforce_model_use(&self, model: &str) -> Result<(), ARCPError> {
let Some(model_use) = self
.lease
.as_ref()
.and_then(|lease| lease.model_use.as_ref())
else {
return Ok(());
};
if model_use.matches(model) {
Ok(())
} else {
Err(ARCPError::PermissionDenied {
detail: format!("model {model} not permitted by lease model.use"),
})
}
}
#[must_use]
pub fn translate_upstream_budget_exhausted(&self, detail: impl Into<String>) -> ARCPError {
ARCPError::BudgetExhausted {
detail: detail.into(),
}
}
pub async fn charge(&self, name: &str, amount: f64, currency: &str) -> Result<(), ARCPError> {
let remaining = self.budget.charge(currency, amount)?;
let mut metric = Envelope::new(MessageType::Metric(MetricPayload {
name: name.to_owned(),
value: amount,
unit: currency.to_owned(),
dims: None,
}));
metric.session_id = Some(self.session_id.clone());
metric.job_id = Some(self.job_id.clone());
metric.correlation_id = Some(self.correlation_id.clone());
let _ = self.out.send(metric).await;
if remaining.is_finite() {
let mut rem = Envelope::new(MessageType::Metric(MetricPayload {
name: "cost.budget.remaining".into(),
value: remaining,
unit: currency.to_owned(),
dims: None,
}));
rem.session_id = Some(self.session_id.clone());
rem.job_id = Some(self.job_id.clone());
rem.correlation_id = Some(self.correlation_id.clone());
let _ = self.out.send(rem).await;
}
Ok(())
}
pub async fn emit_result_chunk(
&self,
result_id: impl Into<String>,
chunk_seq: u64,
data: impl Into<String>,
encoding: ResultChunkEncoding,
more: bool,
) -> Result<(), ARCPError> {
let mut env = Envelope::new(MessageType::JobResultChunk(JobResultChunkPayload {
result_id: result_id.into(),
chunk_seq,
data: data.into(),
encoding,
more,
}));
env.session_id = Some(self.session_id.clone());
env.job_id = Some(self.job_id.clone());
env.correlation_id = Some(self.correlation_id.clone());
self.out
.send(env)
.await
.map_err(|_| ARCPError::Unavailable {
detail: "outbound channel closed".into(),
})
}
}
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    clippy::panic,
    clippy::missing_panics_doc
)]
mod tests {
    use tokio::sync::mpsc;
    use arcp_core::messages::CostBudgetAmount;
    use super::*;
    // Context with no budget and no lease, plus the receiving end of its
    // outbound channel.
    fn build_ctx() -> (ToolContext, mpsc::Receiver<Envelope>) {
        let (out_tx, out_rx) = mpsc::channel(8);
        let ctx = ToolContext {
            cancel: CancellationToken::new(),
            job_id: JobId::new(),
            session_id: SessionId::new(),
            correlation_id: MessageId::new(),
            out: out_tx,
            budget: BudgetTracker::new(),
            lease: None,
        };
        (ctx, out_rx)
    }
    // Tracker with a single USD limit of `amount`.
    fn usd_budget(amount: f64) -> BudgetTracker {
        BudgetTracker::from_budget(&CostBudget {
            amounts: std::iter::once(CostBudgetAmount {
                currency: "USD".to_owned(),
                amount,
            })
            .collect(),
        })
    }
    #[tokio::test]
    async fn accessors_return_internal_ids() {
        let (ctx, _rx) = build_ctx();
        assert!(ctx.correlation_id().as_str().starts_with("msg_"));
        assert!(ctx.job_id().as_str().starts_with("job_"));
    }
    // Without a budget the remaining amount is infinite, so only the cost
    // metric is emitted, and it is fully addressed.
    #[tokio::test]
    async fn unbudgeted_charge_emits_only_cost_metric() {
        let (ctx, mut rx) = build_ctx();
        ctx.charge("cost.llm", 0.25, "USD")
            .await
            .expect("unbudgeted charge ok");
        let metric = rx.try_recv().expect("cost metric emitted");
        assert!(metric.session_id.is_some());
        assert!(metric.job_id.is_some());
        assert!(metric.correlation_id.is_some());
        assert!(rx.try_recv().is_err(), "no remaining metric without a budget");
    }
    #[tokio::test]
    async fn budgeted_charge_also_emits_remaining_metric() {
        let (mut ctx, mut rx) = build_ctx();
        ctx.budget = usd_budget(1.0);
        ctx.charge("cost.llm", 0.25, "USD")
            .await
            .expect("in-budget charge ok");
        assert!(rx.try_recv().is_ok(), "cost metric emitted");
        assert!(rx.try_recv().is_ok(), "remaining metric emitted");
        assert!(rx.try_recv().is_err());
        assert_eq!(ctx.budget().remaining("USD"), Some(0.75));
    }
    #[tokio::test]
    async fn rejected_charge_emits_nothing() {
        let (mut ctx, mut rx) = build_ctx();
        ctx.budget = usd_budget(1.0);
        let err = ctx.charge("cost.llm", 2.0, "USD").await.unwrap_err();
        assert!(matches!(err, ARCPError::BudgetExhausted { .. }));
        assert!(rx.try_recv().is_err());
    }
    // Metric delivery is best-effort: a closed channel must not fail a charge.
    #[tokio::test]
    async fn charge_succeeds_when_outbound_channel_closed() {
        let (ctx, rx) = build_ctx();
        drop(rx);
        assert!(ctx.charge("cost.llm", 0.25, "USD").await.is_ok());
    }
}