1use crate::config::schema::ModelPricing;
2use crate::cost::CostTracker;
3use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
4use std::sync::Arc;
5
6#[derive(Clone)]
11pub(crate) struct ToolLoopCostTrackingContext {
12 pub tracker: Arc<CostTracker>,
13 pub prices: Arc<std::collections::HashMap<String, ModelPricing>>,
14}
15
16impl ToolLoopCostTrackingContext {
17 pub(crate) fn new(
18 tracker: Arc<CostTracker>,
19 prices: Arc<std::collections::HashMap<String, ModelPricing>>,
20 ) -> Self {
21 Self { tracker, prices }
22 }
23}
24
tokio::task_local! {
    /// Task-local slot carrying the cost-tracking context for the current tool
    /// loop. Readers in this module access it with `try_with`, so an unset slot
    /// or a `None` value simply means "no cost tracking" rather than a panic.
    pub(crate) static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
28
29pub(crate) fn record_tool_loop_cost_usage(
32 provider_name: &str,
33 model: &str,
34 usage: &crate::providers::traits::TokenUsage,
35) -> Option<(u64, f64)> {
36 let input_tokens = usage.input_tokens.unwrap_or(0);
37 let output_tokens = usage.output_tokens.unwrap_or(0);
38 let total_tokens = input_tokens.saturating_add(output_tokens);
39
40 let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
41 .try_with(Clone::clone)
42 .ok()
43 .flatten()?;
44
45 if total_tokens == 0 {
46 tracing::warn!(
47 provider = provider_name,
48 model,
49 "Cost tracking received zero-token usage; recording request with zero tokens (provider may not be reporting usage)"
50 );
51 }
52 let pricing = ctx
60 .prices
61 .get(model)
62 .or_else(|| ctx.prices.get(&format!("{provider_name}/{model}")))
63 .or_else(|| {
64 model
65 .rsplit_once('/')
66 .and_then(|(_, suffix)| ctx.prices.get(suffix))
67 })
68 .or_else(|| {
69 let base = model
73 .rsplit_once('-')
74 .filter(|(_, tail)| tail.chars().all(|c| c.is_ascii_digit()))
75 .map_or(model, |(prefix, _)| prefix);
76 ctx.prices.iter().find_map(|(key, entry)| {
77 let model_part = key.rsplit_once('/').map_or(key.as_str(), |(_, m)| m);
79 if model_part.starts_with(base) {
80 Some(entry)
81 } else {
82 None
83 }
84 })
85 });
86 let cost_usage = CostTokenUsage::new(
87 model,
88 input_tokens,
89 output_tokens,
90 pricing.map_or(0.0, |entry| entry.input),
91 pricing.map_or(0.0, |entry| entry.output),
92 );
93
94 if pricing.is_none() {
95 tracing::debug!(
96 provider = provider_name,
97 model,
98 "Cost tracking recorded token usage with zero pricing (no pricing entry found)"
99 );
100 }
101
102 if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) {
103 tracing::warn!(
104 provider = provider_name,
105 model,
106 "Failed to record cost tracking usage: {error}"
107 );
108 }
109
110 Some((cost_usage.total_tokens, cost_usage.cost_usd))
111}
112
113pub(crate) fn check_tool_loop_budget() -> Option<BudgetCheck> {
116 TOOL_LOOP_COST_TRACKING_CONTEXT
117 .try_with(Clone::clone)
118 .ok()
119 .flatten()
120 .map(|ctx| {
121 ctx.tracker
122 .check_budget(0.0)
123 .unwrap_or(BudgetCheck::Allowed)
124 })
125}