Skip to main content

lexicon_spec/
auth.rs

1use std::fmt;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
/// Supported AI authentication providers.
// Serde (de)serializes the variants under lowercase names because of
// `rename_all = "lowercase"`: "claude" and "openai", the same strings that
// `Provider::as_str` returns.
// NOTE(review): `clap::ValueEnum` derives its value names independently of the
// serde attribute — confirm the CLI spelling of `OpenAi` matches "openai".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, clap::ValueEnum)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    Claude,
    OpenAi,
}
13
14impl Provider {
15    pub fn as_str(self) -> &'static str {
16        match self {
17            Self::Claude => "claude",
18            Self::OpenAi => "openai",
19        }
20    }
21
22    pub fn display_name(self) -> &'static str {
23        match self {
24            Self::Claude => "Anthropic / Claude",
25            Self::OpenAi => "OpenAI",
26        }
27    }
28
29    pub fn config(self) -> ProviderConfig {
30        match self {
31            Self::Claude => ProviderConfig {
32                client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e",
33                auth_url: "https://claude.ai/oauth/authorize",
34                token_url: "https://platform.claude.com/v1/oauth/token",
35                default_port: 54321,
36                scopes: "user:inference user:profile",
37                token_exchange_json: true,
38            },
39            Self::OpenAi => ProviderConfig {
40                client_id: "app_EMoamEEZ73f0CkXaXp7hrann",
41                auth_url: "https://auth.openai.com/oauth/authorize",
42                token_url: "https://auth.openai.com/oauth/token",
43                default_port: 1455,
44                scopes: "openid profile email offline_access",
45                token_exchange_json: false,
46            },
47        }
48    }
49
50    /// Environment variable name for the API key (e.g. `ANTHROPIC_API_KEY`).
51    pub fn env_var(self) -> &'static str {
52        match self {
53            Self::Claude => "ANTHROPIC_API_KEY",
54            Self::OpenAi => "OPENAI_API_KEY",
55        }
56    }
57
58    pub const ALL: [Provider; 2] = [Provider::Claude, Provider::OpenAi];
59}
60
61impl fmt::Display for Provider {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.write_str(self.as_str())
64    }
65}
66
67impl FromStr for Provider {
68    type Err = String;
69
70    fn from_str(s: &str) -> Result<Self, Self::Err> {
71        match s.to_lowercase().as_str() {
72            "claude" | "anthropic" => Ok(Self::Claude),
73            "openai" => Ok(Self::OpenAi),
74            other => Err(format!("unknown provider '{other}' — use 'claude' or 'openai'")),
75        }
76    }
77}
78
/// OAuth configuration for a provider.
///
/// All string fields are `'static`; values are produced by `Provider::config`.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Client identifier for this application at the provider.
    pub client_id: &'static str,
    /// URL of the provider's authorization endpoint.
    pub auth_url: &'static str,
    /// URL of the provider's token endpoint.
    pub token_url: &'static str,
    /// Default TCP port for this provider.
    // NOTE(review): presumably the local port used for the OAuth redirect
    // listener — confirm against the code that consumes this field.
    pub default_port: u16,
    /// Scopes to request, as a single space-separated string.
    pub scopes: &'static str,
    /// If true, token exchange uses JSON body (Claude). Otherwise form-encoded (OpenAI).
    pub token_exchange_json: bool,
}
90
/// Access credentials for a single [`Provider`], (de)serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Credentials {
    /// Provider these credentials belong to.
    pub provider: Provider,
    /// Access token string.
    pub access_token: String,
    /// Optional refresh token. Defaults to `None` when missing from the input
    /// and is left out of the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Optional expiry time, interpreted by `Credentials::is_expired` as
    /// seconds since the Unix epoch. Defaults to `None` when missing from the
    /// input and is left out of the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<u64>,
}
100
101impl Credentials {
102    /// Returns true if the token has expired or will expire within 60 seconds.
103    pub fn is_expired(&self) -> bool {
104        self.expires_at.is_some_and(|exp| {
105            let now = std::time::SystemTime::now()
106                .duration_since(std::time::UNIX_EPOCH)
107                .unwrap_or_default()
108                .as_secs();
109            exp <= now + 60
110        })
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds Claude credentials with a fixed token, no refresh token, and the
    /// given expiry.
    fn claude_creds(expires_at: Option<u64>) -> Credentials {
        Credentials {
            provider: Provider::Claude,
            access_token: String::from("test"),
            refresh_token: None,
            expires_at,
        }
    }

    // Each provider's identifier must parse back to the same provider.
    #[test]
    fn provider_round_trip() {
        for provider in Provider::ALL {
            assert_eq!(provider.as_str().parse::<Provider>(), Ok(provider));
        }
    }

    // `Display` prints the lowercase identifier.
    #[test]
    fn provider_display() {
        let expected = [(Provider::Claude, "claude"), (Provider::OpenAi, "openai")];
        for (provider, text) in expected {
            assert_eq!(provider.to_string(), text);
        }
    }

    // No expiry recorded means the credentials never count as expired.
    #[test]
    fn credentials_not_expired_when_no_expiry() {
        assert!(!claude_creds(None).is_expired());
    }

    // A timestamp of 1000 seconds after the epoch is long gone.
    #[test]
    fn credentials_expired_in_past() {
        assert!(claude_creds(Some(1000)).is_expired());
    }

    // A timestamp far beyond the current time is still valid.
    #[test]
    fn credentials_not_expired_far_future() {
        assert!(!claude_creds(Some(u64::MAX / 2)).is_expired());
    }

    // Sanity-check the hard-coded OAuth settings of every provider.
    #[test]
    fn provider_config_has_valid_urls() {
        for provider in Provider::ALL {
            let ProviderConfig {
                client_id,
                auth_url,
                token_url,
                default_port,
                ..
            } = provider.config();

            assert!(auth_url.starts_with("https://"));
            assert!(token_url.starts_with("https://"));
            assert!(!client_id.is_empty());
            assert_ne!(default_port, 0);
        }
    }
}