use crate::config::schema::ModelPricing;
use crate::cost::CostTracker;
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
use std::sync::Arc;
/// Per-task state used by the tool-loop cost helpers in this module to record
/// provider token usage and to check the spending budget.
///
/// Cloning is cheap: both fields are `Arc` handles, so a clone only bumps
/// reference counts.
#[derive(Clone)]
pub(crate) struct ToolLoopCostTrackingContext {
    /// Tracker that receives recorded usage and answers budget checks.
    pub tracker: Arc<CostTracker>,
    /// Pricing table keyed by model identifier. The lookups in this module try
    /// the bare model name, `"{provider}/{model}"`, and the segment after the
    /// last `/` of the model name.
    pub prices: Arc<std::collections::HashMap<String, ModelPricing>>,
}

impl ToolLoopCostTrackingContext {
    /// Bundles a shared tracker and a shared pricing table into one context.
    pub(crate) fn new(
        tracker: Arc<CostTracker>,
        prices: Arc<std::collections::HashMap<String, ModelPricing>>,
    ) -> Self {
        Self {
            prices,
            tracker,
        }
    }
}
// Task-local slot holding the cost-tracking context for the current task.
// Readers in this module access it with `try_with(...).ok().flatten()`, so both
// "no scope entered for this task" and "scope entered with `None`" are treated
// as "cost tracking disabled".
tokio::task_local! {
    /// Cost-tracking context for the current task; `None` means tracking is off.
    pub(crate) static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
pub(crate) fn record_tool_loop_cost_usage(
provider_name: &str,
model: &str,
usage: &crate::providers::traits::TokenUsage,
) -> Option<(u64, f64)> {
let input_tokens = usage.input_tokens.unwrap_or(0);
let output_tokens = usage.output_tokens.unwrap_or(0);
let total_tokens = input_tokens.saturating_add(output_tokens);
if total_tokens == 0 {
return None;
}
let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten()?;
let pricing = ctx
.prices
.get(model)
.or_else(|| ctx.prices.get(&format!("{provider_name}/{model}")))
.or_else(|| {
model
.rsplit_once('/')
.and_then(|(_, suffix)| ctx.prices.get(suffix))
});
let cost_usage = CostTokenUsage::new(
model,
input_tokens,
output_tokens,
pricing.map_or(0.0, |entry| entry.input),
pricing.map_or(0.0, |entry| entry.output),
);
if pricing.is_none() {
tracing::debug!(
provider = provider_name,
model,
"Cost tracking recorded token usage with zero pricing (no pricing entry found)"
);
}
if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) {
tracing::warn!(
provider = provider_name,
model,
"Failed to record cost tracking usage: {error}"
);
}
Some((cost_usage.total_tokens, cost_usage.cost_usd))
}
pub(crate) fn check_tool_loop_budget() -> Option<BudgetCheck> {
TOOL_LOOP_COST_TRACKING_CONTEXT
.try_with(Clone::clone)
.ok()
.flatten()
.map(|ctx| {
ctx.tracker
.check_budget(0.0)
.unwrap_or(BudgetCheck::Allowed)
})
}