1use std::time::Duration;
5
6use crate::cache::{Cache, acquire_lock};
7use crate::error::{AppError, Result};
8use crate::usage::OpenRouterSnapshot;
9
10use super::types::{CreditsData, KeyData, OrEnvelope, combine};
11
/// Root of the OpenRouter REST API used by [`Endpoints::default`].
pub const BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Upper bound for a single OpenRouter request.
const HTTP_TIMEOUT: Duration = Duration::from_secs(10);
/// How long to wait for the cache lock before giving up.
const LOCK_TIMEOUT: Duration = Duration::from_secs(15);

/// Fully-qualified URLs of the two OpenRouter endpoints a snapshot is built from.
#[derive(Debug, Clone)]
pub struct Endpoints {
    /// URL of the `/credits` endpoint.
    pub credits: String,
    /// URL of the `/key` endpoint.
    pub key: String,
}

impl Endpoints {
    /// Builds the endpoint URLs relative to `base` (for example a mock server's
    /// `.../api/v1`). Trailing `/` characters on `base` are ignored, so
    /// `"http://h/api/v1"` and `"http://h/api/v1/"` yield the same URLs.
    pub fn with_base(base: &str) -> Self {
        let base = base.trim_end_matches('/');
        Self {
            credits: format!("{base}/credits"),
            key: format!("{base}/key"),
        }
    }
}

impl Default for Endpoints {
    /// Endpoints of the public OpenRouter API rooted at [`BASE_URL`].
    fn default() -> Self {
        Self::with_base(BASE_URL)
    }
}
30
31#[derive(Debug, Clone)]
32pub struct FetchOutcome {
33 pub snapshot: OpenRouterSnapshot,
34 pub stale: bool,
35 pub last_error: Option<(u16, String)>,
36 pub cache_age: Option<Duration>,
37}
38
39pub async fn fetch_snapshot(
42 client: &reqwest::Client,
43 api_key: &str,
44 cache: &Cache,
45 endpoints: &Endpoints,
46 cache_ttl: Duration,
47) -> Result<FetchOutcome> {
48 cache.ensure_dir()?;
49 let _lock = acquire_lock(&cache.lock_path(), LOCK_TIMEOUT)?;
50
51 if let Some(bytes) = cache.fresh_payload(cache_ttl)? {
52 return Ok(reuse_cache(bytes, cache, false));
53 }
54
55 match fetch_live(client, endpoints, api_key).await {
56 Ok((credits, key)) => {
57 let snap = combine(credits, key);
58 let cache_repr = serde_json::json!({
60 "snapshot": serde_repr(&snap),
61 });
62 let bytes = serde_json::to_vec(&cache_repr).unwrap_or_default();
63 cache.write_payload(&bytes)?;
64 Ok(FetchOutcome {
65 snapshot: snap,
66 stale: false,
67 last_error: None,
68 cache_age: Some(Duration::ZERO),
69 })
70 }
71 Err(e) if e.is_transient() => fallback_silent(cache),
72 Err(AppError::Http { status, body }) => {
73 cache.mark_stale();
74 cache.write_last_error(status, &body);
75 fallback_with_error(cache, Some((status, body)))
76 }
77 Err(e) => {
78 cache.mark_stale();
79 cache.write_last_error(0, &e.to_string());
80 fallback_with_error(cache, Some((0, e.to_string())))
81 }
82 }
83}
84
85fn fallback_silent(cache: &Cache) -> Result<FetchOutcome> {
86 let Some(bytes) = cache.maybe_payload()? else {
87 return Err(AppError::Transport(
88 "openrouter: no cache and network unreachable".into(),
89 ));
90 };
91 Ok(reuse_cache(bytes, cache, true))
92}
93
94fn fallback_with_error(cache: &Cache, last_error: Option<(u16, String)>) -> Result<FetchOutcome> {
95 let Some(bytes) = cache.maybe_payload()? else {
96 return Err(AppError::Other("openrouter: no usable cache".into()));
97 };
98 let mut outcome = reuse_cache(bytes, cache, true);
99 outcome.last_error = last_error;
100 Ok(outcome)
101}
102
103fn reuse_cache(bytes: Vec<u8>, cache: &Cache, stale: bool) -> FetchOutcome {
104 let snap = parse_cache(&bytes).unwrap_or_else(|_| OpenRouterSnapshot {
105 label: "OpenRouter".into(),
106 total_credits: 0.0,
107 total_usage: 0.0,
108 usage_daily: 0.0,
109 usage_weekly: 0.0,
110 usage_monthly: 0.0,
111 is_free_tier: true,
112 limit: None,
113 limit_remaining: None,
114 });
115 FetchOutcome {
116 snapshot: snap,
117 stale,
118 last_error: cache.read_last_error(),
119 cache_age: cache.payload_age(),
120 }
121}
122
123fn parse_cache(bytes: &[u8]) -> Result<OpenRouterSnapshot> {
124 let v: serde_json::Value = serde_json::from_slice(bytes)?;
125 let s = v
126 .get("snapshot")
127 .ok_or_else(|| AppError::Schema("openrouter cache missing 'snapshot' field".into()))?;
128 Ok(OpenRouterSnapshot {
129 label: s["label"].as_str().unwrap_or("OpenRouter").to_string(),
130 total_credits: s["total_credits"].as_f64().unwrap_or(0.0),
131 total_usage: s["total_usage"].as_f64().unwrap_or(0.0),
132 usage_daily: s["usage_daily"].as_f64().unwrap_or(0.0),
133 usage_weekly: s["usage_weekly"].as_f64().unwrap_or(0.0),
134 usage_monthly: s["usage_monthly"].as_f64().unwrap_or(0.0),
135 is_free_tier: s["is_free_tier"].as_bool().unwrap_or(false),
136 limit: s["limit"].as_f64(),
137 limit_remaining: s["limit_remaining"].as_f64(),
138 })
139}
140
141fn serde_repr(snap: &OpenRouterSnapshot) -> serde_json::Value {
142 serde_json::json!({
143 "label": snap.label,
144 "total_credits": snap.total_credits,
145 "total_usage": snap.total_usage,
146 "usage_daily": snap.usage_daily,
147 "usage_weekly": snap.usage_weekly,
148 "usage_monthly": snap.usage_monthly,
149 "is_free_tier": snap.is_free_tier,
150 "limit": snap.limit,
151 "limit_remaining": snap.limit_remaining,
152 })
153}
154
155async fn fetch_live(
156 client: &reqwest::Client,
157 endpoints: &Endpoints,
158 api_key: &str,
159) -> Result<(CreditsData, KeyData)> {
160 let credits_fut = fetch_one::<CreditsData>(client, &endpoints.credits, api_key);
162 let key_fut = fetch_one::<KeyData>(client, &endpoints.key, api_key);
163 let (credits, key) = tokio::join!(credits_fut, key_fut);
164 Ok((credits?, key?))
165}
166
167async fn fetch_one<T: for<'de> serde::Deserialize<'de>>(
168 client: &reqwest::Client,
169 url: &str,
170 api_key: &str,
171) -> Result<T> {
172 let resp = tokio::time::timeout(
173 HTTP_TIMEOUT,
174 client
175 .get(url)
176 .header("Authorization", format!("Bearer {api_key}"))
177 .send(),
178 )
179 .await
180 .map_err(|_| AppError::Transport(format!("openrouter timeout: {url}")))??;
181
182 let status = resp.status();
183 let bytes = resp.bytes().await?;
184
185 if !status.is_success() {
186 let body = String::from_utf8_lossy(&bytes).chars().take(200).collect();
187 return Err(AppError::Http {
188 status: status.as_u16(),
189 body,
190 });
191 }
192 let env: OrEnvelope<T> = serde_json::from_slice(&bytes)
193 .map_err(|e| AppError::Schema(format!("openrouter {url}: {e}")))?;
194 Ok(env.data)
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Body served by both mocked endpoints in the HTTP-error scenario.
    const UNAUTHORIZED: &str = r#"{"error":"unauthorized"}"#;

    /// Fresh cache inside a temporary directory; keep the `TempDir` alive while
    /// the cache is in use.
    fn cache_fixture() -> (TempDir, Cache) {
        let tmp = TempDir::new().unwrap();
        let cache = Cache::at(tmp.path().join("openrouter"));
        cache.ensure_dir().unwrap();
        (tmp, cache)
    }

    /// Endpoints targeting the `/api/v1` routes of the mock server at `base`.
    fn endpoints_at(base: &str) -> Endpoints {
        Endpoints {
            credits: format!("{base}/api/v1/credits"),
            key: format!("{base}/api/v1/key"),
        }
    }

    #[tokio::test]
    async fn live_fetch_combines_both_endpoints() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/api/v1/credits")
            .with_status(200)
            .with_body(r#"{"data":{"total_credits":100.0,"total_usage":25.5}}"#)
            .create_async()
            .await;
        server
            .mock("GET", "/api/v1/key")
            .with_status(200)
            .with_body(
                r#"{"data":{"label":"prod","limit":50.0,"limit_remaining":24.5,
                "usage":25.5,"usage_daily":1.0,"usage_weekly":7.0,"usage_monthly":25.5,
                "is_free_tier":false}}"#,
            )
            .create_async()
            .await;

        let (_tmp, cache) = cache_fixture();
        let endpoints = endpoints_at(&server.url());
        let client = reqwest::Client::new();

        let outcome = fetch_snapshot(&client, "sk-or-test", &cache, &endpoints, Duration::ZERO)
            .await
            .unwrap();

        let snap = &outcome.snapshot;
        assert_eq!(snap.total_credits, 100.0);
        assert_eq!(snap.total_usage, 25.5);
        assert!((snap.balance() - 74.5).abs() < 1e-9);
        assert_eq!(snap.label, "OpenRouter — prod");
        assert!(!outcome.stale);
    }

    #[tokio::test]
    async fn http_error_falls_back_to_cache_when_present() {
        let mut server = mockito::Server::new_async().await;
        for path in ["/api/v1/credits", "/api/v1/key"] {
            server
                .mock("GET", path)
                .with_status(401)
                .with_body(UNAUTHORIZED)
                .create_async()
                .await;
        }

        let (_tmp, cache) = cache_fixture();
        let seed = serde_json::json!({
            "snapshot": {
                "label":"OpenRouter — seed","total_credits": 50.0,
                "total_usage": 10.0,"usage_daily":1.0,"usage_weekly":3.0,
                "usage_monthly":10.0,"is_free_tier":false,
                "limit":null,"limit_remaining":null
            }
        });
        cache.write_payload(seed.to_string().as_bytes()).unwrap();

        let endpoints = endpoints_at(&server.url());
        let client = reqwest::Client::new();
        // Zero TTL, so the seeded payload should not count as fresh.
        let outcome = fetch_snapshot(&client, "k", &cache, &endpoints, Duration::ZERO)
            .await
            .unwrap();

        assert!(outcome.stale);
        assert_eq!(outcome.snapshot.label, "OpenRouter — seed");
        assert!(matches!(outcome.last_error, Some((401, _))));
    }
}