// codex_helper_core/usage_providers.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3use std::time::{Duration, Instant};
4
5use anyhow::Result;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use tracing::{info, warn};
9
10use crate::config::{ProxyConfig, proxy_home_dir};
11use crate::lb::LbState;
12
/// Which usage API protocol a provider speaks.
///
/// Serialized in snake_case in `usage_providers.json` (`budget_http_json`, `yescode_profile`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum ProviderKind {
    /// Simple budget endpoint reporting a monthly budget and spend
    /// (`monthly_budget_usd` / `monthly_spent_usd`); used to decide whether the budget is exhausted.
    BudgetHttpJson,
    /// YesCode account usage, based on the balance fields returned by `/api/v1/auth/profile`.
    YescodeProfile,
}
21
/// One entry in `usage_providers.json`: how to query a usage/balance API for matching upstreams.
#[derive(Debug, Deserialize, Serialize)]
struct UsageProviderConfig {
    /// Stable identifier; used as the poll-throttle key and in log messages.
    id: String,
    /// Which response format / auth scheme the endpoint uses.
    kind: ProviderKind,
    /// Domains matched against an upstream's `base_url` host (exact host or any subdomain).
    domains: Vec<String>,
    /// Absolute URL of the usage endpoint that is queried with HTTP GET.
    endpoint: String,
    /// Optional environment variable name; when set and non-blank it is used as the token
    /// in preference to the upstream's own credentials.
    #[serde(default)]
    token_env: Option<String>,
    /// Desired seconds between polls; values below `MIN_POLL_INTERVAL_SECS` (or unset)
    /// are raised to that floor by the poller.
    #[serde(default)]
    poll_interval_secs: Option<u64>,
}
33
/// Top-level shape of `usage_providers.json`. A missing `providers` key deserializes as empty.
#[derive(Debug, Deserialize, Serialize, Default)]
struct UsageProvidersFile {
    #[serde(default)]
    providers: Vec<UsageProviderConfig>,
}
39
/// Points at one upstream: the Codex config it belongs to (by name) and its position
/// within that config's upstream list.
#[derive(Clone, Debug)]
struct UpstreamRef {
    /// Key into `cfg.codex.configs`.
    config_name: String,
    /// Index into that config's `upstreams` vector.
    index: usize,
}
45
// Global throttle state: records the last poll time per `provider.id`, so the same
// usage API is not queried at high frequency. Lazily initialized on first use.
static LAST_USAGE_POLL: OnceLock<Mutex<HashMap<String, Instant>>> = OnceLock::new();

// Minimal poll interval per provider to avoid hammering usage APIs; also used as the
// interval when a provider does not configure `poll_interval_secs`.
const MIN_POLL_INTERVAL_SECS: u64 = 20;
51
52fn usage_providers_path() -> std::path::PathBuf {
53    proxy_home_dir().join("usage_providers.json")
54}
55
56fn default_providers() -> UsageProvidersFile {
57    UsageProvidersFile {
58        providers: vec![
59            UsageProviderConfig {
60                id: "packycode".to_string(),
61                kind: ProviderKind::BudgetHttpJson,
62                domains: vec!["packycode.com".to_string()],
63                endpoint: "https://www.packycode.com/api/backend/users/info".to_string(),
64                token_env: None,
65                poll_interval_secs: Some(60),
66            },
67            UsageProviderConfig {
68                id: "yescode".to_string(),
69                kind: ProviderKind::YescodeProfile,
70                // yes.vg 匹配 co.yes.vg / cotest.yes.vg 等子域名
71                domains: vec!["yes.vg".to_string()],
72                endpoint: "https://co.yes.vg/api/v1/auth/profile".to_string(),
73                token_env: None,
74                poll_interval_secs: Some(60),
75            },
76        ],
77    }
78}
79
80fn load_providers() -> UsageProvidersFile {
81    let path = usage_providers_path();
82    if let Ok(text) = std::fs::read_to_string(&path)
83        && let Ok(file) = serde_json::from_str::<UsageProvidersFile>(&text)
84    {
85        return file;
86    }
87
88    // 写入默认配置(当前仅包含 packycode),方便用户查看/修改
89    let default = default_providers();
90    if let Ok(text) = serde_json::to_string_pretty(&default) {
91        if let Some(parent) = path.parent() {
92            let _ = std::fs::create_dir_all(parent);
93        }
94        let _ = std::fs::write(&path, text);
95    }
96    default
97}
98
99fn domain_matches(base_url: &str, domains: &[String]) -> bool {
100    let url = match reqwest::Url::parse(base_url) {
101        Ok(u) => u,
102        Err(_) => return false,
103    };
104    let host = match url.host_str() {
105        Some(h) => h,
106        None => return false,
107    };
108    for d in domains {
109        if host == d || host.ends_with(&format!(".{}", d)) {
110            return true;
111        }
112    }
113    false
114}
115
116fn resolve_token(
117    provider: &UsageProviderConfig,
118    upstreams: &[UpstreamRef],
119    cfg: &ProxyConfig,
120) -> Option<String> {
121    // 优先: token_env 环境变量
122    if let Some(env_name) = &provider.token_env
123        && let Ok(v) = std::env::var(env_name)
124        && !v.trim().is_empty()
125    {
126        return Some(v);
127    }
128
129    // 否则: 使用绑定 upstream 的 auth_token(当前 Codex 正在使用的 token)
130    for uref in upstreams {
131        if let Some(service) = cfg.codex.configs.get(&uref.config_name)
132            && let Some(up) = service.upstreams.get(uref.index)
133        {
134            if let Some(token) = up.auth.resolve_auth_token() {
135                return Some(token);
136            }
137            if let Some(token) = up.auth.resolve_api_key() {
138                return Some(token);
139            }
140        }
141    }
142    None
143}
144
145async fn poll_budget_http_json(
146    client: &Client,
147    endpoint: &str,
148    token: &str,
149) -> Result<(bool, f64, f64)> {
150    let resp = client
151        .get(endpoint)
152        .header("Authorization", format!("Bearer {}", token))
153        .header("Content-Type", "application/json")
154        .send()
155        .await?;
156
157    if !resp.status().is_success() {
158        anyhow::bail!("usage provider HTTP {}", resp.status());
159    }
160    let value: serde_json::Value = resp.json().await?;
161
162    let monthly_budget = value
163        .get("monthly_budget_usd")
164        .and_then(|v| v.as_f64())
165        .unwrap_or(0.0);
166    let monthly_spent = value
167        .get("monthly_spent_usd")
168        .and_then(|v| v.as_f64())
169        .unwrap_or(0.0);
170
171    let exhausted = monthly_budget > 0.0 && monthly_spent >= monthly_budget;
172    Ok((exhausted, monthly_budget, monthly_spent))
173}
174
175async fn poll_yescode_profile(
176    client: &Client,
177    endpoint: &str,
178    token: &str,
179) -> Result<(bool, f64, f64, f64)> {
180    let resp = client
181        .get(endpoint)
182        .header("X-API-Key", token)
183        .header("Accept", "application/json")
184        .send()
185        .await?;
186
187    if !resp.status().is_success() {
188        anyhow::bail!("yescode profile HTTP {}", resp.status());
189    }
190    let value: serde_json::Value = resp.json().await?;
191
192    let subscription_balance = value
193        .get("subscription_balance")
194        .and_then(|v| v.as_f64())
195        .unwrap_or(0.0);
196    let paygo_balance = value
197        .get("pay_as_you_go_balance")
198        .and_then(|v| v.as_f64())
199        .unwrap_or(0.0);
200    let total_balance = subscription_balance + paygo_balance;
201
202    // 简单策略:总余额 <= 0 视为额度用尽。
203    let exhausted = total_balance <= 0.0;
204    Ok((
205        exhausted,
206        total_balance,
207        subscription_balance,
208        paygo_balance,
209    ))
210}
211
212fn update_usage_exhausted(
213    lb_states: &Arc<Mutex<HashMap<String, LbState>>>,
214    cfg: &ProxyConfig,
215    upstreams: &[UpstreamRef],
216    exhausted: bool,
217) {
218    let mut map = match lb_states.lock() {
219        Ok(m) => m,
220        Err(_) => return,
221    };
222
223    for uref in upstreams {
224        let service = match cfg.codex.configs.get(&uref.config_name) {
225            Some(s) => s,
226            None => continue,
227        };
228
229        let len = service.upstreams.len();
230        let entry = map
231            .entry(uref.config_name.clone())
232            .or_insert_with(LbState::default);
233        if entry.failure_counts.len() != len {
234            entry.failure_counts.resize(len, 0);
235            entry.cooldown_until.resize(len, None);
236            entry.usage_exhausted.resize(len, false);
237        }
238        if uref.index < entry.usage_exhausted.len() {
239            entry.usage_exhausted[uref.index] = exhausted;
240        }
241    }
242}
243
244/// 在特定 Codex upstream 请求结束后,按需查询一次用量并更新 LB 状态。
245/// 设计为轻量的“按需刷新”,而非后台定时轮询。
246pub async fn poll_for_codex_upstream(
247    cfg: Arc<ProxyConfig>,
248    lb_states: Arc<Mutex<HashMap<String, LbState>>>,
249    config_name: &str,
250    upstream_index: usize,
251) {
252    // Tests should be hermetic and should not depend on any real user `usage_providers.json` on
253    // the machine running the suite. Disable provider polling during tests to avoid flakiness.
254    if cfg!(test) {
255        return;
256    }
257
258    let providers_file = load_providers();
259    if providers_file.providers.is_empty() {
260        return;
261    }
262
263    // Locate the current upstream once; if it no longer exists, bail out quietly.
264    let current_service = match cfg.codex.configs.get(config_name) {
265        Some(s) => s,
266        None => return,
267    };
268    let current_upstream = match current_service.upstreams.get(upstream_index) {
269        Some(u) => u,
270        None => return,
271    };
272    let current_base_url = current_upstream.base_url.clone();
273
274    let now = Instant::now();
275    let poll_map = LAST_USAGE_POLL.get_or_init(|| Mutex::new(HashMap::new()));
276
277    let mut client: Option<Client> = None;
278
279    for provider in providers_file.providers {
280        // Only providers whose domains match the current upstream are considered.
281        if !domain_matches(&current_base_url, &provider.domains) {
282            continue;
283        }
284
285        // Compute effective poll interval with a global minimum to avoid hammering.
286        let mut interval_secs = provider
287            .poll_interval_secs
288            .unwrap_or(MIN_POLL_INTERVAL_SECS);
289        if interval_secs < MIN_POLL_INTERVAL_SECS {
290            interval_secs = MIN_POLL_INTERVAL_SECS;
291        }
292
293        if interval_secs > 0 {
294            let mut map = match poll_map.lock() {
295                Ok(m) => m,
296                Err(_) => continue,
297            };
298            if let Some(last) = map.get(&provider.id)
299                && now.duration_since(*last) < Duration::from_secs(interval_secs)
300            {
301                continue;
302            }
303            map.insert(provider.id.clone(), now);
304        }
305
306        // For diagnostics, still check whether this provider is associated with
307        // multiple hosts across configs, but only once per poll.
308        let mut hosts: Vec<String> = Vec::new();
309        for service in cfg.codex.configs.values() {
310            for upstream in &service.upstreams {
311                if domain_matches(&upstream.base_url, &provider.domains)
312                    && let Ok(url) = reqwest::Url::parse(&upstream.base_url)
313                    && let Some(host) = url.host_str()
314                {
315                    hosts.push(host.to_string());
316                }
317            }
318        }
319        hosts.sort();
320        hosts.dedup();
321        if hosts.len() > 1 {
322            warn!(
323                "usage provider '{}' is associated with multiple hosts: {:?}; \
324将按统一额度处理这些 upstream,如需区分配额请拆分为多个 provider 配置",
325                provider.id, hosts
326            );
327        }
328
329        // Only the current upstream participates in token resolution and usage update.
330        let current_ref = UpstreamRef {
331            config_name: config_name.to_string(),
332            index: upstream_index,
333        };
334        let upstreams = vec![current_ref];
335
336        let c = client.get_or_insert_with(Client::new);
337
338        if let Some(token) = resolve_token(&provider, &upstreams, &cfg) {
339            match provider.kind {
340                ProviderKind::BudgetHttpJson => {
341                    match poll_budget_http_json(c, &provider.endpoint, &token).await {
342                        Ok((exhausted, monthly_budget, monthly_spent)) => {
343                            update_usage_exhausted(&lb_states, &cfg, &upstreams, exhausted);
344                            info!(
345                                "usage provider '{}' exhausted = {} (monthly: {:.2}/{:.2} USD)",
346                                provider.id, exhausted, monthly_spent, monthly_budget
347                            );
348                        }
349                        Err(err) => {
350                            warn!("usage provider '{}' poll failed: {}", provider.id, err);
351                        }
352                    }
353                }
354                ProviderKind::YescodeProfile => {
355                    match poll_yescode_profile(c, &provider.endpoint, &token).await {
356                        Ok((exhausted, total_balance, sub_balance, paygo_balance)) => {
357                            update_usage_exhausted(&lb_states, &cfg, &upstreams, exhausted);
358                            info!(
359                                "usage provider '{}' exhausted = {} (yescode balance: total={:.2}, subscription={:.2}, paygo={:.2})",
360                                provider.id, exhausted, total_balance, sub_balance, paygo_balance
361                            );
362                        }
363                        Err(err) => {
364                            warn!("usage provider '{}' poll failed: {}", provider.id, err);
365                        }
366                    }
367                }
368            }
369        } else {
370            warn!(
371                "usage provider '{}' has no usable token (checked token_env and associated upstream auth_token); \
372跳过本次用量查询,请检查 usage_providers.json 和 ~/.codex-helper/config.json",
373                provider.id
374            );
375        }
376    }
377}