Skip to main content

ai_usagebar/openrouter/
fetch.rs

1//! OpenRouter fetch — combines `/api/v1/credits` and `/api/v1/key` under
2//! the shared cache + flock primitives.
3
4use std::time::Duration;
5
6use crate::cache::{Cache, acquire_lock};
7use crate::error::{AppError, Result};
8use crate::usage::OpenRouterSnapshot;
9
10use super::types::{CreditsData, KeyData, OrEnvelope, combine};
11
/// Production OpenRouter REST API root; endpoint paths are appended to it.
pub const BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Time budget applied to each individual OpenRouter request in `fetch_one`.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum wait for the cache lock in `fetch_snapshot`. The lock holder keeps it
/// across its live fetch, so this is presumably kept above `HTTP_TIMEOUT` on
/// purpose — confirm before tuning either value.
const LOCK_TIMEOUT: Duration = Duration::from_secs(15);

/// URLs of the two OpenRouter endpoints that are combined into one snapshot.
#[derive(Debug, Clone)]
pub struct Endpoints {
    /// `GET {base}/credits` — credit totals (`total_credits`, `total_usage`).
    pub credits: String,
    /// `GET {base}/key` — per-key label, limits and usage windows.
    pub key: String,
}

impl Endpoints {
    /// Build both endpoint URLs under `base` (e.g. a mock server's `/api/v1`).
    ///
    /// A trailing `/` on `base` is ignored so `".../api/v1"` and `".../api/v1/"`
    /// yield the same URLs.
    pub fn with_base(base: &str) -> Self {
        let base = base.trim_end_matches('/');
        Self {
            credits: format!("{base}/credits"),
            key: format!("{base}/key"),
        }
    }
}

impl Default for Endpoints {
    /// The production endpoints under [`BASE_URL`].
    fn default() -> Self {
        Self::with_base(BASE_URL)
    }
}
30
/// Result of `fetch_snapshot`: the snapshot plus how it was obtained.
#[derive(Debug, Clone)]
pub struct FetchOutcome {
    /// Combined credits + key data, either freshly fetched or decoded from the cache.
    pub snapshot: OpenRouterSnapshot,
    /// `true` when the snapshot was served from the cache because the live
    /// fetch failed; `false` for a live fetch or a fresh-cache hit.
    pub stale: bool,
    /// Fetch error as `(http_status, body)`, where status `0` marks a non-HTTP
    /// failure. `None` right after a successful live fetch; otherwise the error
    /// that triggered the fallback or the one previously recorded in the cache.
    pub last_error: Option<(u16, String)>,
    /// `Some(Duration::ZERO)` right after a live fetch; otherwise the cached
    /// payload's age as reported by the cache, if known.
    pub cache_age: Option<Duration>,
}
38
39/// Cache-aware fetch. Mirrors `anthropic::fetch::fetch_snapshot` semantics:
40/// fresh cache short-circuits; on failure, fall back to cache + mark stale.
41pub async fn fetch_snapshot(
42    client: &reqwest::Client,
43    api_key: &str,
44    cache: &Cache,
45    endpoints: &Endpoints,
46    cache_ttl: Duration,
47) -> Result<FetchOutcome> {
48    cache.ensure_dir()?;
49    let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
50
51    if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
52        return Ok(reuse_cache(bytes, cache, false));
53    }
54
55    match fetch_live(client, endpoints, api_key).await {
56        Ok((credits, key)) => {
57            let snap = combine(credits, key);
58            // Serialize back to JSON for the cache.
59            let cache_repr = serde_json::json!({
60                "snapshot": serde_repr(&snap),
61            });
62            let bytes = serde_json::to_vec(&cache_repr).unwrap_or_default();
63            cache.write_payload(&bytes)?;
64            Ok(FetchOutcome {
65                snapshot: snap,
66                stale: false,
67                last_error: None,
68                cache_age: Some(Duration::ZERO),
69            })
70        }
71        Err(e) if e.is_transient() => fallback_silent(cache),
72        Err(AppError::Http { status, body }) => {
73            cache.mark_stale();
74            cache.write_last_error(status, &body);
75            fallback_with_error(cache, Some((status, body)))
76        }
77        Err(e) => {
78            cache.mark_stale();
79            cache.write_last_error(0, &e.to_string());
80            fallback_with_error(cache, Some((0, e.to_string())))
81        }
82    }
83}
84
85fn fallback_silent(cache: &Cache) -> Result<FetchOutcome> {
86    let Some(bytes) = cache.maybe_payload()? else {
87        return Err(AppError::Transport(
88            "openrouter: no cache and network unreachable".into(),
89        ));
90    };
91    Ok(reuse_cache(bytes, cache, true))
92}
93
94fn fallback_with_error(cache: &Cache, last_error: Option<(u16, String)>) -> Result<FetchOutcome> {
95    let Some(bytes) = cache.maybe_payload()? else {
96        return Err(AppError::Other("openrouter: no usable cache".into()));
97    };
98    let mut outcome = reuse_cache(bytes, cache, true);
99    outcome.last_error = last_error;
100    Ok(outcome)
101}
102
103fn reuse_cache(bytes: Vec<u8>, cache: &Cache, stale: bool) -> FetchOutcome {
104    let snap = parse_cache(&bytes).unwrap_or_else(|_| OpenRouterSnapshot {
105        label: "OpenRouter".into(),
106        total_credits: 0.0,
107        total_usage: 0.0,
108        usage_daily: 0.0,
109        usage_weekly: 0.0,
110        usage_monthly: 0.0,
111        is_free_tier: true,
112        limit: None,
113        limit_remaining: None,
114    });
115    FetchOutcome {
116        snapshot: snap,
117        stale,
118        last_error: cache.read_last_error(),
119        cache_age: cache.payload_age(),
120    }
121}
122
123fn parse_cache(bytes: &[u8]) -> Result<OpenRouterSnapshot> {
124    let v: serde_json::Value = serde_json::from_slice(bytes)?;
125    let s = v
126        .get("snapshot")
127        .ok_or_else(|| AppError::Schema("openrouter cache missing 'snapshot' field".into()))?;
128    Ok(OpenRouterSnapshot {
129        label: s["label"].as_str().unwrap_or("OpenRouter").to_string(),
130        total_credits: s["total_credits"].as_f64().unwrap_or(0.0),
131        total_usage: s["total_usage"].as_f64().unwrap_or(0.0),
132        usage_daily: s["usage_daily"].as_f64().unwrap_or(0.0),
133        usage_weekly: s["usage_weekly"].as_f64().unwrap_or(0.0),
134        usage_monthly: s["usage_monthly"].as_f64().unwrap_or(0.0),
135        is_free_tier: s["is_free_tier"].as_bool().unwrap_or(false),
136        limit: s["limit"].as_f64(),
137        limit_remaining: s["limit_remaining"].as_f64(),
138    })
139}
140
141fn serde_repr(snap: &OpenRouterSnapshot) -> serde_json::Value {
142    serde_json::json!({
143        "label": snap.label,
144        "total_credits": snap.total_credits,
145        "total_usage": snap.total_usage,
146        "usage_daily": snap.usage_daily,
147        "usage_weekly": snap.usage_weekly,
148        "usage_monthly": snap.usage_monthly,
149        "is_free_tier": snap.is_free_tier,
150        "limit": snap.limit,
151        "limit_remaining": snap.limit_remaining,
152    })
153}
154
155async fn fetch_live(
156    client: &reqwest::Client,
157    endpoints: &Endpoints,
158    api_key: &str,
159) -> Result<(CreditsData, KeyData)> {
160    // Fetch in parallel.
161    let credits_fut = fetch_one::<CreditsData>(client, &endpoints.credits, api_key);
162    let key_fut = fetch_one::<KeyData>(client, &endpoints.key, api_key);
163    let (credits, key) = tokio::join!(credits_fut, key_fut);
164    Ok((credits?, key?))
165}
166
167async fn fetch_one<T: for<'de> serde::Deserialize<'de>>(
168    client: &reqwest::Client,
169    url: &str,
170    api_key: &str,
171) -> Result<T> {
172    let resp = tokio::time::timeout(
173        HTTP_TIMEOUT,
174        client
175            .get(url)
176            .header("Authorization", format!("Bearer {api_key}"))
177            .send(),
178    )
179    .await
180    .map_err(|_| AppError::Transport(format!("openrouter timeout: {url}")))??;
181
182    let status = resp.status();
183    let bytes = resp.bytes().await?;
184
185    if !status.is_success() {
186        let body = String::from_utf8_lossy(&bytes).chars().take(200).collect();
187        return Err(AppError::Http {
188            status: status.as_u16(),
189            body,
190        });
191    }
192    let env: OrEnvelope<T> = serde_json::from_slice(&bytes)
193        .map_err(|e| AppError::Schema(format!("openrouter {url}: {e}")))?;
194    Ok(env.data)
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A cache rooted in a fresh temp dir. The `TempDir` is returned so the
    /// directory stays alive for as long as the test holds it.
    fn cache_fixture() -> (TempDir, Cache) {
        let tmp = TempDir::new().unwrap();
        let cache = Cache::at(tmp.path().join("openrouter"));
        cache.ensure_dir().unwrap();
        (tmp, cache)
    }

    /// Endpoints pointing at `{server_url}/api/v1/credits` and `/api/v1/key`.
    fn mock_endpoints(server_url: &str) -> Endpoints {
        Endpoints {
            credits: format!("{server_url}/api/v1/credits"),
            key: format!("{server_url}/api/v1/key"),
        }
    }

    #[tokio::test]
    async fn live_fetch_combines_both_endpoints() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/api/v1/credits")
            .with_status(200)
            .with_body(r#"{"data":{"total_credits":100.0,"total_usage":25.5}}"#)
            .create_async()
            .await;
        server
            .mock("GET", "/api/v1/key")
            .with_status(200)
            .with_body(
                r#"{"data":{"label":"prod","limit":50.0,"limit_remaining":24.5,
                "usage":25.5,"usage_daily":1.0,"usage_weekly":7.0,"usage_monthly":25.5,
                "is_free_tier":false}}"#,
            )
            .create_async()
            .await;

        let (_tmp, cache) = cache_fixture();
        let client = reqwest::Client::new();
        let endpoints = mock_endpoints(&server.url());
        // Zero TTL: nothing in the cache counts as fresh, so the live path runs.
        let outcome = fetch_snapshot(
            &client,
            "sk-or-test",
            &cache,
            &endpoints,
            Duration::from_secs(0),
        )
        .await
        .unwrap();

        let snap = &outcome.snapshot;
        assert_eq!(snap.total_credits, 100.0);
        assert_eq!(snap.total_usage, 25.5);
        assert!((snap.balance() - 74.5).abs() < 1e-9);
        assert_eq!(snap.label, "OpenRouter — prod");
        assert!(!outcome.stale);
    }

    #[tokio::test]
    async fn http_error_falls_back_to_cache_when_present() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/api/v1/credits")
            .with_status(401)
            .with_body(r#"{"error":"unauthorized"}"#)
            .create_async()
            .await;
        server
            .mock("GET", "/api/v1/key")
            .with_status(401)
            .with_body(r#"{"error":"unauthorized"}"#)
            .create_async()
            .await;

        let (_tmp, cache) = cache_fixture();
        // Seed the cache with a payload in the on-disk `{"snapshot": {...}}` shape.
        let seed = serde_json::json!({
            "snapshot": {
                "label":"OpenRouter — seed","total_credits": 50.0,
                "total_usage": 10.0,"usage_daily":1.0,"usage_weekly":3.0,
                "usage_monthly":10.0,"is_free_tier":false,
                "limit":null,"limit_remaining":null
            }
        });
        cache.write_payload(seed.to_string().as_bytes()).unwrap();

        let client = reqwest::Client::new();
        let endpoints = mock_endpoints(&server.url());
        let outcome = fetch_snapshot(&client, "k", &cache, &endpoints, Duration::from_secs(0))
            .await
            .unwrap();

        // The 401 must surface as last_error while the seeded snapshot is served stale.
        assert!(outcome.stale);
        assert_eq!(outcome.snapshot.label, "OpenRouter — seed");
        let status = outcome.last_error.as_ref().map(|(code, _)| *code);
        assert_eq!(status, Some(401));
    }
}