Skip to main content

codex_helper_core/
usage.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
/// Predicate for serde's `skip_serializing_if`: true when the counter is zero.
///
/// The argument is a reference because serde hands the field to the predicate
/// by reference.
fn i64_is_zero(value: &i64) -> bool {
    matches!(*value, 0)
}
7
8#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
9#[serde(rename_all = "snake_case")]
10pub enum CacheInputAccounting {
11    #[default]
12    DirectReadSeparate,
13    DirectReadIncludedInInput,
14}
15
16impl CacheInputAccounting {
17    pub fn for_service(service: &str) -> Self {
18        match service.trim().to_ascii_lowercase().as_str() {
19            "codex" | "gemini" => Self::DirectReadIncludedInInput,
20            _ => Self::DirectReadSeparate,
21        }
22    }
23
24    fn includes_direct_read_in_input(self) -> bool {
25        matches!(self, Self::DirectReadIncludedInInput)
26    }
27}
28
/// Normalised view of the cache-related counters of one usage record, as
/// produced by `UsageMetrics::cache_usage_breakdown`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CacheUsageBreakdown {
    /// Reported input tokens, clamped to be non-negative.
    pub input_tokens: i64,
    /// Reported cached input tokens, clamped to be non-negative.
    pub cached_input_tokens: i64,
    /// Reported direct cache-read tokens, clamped to be non-negative.
    pub direct_cache_read_input_tokens: i64,
    /// Sum of `cached_input_tokens` and `direct_cache_read_input_tokens`.
    pub cache_read_input_tokens: i64,
    /// Total cache-creation tokens, clamped to be non-negative.
    pub cache_creation_input_tokens: i64,
    /// Input tokens left after subtracting the cache reads that the chosen
    /// accounting mode treats as already included in the input.
    pub effective_input_tokens: i64,
    /// `effective_input_tokens` + cache creation + cache reads; used as the
    /// denominator of the cache hit rate.
    pub denominator_tokens: i64,
}
39
40#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
41pub struct UsageMetrics {
42    #[serde(default)]
43    pub input_tokens: i64,
44    #[serde(default)]
45    pub output_tokens: i64,
46    #[serde(default)]
47    pub reasoning_tokens: i64,
48    #[serde(default, skip_serializing_if = "i64_is_zero")]
49    pub reasoning_output_tokens: i64,
50    #[serde(default)]
51    pub total_tokens: i64,
52    #[serde(default, skip_serializing_if = "i64_is_zero")]
53    pub cached_input_tokens: i64,
54    #[serde(
55        default,
56        alias = "cache_read_tokens",
57        skip_serializing_if = "i64_is_zero"
58    )]
59    pub cache_read_input_tokens: i64,
60    #[serde(
61        default,
62        alias = "cache_creation_tokens",
63        skip_serializing_if = "i64_is_zero"
64    )]
65    pub cache_creation_input_tokens: i64,
66    #[serde(default, skip_serializing_if = "i64_is_zero")]
67    pub cache_creation_5m_input_tokens: i64,
68    #[serde(default, skip_serializing_if = "i64_is_zero")]
69    pub cache_creation_1h_input_tokens: i64,
70}
71
72impl UsageMetrics {
73    pub fn add_assign(&mut self, other: &UsageMetrics) {
74        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
75        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
76        self.reasoning_tokens = self.reasoning_tokens.saturating_add(other.reasoning_tokens);
77        self.reasoning_output_tokens = self
78            .reasoning_output_tokens
79            .saturating_add(other.reasoning_output_tokens_total());
80        self.total_tokens = self.total_tokens.saturating_add(other.total_tokens);
81        self.cached_input_tokens = self
82            .cached_input_tokens
83            .saturating_add(other.cached_input_tokens);
84        self.cache_read_input_tokens = self
85            .cache_read_input_tokens
86            .saturating_add(other.cache_read_input_tokens);
87        self.cache_creation_input_tokens = self
88            .cache_creation_input_tokens
89            .saturating_add(other.cache_creation_input_tokens);
90        self.cache_creation_5m_input_tokens = self
91            .cache_creation_5m_input_tokens
92            .saturating_add(other.cache_creation_5m_input_tokens);
93        self.cache_creation_1h_input_tokens = self
94            .cache_creation_1h_input_tokens
95            .saturating_add(other.cache_creation_1h_input_tokens);
96    }
97
98    pub fn reasoning_output_tokens_total(&self) -> i64 {
99        self.reasoning_output_tokens.max(self.reasoning_tokens)
100    }
101
102    pub fn cache_creation_tokens_total(&self) -> i64 {
103        let by_ttl = self
104            .cache_creation_5m_input_tokens
105            .saturating_add(self.cache_creation_1h_input_tokens);
106        self.cache_creation_input_tokens.max(by_ttl)
107    }
108
109    pub fn has_cache_tokens(&self) -> bool {
110        self.cached_input_tokens > 0
111            || self.cache_read_input_tokens > 0
112            || self.cache_creation_tokens_total() > 0
113    }
114
115    pub fn cache_read_tokens_total(&self) -> i64 {
116        self.cached_input_tokens
117            .max(0)
118            .saturating_add(self.cache_read_input_tokens.max(0))
119    }
120
121    pub fn cache_usage_breakdown(&self, accounting: CacheInputAccounting) -> CacheUsageBreakdown {
122        let input = self.input_tokens.max(0);
123        let cached = self.cached_input_tokens.max(0);
124        let direct_read = self.cache_read_input_tokens.max(0);
125        let read = cached.saturating_add(direct_read);
126        let create = self.cache_creation_tokens_total().max(0);
127        let included_read = if accounting.includes_direct_read_in_input() {
128            read
129        } else {
130            cached
131        };
132        let effective_input = input.saturating_sub(included_read);
133        let denom = effective_input.saturating_add(create).saturating_add(read);
134
135        CacheUsageBreakdown {
136            input_tokens: input,
137            cached_input_tokens: cached,
138            direct_cache_read_input_tokens: direct_read,
139            cache_read_input_tokens: read,
140            cache_creation_input_tokens: create,
141            effective_input_tokens: effective_input,
142            denominator_tokens: denom,
143        }
144    }
145
146    pub fn cache_hit_rate_with_accounting(&self, accounting: CacheInputAccounting) -> Option<f64> {
147        let breakdown = self.cache_usage_breakdown(accounting);
148        if breakdown.denominator_tokens <= 0 {
149            return None;
150        }
151        Some(breakdown.cache_read_input_tokens as f64 / breakdown.denominator_tokens as f64)
152    }
153
154    pub fn cache_hit_rate_for_service(&self, service: &str) -> Option<f64> {
155        self.cache_hit_rate_with_accounting(CacheInputAccounting::for_service(service))
156    }
157
158    pub fn cache_hit_rate(&self) -> Option<f64> {
159        self.cache_hit_rate_with_accounting(CacheInputAccounting::default())
160    }
161
162    pub fn effective_input_tokens_with_accounting(&self, accounting: CacheInputAccounting) -> i64 {
163        self.cache_usage_breakdown(accounting)
164            .effective_input_tokens
165    }
166
167    pub fn cache_denominator_tokens_with_accounting(
168        &self,
169        accounting: CacheInputAccounting,
170    ) -> Option<i64> {
171        let denom = self.cache_usage_breakdown(accounting).denominator_tokens;
172        if denom <= 0 {
173            return None;
174        }
175        Some(denom)
176    }
177
178    fn derived_total_tokens(&self) -> i64 {
179        self.input_tokens
180            .saturating_add(self.output_tokens)
181            .saturating_add(self.cache_read_input_tokens)
182            .saturating_add(self.cache_creation_tokens_total())
183    }
184}
185
186fn to_i64(v: &Value) -> i64 {
187    match v {
188        Value::Number(n) => n.as_i64().unwrap_or(0),
189        Value::String(s) => s.parse::<f64>().ok().map(|f| f as i64).unwrap_or(0),
190        _ => 0,
191    }
192}
193
194fn extract_usage_obj(payload: &Value) -> Option<&Value> {
195    if let Some(u) = payload.get("usage") {
196        return Some(u);
197    }
198    if let Some(resp) = payload.get("response")
199        && let Some(u) = resp.get("usage")
200    {
201        return Some(u);
202    }
203    None
204}
205
206fn usage_from_value(usage_obj: &Value) -> Option<UsageMetrics> {
207    let mut m = UsageMetrics::default();
208    let mut recognized = false;
209    let mut total_provided = false;
210
211    if let Some(v) = usage_obj.get("input_tokens") {
212        m.input_tokens = to_i64(v);
213        recognized = true;
214    }
215    if let Some(v) = usage_obj.get("output_tokens") {
216        m.output_tokens = to_i64(v);
217        recognized = true;
218    }
219    if let Some(v) = usage_obj.get("total_tokens") {
220        m.total_tokens = to_i64(v);
221        recognized = true;
222        total_provided = true;
223    }
224
225    // OpenAI Chat Completions compatibility (`prompt_tokens` / `completion_tokens`).
226    if let Some(v) = usage_obj.get("prompt_tokens") {
227        m.input_tokens = to_i64(v);
228        recognized = true;
229    }
230    if let Some(v) = usage_obj.get("completion_tokens") {
231        m.output_tokens = to_i64(v);
232        recognized = true;
233    }
234
235    // Some providers may expose reasoning tokens directly.
236    if let Some(v) = usage_obj.get("reasoning_tokens") {
237        let value = to_i64(v);
238        m.reasoning_tokens = value;
239        m.reasoning_output_tokens = value;
240        recognized = true;
241    }
242    if let Some(v) = usage_obj.get("reasoning_output_tokens") {
243        let value = to_i64(v);
244        m.reasoning_output_tokens = value;
245        m.reasoning_tokens = m.reasoning_tokens.max(value);
246        recognized = true;
247    }
248
249    if let Some(details) = usage_obj
250        .get("output_tokens_details")
251        .and_then(|v| v.as_object())
252        && let Some(v) = details.get("reasoning_tokens")
253    {
254        let value = to_i64(v);
255        m.reasoning_tokens = value;
256        m.reasoning_output_tokens = value;
257        recognized = true;
258    }
259    if let Some(details) = usage_obj
260        .get("completion_tokens_details")
261        .and_then(|v| v.as_object())
262        && let Some(v) = details.get("reasoning_tokens")
263    {
264        let value = to_i64(v);
265        m.reasoning_tokens = value;
266        m.reasoning_output_tokens = value;
267        recognized = true;
268    }
269
270    if let Some(details) = usage_obj
271        .get("input_tokens_details")
272        .or_else(|| usage_obj.get("input_token_details"))
273        .and_then(|v| v.as_object())
274        && let Some(v) = details.get("cached_tokens")
275    {
276        m.cached_input_tokens = to_i64(v);
277        recognized = true;
278    } else if let Some(details) = usage_obj
279        .get("prompt_tokens_details")
280        .or_else(|| usage_obj.get("prompt_token_details"))
281        .and_then(|v| v.as_object())
282        && let Some(v) = details.get("cached_tokens")
283    {
284        m.cached_input_tokens = to_i64(v);
285        recognized = true;
286    } else if let Some(v) = usage_obj.get("cached_input_tokens") {
287        m.cached_input_tokens = to_i64(v);
288        recognized = true;
289    } else if let Some(v) = usage_obj
290        .get("cache_read_input_tokens")
291        .or_else(|| usage_obj.get("cache_read_tokens"))
292    {
293        m.cache_read_input_tokens = to_i64(v);
294        recognized = true;
295    }
296    if let Some(v) = usage_obj
297        .get("cache_creation_input_tokens")
298        .or_else(|| usage_obj.get("cache_creation_tokens"))
299    {
300        m.cache_creation_input_tokens = to_i64(v);
301        recognized = true;
302    }
303    if let Some(v) = usage_obj.get("cache_creation_5m_input_tokens") {
304        m.cache_creation_5m_input_tokens = to_i64(v);
305        recognized = true;
306    }
307    if let Some(v) = usage_obj.get("cache_creation_1h_input_tokens") {
308        m.cache_creation_1h_input_tokens = to_i64(v);
309        recognized = true;
310    }
311    if m.cache_creation_input_tokens == 0 {
312        m.cache_creation_input_tokens = m
313            .cache_creation_5m_input_tokens
314            .saturating_add(m.cache_creation_1h_input_tokens);
315    }
316
317    // If total isn't provided, derive it from input/output when possible.
318    if !total_provided {
319        m.total_tokens = m.derived_total_tokens();
320    }
321
322    if !recognized {
323        return None;
324    }
325    Some(m)
326}
327
328pub fn extract_usage_from_bytes(data: &[u8]) -> Option<UsageMetrics> {
329    let text = std::str::from_utf8(data).ok()?.trim();
330    if text.is_empty() {
331        return None;
332    }
333    let json: Value = serde_json::from_str(text).ok()?;
334    let usage_obj = extract_usage_obj(&json)?;
335    usage_from_value(usage_obj)
336}
337
338#[allow(dead_code)]
339pub fn extract_usage_from_sse_bytes(data: &[u8]) -> Option<UsageMetrics> {
340    let text = std::str::from_utf8(data).ok()?;
341    let mut last: Option<UsageMetrics> = None;
342
343    for chunk in text.split("\n\n") {
344        let lines: Vec<&str> = chunk
345            .lines()
346            .map(|l| l.trim())
347            .filter(|l| !l.is_empty())
348            .collect();
349        for line in lines {
350            if let Some(rest) = line.strip_prefix("data:") {
351                let payload_str = rest.trim();
352                if payload_str.is_empty() {
353                    continue;
354                }
355                if let Ok(json) = serde_json::from_str::<Value>(payload_str)
356                    && let Some(usage_obj) = extract_usage_obj(&json)
357                    && let Some(u) = usage_from_value(usage_obj)
358                {
359                    last = Some(u);
360                }
361            }
362        }
363    }
364
365    last
366}
367
368/// Incrementally scan SSE bytes for `data: {json}` lines that contain usage information.
369///
370/// This is designed for streaming scenarios where the response arrives in many chunks:
371/// it avoids repeatedly re-parsing the entire buffer (which can become O(n^2)).
372///
373/// - `scan_pos` is an in/out cursor into `data` (byte index).
374/// - `last` stores the latest usage parsed so far (updated in-place).
375pub fn scan_usage_from_sse_bytes_incremental(
376    data: &[u8],
377    scan_pos: &mut usize,
378    last: &mut Option<UsageMetrics>,
379) {
380    let mut i = (*scan_pos).min(data.len());
381
382    while i < data.len() {
383        let Some(rel_end) = data[i..].iter().position(|b| *b == b'\n') else {
384            break;
385        };
386        let end = i + rel_end;
387        let mut line = &data[i..end];
388        i = end.saturating_add(1);
389
390        if line.ends_with(b"\r") {
391            line = &line[..line.len().saturating_sub(1)];
392        }
393        if line.is_empty() {
394            continue;
395        }
396
397        const DATA_PREFIX: &[u8] = b"data:";
398        if !line.starts_with(DATA_PREFIX) {
399            continue;
400        }
401        let mut payload = &line[DATA_PREFIX.len()..];
402        while !payload.is_empty() && payload[0].is_ascii_whitespace() {
403            payload = &payload[1..];
404        }
405        if payload.is_empty() || payload == b"[DONE]" {
406            continue;
407        }
408
409        if let Ok(json) = serde_json::from_slice::<Value>(payload)
410            && let Some(usage_obj) = extract_usage_obj(&json)
411            && let Some(u) = usage_from_value(usage_obj)
412        {
413            *last = Some(u);
414        }
415    }
416
417    *scan_pos = i;
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;

    /// Runs a complete JSON body through the byte-level entry point.
    fn parse(body: &str) -> Option<UsageMetrics> {
        extract_usage_from_bytes(body.as_bytes())
    }

    /// The incremental scanner and the whole-buffer parser agree on a stream
    /// that is delivered in one piece.
    #[test]
    fn incremental_sse_scan_matches_full_parse() {
        let stream = concat!(
            "event: response.output_text.delta\n",
            "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hi\"}\n",
            "\n",
            "event: response.completed\n",
            "data: {\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n",
            "\n"
        )
        .as_bytes();

        let expected = extract_usage_from_sse_bytes(stream);

        let mut cursor = 0usize;
        let mut latest = None;
        scan_usage_from_sse_bytes_incremental(stream, &mut cursor, &mut latest);

        assert_eq!(latest, expected);
    }

    /// A `data:` line split across two chunks is only parsed once its
    /// terminating newline has arrived.
    #[test]
    fn incremental_sse_scan_handles_split_lines() {
        let head: &[u8] = b"data: {\"response\":{\"usage\":{\"input_tokens\":1";
        let tail: &[u8] = b",\"output_tokens\":2,\"total_tokens\":3}}}\n\n";

        let mut received = Vec::new();
        let mut cursor = 0usize;
        let mut latest = None;

        received.extend_from_slice(head);
        scan_usage_from_sse_bytes_incremental(&received, &mut cursor, &mut latest);
        assert_eq!(latest, None);

        received.extend_from_slice(tail);
        scan_usage_from_sse_bytes_incremental(&received, &mut cursor, &mut latest);

        let expected = UsageMetrics {
            input_tokens: 1,
            output_tokens: 2,
            total_tokens: 3,
            ..UsageMetrics::default()
        };
        assert_eq!(latest, Some(expected));
    }

    /// Chat Completions field names and nested reasoning details are read.
    #[test]
    fn parses_chat_completions_usage_fields() {
        let body = r#"{
          "id":"chatcmpl_x",
          "object":"chat.completion",
          "usage":{
            "prompt_tokens":9,
            "completion_tokens":12,
            "total_tokens":21,
            "completion_tokens_details":{"reasoning_tokens":5}
          }
        }"#;

        let expected = UsageMetrics {
            input_tokens: 9,
            output_tokens: 12,
            reasoning_tokens: 5,
            reasoning_output_tokens: 5,
            total_tokens: 21,
            ..UsageMetrics::default()
        };
        assert_eq!(parse(body), Some(expected));
    }

    /// Usage nested under `response` is found, including cached-token and
    /// reasoning details; the missing total is derived.
    #[test]
    fn parses_responses_usage_cache_and_reasoning_details() {
        let body = r#"{
          "response":{
            "usage":{
              "input_tokens":100,
              "output_tokens":20,
              "input_tokens_details":{"cached_tokens":40},
              "output_tokens_details":{"reasoning_tokens":7}
            }
          }
        }"#;

        let expected = UsageMetrics {
            input_tokens: 100,
            output_tokens: 20,
            reasoning_tokens: 7,
            reasoning_output_tokens: 7,
            cached_input_tokens: 40,
            total_tokens: 120,
            ..UsageMetrics::default()
        };
        assert_eq!(parse(body), Some(expected));
    }

    /// When both are present, `cached_tokens` wins and the direct cache-read
    /// counter is ignored.
    #[test]
    fn parses_openai_cached_tokens_before_direct_cache_read_tokens() {
        let body = r#"{
          "usage":{
            "input_tokens":100,
            "output_tokens":20,
            "input_tokens_details":{"cached_tokens":40},
            "cache_read_input_tokens":30
          }
        }"#;

        let expected = UsageMetrics {
            input_tokens: 100,
            output_tokens: 20,
            cached_input_tokens: 40,
            total_tokens: 120,
            ..UsageMetrics::default()
        };
        assert_eq!(parse(body), Some(expected));
    }

    /// Cache-read and per-TTL cache-creation counters are read, the aggregate
    /// creation counter is filled from their sum, and the total includes them.
    #[test]
    fn parses_anthropic_cache_usage_fields() {
        let body = r#"{
          "usage":{
            "input_tokens":10,
            "output_tokens":5,
            "cache_read_input_tokens":30,
            "cache_creation_5m_input_tokens":20,
            "cache_creation_1h_input_tokens":40
          }
        }"#;

        let expected = UsageMetrics {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 105,
            cache_read_input_tokens: 30,
            cache_creation_input_tokens: 60,
            cache_creation_5m_input_tokens: 20,
            cache_creation_1h_input_tokens: 40,
            ..UsageMetrics::default()
        };
        assert_eq!(parse(body), Some(expected));
    }

    /// Default accounting: direct reads and creation are added to the input
    /// in the denominator (30 / 150).
    #[test]
    fn computes_cache_hit_rate_from_read_and_creation_tokens() {
        let metrics = UsageMetrics {
            input_tokens: 100,
            cache_read_input_tokens: 30,
            cache_creation_input_tokens: 20,
            ..UsageMetrics::default()
        };

        let hit_rate = metrics.cache_hit_rate().expect("cache hit rate");

        assert_eq!(hit_rate, 0.2);
    }

    /// Included-in-input accounting: direct reads are removed from the input
    /// before the denominator is formed (30 / 120).
    #[test]
    fn computes_cache_hit_rate_when_direct_cache_read_is_included_in_input() {
        let metrics = UsageMetrics {
            input_tokens: 100,
            cache_read_input_tokens: 30,
            cache_creation_input_tokens: 20,
            ..UsageMetrics::default()
        };

        let accounting = CacheInputAccounting::DirectReadIncludedInInput;
        let hit_rate = metrics
            .cache_hit_rate_with_accounting(accounting)
            .expect("cache hit rate");

        assert_eq!(hit_rate, 0.25);
    }

    /// Cached input tokens alone give cached / input (40 / 100).
    #[test]
    fn computes_cache_hit_rate_from_cached_input_tokens() {
        let metrics = UsageMetrics {
            input_tokens: 100,
            cached_input_tokens: 40,
            ..UsageMetrics::default()
        };

        let hit_rate = metrics.cache_hit_rate().expect("cache hit rate");

        assert_eq!(hit_rate, 0.4);
    }

    /// Default accounting with both cached and direct-read counters set.
    #[test]
    fn computes_cache_hit_rate_from_mixed_usage_cache_fields() {
        let metrics = UsageMetrics {
            input_tokens: 1_500,
            cached_input_tokens: 50,
            cache_read_input_tokens: 250,
            ..UsageMetrics::default()
        };

        let hit_rate = metrics.cache_hit_rate().expect("cache hit rate");

        let expected = 300.0 / 1_750.0;
        assert!((hit_rate - expected).abs() < f64::EPSILON);
    }

    /// The `codex` service uses included-in-input accounting for the same
    /// counters.
    #[test]
    fn computes_service_specific_cache_hit_rate_from_mixed_usage_cache_fields() {
        let metrics = UsageMetrics {
            input_tokens: 1_500,
            cached_input_tokens: 50,
            cache_read_input_tokens: 250,
            ..UsageMetrics::default()
        };

        let hit_rate = metrics
            .cache_hit_rate_for_service("codex")
            .expect("cache hit rate");

        let expected = 300.0 / 1_500.0;
        assert!((hit_rate - expected).abs() < f64::EPSILON);
    }

    /// A usage object with no recognised key yields no metrics.
    #[test]
    fn unknown_usage_schema_returns_none() {
        let body = r#"{"usage":{"foo":123}}"#;
        assert_eq!(parse(body), None);
    }
}