Skip to main content

construct/agent/
cost.rs

1use crate::config::schema::ModelPricing;
2use crate::cost::CostTracker;
3use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
4use std::sync::Arc;
5
6// ── Cost tracking via task-local ──
7
8/// Context for cost tracking within the tool call loop.
9/// Scoped via `tokio::task_local!` at call sites (channels, gateway).
10#[derive(Clone)]
11pub(crate) struct ToolLoopCostTrackingContext {
12    pub tracker: Arc<CostTracker>,
13    pub prices: Arc<std::collections::HashMap<String, ModelPricing>>,
14}
15
16impl ToolLoopCostTrackingContext {
17    pub(crate) fn new(
18        tracker: Arc<CostTracker>,
19        prices: Arc<std::collections::HashMap<String, ModelPricing>>,
20    ) -> Self {
21        Self { tracker, prices }
22    }
23}
24
tokio::task_local! {
    /// Task-local slot holding the cost tracking context for the current task.
    ///
    /// Readers in this file treat an unscoped slot and a scoped `None` the same
    /// way: cost tracking is considered disabled and they return `None`.
    pub(crate) static TOOL_LOOP_COST_TRACKING_CONTEXT: Option<ToolLoopCostTrackingContext>;
}
28
29/// Record token usage from an LLM response via the task-local cost tracker.
30/// Returns `(total_tokens, cost_usd)` on success, `None` when not scoped or no usage.
31pub(crate) fn record_tool_loop_cost_usage(
32    provider_name: &str,
33    model: &str,
34    usage: &crate::providers::traits::TokenUsage,
35) -> Option<(u64, f64)> {
36    let input_tokens = usage.input_tokens.unwrap_or(0);
37    let output_tokens = usage.output_tokens.unwrap_or(0);
38    let total_tokens = input_tokens.saturating_add(output_tokens);
39
40    let ctx = TOOL_LOOP_COST_TRACKING_CONTEXT
41        .try_with(Clone::clone)
42        .ok()
43        .flatten()?;
44
45    if total_tokens == 0 {
46        tracing::warn!(
47            provider = provider_name,
48            model,
49            "Cost tracking received zero-token usage; recording request with zero tokens (provider may not be reporting usage)"
50        );
51    }
52    // Multi-tier model pricing lookup:
53    //   1. Direct name          → "claude-sonnet-4-6"
54    //   2. Provider/model       → "anthropic/claude-sonnet-4-6"
55    //   3. Suffix after `/`     → strip provider prefix from model string
56    //   4. Fuzzy: find a pricing key whose model portion starts with our model's
57    //      base name (e.g. "claude-sonnet-4" matches "anthropic/claude-sonnet-4-20250514").
58    //      This handles short aliases (claude-sonnet-4-6) vs full versioned names.
59    let pricing = ctx
60        .prices
61        .get(model)
62        .or_else(|| ctx.prices.get(&format!("{provider_name}/{model}")))
63        .or_else(|| {
64            model
65                .rsplit_once('/')
66                .and_then(|(_, suffix)| ctx.prices.get(suffix))
67        })
68        .or_else(|| {
69            // Derive a base name by stripping the last `-<digits>` segment for fuzzy matching.
70            // "claude-sonnet-4-6" → base "claude-sonnet-4"
71            // "claude-opus-4-20250514" → base "claude-opus-4"
72            let base = model
73                .rsplit_once('-')
74                .filter(|(_, tail)| tail.chars().all(|c| c.is_ascii_digit()))
75                .map_or(model, |(prefix, _)| prefix);
76            ctx.prices.iter().find_map(|(key, entry)| {
77                // Extract model portion after provider prefix: "anthropic/claude-sonnet-4-..." → "claude-sonnet-4-..."
78                let model_part = key.rsplit_once('/').map_or(key.as_str(), |(_, m)| m);
79                if model_part.starts_with(base) {
80                    Some(entry)
81                } else {
82                    None
83                }
84            })
85        });
86    let cost_usage = CostTokenUsage::new(
87        model,
88        input_tokens,
89        output_tokens,
90        pricing.map_or(0.0, |entry| entry.input),
91        pricing.map_or(0.0, |entry| entry.output),
92    );
93
94    if pricing.is_none() {
95        tracing::debug!(
96            provider = provider_name,
97            model,
98            "Cost tracking recorded token usage with zero pricing (no pricing entry found)"
99        );
100    }
101
102    if let Err(error) = ctx.tracker.record_usage(cost_usage.clone()) {
103        tracing::warn!(
104            provider = provider_name,
105            model,
106            "Failed to record cost tracking usage: {error}"
107        );
108    }
109
110    Some((cost_usage.total_tokens, cost_usage.cost_usd))
111}
112
113/// Check budget before an LLM call. Returns `None` when no cost tracking
114/// context is scoped (tests, delegate, CLI without cost config).
115pub(crate) fn check_tool_loop_budget() -> Option<BudgetCheck> {
116    TOOL_LOOP_COST_TRACKING_CONTEXT
117        .try_with(Clone::clone)
118        .ok()
119        .flatten()
120        .map(|ctx| {
121            ctx.tracker
122                .check_budget(0.0)
123                .unwrap_or(BudgetCheck::Allowed)
124        })
125}