1use std::fmt;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, clap::ValueEnum)]
8#[serde(rename_all = "lowercase")]
9pub enum Provider {
10 Claude,
11 OpenAi,
12}
13
14impl Provider {
15 pub fn as_str(self) -> &'static str {
16 match self {
17 Self::Claude => "claude",
18 Self::OpenAi => "openai",
19 }
20 }
21
22 pub fn display_name(self) -> &'static str {
23 match self {
24 Self::Claude => "Anthropic / Claude",
25 Self::OpenAi => "OpenAI",
26 }
27 }
28
29 pub fn config(self) -> ProviderConfig {
30 match self {
31 Self::Claude => ProviderConfig {
32 client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e",
33 auth_url: "https://claude.ai/oauth/authorize",
34 token_url: "https://platform.claude.com/v1/oauth/token",
35 default_port: 54321,
36 scopes: "user:inference user:profile",
37 token_exchange_json: true,
38 },
39 Self::OpenAi => ProviderConfig {
40 client_id: "app_EMoamEEZ73f0CkXaXp7hrann",
41 auth_url: "https://auth.openai.com/oauth/authorize",
42 token_url: "https://auth.openai.com/oauth/token",
43 default_port: 1455,
44 scopes: "openid profile email offline_access",
45 token_exchange_json: false,
46 },
47 }
48 }
49
50 pub fn env_var(self) -> &'static str {
52 match self {
53 Self::Claude => "ANTHROPIC_API_KEY",
54 Self::OpenAi => "OPENAI_API_KEY",
55 }
56 }
57
58 pub const ALL: [Provider; 2] = [Provider::Claude, Provider::OpenAi];
59}
60
61impl fmt::Display for Provider {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 f.write_str(self.as_str())
64 }
65}
66
67impl FromStr for Provider {
68 type Err = String;
69
70 fn from_str(s: &str) -> Result<Self, Self::Err> {
71 match s.to_lowercase().as_str() {
72 "claude" | "anthropic" => Ok(Self::Claude),
73 "openai" => Ok(Self::OpenAi),
74 other => Err(format!("unknown provider '{other}' — use 'claude' or 'openai'")),
75 }
76 }
77}
78
/// Static OAuth settings for a single provider.
///
/// Every string field is a `&'static str`, so values can be built in `const`
/// context and cloned cheaply.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// OAuth client identifier.
    pub client_id: &'static str,
    /// URL of the authorization endpoint.
    pub auth_url: &'static str,
    /// URL of the token endpoint.
    pub token_url: &'static str,
    /// Default port number.
    // NOTE(review): presumably the local OAuth callback listener port —
    // confirm against the code that consumes this field.
    pub default_port: u16,
    /// Requested scopes as a single space-separated string.
    pub scopes: &'static str,
    // NOTE(review): presumably selects a JSON body (rather than a
    // form-encoded one) for the token exchange request — confirm against
    // the code that consumes this field.
    pub token_exchange_json: bool,
}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct Credentials {
93 pub provider: Provider,
94 pub access_token: String,
95 #[serde(default, skip_serializing_if = "Option::is_none")]
96 pub refresh_token: Option<String>,
97 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub expires_at: Option<u64>,
99}
100
101impl Credentials {
102 pub fn is_expired(&self) -> bool {
104 self.expires_at.is_some_and(|exp| {
105 let now = std::time::SystemTime::now()
106 .duration_since(std::time::UNIX_EPOCH)
107 .unwrap_or_default()
108 .as_secs();
109 exp <= now + 60
110 })
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds Claude credentials in which only the expiry varies.
    fn creds_expiring_at(expires_at: Option<u64>) -> Credentials {
        Credentials {
            provider: Provider::Claude,
            access_token: "test".into(),
            refresh_token: None,
            expires_at,
        }
    }

    /// Current wall-clock time in seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    #[test]
    fn provider_round_trip() {
        for p in Provider::ALL {
            let parsed: Provider = p.as_str().parse().unwrap();
            assert_eq!(parsed, p);
        }
    }

    #[test]
    fn provider_display() {
        assert_eq!(Provider::Claude.to_string(), "claude");
        assert_eq!(Provider::OpenAi.to_string(), "openai");
    }

    #[test]
    fn provider_parse_accepts_alias_and_ignores_case() {
        assert_eq!("anthropic".parse::<Provider>(), Ok(Provider::Claude));
        assert_eq!("Anthropic".parse::<Provider>(), Ok(Provider::Claude));
        assert_eq!("CLAUDE".parse::<Provider>(), Ok(Provider::Claude));
        assert_eq!("OpenAI".parse::<Provider>(), Ok(Provider::OpenAi));
    }

    #[test]
    fn provider_parse_rejects_unknown() {
        let err = "gemini".parse::<Provider>().unwrap_err();
        assert!(err.contains("gemini"));
        assert!("".parse::<Provider>().is_err());
    }

    #[test]
    fn credentials_not_expired_when_no_expiry() {
        assert!(!creds_expiring_at(None).is_expired());
    }

    #[test]
    fn credentials_expired_in_past() {
        assert!(creds_expiring_at(Some(1000)).is_expired());
    }

    #[test]
    fn credentials_expired_within_leeway_window() {
        // Still 30 s of nominal validity left, but that is inside the 60 s
        // safety margin, so the credentials must already count as expired.
        assert!(creds_expiring_at(Some(unix_now() + 30)).is_expired());
    }

    #[test]
    fn credentials_not_expired_far_future() {
        assert!(!creds_expiring_at(Some(u64::MAX / 2)).is_expired());
    }

    #[test]
    fn provider_config_has_valid_urls() {
        for p in Provider::ALL {
            let cfg = p.config();
            assert!(cfg.auth_url.starts_with("https://"));
            assert!(cfg.token_url.starts_with("https://"));
            assert!(!cfg.client_id.is_empty());
            assert!(cfg.default_port > 0);
        }
    }
}