1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use tracing::{info, warn};
9
10use crate::config::{ProxyConfig, proxy_home_dir};
11use crate::lb::LbState;
12
/// How a usage provider's endpoint is queried and how its response is interpreted.
///
/// Serialized in `usage_providers.json` in snake_case form
/// (`budget_http_json`, `yescode_profile`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum ProviderKind {
    /// Polled with `poll_budget_http_json` (Bearer-token GET returning monthly
    /// budget/spent figures).
    BudgetHttpJson,
    /// Polled with `poll_yescode_profile` (`X-API-Key` GET of a YesCode profile
    /// returning subscription and pay-as-you-go balances).
    YescodeProfile,
}
21
/// One entry of `usage_providers.json`: which upstreams a provider covers and how to
/// query its usage endpoint.
#[derive(Debug, Deserialize, Serialize)]
struct UsageProviderConfig {
    /// Provider name; used in log messages and as the key for poll throttling.
    id: String,
    /// Polling strategy / response format.
    kind: ProviderKind,
    /// Domains this provider covers; an upstream matches when the host of its base URL
    /// is one of these domains or a subdomain of one (see `domain_matches`).
    domains: Vec<String>,
    /// URL queried for usage / balance information.
    endpoint: String,
    /// Optional name of an environment variable holding the token. When that variable
    /// is set to a non-blank value it takes precedence over upstream credentials.
    #[serde(default)]
    token_env: Option<String>,
    /// Desired poll interval in seconds. A missing value, or one below
    /// `MIN_POLL_INTERVAL_SECS`, is treated as that minimum.
    #[serde(default)]
    poll_interval_secs: Option<u64>,
}
33
/// Top-level shape of `usage_providers.json`.
///
/// A missing `providers` key deserializes as an empty list (which disables polling).
#[derive(Debug, Deserialize, Serialize, Default)]
struct UsageProvidersFile {
    #[serde(default)]
    providers: Vec<UsageProviderConfig>,
}
39
/// Identifies one upstream by position: `cfg.codex.configs[config_name].upstreams[index]`.
#[derive(Debug, Clone)]
struct UpstreamRef {
    /// Key into `cfg.codex.configs` (and into the load-balancer state map).
    config_name: String,
    /// Index into that config's `upstreams` list.
    index: usize,
}
45
/// Process-wide record of when each usage provider (keyed by provider `id`) was last
/// polled; used by `poll_for_codex_upstream` to throttle requests.
static LAST_USAGE_POLL: OnceLock<Mutex<HashMap<String, Instant>>> = OnceLock::new();

/// Lower bound, in seconds, on a provider's poll interval. It is also the interval used
/// when `poll_interval_secs` is not configured.
const MIN_POLL_INTERVAL_SECS: u64 = 20;
51
52fn usage_providers_path() -> std::path::PathBuf {
53 proxy_home_dir().join("usage_providers.json")
54}
55
56fn default_providers() -> UsageProvidersFile {
57 UsageProvidersFile {
58 providers: vec![
59 UsageProviderConfig {
60 id: "packycode".to_string(),
61 kind: ProviderKind::BudgetHttpJson,
62 domains: vec!["packycode.com".to_string()],
63 endpoint: "https://www.packycode.com/api/backend/users/info".to_string(),
64 token_env: None,
65 poll_interval_secs: Some(60),
66 },
67 UsageProviderConfig {
68 id: "yescode".to_string(),
69 kind: ProviderKind::YescodeProfile,
70 domains: vec!["yes.vg".to_string()],
72 endpoint: "https://co.yes.vg/api/v1/auth/profile".to_string(),
73 token_env: None,
74 poll_interval_secs: Some(60),
75 },
76 ],
77 }
78}
79
80fn load_providers() -> UsageProvidersFile {
81 let path = usage_providers_path();
82 if let Ok(text) = std::fs::read_to_string(&path)
83 && let Ok(file) = serde_json::from_str::<UsageProvidersFile>(&text)
84 {
85 return file;
86 }
87
88 let default = default_providers();
90 if let Ok(text) = serde_json::to_string_pretty(&default) {
91 if let Some(parent) = path.parent() {
92 let _ = std::fs::create_dir_all(parent);
93 }
94 let _ = std::fs::write(&path, text);
95 }
96 default
97}
98
99fn domain_matches(base_url: &str, domains: &[String]) -> bool {
100 let url = match reqwest::Url::parse(base_url) {
101 Ok(u) => u,
102 Err(_) => return false,
103 };
104 let host = match url.host_str() {
105 Some(h) => h,
106 None => return false,
107 };
108 for d in domains {
109 if host == d || host.ends_with(&format!(".{}", d)) {
110 return true;
111 }
112 }
113 false
114}
115
116fn resolve_token(
117 provider: &UsageProviderConfig,
118 upstreams: &[UpstreamRef],
119 cfg: &ProxyConfig,
120) -> Option<String> {
121 if let Some(env_name) = &provider.token_env
123 && let Ok(v) = std::env::var(env_name)
124 && !v.trim().is_empty()
125 {
126 return Some(v);
127 }
128
129 for uref in upstreams {
131 if let Some(service) = cfg.codex.configs.get(&uref.config_name)
132 && let Some(up) = service.upstreams.get(uref.index)
133 {
134 if let Some(token) = up.auth.resolve_auth_token() {
135 return Some(token);
136 }
137 if let Some(token) = up.auth.resolve_api_key() {
138 return Some(token);
139 }
140 }
141 }
142 None
143}
144
145async fn poll_budget_http_json(
146 client: &Client,
147 endpoint: &str,
148 token: &str,
149) -> Result<(bool, f64, f64)> {
150 let resp = client
151 .get(endpoint)
152 .header("Authorization", format!("Bearer {}", token))
153 .header("Content-Type", "application/json")
154 .send()
155 .await?;
156
157 if !resp.status().is_success() {
158 anyhow::bail!("usage provider HTTP {}", resp.status());
159 }
160 let value: serde_json::Value = resp.json().await?;
161
162 let monthly_budget = value
163 .get("monthly_budget_usd")
164 .and_then(|v| v.as_f64())
165 .unwrap_or(0.0);
166 let monthly_spent = value
167 .get("monthly_spent_usd")
168 .and_then(|v| v.as_f64())
169 .unwrap_or(0.0);
170
171 let exhausted = monthly_budget > 0.0 && monthly_spent >= monthly_budget;
172 Ok((exhausted, monthly_budget, monthly_spent))
173}
174
175async fn poll_yescode_profile(
176 client: &Client,
177 endpoint: &str,
178 token: &str,
179) -> Result<(bool, f64, f64, f64)> {
180 let resp = client
181 .get(endpoint)
182 .header("X-API-Key", token)
183 .header("Accept", "application/json")
184 .send()
185 .await?;
186
187 if !resp.status().is_success() {
188 anyhow::bail!("yescode profile HTTP {}", resp.status());
189 }
190 let value: serde_json::Value = resp.json().await?;
191
192 let subscription_balance = value
193 .get("subscription_balance")
194 .and_then(|v| v.as_f64())
195 .unwrap_or(0.0);
196 let paygo_balance = value
197 .get("pay_as_you_go_balance")
198 .and_then(|v| v.as_f64())
199 .unwrap_or(0.0);
200 let total_balance = subscription_balance + paygo_balance;
201
202 let exhausted = total_balance <= 0.0;
204 Ok((
205 exhausted,
206 total_balance,
207 subscription_balance,
208 paygo_balance,
209 ))
210}
211
212fn update_usage_exhausted(
213 lb_states: &Arc<Mutex<HashMap<String, LbState>>>,
214 cfg: &ProxyConfig,
215 upstreams: &[UpstreamRef],
216 exhausted: bool,
217) {
218 let mut map = match lb_states.lock() {
219 Ok(m) => m,
220 Err(_) => return,
221 };
222
223 for uref in upstreams {
224 let service = match cfg.codex.configs.get(&uref.config_name) {
225 Some(s) => s,
226 None => continue,
227 };
228
229 let len = service.upstreams.len();
230 let entry = map
231 .entry(uref.config_name.clone())
232 .or_insert_with(LbState::default);
233 if entry.failure_counts.len() != len {
234 entry.failure_counts.resize(len, 0);
235 entry.cooldown_until.resize(len, None);
236 entry.usage_exhausted.resize(len, false);
237 }
238 if uref.index < entry.usage_exhausted.len() {
239 entry.usage_exhausted[uref.index] = exhausted;
240 }
241 }
242}
243
244pub async fn poll_for_codex_upstream(
247 cfg: Arc<ProxyConfig>,
248 lb_states: Arc<Mutex<HashMap<String, LbState>>>,
249 config_name: &str,
250 upstream_index: usize,
251) {
252 if cfg!(test) {
255 return;
256 }
257
258 let providers_file = load_providers();
259 if providers_file.providers.is_empty() {
260 return;
261 }
262
263 let current_service = match cfg.codex.configs.get(config_name) {
265 Some(s) => s,
266 None => return,
267 };
268 let current_upstream = match current_service.upstreams.get(upstream_index) {
269 Some(u) => u,
270 None => return,
271 };
272 let current_base_url = current_upstream.base_url.clone();
273
274 let now = Instant::now();
275 let poll_map = LAST_USAGE_POLL.get_or_init(|| Mutex::new(HashMap::new()));
276
277 let mut client: Option<Client> = None;
278
279 for provider in providers_file.providers {
280 if !domain_matches(¤t_base_url, &provider.domains) {
282 continue;
283 }
284
285 let mut interval_secs = provider
287 .poll_interval_secs
288 .unwrap_or(MIN_POLL_INTERVAL_SECS);
289 if interval_secs < MIN_POLL_INTERVAL_SECS {
290 interval_secs = MIN_POLL_INTERVAL_SECS;
291 }
292
293 if interval_secs > 0 {
294 let mut map = match poll_map.lock() {
295 Ok(m) => m,
296 Err(_) => continue,
297 };
298 if let Some(last) = map.get(&provider.id)
299 && now.duration_since(*last) < Duration::from_secs(interval_secs)
300 {
301 continue;
302 }
303 map.insert(provider.id.clone(), now);
304 }
305
306 let mut hosts: Vec<String> = Vec::new();
309 for service in cfg.codex.configs.values() {
310 for upstream in &service.upstreams {
311 if domain_matches(&upstream.base_url, &provider.domains)
312 && let Ok(url) = reqwest::Url::parse(&upstream.base_url)
313 && let Some(host) = url.host_str()
314 {
315 hosts.push(host.to_string());
316 }
317 }
318 }
319 hosts.sort();
320 hosts.dedup();
321 if hosts.len() > 1 {
322 warn!(
323 "usage provider '{}' is associated with multiple hosts: {:?}; \
324将按统一额度处理这些 upstream,如需区分配额请拆分为多个 provider 配置",
325 provider.id, hosts
326 );
327 }
328
329 let current_ref = UpstreamRef {
331 config_name: config_name.to_string(),
332 index: upstream_index,
333 };
334 let upstreams = vec![current_ref];
335
336 let c = client.get_or_insert_with(Client::new);
337
338 if let Some(token) = resolve_token(&provider, &upstreams, &cfg) {
339 match provider.kind {
340 ProviderKind::BudgetHttpJson => {
341 match poll_budget_http_json(c, &provider.endpoint, &token).await {
342 Ok((exhausted, monthly_budget, monthly_spent)) => {
343 update_usage_exhausted(&lb_states, &cfg, &upstreams, exhausted);
344 info!(
345 "usage provider '{}' exhausted = {} (monthly: {:.2}/{:.2} USD)",
346 provider.id, exhausted, monthly_spent, monthly_budget
347 );
348 }
349 Err(err) => {
350 warn!("usage provider '{}' poll failed: {}", provider.id, err);
351 }
352 }
353 }
354 ProviderKind::YescodeProfile => {
355 match poll_yescode_profile(c, &provider.endpoint, &token).await {
356 Ok((exhausted, total_balance, sub_balance, paygo_balance)) => {
357 update_usage_exhausted(&lb_states, &cfg, &upstreams, exhausted);
358 info!(
359 "usage provider '{}' exhausted = {} (yescode balance: total={:.2}, subscription={:.2}, paygo={:.2})",
360 provider.id, exhausted, total_balance, sub_balance, paygo_balance
361 );
362 }
363 Err(err) => {
364 warn!("usage provider '{}' poll failed: {}", provider.id, err);
365 }
366 }
367 }
368 }
369 } else {
370 warn!(
371 "usage provider '{}' has no usable token (checked token_env and associated upstream auth_token); \
372跳过本次用量查询,请检查 usage_providers.json 和 ~/.codex-helper/config.json",
373 provider.id
374 );
375 }
376 }
377}