1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::net::{IpAddr, Ipv4Addr};
6use secrecy::{ExposeSecret, Secret, SecretString};
7use tracing::{debug, warn};
8use url::Url;
9
10use crate::{AuthError, Result};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum ProviderType {
16 Google,
18 Microsoft,
20 Okta,
22 Generic,
24}
25
26impl std::fmt::Display for ProviderType {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 ProviderType::Google => write!(f, "google"),
30 ProviderType::Microsoft => write!(f, "microsoft"),
31 ProviderType::Okta => write!(f, "okta"),
32 ProviderType::Generic => write!(f, "generic"),
33 }
34 }
35}
36
37#[derive(Clone)]
39pub struct ProviderConfig {
40 pub provider_type: ProviderType,
42 pub client_id: String,
44 pub client_secret: SecretString,
46 pub issuer_url: String,
48 pub authorization_endpoint: Option<String>,
50 pub token_endpoint: Option<String>,
52 pub userinfo_endpoint: Option<String>,
54 pub device_authorization_endpoint: Option<String>,
56 pub jwks_uri: Option<String>,
58 pub scopes: Vec<String>,
60 pub allowed_domains: Vec<String>,
62 pub required_groups: Vec<String>,
64 pub group_claim: String,
66 pub email_claim: String,
68 pub name_claim: String,
70 pub additional_params: HashMap<String, String>,
72}
73
74impl ProviderConfig {
75 pub fn google(client_id: &str, client_secret: &str, allowed_domain: Option<&str>) -> Self {
77 let mut config = Self {
78 provider_type: ProviderType::Google,
79 client_id: client_id.to_string(),
80 client_secret: SecretString::new(client_secret.to_string()),
81 issuer_url: "https://accounts.google.com".to_string(),
82 authorization_endpoint: Some("https://accounts.google.com/o/oauth2/v2/auth".to_string()),
83 token_endpoint: Some("https://oauth2.googleapis.com/token".to_string()),
84 userinfo_endpoint: Some("https://openidconnect.googleapis.com/v1/userinfo".to_string()),
85 device_authorization_endpoint: Some("https://oauth2.googleapis.com/device/code".to_string()),
86 jwks_uri: Some("https://www.googleapis.com/oauth2/v3/certs".to_string()),
87 scopes: vec![
88 "openid".to_string(),
89 "email".to_string(),
90 "profile".to_string(),
91 ],
92 allowed_domains: vec![],
93 required_groups: vec![],
94 group_claim: "groups".to_string(),
95 email_claim: "email".to_string(),
96 name_claim: "name".to_string(),
97 additional_params: HashMap::new(),
98 };
99
100 if let Some(domain) = allowed_domain {
101 config.allowed_domains.push(domain.to_string());
102 config.additional_params.insert("hd".to_string(), domain.to_string());
104 }
105
106 config
107 }
108
109 pub fn microsoft(client_id: &str, client_secret: &str, tenant_id: &str) -> Self {
111 let base_url = format!("https://login.microsoftonline.com/{}", tenant_id);
112
113 Self {
114 provider_type: ProviderType::Microsoft,
115 client_id: client_id.to_string(),
116 client_secret: SecretString::new(client_secret.to_string()),
117 issuer_url: format!("{}/v2.0", base_url),
118 authorization_endpoint: Some(format!("{}/oauth2/v2.0/authorize", base_url)),
119 token_endpoint: Some(format!("{}/oauth2/v2.0/token", base_url)),
120 userinfo_endpoint: Some("https://graph.microsoft.com/oidc/userinfo".to_string()),
121 device_authorization_endpoint: Some(format!("{}/oauth2/v2.0/devicecode", base_url)),
122 jwks_uri: Some(format!("{}/discovery/v2.0/keys", base_url)),
123 scopes: vec![
124 "openid".to_string(),
125 "email".to_string(),
126 "profile".to_string(),
127 "offline_access".to_string(),
128 ],
129 allowed_domains: vec![],
130 required_groups: vec![],
131 group_claim: "groups".to_string(),
132 email_claim: "email".to_string(),
133 name_claim: "name".to_string(),
134 additional_params: HashMap::new(),
135 }
136 }
137
138 pub fn okta(client_id: &str, client_secret: &str, domain: &str, auth_server_id: Option<&str>) -> Self {
140 let auth_server = auth_server_id.unwrap_or("default");
141 let base_url = format!("https://{}/oauth2/{}", domain, auth_server);
142
143 Self {
144 provider_type: ProviderType::Okta,
145 client_id: client_id.to_string(),
146 client_secret: SecretString::new(client_secret.to_string()),
147 issuer_url: base_url.clone(),
148 authorization_endpoint: Some(format!("{}/v1/authorize", base_url)),
149 token_endpoint: Some(format!("{}/v1/token", base_url)),
150 userinfo_endpoint: Some(format!("{}/v1/userinfo", base_url)),
151 device_authorization_endpoint: Some(format!("{}/v1/device/authorize", base_url)),
152 jwks_uri: Some(format!("{}/v1/keys", base_url)),
153 scopes: vec![
154 "openid".to_string(),
155 "email".to_string(),
156 "profile".to_string(),
157 "groups".to_string(),
158 "offline_access".to_string(),
159 ],
160 allowed_domains: vec![],
161 required_groups: vec![],
162 group_claim: "groups".to_string(),
163 email_claim: "email".to_string(),
164 name_claim: "name".to_string(),
165 additional_params: HashMap::new(),
166 }
167 }
168
169 pub fn generic(client_id: &str, client_secret: &str, issuer_url: &str) -> Self {
171 Self {
172 provider_type: ProviderType::Generic,
173 client_id: client_id.to_string(),
174 client_secret: SecretString::new(client_secret.to_string()),
175 issuer_url: issuer_url.to_string(),
176 authorization_endpoint: None,
177 token_endpoint: None,
178 userinfo_endpoint: None,
179 device_authorization_endpoint: None,
180 jwks_uri: None,
181 scopes: vec![
182 "openid".to_string(),
183 "email".to_string(),
184 "profile".to_string(),
185 ],
186 allowed_domains: vec![],
187 required_groups: vec![],
188 group_claim: "groups".to_string(),
189 email_claim: "email".to_string(),
190 name_claim: "name".to_string(),
191 additional_params: HashMap::new(),
192 }
193 }
194
195 pub fn validate(&self) -> Result<()> {
197 if self.client_id.is_empty() {
198 return Err(AuthError::ConfigError("client_id is required".into()));
199 }
200 if self.client_secret.expose_secret().is_empty() {
201 return Err(AuthError::ConfigError("client_secret is required".into()));
202 }
203 if self.issuer_url.is_empty() {
204 return Err(AuthError::ConfigError("issuer_url is required".into()));
205 }
206 Ok(())
207 }
208}
209
210pub struct OAuthProvider {
212 config: ProviderConfig,
214 http_client: reqwest::Client,
216 metadata: Option<OidcMetadata>,
218}
219
220#[derive(Debug, Clone, Deserialize)]
222pub struct OidcMetadata {
223 pub issuer: String,
224 pub authorization_endpoint: String,
225 pub token_endpoint: String,
226 #[serde(default)]
227 pub userinfo_endpoint: Option<String>,
228 #[serde(default)]
229 pub device_authorization_endpoint: Option<String>,
230 pub jwks_uri: String,
231 #[serde(default)]
232 pub scopes_supported: Vec<String>,
233 #[serde(default)]
234 pub response_types_supported: Vec<String>,
235 #[serde(default)]
236 pub grant_types_supported: Vec<String>,
237}
238
/// Returns `true` if `ip` is an address that outbound provider requests must
/// not target.
///
/// Blocked IPv4 ranges:
///
/// * loopback (127/8);
/// * RFC 1918 private ranges (10/8, 172.16/12, 192.168/16);
/// * link-local (169.254/16, which includes cloud metadata services);
/// * "this network" (0/8);
/// * carrier-grade NAT (100.64/10);
/// * multicast (224/4);
/// * broadcast (255.255.255.255).
///
/// Blocked IPv6 ranges:
///
/// * loopback and unspecified;
/// * link-local (fe80::/10);
/// * unique-local (fc00::/7);
/// * multicast.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => is_private_ipv4(ipv4),
        IpAddr::V6(ipv6) => {
            // IPv4-mapped (::ffff:a.b.c.d) and IPv4-compatible (::a.b.c.d)
            // literals reach the embedded IPv4 host, so judge them by the IPv4
            // rules; otherwise `https://[::ffff:127.0.0.1]/` slips through.
            // `::` and `::1` map to 0.0.0.0 / 0.0.0.1 and are caught by 0/8.
            if let Some(embedded) = ipv6.to_ipv4() {
                return is_private_ipv4(embedded);
            }
            let first = ipv6.segments()[0];
            ipv6.is_loopback()
                || ipv6.is_unspecified()
                || ipv6.is_multicast()
                || (first & 0xffc0) == 0xfe80 // fe80::/10 link-local
                || (first & 0xfe00) == 0xfc00 // fc00::/7 unique local
        }
    }
}

/// IPv4 half of [`is_private_ip`]. It is also applied to IPv4 addresses
/// embedded in IPv6 literals.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    ip.is_private() // 10/8, 172.16/12, 192.168/16
        || ip.is_loopback() // 127/8
        || ip.is_link_local() // 169.254/16
        || first == 0 // 0/8 "this network"
        || (first == 100 && (64..=127).contains(&second)) // 100.64/10 CGNAT
        || ip.is_multicast() // 224/4
        || ip.is_broadcast() // 255.255.255.255
}
267
268fn validate_url_for_ssrf(url_str: &str) -> Result<()> {
270 let url = Url::parse(url_str)
271 .map_err(|e| AuthError::ConfigError(format!("Invalid URL: {}", e)))?;
272
273 if url.scheme() != "https" {
275 return Err(AuthError::ConfigError(
276 "Only HTTPS URLs are allowed for security".into()
277 ));
278 }
279
280 if let Some(host) = url.host() {
282 match host {
283 url::Host::Domain(domain) => {
284 if domain == "localhost" || domain.ends_with(".localhost") {
286 return Err(AuthError::ConfigError(
287 "localhost URLs are not allowed".into()
288 ));
289 }
290 if domain.ends_with(".local") {
292 return Err(AuthError::ConfigError(
293 ".local domains are not allowed".into()
294 ));
295 }
296 }
297 url::Host::Ipv4(ip) => {
298 if is_private_ip(IpAddr::V4(ip)) {
299 return Err(AuthError::ConfigError(
300 "Private IP addresses are not allowed".into()
301 ));
302 }
303 }
304 url::Host::Ipv6(ip) => {
305 if is_private_ip(IpAddr::V6(ip)) {
306 return Err(AuthError::ConfigError(
307 "Private IP addresses are not allowed".into()
308 ));
309 }
310 }
311 }
312 }
313
314 Ok(())
315}
316
317impl OAuthProvider {
318 pub fn new(config: ProviderConfig) -> Self {
320 Self {
321 config,
322 http_client: reqwest::Client::new(),
323 metadata: None,
324 }
325 }
326
327 pub fn config(&self) -> &ProviderConfig {
329 &self.config
330 }
331
332 pub async fn discover(&mut self) -> Result<()> {
334 let discovery_url = format!("{}/.well-known/openid-configuration", self.config.issuer_url);
335
336 validate_url_for_ssrf(&discovery_url)?;
338
339 debug!("Performing OIDC discovery for {}", self.config.issuer_url);
340
341 let response = self.http_client
342 .get(&discovery_url)
343 .send()
344 .await
345 .map_err(|e| {
346 warn!("OIDC discovery failed: {}", e);
347 AuthError::DiscoveryFailed("Failed to connect to provider".into())
348 })?;
349
350 if !response.status().is_success() {
351 let status = response.status();
352 warn!("OIDC discovery returned status {}", status);
353 return Err(AuthError::DiscoveryFailed(
354 "Provider discovery failed".into()
355 ));
356 }
357
358 let metadata: OidcMetadata = response.json().await
359 .map_err(|e| {
360 warn!("Failed to parse OIDC metadata: {}", e);
361 AuthError::DiscoveryFailed("Invalid provider response".into())
362 })?;
363
364 validate_url_for_ssrf(&metadata.authorization_endpoint)?;
366 validate_url_for_ssrf(&metadata.token_endpoint)?;
367 validate_url_for_ssrf(&metadata.jwks_uri)?;
368 if let Some(ref uri) = metadata.userinfo_endpoint {
369 validate_url_for_ssrf(uri)?;
370 }
371
372 self.config.authorization_endpoint = Some(metadata.authorization_endpoint.clone());
374 self.config.token_endpoint = Some(metadata.token_endpoint.clone());
375 self.config.userinfo_endpoint = metadata.userinfo_endpoint.clone();
376 self.config.device_authorization_endpoint = metadata.device_authorization_endpoint.clone();
377 self.config.jwks_uri = Some(metadata.jwks_uri.clone());
378
379 self.metadata = Some(metadata);
380
381 Ok(())
382 }
383
384 pub fn authorization_endpoint(&self) -> Result<&str> {
386 self.config.authorization_endpoint
387 .as_deref()
388 .ok_or_else(|| AuthError::ConfigError("authorization_endpoint not configured".into()))
389 }
390
391 pub fn token_endpoint(&self) -> Result<&str> {
393 self.config.token_endpoint
394 .as_deref()
395 .ok_or_else(|| AuthError::ConfigError("token_endpoint not configured".into()))
396 }
397
398 pub fn device_authorization_endpoint(&self) -> Result<&str> {
400 self.config.device_authorization_endpoint
401 .as_deref()
402 .ok_or_else(|| AuthError::ConfigError("device_authorization_endpoint not configured".into()))
403 }
404
405 pub fn is_domain_allowed(&self, email: &str) -> bool {
407 if self.config.allowed_domains.is_empty() {
408 return true;
409 }
410
411 if addr::parse_email_address(email).is_err() {
413 debug!("Invalid email format: {}", email);
414 return false;
415 }
416
417 let email_domain = match email.rsplit_once('@') {
419 Some((_, domain)) => domain.to_lowercase(),
420 None => return false,
421 };
422
423 self.config.allowed_domains.iter().any(|d| {
424 d.to_lowercase() == email_domain
425 })
426 }
427
428 pub fn is_in_required_group(&self, groups: &[String]) -> bool {
430 if self.config.required_groups.is_empty() {
431 return true;
432 }
433
434 self.config.required_groups.iter().any(|g| groups.contains(g))
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_google_config() {
        let config = ProviderConfig::google("client-id", "client-secret", Some("example.com"));

        assert_eq!(config.provider_type, ProviderType::Google);
        assert_eq!(config.allowed_domains, vec!["example.com"]);
        assert!(config.additional_params.contains_key("hd"));
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_microsoft_config() {
        let config = ProviderConfig::microsoft("client-id", "client-secret", "tenant-id");

        assert_eq!(config.provider_type, ProviderType::Microsoft);
        assert!(config.issuer_url.contains("tenant-id"));
        assert_eq!(
            config.token_endpoint.as_deref(),
            Some("https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token")
        );
    }

    #[test]
    fn test_okta_default_auth_server() {
        let config = ProviderConfig::okta("id", "secret", "dev.okta.com", None);

        assert_eq!(config.provider_type, ProviderType::Okta);
        assert_eq!(config.issuer_url, "https://dev.okta.com/oauth2/default");
        assert_eq!(
            config.jwks_uri.as_deref(),
            Some("https://dev.okta.com/oauth2/default/v1/keys")
        );
    }

    #[test]
    fn test_generic_has_no_endpoints() {
        let config = ProviderConfig::generic("id", "secret", "https://idp.example.com");

        assert_eq!(config.provider_type, ProviderType::Generic);
        assert!(config.authorization_endpoint.is_none());
        assert!(config.token_endpoint.is_none());
        assert!(config.jwks_uri.is_none());
    }

    #[test]
    fn test_provider_type_display() {
        assert_eq!(ProviderType::Google.to_string(), "google");
        assert_eq!(ProviderType::Microsoft.to_string(), "microsoft");
        assert_eq!(ProviderType::Okta.to_string(), "okta");
        assert_eq!(ProviderType::Generic.to_string(), "generic");
    }

    #[test]
    fn test_validate_rejects_missing_fields() {
        let issuer = "https://idp.example.com";
        assert!(ProviderConfig::generic("", "secret", issuer).validate().is_err());
        assert!(ProviderConfig::generic("id", "", issuer).validate().is_err());
        assert!(ProviderConfig::generic("id", "secret", "").validate().is_err());
    }

    #[test]
    fn test_domain_check() {
        let config = ProviderConfig::google("id", "secret", Some("example.com"));
        let provider = OAuthProvider::new(config);

        assert!(provider.is_domain_allowed("user@example.com"));
        assert!(!provider.is_domain_allowed("user@other.com"));
        assert!(!provider.is_domain_allowed("not-an-email"));
    }

    #[test]
    fn test_domain_check_unrestricted() {
        let provider = OAuthProvider::new(ProviderConfig::google("id", "secret", None));

        assert!(provider.is_domain_allowed("anyone@anywhere.org"));
    }

    #[test]
    fn test_required_groups() {
        let mut config = ProviderConfig::generic("id", "secret", "https://idp.example.com");
        config.required_groups = vec!["admins".to_string()];
        let provider = OAuthProvider::new(config);

        assert!(provider.is_in_required_group(&["users".to_string(), "admins".to_string()]));
        assert!(!provider.is_in_required_group(&["users".to_string()]));
        assert!(!provider.is_in_required_group(&[]));
    }

    #[test]
    fn test_private_ip_detection() {
        assert!(is_private_ip("10.0.0.1".parse().unwrap()));
        assert!(is_private_ip("172.16.5.4".parse().unwrap()));
        assert!(is_private_ip("192.168.0.1".parse().unwrap()));
        assert!(is_private_ip("127.0.0.1".parse().unwrap()));
        assert!(is_private_ip("169.254.169.254".parse().unwrap()));
        assert!(is_private_ip("::1".parse().unwrap()));
        assert!(is_private_ip("fd12::1".parse().unwrap()));
        assert!(!is_private_ip("8.8.8.8".parse().unwrap()));
        assert!(!is_private_ip("2001:4860:4860::8888".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_validation() {
        assert!(validate_url_for_ssrf(
            "https://accounts.google.com/.well-known/openid-configuration"
        )
        .is_ok());
        assert!(validate_url_for_ssrf("http://accounts.google.com/").is_err());
        assert!(validate_url_for_ssrf("not a url").is_err());
        assert!(validate_url_for_ssrf("https://localhost/").is_err());
        assert!(validate_url_for_ssrf("https://printer.local/").is_err());
        assert!(validate_url_for_ssrf("https://127.0.0.1/").is_err());
        assert!(validate_url_for_ssrf("https://[::1]/").is_err());
    }
}