Skip to main content

ai_usagebar/zai/
fetch.rs

1//! Z.AI fetch. Note the auth-header quirk — the API key is passed as
2//! `Authorization: <KEY>` WITHOUT the `Bearer` prefix. Sending `Bearer …`
3//! returns 401.
4
5use std::time::Duration;
6
7use crate::cache::{Cache, acquire_lock};
8use crate::error::{AppError, Result};
9use crate::usage::ZaiSnapshot;
10
11use super::types::Envelope;
12
/// Z.AI quota-limit endpoint used by [`Endpoints::default`].
pub const QUOTA_URL: &str = "https://api.z.ai/api/monitor/usage/quota/limit";
/// Time budget for a live request against the quota endpoint (10 s).
const HTTP_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Timeout handed to `acquire_lock` when taking the cache lock (15 s).
const LOCK_TIMEOUT: Duration = Duration::from_millis(15_000);
16
17#[derive(Debug, Clone)]
18pub struct Endpoints {
19    pub quota: String,
20}
21
22impl Default for Endpoints {
23    fn default() -> Self {
24        Self {
25            quota: QUOTA_URL.into(),
26        }
27    }
28}
29
/// Result of [`fetch_snapshot`]: the usage snapshot plus freshness/cache metadata.
#[derive(Debug, Clone)]
pub struct FetchOutcome {
    /// Parsed usage snapshot. It comes from a live fetch or the cache, or is a placeholder
    /// with plan "GLM Coding Unknown" when the cached payload cannot be decoded.
    pub snapshot: ZaiSnapshot,
    /// `true` when the snapshot was served from cache because the live fetch failed;
    /// `false` for a live fetch or a cache hit within the TTL.
    pub stale: bool,
    /// `(status, message)` of the failure that was just recorded, or whatever
    /// `Cache::read_last_error` returns on cache-served paths. Status is the HTTP code,
    /// or `0` for non-HTTP failures. `None` after a successful live fetch.
    pub last_error: Option<(u16, String)>,
    /// `Cache::payload_age()` for cache-served outcomes; `Some(Duration::ZERO)` right
    /// after a successful live fetch.
    pub cache_age: Option<Duration>,
}
37
38pub async fn fetch_snapshot(
39    client: &reqwest::Client,
40    api_key: &str,
41    cache: &Cache,
42    endpoints: &Endpoints,
43    cache_ttl: Duration,
44    config_plan_tier: Option<&str>,
45) -> Result<FetchOutcome> {
46    cache.ensure_dir()?;
47    let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
48
49    if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
50        return Ok(reuse(bytes, cache, false, config_plan_tier));
51    }
52
53    match fetch_live(client, &endpoints.quota, api_key).await {
54        Ok(bytes) => {
55            cache.write_payload(&bytes)?;
56            let env: Envelope = serde_json::from_slice(&bytes)?;
57            Ok(FetchOutcome {
58                snapshot: env.into_snapshot(config_plan_tier),
59                stale: false,
60                last_error: None,
61                cache_age: Some(Duration::ZERO),
62            })
63        }
64        Err(e) if e.is_transient() => fallback_silent(cache, config_plan_tier),
65        Err(AppError::Http { status, body }) => {
66            cache.mark_stale();
67            cache.write_last_error(status, &body);
68            fallback_with_error(cache, Some((status, body)), config_plan_tier)
69        }
70        Err(e) => {
71            cache.mark_stale();
72            cache.write_last_error(0, &e.to_string());
73            fallback_with_error(cache, Some((0, e.to_string())), config_plan_tier)
74        }
75    }
76}
77
78fn reuse(bytes: Vec<u8>, cache: &Cache, stale: bool, tier: Option<&str>) -> FetchOutcome {
79    let snapshot = serde_json::from_slice::<Envelope>(&bytes)
80        .map(|e| e.into_snapshot(tier))
81        .unwrap_or_else(|_| ZaiSnapshot {
82            plan: "GLM Coding Unknown".into(),
83            session: None,
84            weekly: None,
85            mcp: None,
86        });
87    FetchOutcome {
88        snapshot,
89        stale,
90        last_error: cache.read_last_error(),
91        cache_age: cache.payload_age(),
92    }
93}
94
95fn fallback_silent(cache: &Cache, tier: Option<&str>) -> Result<FetchOutcome> {
96    let Some(bytes) = cache.maybe_payload()? else {
97        return Err(AppError::Transport(
98            "zai: no cache and network unreachable".into(),
99        ));
100    };
101    Ok(reuse(bytes, cache, true, tier))
102}
103
104fn fallback_with_error(
105    cache: &Cache,
106    last_error: Option<(u16, String)>,
107    tier: Option<&str>,
108) -> Result<FetchOutcome> {
109    let Some(bytes) = cache.maybe_payload()? else {
110        return Err(AppError::Other("zai: no usable cache".into()));
111    };
112    let mut out = reuse(bytes, cache, true, tier);
113    out.last_error = last_error;
114    Ok(out)
115}
116
117async fn fetch_live(client: &reqwest::Client, url: &str, api_key: &str) -> Result<Vec<u8>> {
118    let resp = tokio::time::timeout(
119        HTTP_TIMEOUT,
120        client
121            .get(url)
122            .header("Authorization", api_key) // NO `Bearer ` prefix.
123            .header("Accept-Language", "en-US,en")
124            .header("Content-Type", "application/json")
125            .send(),
126    )
127    .await
128    .map_err(|_| AppError::Transport(format!("zai timeout: {url}")))??;
129
130    let status = resp.status();
131    let bytes = resp.bytes().await?.to_vec();
132
133    if !status.is_success() {
134        let body = String::from_utf8_lossy(&bytes).chars().take(200).collect();
135        return Err(AppError::Http {
136            status: status.as_u16(),
137            body,
138        });
139    }
140
141    // Sanity check we got a valid envelope. Schema drift surfaces here.
142    let _: Envelope = serde_json::from_slice(&bytes)
143        .map_err(|e| AppError::Schema(format!("zai quota response: {e}")))?;
144    Ok(bytes)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Path of the quota endpoint on the mock server.
    const QUOTA_PATH: &str = "/api/monitor/usage/quota/limit";

    /// Cached envelope with plan level "lite" and a single 10% token limit.
    const LITE_SEED: &str = r#"{"code":200,"data":{"limits":[
        {"type":"TOKENS_LIMIT","percentage":10}
    ],"level":"lite"},"success":true}"#;

    /// Empty cache rooted in a temp dir. Keep the `TempDir` alive for the whole test.
    fn cache_fixture() -> (TempDir, Cache) {
        let td = TempDir::new().unwrap();
        let cache = Cache::at(td.path().join("zai"));
        cache.ensure_dir().unwrap();
        (td, cache)
    }

    /// Endpoints whose quota URL points at `QUOTA_PATH` on the given mock server.
    fn endpoints_for(base_url: &str) -> Endpoints {
        Endpoints {
            quota: format!("{}{}", base_url, QUOTA_PATH),
        }
    }

    #[tokio::test]
    async fn live_200_parses_real_shape() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", QUOTA_PATH)
            .with_status(200)
            .with_body(
                r#"{"code":200,"msg":"Operation successful","data":{
                    "limits":[
                        {"type":"TOKENS_LIMIT","unit":3,"number":5,"percentage":42},
                        {"type":"TOKENS_LIMIT","unit":6,"number":1,"percentage":15,"nextResetTime":1779792169974}
                    ],"level":"pro"
                },"success":true}"#,
            )
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let client = reqwest::Client::new();
        let endpoints = endpoints_for(&server.url());
        let out = fetch_snapshot(
            &client,
            "fake-key",
            &cache,
            &endpoints,
            Duration::from_secs(0),
            None,
        )
        .await
        .unwrap();
        assert_eq!(out.snapshot.plan, "GLM Coding Pro");
        assert_eq!(out.snapshot.session.as_ref().unwrap().utilization_pct, 42);
        assert_eq!(out.snapshot.weekly.as_ref().unwrap().utilization_pct, 15);
    }

    #[tokio::test]
    async fn http_401_falls_back_to_cache_when_present() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", QUOTA_PATH)
            .with_status(401)
            .with_body(r#"{"code":401,"msg":"Unauthorized"}"#)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        cache.write_payload(LITE_SEED.as_bytes()).unwrap();

        let client = reqwest::Client::new();
        let endpoints = endpoints_for(&server.url());
        let out = fetch_snapshot(
            &client,
            "k",
            &cache,
            &endpoints,
            Duration::from_secs(0),
            None,
        )
        .await
        .unwrap();
        assert!(out.stale);
        assert_eq!(out.snapshot.session.as_ref().unwrap().utilization_pct, 10);
        assert_eq!(out.last_error.as_ref().map(|(c, _)| *c), Some(401));
    }

    #[tokio::test]
    async fn http_401_without_cache_is_an_error() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", QUOTA_PATH)
            .with_status(401)
            .with_body(r#"{"code":401,"msg":"Unauthorized"}"#)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let client = reqwest::Client::new();
        let endpoints = endpoints_for(&server.url());
        let result = fetch_snapshot(
            &client,
            "k",
            &cache,
            &endpoints,
            Duration::from_secs(0),
            None,
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn fresh_cache_is_served_without_network() {
        let mut server = mockito::Server::new_async().await;
        let quota = server
            .mock("GET", QUOTA_PATH)
            .expect(0)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        cache.write_payload(LITE_SEED.as_bytes()).unwrap();

        let client = reqwest::Client::new();
        let endpoints = endpoints_for(&server.url());
        let out = fetch_snapshot(
            &client,
            "k",
            &cache,
            &endpoints,
            Duration::from_secs(3600),
            None,
        )
        .await
        .unwrap();
        assert!(!out.stale);
        assert_eq!(out.snapshot.session.as_ref().unwrap().utilization_pct, 10);
        quota.assert_async().await;
    }
}