Skip to main content

chainlink/
token_usage.rs

1use serde::Deserialize;
2
3/// Raw token usage as returned by the Claude API response metadata.
4#[derive(Debug, Clone, Deserialize)]
5pub struct RawTokenUsage {
6    pub input_tokens: i64,
7    pub output_tokens: i64,
8    #[serde(default)]
9    pub cache_read_input_tokens: Option<i64>,
10    #[serde(default)]
11    pub cache_creation_input_tokens: Option<i64>,
12}
13
/// Parsed usage ready for database insertion.
///
/// When built by `parse_api_usage`, the token counts are copied unchanged
/// from the raw API usage and `cost_estimate` is `None` if the model has no
/// known pricing.
#[derive(Debug, Clone, PartialEq)]
pub struct ParsedUsage {
    /// Identifier of the agent that made the request.
    pub agent_id: String,
    /// Optional session the request belongs to.
    pub session_id: Option<i64>,
    /// Input tokens billed for the request.
    pub input_tokens: i64,
    /// Output tokens billed for the request.
    pub output_tokens: i64,
    /// Cache-read tokens, if the API reported them.
    pub cache_read_tokens: Option<i64>,
    /// Cache-creation tokens, if the API reported them.
    pub cache_creation_tokens: Option<i64>,
    /// Model name exactly as supplied by the caller.
    pub model: String,
    /// Estimated cost in USD, or `None` when the model's pricing is unknown.
    pub cost_estimate: Option<f64>,
}
26
/// Per-million-token pricing for a model, in USD.
///
/// A plain bundle of four rates, so it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
struct ModelPricing {
    /// USD per million input tokens.
    input_per_mtok: f64,
    /// USD per million output tokens.
    output_per_mtok: f64,
    /// USD per million cache-read tokens.
    cache_read_per_mtok: f64,
    /// USD per million cache-creation tokens.
    cache_creation_per_mtok: f64,
}
34
35/// Estimate cost in USD for a given model and token counts.
36pub fn estimate_cost(
37    model: &str,
38    input_tokens: i64,
39    output_tokens: i64,
40    cache_read_tokens: Option<i64>,
41    cache_creation_tokens: Option<i64>,
42) -> Option<f64> {
43    let pricing = model_pricing(model)?;
44
45    let input_cost = input_tokens as f64 * pricing.input_per_mtok / 1_000_000.0;
46    let output_cost = output_tokens as f64 * pricing.output_per_mtok / 1_000_000.0;
47    let cache_read_cost =
48        cache_read_tokens.unwrap_or(0) as f64 * pricing.cache_read_per_mtok / 1_000_000.0;
49    let cache_create_cost =
50        cache_creation_tokens.unwrap_or(0) as f64 * pricing.cache_creation_per_mtok / 1_000_000.0;
51
52    Some(input_cost + output_cost + cache_read_cost + cache_create_cost)
53}
54
55fn model_pricing(model: &str) -> Option<ModelPricing> {
56    let m = model.to_lowercase();
57    if m.contains("opus") {
58        Some(ModelPricing {
59            input_per_mtok: 15.0,
60            output_per_mtok: 75.0,
61            cache_read_per_mtok: 1.5,
62            cache_creation_per_mtok: 18.75,
63        })
64    } else if m.contains("sonnet") {
65        Some(ModelPricing {
66            input_per_mtok: 3.0,
67            output_per_mtok: 15.0,
68            cache_read_per_mtok: 0.3,
69            cache_creation_per_mtok: 3.75,
70        })
71    } else if m.contains("haiku") {
72        Some(ModelPricing {
73            input_per_mtok: 0.80,
74            output_per_mtok: 4.0,
75            cache_read_per_mtok: 0.08,
76            cache_creation_per_mtok: 1.0,
77        })
78    } else {
79        None
80    }
81}
82
83/// Parse raw API usage into a fully resolved ParsedUsage with cost estimate.
84pub fn parse_api_usage(
85    raw: &RawTokenUsage,
86    model: &str,
87    agent_id: &str,
88    session_id: Option<i64>,
89) -> ParsedUsage {
90    let cost = estimate_cost(
91        model,
92        raw.input_tokens,
93        raw.output_tokens,
94        raw.cache_read_input_tokens,
95        raw.cache_creation_input_tokens,
96    );
97
98    ParsedUsage {
99        agent_id: agent_id.to_string(),
100        session_id,
101        input_tokens: raw.input_tokens,
102        output_tokens: raw.output_tokens,
103        cache_read_tokens: raw.cache_read_input_tokens,
104        cache_creation_tokens: raw.cache_creation_input_tokens,
105        model: model.to_string(),
106        cost_estimate: cost,
107    }
108}
109
110/// Aggregated usage summary grouped by agent and model.
111#[derive(Debug, Clone, serde::Serialize)]
112pub struct UsageSummaryRow {
113    pub agent_id: String,
114    pub model: String,
115    pub request_count: i64,
116    pub total_input_tokens: i64,
117    pub total_output_tokens: i64,
118    pub total_cache_read_tokens: i64,
119    pub total_cache_creation_tokens: i64,
120    pub total_cost: f64,
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for dollar amounts computed in floating point. It is tight
    /// enough to catch mispricing on small token counts, where a loose
    /// tolerance such as 0.01 would pass any answer.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_estimate_cost_opus() {
        let cost = estimate_cost("claude-opus-4-6", 1_000_000, 1_000_000, None, None).unwrap();
        // $15 input + $75 output = $90
        assert!((cost - 90.0).abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_sonnet() {
        let cost = estimate_cost("claude-sonnet-4-6", 1_000_000, 1_000_000, None, None).unwrap();
        // $3 input + $15 output = $18
        assert!((cost - 18.0).abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_haiku() {
        let cost = estimate_cost("claude-haiku-4-5", 1_000_000, 1_000_000, None, None).unwrap();
        // $0.80 input + $4 output = $4.80
        assert!((cost - 4.80).abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_with_cache() {
        let cost = estimate_cost(
            "claude-opus-4-6",
            500_000,
            200_000,
            Some(300_000),
            Some(100_000),
        )
        .unwrap();
        let expected = 500_000.0 * 15.0 / 1_000_000.0
            + 200_000.0 * 75.0 / 1_000_000.0
            + 300_000.0 * 1.5 / 1_000_000.0
            + 100_000.0 * 18.75 / 1_000_000.0;
        assert!((cost - expected).abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_cache_creation_only() {
        let cost = estimate_cost("claude-sonnet-4-6", 0, 0, None, Some(1_000_000)).unwrap();
        // $3.75 per million cache-creation tokens for Sonnet.
        assert!((cost - 3.75).abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_zero_tokens() {
        let cost = estimate_cost("claude-opus-4-6", 0, 0, Some(0), Some(0)).unwrap();
        assert!(cost.abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_is_case_insensitive() {
        let cost = estimate_cost("Claude-SONNET-4-6", 1_000_000, 0, None, None).unwrap();
        assert!((cost - 3.0).abs() < EPS);
    }

    #[test]
    fn test_estimate_cost_unknown_model() {
        assert!(estimate_cost("gpt-4", 1000, 1000, None, None).is_none());
    }

    #[test]
    fn test_parse_api_usage() {
        let raw = RawTokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_read_input_tokens: Some(200),
            cache_creation_input_tokens: None,
        };
        let parsed = parse_api_usage(&raw, "claude-sonnet-4-6", "worker-1", Some(42));
        assert_eq!(parsed.agent_id, "worker-1");
        assert_eq!(parsed.session_id, Some(42));
        assert_eq!(parsed.input_tokens, 1000);
        assert_eq!(parsed.output_tokens, 500);
        assert_eq!(parsed.cache_read_tokens, Some(200));
        assert_eq!(parsed.cache_creation_tokens, None);
        assert_eq!(parsed.model, "claude-sonnet-4-6");
        // 1000 * $3 + 500 * $15 + 200 * $0.30, all per million tokens.
        let cost = parsed.cost_estimate.unwrap();
        assert!((cost - 0.01056).abs() < EPS);
    }

    #[test]
    fn test_parse_api_usage_unknown_model_has_no_cost() {
        let raw = RawTokenUsage {
            input_tokens: 10,
            output_tokens: 20,
            cache_read_input_tokens: None,
            cache_creation_input_tokens: Some(5),
        };
        let parsed = parse_api_usage(&raw, "gpt-4", "worker-2", None);
        assert_eq!(parsed.session_id, None);
        assert_eq!(parsed.input_tokens, 10);
        assert_eq!(parsed.output_tokens, 20);
        assert_eq!(parsed.cache_creation_tokens, Some(5));
        assert!(parsed.cost_estimate.is_none());
    }

    #[test]
    fn test_raw_token_usage_deserialize() {
        let json = r#"{"input_tokens": 100, "output_tokens": 50}"#;
        let raw: RawTokenUsage = serde_json::from_str(json).unwrap();
        assert_eq!(raw.input_tokens, 100);
        assert_eq!(raw.output_tokens, 50);
        assert!(raw.cache_read_input_tokens.is_none());
        assert!(raw.cache_creation_input_tokens.is_none());
    }

    #[test]
    fn test_raw_token_usage_with_cache_fields() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 30,
            "cache_creation_input_tokens": 10
        }"#;
        let raw: RawTokenUsage = serde_json::from_str(json).unwrap();
        assert_eq!(raw.cache_read_input_tokens, Some(30));
        assert_eq!(raw.cache_creation_input_tokens, Some(10));
    }
}