Skip to main content

codex_helper_core/
usage.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
5pub struct UsageMetrics {
6    #[serde(default)]
7    pub input_tokens: i64,
8    #[serde(default)]
9    pub output_tokens: i64,
10    #[serde(default)]
11    pub reasoning_tokens: i64,
12    #[serde(default)]
13    pub total_tokens: i64,
14}
15
16impl UsageMetrics {
17    pub fn add_assign(&mut self, other: &UsageMetrics) {
18        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
19        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
20        self.reasoning_tokens = self.reasoning_tokens.saturating_add(other.reasoning_tokens);
21        self.total_tokens = self.total_tokens.saturating_add(other.total_tokens);
22    }
23}
24
25fn to_i64(v: &Value) -> i64 {
26    match v {
27        Value::Number(n) => n.as_i64().unwrap_or(0),
28        Value::String(s) => s.parse::<f64>().ok().map(|f| f as i64).unwrap_or(0),
29        _ => 0,
30    }
31}
32
33fn extract_usage_obj(payload: &Value) -> Option<&Value> {
34    if let Some(u) = payload.get("usage") {
35        return Some(u);
36    }
37    if let Some(resp) = payload.get("response")
38        && let Some(u) = resp.get("usage")
39    {
40        return Some(u);
41    }
42    None
43}
44
45fn usage_from_value(usage_obj: &Value) -> Option<UsageMetrics> {
46    let mut m = UsageMetrics::default();
47    let mut recognized = false;
48
49    if let Some(v) = usage_obj.get("input_tokens") {
50        m.input_tokens = to_i64(v);
51        recognized = true;
52    }
53    if let Some(v) = usage_obj.get("output_tokens") {
54        m.output_tokens = to_i64(v);
55        recognized = true;
56    }
57    if let Some(v) = usage_obj.get("total_tokens") {
58        m.total_tokens = to_i64(v);
59        recognized = true;
60    }
61
62    // OpenAI Chat Completions compatibility (`prompt_tokens` / `completion_tokens`).
63    if let Some(v) = usage_obj.get("prompt_tokens") {
64        m.input_tokens = to_i64(v);
65        recognized = true;
66    }
67    if let Some(v) = usage_obj.get("completion_tokens") {
68        m.output_tokens = to_i64(v);
69        recognized = true;
70    }
71
72    // Some providers may expose reasoning tokens directly.
73    if let Some(v) = usage_obj.get("reasoning_tokens") {
74        m.reasoning_tokens = to_i64(v);
75        recognized = true;
76    }
77
78    if let Some(details) = usage_obj
79        .get("output_tokens_details")
80        .and_then(|v| v.as_object())
81        && let Some(v) = details.get("reasoning_tokens")
82    {
83        m.reasoning_tokens = to_i64(v);
84        recognized = true;
85    }
86    if let Some(details) = usage_obj
87        .get("completion_tokens_details")
88        .and_then(|v| v.as_object())
89        && let Some(v) = details.get("reasoning_tokens")
90    {
91        m.reasoning_tokens = to_i64(v);
92        recognized = true;
93    }
94
95    // If total isn't provided, derive it from input/output when possible.
96    if usage_obj.get("total_tokens").is_none() {
97        m.total_tokens = m.input_tokens.saturating_add(m.output_tokens);
98    }
99
100    if !recognized {
101        return None;
102    }
103    Some(m)
104}
105
106pub fn extract_usage_from_bytes(data: &[u8]) -> Option<UsageMetrics> {
107    let text = std::str::from_utf8(data).ok()?.trim();
108    if text.is_empty() {
109        return None;
110    }
111    let json: Value = serde_json::from_str(text).ok()?;
112    let usage_obj = extract_usage_obj(&json)?;
113    usage_from_value(usage_obj)
114}
115
116#[allow(dead_code)]
117pub fn extract_usage_from_sse_bytes(data: &[u8]) -> Option<UsageMetrics> {
118    let text = std::str::from_utf8(data).ok()?;
119    let mut last: Option<UsageMetrics> = None;
120
121    for chunk in text.split("\n\n") {
122        let lines: Vec<&str> = chunk
123            .lines()
124            .map(|l| l.trim())
125            .filter(|l| !l.is_empty())
126            .collect();
127        for line in lines {
128            if let Some(rest) = line.strip_prefix("data:") {
129                let payload_str = rest.trim();
130                if payload_str.is_empty() {
131                    continue;
132                }
133                if let Ok(json) = serde_json::from_str::<Value>(payload_str)
134                    && let Some(usage_obj) = extract_usage_obj(&json)
135                    && let Some(u) = usage_from_value(usage_obj)
136                {
137                    last = Some(u);
138                }
139            }
140        }
141    }
142
143    last
144}
145
146/// Incrementally scan SSE bytes for `data: {json}` lines that contain usage information.
147///
148/// This is designed for streaming scenarios where the response arrives in many chunks:
149/// it avoids repeatedly re-parsing the entire buffer (which can become O(n^2)).
150///
151/// - `scan_pos` is an in/out cursor into `data` (byte index).
152/// - `last` stores the latest usage parsed so far (updated in-place).
153pub fn scan_usage_from_sse_bytes_incremental(
154    data: &[u8],
155    scan_pos: &mut usize,
156    last: &mut Option<UsageMetrics>,
157) {
158    let mut i = (*scan_pos).min(data.len());
159
160    while i < data.len() {
161        let Some(rel_end) = data[i..].iter().position(|b| *b == b'\n') else {
162            break;
163        };
164        let end = i + rel_end;
165        let mut line = &data[i..end];
166        i = end.saturating_add(1);
167
168        if line.ends_with(b"\r") {
169            line = &line[..line.len().saturating_sub(1)];
170        }
171        if line.is_empty() {
172            continue;
173        }
174
175        const DATA_PREFIX: &[u8] = b"data:";
176        if !line.starts_with(DATA_PREFIX) {
177            continue;
178        }
179        let mut payload = &line[DATA_PREFIX.len()..];
180        while !payload.is_empty() && payload[0].is_ascii_whitespace() {
181            payload = &payload[1..];
182        }
183        if payload.is_empty() || payload == b"[DONE]" {
184            continue;
185        }
186
187        if let Ok(json) = serde_json::from_slice::<Value>(payload)
188            && let Some(usage_obj) = extract_usage_obj(&json)
189            && let Some(u) = usage_from_value(usage_obj)
190        {
191            *last = Some(u);
192        }
193    }
194
195    *scan_pos = i;
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;

    /// Shorthand for spelling out an expected [`UsageMetrics`] value.
    fn metrics(input: i64, output: i64, reasoning: i64, total: i64) -> UsageMetrics {
        UsageMetrics {
            input_tokens: input,
            output_tokens: output,
            reasoning_tokens: reasoning,
            total_tokens: total,
        }
    }

    /// Scanning a complete stream in one call must agree with the one-shot
    /// SSE parser.
    #[test]
    fn incremental_sse_scan_matches_full_parse() {
        // Joined with "\n"; the two trailing empty entries produce the final
        // blank line that terminates the last event.
        let stream = [
            "event: response.output_text.delta",
            r#"data: {"type":"response.output_text.delta","delta":"hi"}"#,
            "",
            "event: response.completed",
            r#"data: {"response":{"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}}"#,
            "",
            "",
        ]
        .join("\n");
        let bytes = stream.as_bytes();

        let mut cursor = 0usize;
        let mut seen = None;
        scan_usage_from_sse_bytes_incremental(bytes, &mut cursor, &mut seen);

        assert_eq!(seen, extract_usage_from_sse_bytes(bytes));
    }

    /// A `data:` line delivered in two pieces is parsed only once its
    /// terminating newline has arrived.
    #[test]
    fn incremental_sse_scan_handles_split_lines() {
        let head: &[u8] = br#"data: {"response":{"usage":{"input_tokens":1"#;
        let tail: &[u8] = b",\"output_tokens\":2,\"total_tokens\":3}}}\n\n";

        let mut buffer = head.to_vec();
        let mut cursor = 0usize;
        let mut seen = None;

        scan_usage_from_sse_bytes_incremental(&buffer, &mut cursor, &mut seen);
        assert_eq!(seen, None);

        buffer.extend_from_slice(tail);
        scan_usage_from_sse_bytes_incremental(&buffer, &mut cursor, &mut seen);
        assert_eq!(seen, Some(metrics(1, 2, 0, 3)));
    }

    /// Chat Completions field names, including nested reasoning details, are
    /// mapped onto the metrics struct.
    #[test]
    fn parses_chat_completions_usage_fields() {
        let body = r#"{
          "id":"chatcmpl_x",
          "object":"chat.completion",
          "usage":{
            "prompt_tokens":9,
            "completion_tokens":12,
            "total_tokens":21,
            "completion_tokens_details":{"reasoning_tokens":5}
          }
        }"#;

        let parsed = extract_usage_from_bytes(body.as_bytes());

        assert_eq!(parsed, Some(metrics(9, 12, 5, 21)));
    }

    /// A usage object with none of the known keys yields no metrics.
    #[test]
    fn unknown_usage_schema_returns_none() {
        let body = r#"{"usage":{"foo":123}}"#;

        let parsed = extract_usage_from_bytes(body.as_bytes());

        assert_eq!(parsed, None);
    }
}