1use std::time::Duration;
6
7use crate::cache::{Cache, acquire_lock};
8use crate::error::{AppError, Result};
9use crate::usage::ZaiSnapshot;
10
11use super::types::Envelope;
12
/// Production Z.ai endpoint queried for the account's quota limits.
pub const QUOTA_URL: &str = "https://api.z.ai/api/monitor/usage/quota/limit";

/// Time budget handed to `tokio::time::timeout` for live quota requests in `fetch_live`.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

/// Timeout handed to `acquire_lock` when taking the cache lock in `fetch_snapshot`.
const LOCK_TIMEOUT: Duration = Duration::from_secs(15);
16
17#[derive(Debug, Clone)]
18pub struct Endpoints {
19 pub quota: String,
20}
21
22impl Default for Endpoints {
23 fn default() -> Self {
24 Self {
25 quota: QUOTA_URL.into(),
26 }
27 }
28}
29
/// Result of `fetch_snapshot`: the usage snapshot plus metadata about where it came from.
#[derive(Debug, Clone)]
pub struct FetchOutcome {
    /// Usage data decoded from the live response or from the cached payload.
    pub snapshot: ZaiSnapshot,
    /// `true` when the snapshot was served from the cache after a failed live fetch;
    /// `false` for a live response or a cache entry reused while still fresh.
    pub stale: bool,
    /// Status code and message of a fetch failure: either the failure that just happened,
    /// or the one recorded in the cache when a cached payload is reused. `None` after a
    /// successful live fetch. A status of `0` denotes a non-HTTP failure.
    pub last_error: Option<(u16, String)>,
    /// Age of the payload backing `snapshot`, as reported by the cache;
    /// `Some(Duration::ZERO)` for a payload fetched during this call.
    pub cache_age: Option<Duration>,
}
37
38pub async fn fetch_snapshot(
39 client: &reqwest::Client,
40 api_key: &str,
41 cache: &Cache,
42 endpoints: &Endpoints,
43 cache_ttl: Duration,
44 config_plan_tier: Option<&str>,
45) -> Result<FetchOutcome> {
46 cache.ensure_dir()?;
47 let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
48
49 if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
50 return Ok(reuse(bytes, cache, false, config_plan_tier));
51 }
52
53 match fetch_live(client, &endpoints.quota, api_key).await {
54 Ok(bytes) => {
55 cache.write_payload(&bytes)?;
56 let env: Envelope = serde_json::from_slice(&bytes)?;
57 Ok(FetchOutcome {
58 snapshot: env.into_snapshot(config_plan_tier),
59 stale: false,
60 last_error: None,
61 cache_age: Some(Duration::ZERO),
62 })
63 }
64 Err(e) if e.is_transient() => fallback_silent(cache, config_plan_tier),
65 Err(AppError::Http { status, body }) => {
66 cache.mark_stale();
67 cache.write_last_error(status, &body);
68 fallback_with_error(cache, Some((status, body)), config_plan_tier)
69 }
70 Err(e) => {
71 cache.mark_stale();
72 cache.write_last_error(0, &e.to_string());
73 fallback_with_error(cache, Some((0, e.to_string())), config_plan_tier)
74 }
75 }
76}
77
78fn reuse(bytes: Vec<u8>, cache: &Cache, stale: bool, tier: Option<&str>) -> FetchOutcome {
79 let snapshot = serde_json::from_slice::<Envelope>(&bytes)
80 .map(|e| e.into_snapshot(tier))
81 .unwrap_or_else(|_| ZaiSnapshot {
82 plan: "GLM Coding Unknown".into(),
83 session: None,
84 weekly: None,
85 mcp: None,
86 });
87 FetchOutcome {
88 snapshot,
89 stale,
90 last_error: cache.read_last_error(),
91 cache_age: cache.payload_age(),
92 }
93}
94
95fn fallback_silent(cache: &Cache, tier: Option<&str>) -> Result<FetchOutcome> {
96 let Some(bytes) = cache.maybe_payload()? else {
97 return Err(AppError::Transport(
98 "zai: no cache and network unreachable".into(),
99 ));
100 };
101 Ok(reuse(bytes, cache, true, tier))
102}
103
104fn fallback_with_error(
105 cache: &Cache,
106 last_error: Option<(u16, String)>,
107 tier: Option<&str>,
108) -> Result<FetchOutcome> {
109 let Some(bytes) = cache.maybe_payload()? else {
110 return Err(AppError::Other("zai: no usable cache".into()));
111 };
112 let mut out = reuse(bytes, cache, true, tier);
113 out.last_error = last_error;
114 Ok(out)
115}
116
117async fn fetch_live(client: &reqwest::Client, url: &str, api_key: &str) -> Result<Vec<u8>> {
118 let resp = tokio::time::timeout(
119 HTTP_TIMEOUT,
120 client
121 .get(url)
122 .header("Authorization", api_key) .header("Accept-Language", "en-US,en")
124 .header("Content-Type", "application/json")
125 .send(),
126 )
127 .await
128 .map_err(|_| AppError::Transport(format!("zai timeout: {url}")))??;
129
130 let status = resp.status();
131 let bytes = resp.bytes().await?.to_vec();
132
133 if !status.is_success() {
134 let body = String::from_utf8_lossy(&bytes).chars().take(200).collect();
135 return Err(AppError::Http {
136 status: status.as_u16(),
137 body,
138 });
139 }
140
141 let _: Envelope = serde_json::from_slice(&bytes)
143 .map_err(|e| AppError::Schema(format!("zai quota response: {e}")))?;
144 Ok(bytes)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Path of the quota endpoint, shared by the mocks and the endpoint builder.
    const QUOTA_PATH: &str = "/api/monitor/usage/quota/limit";

    /// Builds a cache rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned too: the directory lives only as long as the test
    /// keeps that handle alive.
    fn cache_fixture() -> (TempDir, Cache) {
        let tmp = TempDir::new().unwrap();
        let cache = Cache::at(tmp.path().join("zai"));
        cache.ensure_dir().unwrap();
        (tmp, cache)
    }

    /// Endpoints whose quota URL targets the mock server at `base_url`.
    fn endpoints_for(base_url: &str) -> Endpoints {
        Endpoints {
            quota: format!("{base_url}{QUOTA_PATH}"),
        }
    }

    /// Calls `fetch_snapshot` against `base_url` with a zero cache TTL and no plan-tier
    /// override, unwrapping the result.
    async fn fetch_with_zero_ttl(base_url: &str, cache: &Cache, api_key: &str) -> FetchOutcome {
        let client = reqwest::Client::new();
        fetch_snapshot(
            &client,
            api_key,
            cache,
            &endpoints_for(base_url),
            Duration::from_secs(0),
            None,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn live_200_parses_real_shape() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", QUOTA_PATH)
            .with_status(200)
            .with_body(
                r#"{"code":200,"msg":"Operation successful","data":{
                "limits":[
                    {"type":"TOKENS_LIMIT","unit":3,"number":5,"percentage":42},
                    {"type":"TOKENS_LIMIT","unit":6,"number":1,"percentage":15,"nextResetTime":1779792169974}
                ],"level":"pro"
            },"success":true}"#,
            )
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let out = fetch_with_zero_ttl(&server.url(), &cache, "fake-key").await;

        assert_eq!(out.snapshot.plan, "GLM Coding Pro");
        assert_eq!(out.snapshot.session.as_ref().unwrap().utilization_pct, 42);
        assert_eq!(out.snapshot.weekly.as_ref().unwrap().utilization_pct, 15);
    }

    #[tokio::test]
    async fn http_401_falls_back_to_cache_when_present() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", QUOTA_PATH)
            .with_status(401)
            .with_body(r#"{"code":401,"msg":"Unauthorized"}"#)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let seed = r#"{"code":200,"data":{"limits":[
            {"type":"TOKENS_LIMIT","percentage":10}
        ],"level":"lite"},"success":true}"#;
        cache.write_payload(seed.as_bytes()).unwrap();

        let out = fetch_with_zero_ttl(&server.url(), &cache, "k").await;

        assert!(out.stale);
        assert_eq!(out.snapshot.session.as_ref().unwrap().utilization_pct, 10);
        assert_eq!(out.last_error.as_ref().map(|(c, _)| *c), Some(401));
    }
}