1use std::time::Duration;
4
5use crate::cache::{Cache, acquire_lock};
6use crate::error::{AppError, Result};
7use crate::usage::DeepseekSnapshot;
8
9use super::types::BalanceResponse;
10
/// Root URL of the DeepSeek HTTP API; [`Endpoints::default`] builds endpoint URLs from it.
pub const BASE_URL: &str = "https://api.deepseek.com";
/// Duration passed to `tokio::time::timeout` around the live balance request in `fetch_live`.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);
/// Timeout passed to `acquire_lock` by `fetch_snapshot` (presumably the longest it waits for
/// the cache lock — confirm against `crate::cache::acquire_lock`).
const LOCK_TIMEOUT: Duration = Duration::from_secs(15);
14
/// URLs of the DeepSeek endpoints this module calls.
///
/// Kept separate from [`BASE_URL`] so callers (e.g. the tests below) can point requests at a
/// different server.
#[derive(Debug, Clone)]
pub struct Endpoints {
    /// Full URL of the balance endpoint queried by `fetch_snapshot`.
    pub balance: String,
}
19
20impl Default for Endpoints {
21 fn default() -> Self {
22 Self {
23 balance: format!("{BASE_URL}/user/balance"),
24 }
25 }
26}
27
/// What [`fetch_snapshot`] returns: the balance snapshot plus metadata about its provenance.
#[derive(Debug, Clone)]
pub struct FetchOutcome {
    /// Balance data, either fetched live or decoded from the on-disk cache.
    pub snapshot: DeepseekSnapshot,
    /// `true` when the snapshot was served from cache as a fallback after a failed live
    /// fetch; `false` for a successful live fetch or a fresh cache hit.
    pub stale: bool,
    /// Most recent error as `(status, message)`; `fetch_snapshot` uses status `0` for
    /// non-HTTP failures. `None` after a successful live fetch. When the cache is reused, this
    /// is either the error that just occurred or whatever `Cache::read_last_error` returns.
    pub last_error: Option<(u16, String)>,
    /// Age of the cached payload (`Cache::payload_age`); `Some(Duration::ZERO)` right after a
    /// successful live fetch.
    pub cache_age: Option<Duration>,
}
35
36pub async fn fetch_snapshot(
37 client: &reqwest::Client,
38 api_key: &str,
39 cache: &Cache,
40 endpoints: &Endpoints,
41 cache_ttl: Duration,
42) -> Result<FetchOutcome> {
43 cache.ensure_dir()?;
44 let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
45
46 if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
47 return Ok(reuse_cache(bytes, cache, false));
48 }
49
50 match fetch_live(client, &endpoints.balance, api_key).await {
51 Ok(snap) => {
52 let bytes = serde_json::to_vec(&snap_to_json(&snap)).unwrap_or_default();
53 cache.write_payload(&bytes)?;
54 Ok(FetchOutcome {
55 snapshot: snap,
56 stale: false,
57 last_error: None,
58 cache_age: Some(Duration::ZERO),
59 })
60 }
61 Err(e) if e.is_transient() => fallback_silent(cache),
62 Err(AppError::Http { status, body }) => {
63 cache.mark_stale();
64 cache.write_last_error(status, &body);
65 fallback_with_error(cache, Some((status, body)))
66 }
67 Err(e) => {
68 cache.mark_stale();
69 cache.write_last_error(0, &e.to_string());
70 fallback_with_error(cache, Some((0, e.to_string())))
71 }
72 }
73}
74
75fn fallback_silent(cache: &Cache) -> Result<FetchOutcome> {
76 let Some(bytes) = cache.maybe_payload()? else {
77 return Err(AppError::Transport(
78 "deepseek: no cache and network unreachable".into(),
79 ));
80 };
81 Ok(reuse_cache(bytes, cache, true))
82}
83
84fn fallback_with_error(cache: &Cache, last_error: Option<(u16, String)>) -> Result<FetchOutcome> {
85 let Some(bytes) = cache.maybe_payload()? else {
86 return Err(AppError::Other("deepseek: no usable cache".into()));
87 };
88 let mut outcome = reuse_cache(bytes, cache, true);
89 outcome.last_error = last_error;
90 Ok(outcome)
91}
92
93fn reuse_cache(bytes: Vec<u8>, cache: &Cache, stale: bool) -> FetchOutcome {
94 let snap = parse_cache(&bytes).unwrap_or_default();
95 FetchOutcome {
96 snapshot: snap,
97 stale,
98 last_error: cache.read_last_error(),
99 cache_age: cache.payload_age(),
100 }
101}
102
103fn parse_cache(bytes: &[u8]) -> Result<DeepseekSnapshot> {
104 let v: serde_json::Value = serde_json::from_slice(bytes)?;
105 Ok(DeepseekSnapshot {
106 is_available: v["is_available"].as_bool().unwrap_or(false),
107 balance: v["balance"].as_f64().unwrap_or(0.0),
108 granted: v["granted"].as_f64().unwrap_or(0.0),
109 topped_up: v["topped_up"].as_f64().unwrap_or(0.0),
110 currency: v["currency"].as_str().unwrap_or("").to_string(),
111 })
112}
113
114fn snap_to_json(snap: &DeepseekSnapshot) -> serde_json::Value {
115 serde_json::json!({
116 "is_available": snap.is_available,
117 "balance": snap.balance,
118 "granted": snap.granted,
119 "topped_up": snap.topped_up,
120 "currency": snap.currency,
121 })
122}
123
124async fn fetch_live(
125 client: &reqwest::Client,
126 url: &str,
127 api_key: &str,
128) -> Result<DeepseekSnapshot> {
129 let resp = tokio::time::timeout(
130 HTTP_TIMEOUT,
131 client
132 .get(url)
133 .header("Authorization", format!("Bearer {api_key}"))
134 .header("Accept", "application/json")
135 .send(),
136 )
137 .await
138 .map_err(|_| AppError::Transport(format!("deepseek timeout: {url}")))??;
139
140 let status = resp.status();
141 let bytes = resp.bytes().await?;
142
143 if !status.is_success() {
144 let body = String::from_utf8_lossy(&bytes).chars().take(200).collect();
145 return Err(AppError::Http {
146 status: status.as_u16(),
147 body,
148 });
149 }
150
151 let r: BalanceResponse = serde_json::from_slice(&bytes)
152 .map_err(|e| AppError::Schema(format!("deepseek balance response: {e}")))?;
153 Ok(r.into_snapshot())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates an isolated cache under a fresh temp dir. Keep the returned `TempDir` alive
    /// for the whole test; dropping it deletes the directory.
    fn cache_fixture() -> (TempDir, Cache) {
        let td = TempDir::new().unwrap();
        let cache = Cache::at(td.path().join("deepseek"));
        cache.ensure_dir().unwrap();
        (td, cache)
    }

    /// Endpoints that route the balance call to a mock server rooted at `base_url`.
    fn endpoints_for(base_url: &str) -> Endpoints {
        Endpoints {
            balance: format!("{base_url}/user/balance"),
        }
    }

    /// Calls `fetch_snapshot` with a zero TTL so a seeded payload is not served as fresh
    /// and the mock server is always queried.
    async fn fetch_with(base_url: &str, api_key: &str, cache: &Cache) -> Result<FetchOutcome> {
        let client = reqwest::Client::new();
        fetch_snapshot(
            &client,
            api_key,
            cache,
            &endpoints_for(base_url),
            Duration::ZERO,
        )
        .await
    }

    #[tokio::test]
    async fn live_200_returns_snapshot() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/user/balance")
            .with_status(200)
            .with_body(
                r#"{
  "is_available": true,
  "balance_infos": [
    {"currency": "USD", "total_balance": "5.00", "granted_balance": "5.00", "topped_up_balance": "0.00"}
  ]
}"#,
            )
            .create_async()
            .await;
        let (_td, cache) = cache_fixture();

        let out = fetch_with(&server.url(), "sk-test", &cache).await.unwrap();

        assert!(!out.stale);
        assert!(out.snapshot.is_available);
        assert!((out.snapshot.balance - 5.0).abs() < 1e-9);
        assert_eq!(out.snapshot.currency, "USD");
    }

    #[tokio::test]
    async fn http_401_falls_back_to_cache() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/user/balance")
            .with_status(401)
            .with_body(r#"{"error": "invalid api key"}"#)
            .create_async()
            .await;
        let (_td, cache) = cache_fixture();
        let seed = serde_json::json!({
            "is_available": true,
            "balance": 3.0,
            "granted": 3.0,
            "topped_up": 0.0,
            "currency": "USD"
        });
        cache.write_payload(seed.to_string().as_bytes()).unwrap();

        let out = fetch_with(&server.url(), "bad-key", &cache).await.unwrap();

        assert!(out.stale);
        assert!((out.snapshot.balance - 3.0).abs() < 1e-9);
        assert_eq!(out.last_error.as_ref().map(|(c, _)| *c), Some(401));
    }

    /// With no cached payload to fall back on, a failed live fetch must surface as an error
    /// rather than a fabricated snapshot.
    #[tokio::test]
    async fn http_401_without_cache_is_an_error() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/user/balance")
            .with_status(401)
            .with_body(r#"{"error": "invalid api key"}"#)
            .create_async()
            .await;
        let (_td, cache) = cache_fixture();

        let result = fetch_with(&server.url(), "bad-key", &cache).await;

        assert!(result.is_err());
    }
}