Skip to main content

ai_usagebar/deepseek/
fetch.rs

1//! Fetch DeepSeek usage from `/user/balance`.
2
3use std::time::Duration;
4
5use crate::cache::{Cache, acquire_lock};
6use crate::error::{AppError, Result};
7use crate::usage::DeepseekSnapshot;
8
9use super::types::BalanceResponse;
10
/// Root URL of the DeepSeek REST API; endpoint URLs are formed by appending a path.
pub const BASE_URL: &str = "https://api.deepseek.com";

/// Timeout applied to the DeepSeek HTTP request made by `fetch_live`.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

/// Timeout handed to `acquire_lock` when `fetch_snapshot` takes the cache lock.
const LOCK_TIMEOUT: Duration = Duration::from_secs(15);

/// Fully-qualified URLs used by the fetcher.
///
/// Production code uses [`Endpoints::default`]; tests point the URLs at a mock server.
#[derive(Debug, Clone)]
pub struct Endpoints {
    /// URL of the `GET /user/balance` endpoint.
    pub balance: String,
}

impl Default for Endpoints {
    /// Endpoints rooted at `BASE_URL`.
    fn default() -> Self {
        let balance = [BASE_URL, "/user/balance"].concat();
        Endpoints { balance }
    }
}
27
28#[derive(Debug, Clone)]
29pub struct FetchOutcome {
30    pub snapshot: DeepseekSnapshot,
31    pub stale: bool,
32    pub last_error: Option<(u16, String)>,
33    pub cache_age: Option<Duration>,
34}
35
36pub async fn fetch_snapshot(
37    client: &reqwest::Client,
38    api_key: &str,
39    cache: &Cache,
40    endpoints: &Endpoints,
41    cache_ttl: Duration,
42) -> Result<FetchOutcome> {
43    cache.ensure_dir()?;
44    let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
45
46    if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
47        return Ok(reuse_cache(bytes, cache, false));
48    }
49
50    match fetch_live(client, &endpoints.balance, api_key).await {
51        Ok(snap) => {
52            let bytes = serde_json::to_vec(&snap_to_json(&snap)).unwrap_or_default();
53            cache.write_payload(&bytes)?;
54            Ok(FetchOutcome {
55                snapshot: snap,
56                stale: false,
57                last_error: None,
58                cache_age: Some(Duration::ZERO),
59            })
60        }
61        Err(e) if e.is_transient() => fallback_silent(cache),
62        Err(AppError::Http { status, body }) => {
63            cache.mark_stale();
64            cache.write_last_error(status, &body);
65            fallback_with_error(cache, Some((status, body)))
66        }
67        Err(e) => {
68            cache.mark_stale();
69            cache.write_last_error(0, &e.to_string());
70            fallback_with_error(cache, Some((0, e.to_string())))
71        }
72    }
73}
74
75fn fallback_silent(cache: &Cache) -> Result<FetchOutcome> {
76    let Some(bytes) = cache.maybe_payload()? else {
77        return Err(AppError::Transport(
78            "deepseek: no cache and network unreachable".into(),
79        ));
80    };
81    Ok(reuse_cache(bytes, cache, true))
82}
83
84fn fallback_with_error(cache: &Cache, last_error: Option<(u16, String)>) -> Result<FetchOutcome> {
85    let Some(bytes) = cache.maybe_payload()? else {
86        return Err(AppError::Other("deepseek: no usable cache".into()));
87    };
88    let mut outcome = reuse_cache(bytes, cache, true);
89    outcome.last_error = last_error;
90    Ok(outcome)
91}
92
93fn reuse_cache(bytes: Vec<u8>, cache: &Cache, stale: bool) -> FetchOutcome {
94    let snap = parse_cache(&bytes).unwrap_or_default();
95    FetchOutcome {
96        snapshot: snap,
97        stale,
98        last_error: cache.read_last_error(),
99        cache_age: cache.payload_age(),
100    }
101}
102
103fn parse_cache(bytes: &[u8]) -> Result<DeepseekSnapshot> {
104    let v: serde_json::Value = serde_json::from_slice(bytes)?;
105    Ok(DeepseekSnapshot {
106        is_available: v["is_available"].as_bool().unwrap_or(false),
107        balance: v["balance"].as_f64().unwrap_or(0.0),
108        granted: v["granted"].as_f64().unwrap_or(0.0),
109        topped_up: v["topped_up"].as_f64().unwrap_or(0.0),
110        currency: v["currency"].as_str().unwrap_or("").to_string(),
111    })
112}
113
114fn snap_to_json(snap: &DeepseekSnapshot) -> serde_json::Value {
115    serde_json::json!({
116        "is_available": snap.is_available,
117        "balance": snap.balance,
118        "granted": snap.granted,
119        "topped_up": snap.topped_up,
120        "currency": snap.currency,
121    })
122}
123
124async fn fetch_live(
125    client: &reqwest::Client,
126    url: &str,
127    api_key: &str,
128) -> Result<DeepseekSnapshot> {
129    let resp = tokio::time::timeout(
130        HTTP_TIMEOUT,
131        client
132            .get(url)
133            .header("Authorization", format!("Bearer {api_key}"))
134            .header("Accept", "application/json")
135            .send(),
136    )
137    .await
138    .map_err(|_| AppError::Transport(format!("deepseek timeout: {url}")))??;
139
140    let status = resp.status();
141    let bytes = resp.bytes().await?;
142
143    if !status.is_success() {
144        let body = String::from_utf8_lossy(&bytes).chars().take(200).collect();
145        return Err(AppError::Http {
146            status: status.as_u16(),
147            body,
148        });
149    }
150
151    let r: BalanceResponse = serde_json::from_slice(&bytes)
152        .map_err(|e| AppError::Schema(format!("deepseek balance response: {e}")))?;
153    Ok(r.into_snapshot())
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates an empty cache under a temporary directory.
    ///
    /// The returned `TempDir` must be kept alive: dropping it deletes the directory.
    fn cache_fixture() -> (TempDir, Cache) {
        let td = TempDir::new().unwrap();
        let cache = Cache::at(td.path().join("deepseek"));
        cache.ensure_dir().unwrap();
        (td, cache)
    }

    /// Endpoints aimed at a mock server's base URL.
    fn endpoints_for(base_url: &str) -> Endpoints {
        Endpoints {
            balance: format!("{base_url}/user/balance"),
        }
    }

    #[tokio::test]
    async fn live_200_returns_snapshot() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/user/balance")
            .with_status(200)
            .with_body(r#"{
                "is_available": true,
                "balance_infos": [
                    {"currency": "USD", "total_balance": "5.00", "granted_balance": "5.00", "topped_up_balance": "0.00"}
                ]
            }"#)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let client = reqwest::Client::new();
        let out = fetch_snapshot(
            &client,
            "sk-test",
            &cache,
            &endpoints_for(&server.url()),
            Duration::from_secs(0),
        )
        .await
        .unwrap();

        mock.assert_async().await;
        assert!(out.snapshot.is_available);
        assert!((out.snapshot.balance - 5.0).abs() < 1e-9);
        assert_eq!(out.snapshot.currency, "USD");
        assert!(!out.stale);
        assert!(out.last_error.is_none());
        assert_eq!(out.cache_age, Some(Duration::ZERO));

        // The live snapshot must be persisted in the format `parse_cache` reads back.
        let cached = cache
            .maybe_payload()
            .unwrap()
            .expect("payload written after a successful live fetch");
        let reparsed = parse_cache(&cached).unwrap();
        assert!((reparsed.balance - 5.0).abs() < 1e-9);
        assert_eq!(reparsed.currency, "USD");
    }

    #[tokio::test]
    async fn http_401_falls_back_to_cache() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/user/balance")
            .with_status(401)
            .with_body(r#"{"error": "invalid api key"}"#)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let seed = serde_json::json!({
            "is_available": true,
            "balance": 3.0,
            "granted": 3.0,
            "topped_up": 0.0,
            "currency": "USD"
        });
        cache.write_payload(seed.to_string().as_bytes()).unwrap();

        let client = reqwest::Client::new();
        let out = fetch_snapshot(
            &client,
            "bad-key",
            &cache,
            &endpoints_for(&server.url()),
            Duration::from_secs(0),
        )
        .await
        .unwrap();
        assert!(out.stale);
        assert!((out.snapshot.balance - 3.0).abs() < 1e-9);
        let (status, message) = out.last_error.expect("HTTP error should be reported");
        assert_eq!(status, 401);
        assert!(message.contains("invalid api key"));
    }

    #[tokio::test]
    async fn http_401_without_cache_is_an_error() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/user/balance")
            .with_status(401)
            .with_body(r#"{"error": "invalid api key"}"#)
            .create_async()
            .await;

        let (_td, cache) = cache_fixture();
        let client = reqwest::Client::new();
        let result = fetch_snapshot(
            &client,
            "bad-key",
            &cache,
            &endpoints_for(&server.url()),
            Duration::from_secs(0),
        )
        .await;
        assert!(
            result.is_err(),
            "a non-transient failure with nothing cached must surface as an error"
        );
    }
}