Skip to main content

corevpn_auth/
provider.rs

1//! OAuth2/OIDC Provider Configuration
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::net::{IpAddr, Ipv4Addr};
6use secrecy::{ExposeSecret, Secret, SecretString};
7use tracing::{debug, warn};
8use url::Url;
9
10use crate::{AuthError, Result};
11
12/// OAuth2 provider type
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum ProviderType {
16    /// Google Workspace
17    Google,
18    /// Microsoft Entra ID
19    Microsoft,
20    /// Okta
21    Okta,
22    /// Generic OIDC
23    Generic,
24}
25
26impl std::fmt::Display for ProviderType {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            ProviderType::Google => write!(f, "google"),
30            ProviderType::Microsoft => write!(f, "microsoft"),
31            ProviderType::Okta => write!(f, "okta"),
32            ProviderType::Generic => write!(f, "generic"),
33        }
34    }
35}
36
37/// Provider configuration
38#[derive(Clone)]
39pub struct ProviderConfig {
40    /// Provider type
41    pub provider_type: ProviderType,
42    /// OAuth2 Client ID
43    pub client_id: String,
44    /// OAuth2 Client Secret (encrypted at rest)
45    pub client_secret: SecretString,
46    /// Issuer URL (for OIDC discovery)
47    pub issuer_url: String,
48    /// Authorization endpoint (optional, discovered via OIDC)
49    pub authorization_endpoint: Option<String>,
50    /// Token endpoint (optional, discovered via OIDC)
51    pub token_endpoint: Option<String>,
52    /// UserInfo endpoint (optional)
53    pub userinfo_endpoint: Option<String>,
54    /// Device authorization endpoint (for device flow)
55    pub device_authorization_endpoint: Option<String>,
56    /// JWKS URI for token validation
57    pub jwks_uri: Option<String>,
58    /// Scopes to request
59    pub scopes: Vec<String>,
60    /// Allowed domains (empty = all allowed)
61    pub allowed_domains: Vec<String>,
62    /// Required groups (user must be in at least one)
63    pub required_groups: Vec<String>,
64    /// Group claim name in ID token
65    pub group_claim: String,
66    /// Email claim name in ID token
67    pub email_claim: String,
68    /// Name claim name in ID token
69    pub name_claim: String,
70    /// Additional parameters for authorization URL
71    pub additional_params: HashMap<String, String>,
72}
73
74impl ProviderConfig {
75    /// Create a Google provider configuration
76    pub fn google(client_id: &str, client_secret: &str, allowed_domain: Option<&str>) -> Self {
77        let mut config = Self {
78            provider_type: ProviderType::Google,
79            client_id: client_id.to_string(),
80            client_secret: SecretString::new(client_secret.to_string()),
81            issuer_url: "https://accounts.google.com".to_string(),
82            authorization_endpoint: Some("https://accounts.google.com/o/oauth2/v2/auth".to_string()),
83            token_endpoint: Some("https://oauth2.googleapis.com/token".to_string()),
84            userinfo_endpoint: Some("https://openidconnect.googleapis.com/v1/userinfo".to_string()),
85            device_authorization_endpoint: Some("https://oauth2.googleapis.com/device/code".to_string()),
86            jwks_uri: Some("https://www.googleapis.com/oauth2/v3/certs".to_string()),
87            scopes: vec![
88                "openid".to_string(),
89                "email".to_string(),
90                "profile".to_string(),
91            ],
92            allowed_domains: vec![],
93            required_groups: vec![],
94            group_claim: "groups".to_string(),
95            email_claim: "email".to_string(),
96            name_claim: "name".to_string(),
97            additional_params: HashMap::new(),
98        };
99
100        if let Some(domain) = allowed_domain {
101            config.allowed_domains.push(domain.to_string());
102            // Add hd parameter to restrict to domain
103            config.additional_params.insert("hd".to_string(), domain.to_string());
104        }
105
106        config
107    }
108
109    /// Create a Microsoft provider configuration
110    pub fn microsoft(client_id: &str, client_secret: &str, tenant_id: &str) -> Self {
111        let base_url = format!("https://login.microsoftonline.com/{}", tenant_id);
112
113        Self {
114            provider_type: ProviderType::Microsoft,
115            client_id: client_id.to_string(),
116            client_secret: SecretString::new(client_secret.to_string()),
117            issuer_url: format!("{}/v2.0", base_url),
118            authorization_endpoint: Some(format!("{}/oauth2/v2.0/authorize", base_url)),
119            token_endpoint: Some(format!("{}/oauth2/v2.0/token", base_url)),
120            userinfo_endpoint: Some("https://graph.microsoft.com/oidc/userinfo".to_string()),
121            device_authorization_endpoint: Some(format!("{}/oauth2/v2.0/devicecode", base_url)),
122            jwks_uri: Some(format!("{}/discovery/v2.0/keys", base_url)),
123            scopes: vec![
124                "openid".to_string(),
125                "email".to_string(),
126                "profile".to_string(),
127                "offline_access".to_string(),
128            ],
129            allowed_domains: vec![],
130            required_groups: vec![],
131            group_claim: "groups".to_string(),
132            email_claim: "email".to_string(),
133            name_claim: "name".to_string(),
134            additional_params: HashMap::new(),
135        }
136    }
137
138    /// Create an Okta provider configuration
139    pub fn okta(client_id: &str, client_secret: &str, domain: &str, auth_server_id: Option<&str>) -> Self {
140        let auth_server = auth_server_id.unwrap_or("default");
141        let base_url = format!("https://{}/oauth2/{}", domain, auth_server);
142
143        Self {
144            provider_type: ProviderType::Okta,
145            client_id: client_id.to_string(),
146            client_secret: SecretString::new(client_secret.to_string()),
147            issuer_url: base_url.clone(),
148            authorization_endpoint: Some(format!("{}/v1/authorize", base_url)),
149            token_endpoint: Some(format!("{}/v1/token", base_url)),
150            userinfo_endpoint: Some(format!("{}/v1/userinfo", base_url)),
151            device_authorization_endpoint: Some(format!("{}/v1/device/authorize", base_url)),
152            jwks_uri: Some(format!("{}/v1/keys", base_url)),
153            scopes: vec![
154                "openid".to_string(),
155                "email".to_string(),
156                "profile".to_string(),
157                "groups".to_string(),
158                "offline_access".to_string(),
159            ],
160            allowed_domains: vec![],
161            required_groups: vec![],
162            group_claim: "groups".to_string(),
163            email_claim: "email".to_string(),
164            name_claim: "name".to_string(),
165            additional_params: HashMap::new(),
166        }
167    }
168
169    /// Create a generic OIDC provider configuration
170    pub fn generic(client_id: &str, client_secret: &str, issuer_url: &str) -> Self {
171        Self {
172            provider_type: ProviderType::Generic,
173            client_id: client_id.to_string(),
174            client_secret: SecretString::new(client_secret.to_string()),
175            issuer_url: issuer_url.to_string(),
176            authorization_endpoint: None,
177            token_endpoint: None,
178            userinfo_endpoint: None,
179            device_authorization_endpoint: None,
180            jwks_uri: None,
181            scopes: vec![
182                "openid".to_string(),
183                "email".to_string(),
184                "profile".to_string(),
185            ],
186            allowed_domains: vec![],
187            required_groups: vec![],
188            group_claim: "groups".to_string(),
189            email_claim: "email".to_string(),
190            name_claim: "name".to_string(),
191            additional_params: HashMap::new(),
192        }
193    }
194
195    /// Validate configuration
196    pub fn validate(&self) -> Result<()> {
197        if self.client_id.is_empty() {
198            return Err(AuthError::ConfigError("client_id is required".into()));
199        }
200        if self.client_secret.expose_secret().is_empty() {
201            return Err(AuthError::ConfigError("client_secret is required".into()));
202        }
203        if self.issuer_url.is_empty() {
204            return Err(AuthError::ConfigError("issuer_url is required".into()));
205        }
206        Ok(())
207    }
208}
209
210/// OAuth2 Provider with runtime state
211pub struct OAuthProvider {
212    /// Configuration
213    config: ProviderConfig,
214    /// HTTP client
215    http_client: reqwest::Client,
216    /// Discovered metadata (if using OIDC discovery)
217    metadata: Option<OidcMetadata>,
218}
219
220/// OIDC Discovery metadata
221#[derive(Debug, Clone, Deserialize)]
222pub struct OidcMetadata {
223    pub issuer: String,
224    pub authorization_endpoint: String,
225    pub token_endpoint: String,
226    #[serde(default)]
227    pub userinfo_endpoint: Option<String>,
228    #[serde(default)]
229    pub device_authorization_endpoint: Option<String>,
230    pub jwks_uri: String,
231    #[serde(default)]
232    pub scopes_supported: Vec<String>,
233    #[serde(default)]
234    pub response_types_supported: Vec<String>,
235    #[serde(default)]
236    pub grant_types_supported: Vec<String>,
237}
238
/// Check whether an IP address is one that outbound provider requests must
/// not target (private, loopback, link-local and similar non-public ranges).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped and judged by
/// the IPv4 rules, so an internal IPv4 address cannot slip through by being
/// written in IPv6 form.
///
/// Returns `true` when the address should be rejected.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            // ::ffff:a.b.c.d reaches the embedded IPv4 host.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ipv4(mapped);
            }

            // ::1 (loopback)
            v6.is_loopback()
                // :: (unspecified)
                || v6.is_unspecified()
                // fe80::/10 (link-local)
                || v6.is_unicast_link_local()
                // fc00::/7 (unique local)
                || (v6.segments()[0] & 0xfe00) == 0xfc00
        }
    }
}

/// IPv4 half of [`is_private_ip`]: `true` for any non-public IPv4 range.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();

    // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
    ip.is_private()
        // 127.0.0.0/8
        || ip.is_loopback()
        // 169.254.0.0/16 (includes the common cloud metadata address)
        || ip.is_link_local()
        // 0.0.0.0/8
        || first == 0
        // 100.64.0.0/10 (RFC 6598 shared address space; not publicly routable)
        || (first == 100 && (second & 0xc0) == 0x40)
}
267
268/// Validate URL for SSRF protection
269fn validate_url_for_ssrf(url_str: &str) -> Result<()> {
270    let url = Url::parse(url_str)
271        .map_err(|e| AuthError::ConfigError(format!("Invalid URL: {}", e)))?;
272
273    // Only allow HTTPS
274    if url.scheme() != "https" {
275        return Err(AuthError::ConfigError(
276            "Only HTTPS URLs are allowed for security".into()
277        ));
278    }
279
280    // Resolve hostname and check for private IPs
281    if let Some(host) = url.host() {
282        match host {
283            url::Host::Domain(domain) => {
284                // Block localhost variants
285                if domain == "localhost" || domain.ends_with(".localhost") {
286                    return Err(AuthError::ConfigError(
287                        "localhost URLs are not allowed".into()
288                    ));
289                }
290                // Block .local domains (mDNS)
291                if domain.ends_with(".local") {
292                    return Err(AuthError::ConfigError(
293                        ".local domains are not allowed".into()
294                    ));
295                }
296            }
297            url::Host::Ipv4(ip) => {
298                if is_private_ip(IpAddr::V4(ip)) {
299                    return Err(AuthError::ConfigError(
300                        "Private IP addresses are not allowed".into()
301                    ));
302                }
303            }
304            url::Host::Ipv6(ip) => {
305                if is_private_ip(IpAddr::V6(ip)) {
306                    return Err(AuthError::ConfigError(
307                        "Private IP addresses are not allowed".into()
308                    ));
309                }
310            }
311        }
312    }
313
314    Ok(())
315}
316
317impl OAuthProvider {
318    /// Create a new provider
319    pub fn new(config: ProviderConfig) -> Self {
320        Self {
321            config,
322            http_client: reqwest::Client::new(),
323            metadata: None,
324        }
325    }
326
327    /// Get the configuration
328    pub fn config(&self) -> &ProviderConfig {
329        &self.config
330    }
331
332    /// Perform OIDC discovery
333    pub async fn discover(&mut self) -> Result<()> {
334        let discovery_url = format!("{}/.well-known/openid-configuration", self.config.issuer_url);
335
336        // Validate URL for SSRF protection
337        validate_url_for_ssrf(&discovery_url)?;
338
339        debug!("Performing OIDC discovery for {}", self.config.issuer_url);
340
341        let response = self.http_client
342            .get(&discovery_url)
343            .send()
344            .await
345            .map_err(|e| {
346                warn!("OIDC discovery failed: {}", e);
347                AuthError::DiscoveryFailed("Failed to connect to provider".into())
348            })?;
349
350        if !response.status().is_success() {
351            let status = response.status();
352            warn!("OIDC discovery returned status {}", status);
353            return Err(AuthError::DiscoveryFailed(
354                "Provider discovery failed".into()
355            ));
356        }
357
358        let metadata: OidcMetadata = response.json().await
359            .map_err(|e| {
360                warn!("Failed to parse OIDC metadata: {}", e);
361                AuthError::DiscoveryFailed("Invalid provider response".into())
362            })?;
363
364        // Validate discovered endpoints for SSRF
365        validate_url_for_ssrf(&metadata.authorization_endpoint)?;
366        validate_url_for_ssrf(&metadata.token_endpoint)?;
367        validate_url_for_ssrf(&metadata.jwks_uri)?;
368        if let Some(ref uri) = metadata.userinfo_endpoint {
369            validate_url_for_ssrf(uri)?;
370        }
371
372        // Update config with discovered endpoints
373        self.config.authorization_endpoint = Some(metadata.authorization_endpoint.clone());
374        self.config.token_endpoint = Some(metadata.token_endpoint.clone());
375        self.config.userinfo_endpoint = metadata.userinfo_endpoint.clone();
376        self.config.device_authorization_endpoint = metadata.device_authorization_endpoint.clone();
377        self.config.jwks_uri = Some(metadata.jwks_uri.clone());
378
379        self.metadata = Some(metadata);
380
381        Ok(())
382    }
383
384    /// Get authorization endpoint
385    pub fn authorization_endpoint(&self) -> Result<&str> {
386        self.config.authorization_endpoint
387            .as_deref()
388            .ok_or_else(|| AuthError::ConfigError("authorization_endpoint not configured".into()))
389    }
390
391    /// Get token endpoint
392    pub fn token_endpoint(&self) -> Result<&str> {
393        self.config.token_endpoint
394            .as_deref()
395            .ok_or_else(|| AuthError::ConfigError("token_endpoint not configured".into()))
396    }
397
398    /// Get device authorization endpoint
399    pub fn device_authorization_endpoint(&self) -> Result<&str> {
400        self.config.device_authorization_endpoint
401            .as_deref()
402            .ok_or_else(|| AuthError::ConfigError("device_authorization_endpoint not configured".into()))
403    }
404
405    /// Check if email domain is allowed
406    pub fn is_domain_allowed(&self, email: &str) -> bool {
407        if self.config.allowed_domains.is_empty() {
408            return true;
409        }
410
411        // Parse email properly using addr crate for validation
412        if addr::parse_email_address(email).is_err() {
413            debug!("Invalid email format: {}", email);
414            return false;
415        }
416
417        // Extract domain part after @ (already validated by addr crate)
418        let email_domain = match email.rsplit_once('@') {
419            Some((_, domain)) => domain.to_lowercase(),
420            None => return false,
421        };
422
423        self.config.allowed_domains.iter().any(|d| {
424            d.to_lowercase() == email_domain
425        })
426    }
427
428    /// Check if user is in required groups
429    pub fn is_in_required_group(&self, groups: &[String]) -> bool {
430        if self.config.required_groups.is_empty() {
431            return true;
432        }
433
434        self.config.required_groups.iter().any(|g| groups.contains(g))
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// The Google preset with a domain records it in the allow-list and as `hd`.
    #[test]
    fn test_google_config() {
        let google = ProviderConfig::google("client-id", "client-secret", Some("example.com"));

        assert_eq!(google.provider_type, ProviderType::Google);
        assert_eq!(google.allowed_domains, vec!["example.com"]);
        assert!(google.additional_params.get("hd").is_some());
    }

    /// The Microsoft preset embeds the tenant id in the issuer URL.
    #[test]
    fn test_microsoft_config() {
        let microsoft = ProviderConfig::microsoft("client-id", "client-secret", "tenant-id");

        assert_eq!(microsoft.provider_type, ProviderType::Microsoft);
        assert!(microsoft.issuer_url.contains("tenant-id"));
    }

    /// Only addresses in the configured domain pass the domain check.
    #[test]
    fn test_domain_check() {
        let provider = OAuthProvider::new(ProviderConfig::google("id", "secret", Some("example.com")));

        let allowed = provider.is_domain_allowed("user@example.com");
        let rejected = provider.is_domain_allowed("user@other.com");

        assert!(allowed);
        assert!(!rejected);
    }
}