Skip to main content

ai_usagebar/openai/
fetch.rs

1//! Orchestrate: read ~/.codex/auth.json → maybe refresh → fetch usage → cache.
2//!
3//! Mirrors `anthropic::fetch::fetch_snapshot` but for the Codex OAuth flow.
4
5use std::path::Path;
6use std::time::Duration;
7
8use chrono::Utc;
9
10use crate::cache::{Cache, acquire_lock};
11use crate::error::{AppError, Result};
12use crate::usage::OpenAiSnapshot;
13
14use super::creds::{self, Tokens};
15use super::oauth;
16use super::types::UsageResponse;
17
/// ChatGPT backend endpoint that reports Codex rate-limit window usage.
pub const USAGE_URL: &str = "https://chatgpt.com/backend-api/wham/usage";

/// Upper bound on one usage request (send + body read) in `fetch_snapshot`.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

/// Upper bound on one OAuth token-refresh round-trip in `fetch_snapshot`.
const REFRESH_TIMEOUT: Duration = Duration::from_secs(25);

/// How long to wait for the per-provider cache lock. Kept above
/// `REFRESH_TIMEOUT + HTTP_TIMEOUT` (35 s), the network budget another
/// `fetch_snapshot` call may spend while holding the lock.
const LOCK_TIMEOUT: Duration = Duration::from_secs(45);
22
/// URLs contacted by `fetch_snapshot`.
///
/// Kept as owned strings so tests can point both at a local mock server.
#[derive(Debug, Clone)]
pub struct Endpoints {
    /// GET endpoint that returns the usage JSON.
    pub usage: String,
    /// OAuth token endpoint used to refresh the access token.
    pub token: String,
}
28
29impl Default for Endpoints {
30    fn default() -> Self {
31        Self {
32            usage: USAGE_URL.into(),
33            token: oauth::TOKEN_URL.into(),
34        }
35    }
36}
37
38#[derive(Debug, Clone)]
39pub struct FetchOutcome {
40    pub snapshot: OpenAiSnapshot,
41    pub stale: bool,
42    pub last_error: Option<(u16, String)>,
43    pub cache_age: Option<Duration>,
44}
45
46pub async fn fetch_snapshot(
47    client: &reqwest::Client,
48    creds_path: &Path,
49    cache: &Cache,
50    endpoints: &Endpoints,
51    cache_ttl: Duration,
52) -> Result<FetchOutcome> {
53    cache.ensure_dir()?;
54    let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
55
56    let mut auth = creds::read_from(creds_path)?;
57    let plan_hint = auth.tokens.plan_type_from_id_token();
58
59    if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
60        return Ok(reuse(bytes, cache, false, plan_hint.as_deref()));
61    }
62
63    // Maybe refresh — Codex CLI doesn't always populate expires_at, so we use
64    // the id_token's exp claim.
65    let now = Utc::now().timestamp();
66    if oauth::needs_refresh(auth.tokens.expires_at_secs(), now) {
67        match tokio::time::timeout(
68            REFRESH_TIMEOUT,
69            oauth::refresh(client, &endpoints.token, &auth.tokens.refresh_token),
70        )
71        .await
72        {
73            Ok(Ok(rr)) => {
74                auth.tokens.access_token = rr.access_token;
75                if let Some(rt) = rr.refresh_token {
76                    auth.tokens.refresh_token = rt;
77                }
78                if let Some(id) = rr.id_token {
79                    auth.tokens.id_token = id;
80                }
81                let _ = creds::write_back(creds_path, &auth);
82            }
83            Ok(Err(AppError::Http { status, body })) => {
84                cache.write_last_error(status, &body);
85                return handle_auth_failure(cache, plan_hint.as_deref(), false);
86            }
87            Ok(Err(e)) if e.is_transient() => {
88                return handle_auth_failure(cache, plan_hint.as_deref(), true);
89            }
90            Ok(Err(e)) => {
91                cache.write_last_error(0, &e.to_string());
92                return handle_auth_failure(cache, plan_hint.as_deref(), false);
93            }
94            Err(_) => return handle_auth_failure(cache, plan_hint.as_deref(), true),
95        }
96    }
97
98    match tokio::time::timeout(
99        HTTP_TIMEOUT,
100        fetch_usage(client, &endpoints.usage, &auth.tokens),
101    )
102    .await
103    {
104        Ok(Ok(bytes)) => {
105            cache.write_payload(&bytes)?;
106            let snap = parse_payload(&bytes, plan_hint.as_deref())?;
107            Ok(FetchOutcome {
108                snapshot: snap,
109                stale: false,
110                last_error: None,
111                cache_age: Some(Duration::ZERO),
112            })
113        }
114        Ok(Err(AppError::Http { status, body })) => {
115            cache.mark_stale();
116            cache.write_last_error(status, &body);
117            fallback(cache, plan_hint.as_deref(), Some((status, body)))
118        }
119        Ok(Err(e)) if e.is_transient() => fallback_silent(cache, plan_hint.as_deref()),
120        Ok(Err(e)) => {
121            cache.mark_stale();
122            cache.write_last_error(0, &e.to_string());
123            fallback(cache, plan_hint.as_deref(), Some((0, e.to_string())))
124        }
125        Err(_) => fallback_silent(cache, plan_hint.as_deref()),
126    }
127}
128
129fn reuse(bytes: Vec<u8>, cache: &Cache, stale: bool, plan_hint: Option<&str>) -> FetchOutcome {
130    let snap = parse_payload(&bytes, plan_hint).unwrap_or_else(|_| empty(plan_hint));
131    FetchOutcome {
132        snapshot: snap,
133        stale,
134        last_error: cache.read_last_error(),
135        cache_age: cache.payload_age(),
136    }
137}
138
139fn fallback(
140    cache: &Cache,
141    plan_hint: Option<&str>,
142    last_error: Option<(u16, String)>,
143) -> Result<FetchOutcome> {
144    let Some(bytes) = cache.maybe_payload()? else {
145        return Err(AppError::Other("openai: no usable cache".into()));
146    };
147    let mut out = reuse(bytes, cache, true, plan_hint);
148    out.last_error = last_error;
149    Ok(out)
150}
151
152fn fallback_silent(cache: &Cache, plan_hint: Option<&str>) -> Result<FetchOutcome> {
153    let Some(bytes) = cache.maybe_payload()? else {
154        return Err(AppError::Transport(
155            "openai: no cache and network unreachable".into(),
156        ));
157    };
158    Ok(reuse(bytes, cache, true, plan_hint))
159}
160
161fn handle_auth_failure(
162    cache: &Cache,
163    plan_hint: Option<&str>,
164    transient: bool,
165) -> Result<FetchOutcome> {
166    let Some(bytes) = cache.maybe_payload()? else {
167        return if transient {
168            Err(AppError::Transport(
169                "openai: no cache and refresh failed transiently".into(),
170            ))
171        } else {
172            Err(AppError::Credentials(
173                "openai: token refresh failed; run `codex login` to re-auth".into(),
174            ))
175        };
176    };
177    Ok(reuse(bytes, cache, true, plan_hint))
178}
179
180fn parse_payload(bytes: &[u8], plan_hint: Option<&str>) -> Result<OpenAiSnapshot> {
181    let r: UsageResponse = serde_json::from_slice(bytes)?;
182    Ok(r.into_snapshot(plan_hint))
183}
184
185fn empty(plan_hint: Option<&str>) -> OpenAiSnapshot {
186    UsageResponse::default().into_snapshot(plan_hint)
187}
188
189async fn fetch_usage(client: &reqwest::Client, url: &str, t: &Tokens) -> Result<Vec<u8>> {
190    let mut req = client
191        .get(url)
192        .header("Authorization", format!("Bearer {}", t.access_token))
193        .header("User-Agent", "codex-cli");
194    if let Some(aid) = t.account_id.as_deref() {
195        req = req.header("ChatGPT-Account-Id", aid);
196    }
197    let resp = req.send().await?;
198    let status = resp.status();
199    let bytes = resp.bytes().await?.to_vec();
200
201    if !status.is_success() {
202        let body: String = String::from_utf8_lossy(&bytes).chars().take(200).collect();
203        return Err(AppError::Http {
204            status: status.as_u16(),
205            body,
206        });
207    }
208    let _: UsageResponse = serde_json::from_slice(&bytes)
209        .map_err(|e| AppError::Schema(format!("openai usage response: {e}")))?;
210    Ok(bytes)
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Path of the usage endpoint on the mock server.
    const USAGE_PATH: &str = "/backend-api/wham/usage";

    /// Unsigned JWT (`alg: none`) carrying `claims` as its payload.
    fn fake_jwt(claims: serde_json::Value) -> String {
        let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let header = engine.encode(br#"{"alg":"none","typ":"JWT"}"#);
        let payload = engine.encode(claims.to_string().as_bytes());
        format!("{header}.{payload}.sig")
    }

    /// auth.json whose id_token expires in one hour and claims the `plus` plan.
    fn future_creds() -> NamedTempFile {
        let exp = Utc::now().timestamp() + 3600;
        let jwt = fake_jwt(serde_json::json!({
            "exp": exp,
            "https://api.openai.com/auth": {"chatgpt_plan_type": "plus"}
        }));
        let body = format!(
            r#"{{"tokens":{{"access_token":"AT","refresh_token":"RT","id_token":"{jwt}",
                "account_id":"acc"}}}}"#
        );
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(body.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Cache rooted in a temp dir; the returned `TempDir` must outlive the cache.
    fn cache_fixture() -> (TempDir, Cache) {
        let dir = TempDir::new().unwrap();
        let cache = Cache::at(dir.path().join("openai"));
        cache.ensure_dir().unwrap();
        (dir, cache)
    }

    /// Endpoints pointing at a mock server whose base URL is `base`.
    fn endpoints_for(base: &str) -> Endpoints {
        Endpoints {
            usage: format!("{}{}", base, USAGE_PATH),
            token: format!("{}/oauth/token", base),
        }
    }

    /// Runs `fetch_snapshot` with a fresh client and a zero cache TTL.
    async fn fetch_with_zero_ttl(
        creds: &NamedTempFile,
        cache: &Cache,
        endpoints: &Endpoints,
    ) -> FetchOutcome {
        let client = reqwest::Client::new();
        fetch_snapshot(&client, creds.path(), cache, endpoints, Duration::ZERO)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn live_200_returns_snapshot_with_plan_from_id_token() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", USAGE_PATH)
            .with_status(200)
            .with_body(
                r#"{"plan_type":"plus","rate_limit":{
                "primary_window":{"used_percent":1,"limit_window_seconds":18000,"reset_at":1779597324},
                "secondary_window":{"used_percent":0,"limit_window_seconds":604800,"reset_at":1780184124}
            }}"#,
            )
            .create_async()
            .await;
        let (_td, cache) = cache_fixture();
        let creds = future_creds();

        let out = fetch_with_zero_ttl(&creds, &cache, &endpoints_for(&server.url())).await;

        assert_eq!(out.snapshot.plan, "ChatGPT Plus");
        assert_eq!(out.snapshot.session.utilization_pct, 1);
        assert!(!out.stale);
    }

    #[tokio::test]
    async fn http_500_falls_back_to_cache_when_present() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", USAGE_PATH)
            .with_status(500)
            .with_body(r#"{"error":{"message":"upstream"}}"#)
            .create_async()
            .await;
        let (_td, cache) = cache_fixture();
        cache
            .write_payload(
                br#"{"plan_type":"pro","rate_limit":{"primary_window":{"used_percent":50,"limit_window_seconds":18000}}}"#,
            )
            .unwrap();
        let creds = future_creds();

        let out = fetch_with_zero_ttl(&creds, &cache, &endpoints_for(&server.url())).await;

        assert!(out.stale);
        assert_eq!(out.snapshot.session.utilization_pct, 50);
        assert_eq!(out.last_error.as_ref().map(|(c, _)| *c), Some(500));
    }
}