1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::{AuthError, Result};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ProviderType {
12 Google,
14 Microsoft,
16 Okta,
18 Generic,
20}
21
22impl std::fmt::Display for ProviderType {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 ProviderType::Google => write!(f, "google"),
26 ProviderType::Microsoft => write!(f, "microsoft"),
27 ProviderType::Okta => write!(f, "okta"),
28 ProviderType::Generic => write!(f, "generic"),
29 }
30 }
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ProviderConfig {
36 pub provider_type: ProviderType,
38 pub client_id: String,
40 #[serde(skip_serializing)]
42 pub client_secret: String,
43 pub issuer_url: String,
45 pub authorization_endpoint: Option<String>,
47 pub token_endpoint: Option<String>,
49 pub userinfo_endpoint: Option<String>,
51 pub device_authorization_endpoint: Option<String>,
53 pub jwks_uri: Option<String>,
55 pub scopes: Vec<String>,
57 pub allowed_domains: Vec<String>,
59 pub required_groups: Vec<String>,
61 pub group_claim: String,
63 pub email_claim: String,
65 pub name_claim: String,
67 pub additional_params: HashMap<String, String>,
69}
70
71impl ProviderConfig {
72 pub fn google(client_id: &str, client_secret: &str, allowed_domain: Option<&str>) -> Self {
74 let mut config = Self {
75 provider_type: ProviderType::Google,
76 client_id: client_id.to_string(),
77 client_secret: client_secret.to_string(),
78 issuer_url: "https://accounts.google.com".to_string(),
79 authorization_endpoint: Some("https://accounts.google.com/o/oauth2/v2/auth".to_string()),
80 token_endpoint: Some("https://oauth2.googleapis.com/token".to_string()),
81 userinfo_endpoint: Some("https://openidconnect.googleapis.com/v1/userinfo".to_string()),
82 device_authorization_endpoint: Some("https://oauth2.googleapis.com/device/code".to_string()),
83 jwks_uri: Some("https://www.googleapis.com/oauth2/v3/certs".to_string()),
84 scopes: vec![
85 "openid".to_string(),
86 "email".to_string(),
87 "profile".to_string(),
88 ],
89 allowed_domains: vec![],
90 required_groups: vec![],
91 group_claim: "groups".to_string(),
92 email_claim: "email".to_string(),
93 name_claim: "name".to_string(),
94 additional_params: HashMap::new(),
95 };
96
97 if let Some(domain) = allowed_domain {
98 config.allowed_domains.push(domain.to_string());
99 config.additional_params.insert("hd".to_string(), domain.to_string());
101 }
102
103 config
104 }
105
106 pub fn microsoft(client_id: &str, client_secret: &str, tenant_id: &str) -> Self {
108 let base_url = format!("https://login.microsoftonline.com/{}", tenant_id);
109
110 Self {
111 provider_type: ProviderType::Microsoft,
112 client_id: client_id.to_string(),
113 client_secret: client_secret.to_string(),
114 issuer_url: format!("{}/v2.0", base_url),
115 authorization_endpoint: Some(format!("{}/oauth2/v2.0/authorize", base_url)),
116 token_endpoint: Some(format!("{}/oauth2/v2.0/token", base_url)),
117 userinfo_endpoint: Some("https://graph.microsoft.com/oidc/userinfo".to_string()),
118 device_authorization_endpoint: Some(format!("{}/oauth2/v2.0/devicecode", base_url)),
119 jwks_uri: Some(format!("{}/discovery/v2.0/keys", base_url)),
120 scopes: vec![
121 "openid".to_string(),
122 "email".to_string(),
123 "profile".to_string(),
124 "offline_access".to_string(),
125 ],
126 allowed_domains: vec![],
127 required_groups: vec![],
128 group_claim: "groups".to_string(),
129 email_claim: "email".to_string(),
130 name_claim: "name".to_string(),
131 additional_params: HashMap::new(),
132 }
133 }
134
135 pub fn okta(client_id: &str, client_secret: &str, domain: &str, auth_server_id: Option<&str>) -> Self {
137 let auth_server = auth_server_id.unwrap_or("default");
138 let base_url = format!("https://{}/oauth2/{}", domain, auth_server);
139
140 Self {
141 provider_type: ProviderType::Okta,
142 client_id: client_id.to_string(),
143 client_secret: client_secret.to_string(),
144 issuer_url: base_url.clone(),
145 authorization_endpoint: Some(format!("{}/v1/authorize", base_url)),
146 token_endpoint: Some(format!("{}/v1/token", base_url)),
147 userinfo_endpoint: Some(format!("{}/v1/userinfo", base_url)),
148 device_authorization_endpoint: Some(format!("{}/v1/device/authorize", base_url)),
149 jwks_uri: Some(format!("{}/v1/keys", base_url)),
150 scopes: vec![
151 "openid".to_string(),
152 "email".to_string(),
153 "profile".to_string(),
154 "groups".to_string(),
155 "offline_access".to_string(),
156 ],
157 allowed_domains: vec![],
158 required_groups: vec![],
159 group_claim: "groups".to_string(),
160 email_claim: "email".to_string(),
161 name_claim: "name".to_string(),
162 additional_params: HashMap::new(),
163 }
164 }
165
166 pub fn generic(client_id: &str, client_secret: &str, issuer_url: &str) -> Self {
168 Self {
169 provider_type: ProviderType::Generic,
170 client_id: client_id.to_string(),
171 client_secret: client_secret.to_string(),
172 issuer_url: issuer_url.to_string(),
173 authorization_endpoint: None,
174 token_endpoint: None,
175 userinfo_endpoint: None,
176 device_authorization_endpoint: None,
177 jwks_uri: None,
178 scopes: vec![
179 "openid".to_string(),
180 "email".to_string(),
181 "profile".to_string(),
182 ],
183 allowed_domains: vec![],
184 required_groups: vec![],
185 group_claim: "groups".to_string(),
186 email_claim: "email".to_string(),
187 name_claim: "name".to_string(),
188 additional_params: HashMap::new(),
189 }
190 }
191
192 pub fn validate(&self) -> Result<()> {
194 if self.client_id.is_empty() {
195 return Err(AuthError::ConfigError("client_id is required".into()));
196 }
197 if self.client_secret.is_empty() {
198 return Err(AuthError::ConfigError("client_secret is required".into()));
199 }
200 if self.issuer_url.is_empty() {
201 return Err(AuthError::ConfigError("issuer_url is required".into()));
202 }
203 Ok(())
204 }
205}
206
207pub struct OAuthProvider {
209 config: ProviderConfig,
211 http_client: reqwest::Client,
213 metadata: Option<OidcMetadata>,
215}
216
217#[derive(Debug, Clone, Deserialize)]
219pub struct OidcMetadata {
220 pub issuer: String,
221 pub authorization_endpoint: String,
222 pub token_endpoint: String,
223 #[serde(default)]
224 pub userinfo_endpoint: Option<String>,
225 #[serde(default)]
226 pub device_authorization_endpoint: Option<String>,
227 pub jwks_uri: String,
228 #[serde(default)]
229 pub scopes_supported: Vec<String>,
230 #[serde(default)]
231 pub response_types_supported: Vec<String>,
232 #[serde(default)]
233 pub grant_types_supported: Vec<String>,
234}
235
236impl OAuthProvider {
237 pub fn new(config: ProviderConfig) -> Self {
239 Self {
240 config,
241 http_client: reqwest::Client::new(),
242 metadata: None,
243 }
244 }
245
246 pub fn config(&self) -> &ProviderConfig {
248 &self.config
249 }
250
251 pub async fn discover(&mut self) -> Result<()> {
253 let discovery_url = format!("{}/.well-known/openid-configuration", self.config.issuer_url);
254
255 let response = self.http_client
256 .get(&discovery_url)
257 .send()
258 .await?;
259
260 if !response.status().is_success() {
261 return Err(AuthError::DiscoveryFailed(format!(
262 "HTTP {}: {}",
263 response.status(),
264 response.text().await.unwrap_or_default()
265 )));
266 }
267
268 let metadata: OidcMetadata = response.json().await?;
269
270 self.config.authorization_endpoint = Some(metadata.authorization_endpoint.clone());
272 self.config.token_endpoint = Some(metadata.token_endpoint.clone());
273 self.config.userinfo_endpoint = metadata.userinfo_endpoint.clone();
274 self.config.device_authorization_endpoint = metadata.device_authorization_endpoint.clone();
275 self.config.jwks_uri = Some(metadata.jwks_uri.clone());
276
277 self.metadata = Some(metadata);
278
279 Ok(())
280 }
281
282 pub fn authorization_endpoint(&self) -> Result<&str> {
284 self.config.authorization_endpoint
285 .as_deref()
286 .ok_or_else(|| AuthError::ConfigError("authorization_endpoint not configured".into()))
287 }
288
289 pub fn token_endpoint(&self) -> Result<&str> {
291 self.config.token_endpoint
292 .as_deref()
293 .ok_or_else(|| AuthError::ConfigError("token_endpoint not configured".into()))
294 }
295
296 pub fn device_authorization_endpoint(&self) -> Result<&str> {
298 self.config.device_authorization_endpoint
299 .as_deref()
300 .ok_or_else(|| AuthError::ConfigError("device_authorization_endpoint not configured".into()))
301 }
302
303 pub fn is_domain_allowed(&self, email: &str) -> bool {
305 if self.config.allowed_domains.is_empty() {
306 return true;
307 }
308
309 let domain = email.split('@').nth(1).unwrap_or("");
310 self.config.allowed_domains.iter().any(|d| d == domain)
311 }
312
313 pub fn is_in_required_group(&self, groups: &[String]) -> bool {
315 if self.config.required_groups.is_empty() {
316 return true;
317 }
318
319 self.config.required_groups.iter().any(|g| groups.contains(g))
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// The Google preset records the hosted domain both as an allowed domain
    /// and as the `hd` additional parameter.
    #[test]
    fn test_google_config() {
        let cfg = ProviderConfig::google("client-id", "client-secret", Some("example.com"));

        assert_eq!(cfg.provider_type, ProviderType::Google);
        assert_eq!(cfg.allowed_domains, ["example.com"]);
        assert!(cfg.additional_params.get("hd").is_some());
    }

    /// The Microsoft preset embeds the tenant id in the issuer URL.
    #[test]
    fn test_microsoft_config() {
        let cfg = ProviderConfig::microsoft("client-id", "client-secret", "tenant-id");

        assert_eq!(cfg.provider_type, ProviderType::Microsoft);
        assert!(cfg.issuer_url.contains("tenant-id"));
    }

    /// Only addresses in the configured domain pass the domain check.
    #[test]
    fn test_domain_check() {
        let provider = OAuthProvider::new(ProviderConfig::google(
            "id",
            "secret",
            Some("example.com"),
        ));

        let accepted = provider.is_domain_allowed("user@example.com");
        let rejected = provider.is_domain_allowed("user@other.com");

        assert!(accepted);
        assert!(!rejected);
    }
}