use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use crate::config::{ProxyConfig, proxy_home_dir};
use crate::lb::LbState;
/// Response format of a usage provider's endpoint. It selects which poller
/// handles the provider: `poll_budget_http_json` or `poll_yescode_profile`.
///
/// Serialized in `snake_case`, i.e. `"budget_http_json"` / `"yescode_profile"`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
enum ProviderKind {
    /// Queried with an `Authorization: Bearer <token>` header. It reads the
    /// `monthly_budget_usd` and `monthly_spent_usd` fields of the JSON reply.
    BudgetHttpJson,
    /// Queried with an `X-API-Key: <token>` header. It reads the
    /// `subscription_balance` and `pay_as_you_go_balance` fields of the JSON reply.
    YescodeProfile,
}
/// One entry of `usage_providers.json`: how to query a provider's usage/quota
/// endpoint, and which upstreams it applies to.
#[derive(Debug, Deserialize, Serialize)]
struct UsageProviderConfig {
    /// Key used to rate-limit polls of this provider (see `LAST_USAGE_POLL`)
    /// and to identify it in log messages.
    id: String,
    /// Selects the request/response format of `endpoint`.
    kind: ProviderKind,
    /// Upstreams whose base-URL host equals one of these domains, or is a
    /// subdomain of one, are associated with this provider.
    domains: Vec<String>,
    /// URL fetched with `GET` to obtain the usage/balance JSON.
    endpoint: String,
    /// Optional name of an environment variable holding the token. When it is
    /// set and non-blank it takes precedence over the upstream's own credentials.
    #[serde(default)]
    token_env: Option<String>,
    /// Minimum seconds between polls. When absent, `MIN_POLL_INTERVAL_SECS` is
    /// used, and values below that minimum are raised to it.
    #[serde(default)]
    poll_interval_secs: Option<u64>,
}
/// Top-level shape of `usage_providers.json`.
#[derive(Debug, Deserialize, Serialize, Default)]
struct UsageProvidersFile {
    /// Configured providers; an absent key deserializes as an empty list.
    #[serde(default)]
    providers: Vec<UsageProviderConfig>,
}
/// Points at a single codex upstream. `config_name` is a key of
/// `cfg.codex.configs`, and `index` is a position in that config's `upstreams` list.
#[derive(Debug, Clone)]
struct UpstreamRef {
    config_name: String,
    index: usize,
}
/// Time of the most recent poll attempt for each provider `id`. It is initialized
/// on first use and shared by every call, so rate limiting applies process-wide.
/// The timestamp is recorded before the HTTP request is sent.
static LAST_USAGE_POLL: OnceLock<Mutex<HashMap<String, Instant>>> = OnceLock::new();
/// Default poll interval, and the lowest interval a provider may configure.
const MIN_POLL_INTERVAL_SECS: u64 = 20;
/// Location of the usage-provider configuration file inside the proxy home directory.
fn usage_providers_path() -> std::path::PathBuf {
    let home = proxy_home_dir();
    home.join("usage_providers.json")
}
fn default_providers() -> UsageProvidersFile {
UsageProvidersFile {
providers: vec![
UsageProviderConfig {
id: "packycode".to_string(),
kind: ProviderKind::BudgetHttpJson,
domains: vec!["packycode.com".to_string()],
endpoint: "https://www.packycode.com/api/backend/users/info".to_string(),
token_env: None,
poll_interval_secs: Some(60),
},
UsageProviderConfig {
id: "yescode".to_string(),
kind: ProviderKind::YescodeProfile,
domains: vec!["yes.vg".to_string()],
endpoint: "https://co.yes.vg/api/v1/auth/profile".to_string(),
token_env: None,
poll_interval_secs: Some(60),
},
],
}
}
fn load_providers() -> UsageProvidersFile {
let path = usage_providers_path();
if let Ok(text) = std::fs::read_to_string(&path)
&& let Ok(file) = serde_json::from_str::<UsageProvidersFile>(&text)
{
return file;
}
let default = default_providers();
if let Ok(text) = serde_json::to_string_pretty(&default) {
if let Some(parent) = path.parent() {
let _ = std::fs::create_dir_all(parent);
}
let _ = std::fs::write(&path, text);
}
default
}
/// Returns `true` when the host of `base_url` equals one of `domains` or is a
/// subdomain of one. For example, `api.packycode.com` matches `packycode.com`.
///
/// The comparison ignores ASCII case and leading/trailing dots. It also ignores
/// whitespace around configured domains, so entries such as `PackyCode.com` or
/// `yes.vg.` still match. Blank entries never match. An unparseable URL or a
/// URL without a host yields `false`.
fn domain_matches(base_url: &str, domains: &[String]) -> bool {
    let Ok(url) = reqwest::Url::parse(base_url) else {
        return false;
    };
    let Some(host) = url.host_str() else {
        return false;
    };
    // A fully-qualified host such as `packycode.com.` should still match.
    let host = host.trim_end_matches('.');
    domains
        .iter()
        .any(|domain| host_matches_domain(host, domain))
}

/// Returns `true` when `host` equals `domain` or ends with `.` followed by
/// `domain`. Case is ignored, and `domain` has surrounding whitespace and dots
/// trimmed first. A `domain` that is empty after trimming never matches.
fn host_matches_domain(host: &str, domain: &str) -> bool {
    let domain = domain.trim().trim_matches('.');
    if domain.is_empty() || host.len() < domain.len() {
        return false;
    }
    // Compare bytes so the split point never has to fall on a char boundary.
    let (host, domain) = (host.as_bytes(), domain.as_bytes());
    let split = host.len() - domain.len();
    host[split..].eq_ignore_ascii_case(domain) && (split == 0 || host[split - 1] == b'.')
}
/// Picks the credential used to query `provider`'s usage endpoint.
///
/// Resolution order:
/// 1. The environment variable named by `provider.token_env`, if it is set and
///    not blank. Surrounding whitespace is trimmed, so a stray trailing newline
///    (common when the variable is filled from a file) never reaches the HTTP header.
/// 2. For each entry of `upstreams`, in order: that upstream's
///    `resolve_auth_token()`, then its `resolve_api_key()`. References to
///    unknown configs or out-of-range indexes are skipped.
///
/// Returns `None` when no source yields a token.
fn resolve_token(
    provider: &UsageProviderConfig,
    upstreams: &[UpstreamRef],
    cfg: &ProxyConfig,
) -> Option<String> {
    if let Some(env_name) = &provider.token_env
        && let Ok(value) = std::env::var(env_name)
    {
        let trimmed = value.trim();
        if !trimmed.is_empty() {
            return Some(trimmed.to_string());
        }
    }
    upstreams.iter().find_map(|uref| {
        let service = cfg.codex.configs.get(&uref.config_name)?;
        let up = service.upstreams.get(uref.index)?;
        up.auth
            .resolve_auth_token()
            .or_else(|| up.auth.resolve_api_key())
    })
}
/// Queries a budget-style usage endpoint, authenticating with a bearer token.
///
/// Reads the numeric `monthly_budget_usd` and `monthly_spent_usd` fields of the
/// JSON reply. A missing or non-numeric field counts as `0.0`.
///
/// Returns `(exhausted, monthly_budget, monthly_spent)`. `exhausted` is `true`
/// only when the budget is positive and the spend has reached it.
///
/// # Errors
/// Fails on transport errors, on a non-success HTTP status, or when the body
/// is not valid JSON.
async fn poll_budget_http_json(
    client: &Client,
    endpoint: &str,
    token: &str,
) -> Result<(bool, f64, f64)> {
    let bearer = format!("Bearer {}", token);
    let request = client
        .get(endpoint)
        .header("Authorization", bearer)
        .header("Content-Type", "application/json");
    let response = request.send().await?;

    let status = response.status();
    if !status.is_success() {
        anyhow::bail!("usage provider HTTP {}", status);
    }

    let body: serde_json::Value = response.json().await?;
    let read_usd = |key: &str| {
        body.get(key)
            .and_then(serde_json::Value::as_f64)
            .unwrap_or(0.0)
    };
    let monthly_budget = read_usd("monthly_budget_usd");
    let monthly_spent = read_usd("monthly_spent_usd");

    let has_budget = monthly_budget > 0.0;
    let exhausted = has_budget && monthly_spent >= monthly_budget;
    Ok((exhausted, monthly_budget, monthly_spent))
}
/// Queries a yescode profile endpoint, authenticating with an `X-API-Key` header.
///
/// Reads the numeric `subscription_balance` and `pay_as_you_go_balance` fields
/// of the JSON reply. If only one is present, the other counts as `0.0`.
///
/// Returns `(exhausted, total_balance, subscription_balance, paygo_balance)`.
/// `exhausted` is `true` when the combined balance is `<= 0.0`.
///
/// # Errors
/// - Transport errors, a non-success HTTP status, or a non-JSON body.
/// - A reply containing neither balance field. Without this check, an
///   unexpected payload (API change, error object) would read as a zero
///   balance and falsely mark the upstream as exhausted.
async fn poll_yescode_profile(
    client: &Client,
    endpoint: &str,
    token: &str,
) -> Result<(bool, f64, f64, f64)> {
    let resp = client
        .get(endpoint)
        .header("X-API-Key", token)
        .header("Accept", "application/json")
        .send()
        .await?;
    if !resp.status().is_success() {
        anyhow::bail!("yescode profile HTTP {}", resp.status());
    }
    let value: serde_json::Value = resp.json().await?;
    let subscription = value
        .get("subscription_balance")
        .and_then(|v| v.as_f64());
    let paygo = value
        .get("pay_as_you_go_balance")
        .and_then(|v| v.as_f64());
    if subscription.is_none() && paygo.is_none() {
        anyhow::bail!(
            "yescode profile response has no numeric subscription_balance or pay_as_you_go_balance"
        );
    }
    let subscription_balance = subscription.unwrap_or(0.0);
    let paygo_balance = paygo.unwrap_or(0.0);
    let total_balance = subscription_balance + paygo_balance;
    let exhausted = total_balance <= 0.0;
    Ok((
        exhausted,
        total_balance,
        subscription_balance,
        paygo_balance,
    ))
}
/// Sets the usage-exhausted flag to `exhausted` for every upstream in
/// `upstreams`, inside the shared load-balancer state.
///
/// The per-config `LbState` entry is created on demand. Its per-upstream
/// vectors (`failure_counts`, `cooldown_until`, `usage_exhausted`) are each
/// resized to the config's current upstream count. Each is resized
/// independently, so one vector with a stale length can no longer cause the
/// update to be skipped silently. References to unknown configs or
/// out-of-range indexes are ignored. Nothing is updated if the mutex is poisoned.
fn update_usage_exhausted(
    lb_states: &Arc<Mutex<HashMap<String, LbState>>>,
    cfg: &ProxyConfig,
    upstreams: &[UpstreamRef],
    exhausted: bool,
) {
    let Ok(mut map) = lb_states.lock() else {
        return;
    };
    for uref in upstreams {
        let Some(service) = cfg.codex.configs.get(&uref.config_name) else {
            continue;
        };
        let len = service.upstreams.len();
        let entry = map
            .entry(uref.config_name.clone())
            .or_insert_with(LbState::default);
        // `resize` is a no-op when the length already matches.
        entry.failure_counts.resize(len, 0);
        entry.cooldown_until.resize(len, None);
        entry.usage_exhausted.resize(len, false);
        if let Some(flag) = entry.usage_exhausted.get_mut(uref.index) {
            *flag = exhausted;
        }
    }
}
/// Polls each configured usage provider whose domains match the base URL of
/// upstream `upstream_index` in codex config `config_name`. The resulting
/// "usage exhausted" flag is stored for that upstream in `lb_states`.
///
/// Behavior:
/// - Does nothing in test builds, when no providers are configured, or when
///   the config/upstream does not exist.
/// - Each provider is polled at most once per `poll_interval_secs`, and never
///   more often than every `MIN_POLL_INTERVAL_SECS`. The limit is process-wide
///   and keyed by provider `id`. The slot is claimed before the request, so a
///   failed poll also waits for the next interval.
/// - Warns when a provider's domains match several distinct upstream hosts,
///   because those upstreams then share one quota.
/// - Poll failures and missing tokens are only logged; they never change the flag.
/// - Every HTTP request has a timeout, so a hung provider endpoint cannot stall
///   this task.
pub async fn poll_for_codex_upstream(
    cfg: Arc<ProxyConfig>,
    lb_states: Arc<Mutex<HashMap<String, LbState>>>,
    config_name: &str,
    upstream_index: usize,
) {
    if cfg!(test) {
        return;
    }
    let providers_file = load_providers();
    if providers_file.providers.is_empty() {
        return;
    }
    let Some(current_upstream) = cfg
        .codex
        .configs
        .get(config_name)
        .and_then(|service| service.upstreams.get(upstream_index))
    else {
        return;
    };
    let current_base_url = &current_upstream.base_url;
    let now = Instant::now();
    // The same single upstream is updated for every matching provider.
    let upstreams = [UpstreamRef {
        config_name: config_name.to_string(),
        index: upstream_index,
    }];
    // Built lazily: most calls match no provider or are rate-limited.
    let mut client: Option<Client> = None;

    for provider in providers_file.providers {
        if !domain_matches(current_base_url, &provider.domains) {
            continue;
        }
        // Clamped to >= MIN_POLL_INTERVAL_SECS, hence always non-zero.
        let interval_secs = provider
            .poll_interval_secs
            .unwrap_or(MIN_POLL_INTERVAL_SECS)
            .max(MIN_POLL_INTERVAL_SECS);
        if !try_claim_poll_slot(&provider.id, now, Duration::from_secs(interval_secs)) {
            continue;
        }

        let hosts = hosts_matching_domains(&cfg, &provider.domains);
        if hosts.len() > 1 {
            warn!(
                "usage provider '{}' is associated with multiple hosts: {:?}; \
                 将按统一额度处理这些 upstream,如需区分配额请拆分为多个 provider 配置",
                provider.id, hosts
            );
        }

        let Some(token) = resolve_token(&provider, &upstreams, &cfg) else {
            warn!(
                "usage provider '{}' has no usable token (checked token_env and associated upstream auth_token); \
                 跳过本次用量查询,请检查 usage_providers.json 和 ~/.codex-helper/config.json",
                provider.id
            );
            continue;
        };

        let c = client.get_or_insert_with(build_usage_client);
        match provider.kind {
            ProviderKind::BudgetHttpJson => {
                match poll_budget_http_json(c, &provider.endpoint, &token).await {
                    Ok((exhausted, monthly_budget, monthly_spent)) => {
                        update_usage_exhausted(&lb_states, &cfg, &upstreams, exhausted);
                        info!(
                            "usage provider '{}' exhausted = {} (monthly: {:.2}/{:.2} USD)",
                            provider.id, exhausted, monthly_spent, monthly_budget
                        );
                    }
                    Err(err) => {
                        warn!("usage provider '{}' poll failed: {}", provider.id, err);
                    }
                }
            }
            ProviderKind::YescodeProfile => {
                match poll_yescode_profile(c, &provider.endpoint, &token).await {
                    Ok((exhausted, total_balance, sub_balance, paygo_balance)) => {
                        update_usage_exhausted(&lb_states, &cfg, &upstreams, exhausted);
                        info!(
                            "usage provider '{}' exhausted = {} (yescode balance: total={:.2}, subscription={:.2}, paygo={:.2})",
                            provider.id, exhausted, total_balance, sub_balance, paygo_balance
                        );
                    }
                    Err(err) => {
                        warn!("usage provider '{}' poll failed: {}", provider.id, err);
                    }
                }
            }
        }
    }
}

/// Claims the poll slot for `provider_id`. Returns `true` and records `now` when
/// the provider was never polled, or when at least `interval` has passed since
/// its last poll. Returns `false` when the last poll was too recent or the mutex
/// is poisoned. The guard is released before returning, so it is never held
/// across an `.await`.
fn try_claim_poll_slot(provider_id: &str, now: Instant, interval: Duration) -> bool {
    let poll_map = LAST_USAGE_POLL.get_or_init(|| Mutex::new(HashMap::new()));
    let Ok(mut map) = poll_map.lock() else {
        return false;
    };
    if let Some(last) = map.get(provider_id)
        && now.duration_since(*last) < interval
    {
        return false;
    }
    map.insert(provider_id.to_string(), now);
    true
}

/// Hosts of all codex upstreams whose base URL matches `domains`, sorted and
/// de-duplicated.
fn hosts_matching_domains(cfg: &ProxyConfig, domains: &[String]) -> Vec<String> {
    let mut hosts: Vec<String> = cfg
        .codex
        .configs
        .values()
        .flat_map(|service| service.upstreams.iter())
        .filter(|upstream| domain_matches(&upstream.base_url, domains))
        .filter_map(|upstream| {
            let url = reqwest::Url::parse(&upstream.base_url).ok()?;
            url.host_str().map(str::to_string)
        })
        .collect();
    hosts.sort_unstable();
    hosts.dedup();
    hosts
}

/// HTTP client for usage polling, with a per-request timeout. If the builder
/// fails, it falls back to a default (timeout-less) client.
fn build_usage_client() -> Client {
    Client::builder()
        .timeout(Duration::from_secs(15))
        .build()
        .unwrap_or_else(|err| {
            warn!(
                "failed to build usage provider HTTP client with timeout: {}; using defaults",
                err
            );
            Client::new()
        })
}