corevpn_auth/
provider.rs

1//! OAuth2/OIDC Provider Configuration
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use crate::{AuthError, Result};
7
8/// OAuth2 provider type
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum ProviderType {
12    /// Google Workspace
13    Google,
14    /// Microsoft Entra ID
15    Microsoft,
16    /// Okta
17    Okta,
18    /// Generic OIDC
19    Generic,
20}
21
22impl std::fmt::Display for ProviderType {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            ProviderType::Google => write!(f, "google"),
26            ProviderType::Microsoft => write!(f, "microsoft"),
27            ProviderType::Okta => write!(f, "okta"),
28            ProviderType::Generic => write!(f, "generic"),
29        }
30    }
31}
32
33/// Provider configuration
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ProviderConfig {
36    /// Provider type
37    pub provider_type: ProviderType,
38    /// OAuth2 Client ID
39    pub client_id: String,
40    /// OAuth2 Client Secret (encrypted at rest)
41    #[serde(skip_serializing)]
42    pub client_secret: String,
43    /// Issuer URL (for OIDC discovery)
44    pub issuer_url: String,
45    /// Authorization endpoint (optional, discovered via OIDC)
46    pub authorization_endpoint: Option<String>,
47    /// Token endpoint (optional, discovered via OIDC)
48    pub token_endpoint: Option<String>,
49    /// UserInfo endpoint (optional)
50    pub userinfo_endpoint: Option<String>,
51    /// Device authorization endpoint (for device flow)
52    pub device_authorization_endpoint: Option<String>,
53    /// JWKS URI for token validation
54    pub jwks_uri: Option<String>,
55    /// Scopes to request
56    pub scopes: Vec<String>,
57    /// Allowed domains (empty = all allowed)
58    pub allowed_domains: Vec<String>,
59    /// Required groups (user must be in at least one)
60    pub required_groups: Vec<String>,
61    /// Group claim name in ID token
62    pub group_claim: String,
63    /// Email claim name in ID token
64    pub email_claim: String,
65    /// Name claim name in ID token
66    pub name_claim: String,
67    /// Additional parameters for authorization URL
68    pub additional_params: HashMap<String, String>,
69}
70
71impl ProviderConfig {
72    /// Create a Google provider configuration
73    pub fn google(client_id: &str, client_secret: &str, allowed_domain: Option<&str>) -> Self {
74        let mut config = Self {
75            provider_type: ProviderType::Google,
76            client_id: client_id.to_string(),
77            client_secret: client_secret.to_string(),
78            issuer_url: "https://accounts.google.com".to_string(),
79            authorization_endpoint: Some("https://accounts.google.com/o/oauth2/v2/auth".to_string()),
80            token_endpoint: Some("https://oauth2.googleapis.com/token".to_string()),
81            userinfo_endpoint: Some("https://openidconnect.googleapis.com/v1/userinfo".to_string()),
82            device_authorization_endpoint: Some("https://oauth2.googleapis.com/device/code".to_string()),
83            jwks_uri: Some("https://www.googleapis.com/oauth2/v3/certs".to_string()),
84            scopes: vec![
85                "openid".to_string(),
86                "email".to_string(),
87                "profile".to_string(),
88            ],
89            allowed_domains: vec![],
90            required_groups: vec![],
91            group_claim: "groups".to_string(),
92            email_claim: "email".to_string(),
93            name_claim: "name".to_string(),
94            additional_params: HashMap::new(),
95        };
96
97        if let Some(domain) = allowed_domain {
98            config.allowed_domains.push(domain.to_string());
99            // Add hd parameter to restrict to domain
100            config.additional_params.insert("hd".to_string(), domain.to_string());
101        }
102
103        config
104    }
105
106    /// Create a Microsoft provider configuration
107    pub fn microsoft(client_id: &str, client_secret: &str, tenant_id: &str) -> Self {
108        let base_url = format!("https://login.microsoftonline.com/{}", tenant_id);
109
110        Self {
111            provider_type: ProviderType::Microsoft,
112            client_id: client_id.to_string(),
113            client_secret: client_secret.to_string(),
114            issuer_url: format!("{}/v2.0", base_url),
115            authorization_endpoint: Some(format!("{}/oauth2/v2.0/authorize", base_url)),
116            token_endpoint: Some(format!("{}/oauth2/v2.0/token", base_url)),
117            userinfo_endpoint: Some("https://graph.microsoft.com/oidc/userinfo".to_string()),
118            device_authorization_endpoint: Some(format!("{}/oauth2/v2.0/devicecode", base_url)),
119            jwks_uri: Some(format!("{}/discovery/v2.0/keys", base_url)),
120            scopes: vec![
121                "openid".to_string(),
122                "email".to_string(),
123                "profile".to_string(),
124                "offline_access".to_string(),
125            ],
126            allowed_domains: vec![],
127            required_groups: vec![],
128            group_claim: "groups".to_string(),
129            email_claim: "email".to_string(),
130            name_claim: "name".to_string(),
131            additional_params: HashMap::new(),
132        }
133    }
134
135    /// Create an Okta provider configuration
136    pub fn okta(client_id: &str, client_secret: &str, domain: &str, auth_server_id: Option<&str>) -> Self {
137        let auth_server = auth_server_id.unwrap_or("default");
138        let base_url = format!("https://{}/oauth2/{}", domain, auth_server);
139
140        Self {
141            provider_type: ProviderType::Okta,
142            client_id: client_id.to_string(),
143            client_secret: client_secret.to_string(),
144            issuer_url: base_url.clone(),
145            authorization_endpoint: Some(format!("{}/v1/authorize", base_url)),
146            token_endpoint: Some(format!("{}/v1/token", base_url)),
147            userinfo_endpoint: Some(format!("{}/v1/userinfo", base_url)),
148            device_authorization_endpoint: Some(format!("{}/v1/device/authorize", base_url)),
149            jwks_uri: Some(format!("{}/v1/keys", base_url)),
150            scopes: vec![
151                "openid".to_string(),
152                "email".to_string(),
153                "profile".to_string(),
154                "groups".to_string(),
155                "offline_access".to_string(),
156            ],
157            allowed_domains: vec![],
158            required_groups: vec![],
159            group_claim: "groups".to_string(),
160            email_claim: "email".to_string(),
161            name_claim: "name".to_string(),
162            additional_params: HashMap::new(),
163        }
164    }
165
166    /// Create a generic OIDC provider configuration
167    pub fn generic(client_id: &str, client_secret: &str, issuer_url: &str) -> Self {
168        Self {
169            provider_type: ProviderType::Generic,
170            client_id: client_id.to_string(),
171            client_secret: client_secret.to_string(),
172            issuer_url: issuer_url.to_string(),
173            authorization_endpoint: None,
174            token_endpoint: None,
175            userinfo_endpoint: None,
176            device_authorization_endpoint: None,
177            jwks_uri: None,
178            scopes: vec![
179                "openid".to_string(),
180                "email".to_string(),
181                "profile".to_string(),
182            ],
183            allowed_domains: vec![],
184            required_groups: vec![],
185            group_claim: "groups".to_string(),
186            email_claim: "email".to_string(),
187            name_claim: "name".to_string(),
188            additional_params: HashMap::new(),
189        }
190    }
191
192    /// Validate configuration
193    pub fn validate(&self) -> Result<()> {
194        if self.client_id.is_empty() {
195            return Err(AuthError::ConfigError("client_id is required".into()));
196        }
197        if self.client_secret.is_empty() {
198            return Err(AuthError::ConfigError("client_secret is required".into()));
199        }
200        if self.issuer_url.is_empty() {
201            return Err(AuthError::ConfigError("issuer_url is required".into()));
202        }
203        Ok(())
204    }
205}
206
207/// OAuth2 Provider with runtime state
208pub struct OAuthProvider {
209    /// Configuration
210    config: ProviderConfig,
211    /// HTTP client
212    http_client: reqwest::Client,
213    /// Discovered metadata (if using OIDC discovery)
214    metadata: Option<OidcMetadata>,
215}
216
217/// OIDC Discovery metadata
218#[derive(Debug, Clone, Deserialize)]
219pub struct OidcMetadata {
220    pub issuer: String,
221    pub authorization_endpoint: String,
222    pub token_endpoint: String,
223    #[serde(default)]
224    pub userinfo_endpoint: Option<String>,
225    #[serde(default)]
226    pub device_authorization_endpoint: Option<String>,
227    pub jwks_uri: String,
228    #[serde(default)]
229    pub scopes_supported: Vec<String>,
230    #[serde(default)]
231    pub response_types_supported: Vec<String>,
232    #[serde(default)]
233    pub grant_types_supported: Vec<String>,
234}
235
236impl OAuthProvider {
237    /// Create a new provider
238    pub fn new(config: ProviderConfig) -> Self {
239        Self {
240            config,
241            http_client: reqwest::Client::new(),
242            metadata: None,
243        }
244    }
245
246    /// Get the configuration
247    pub fn config(&self) -> &ProviderConfig {
248        &self.config
249    }
250
251    /// Perform OIDC discovery
252    pub async fn discover(&mut self) -> Result<()> {
253        let discovery_url = format!("{}/.well-known/openid-configuration", self.config.issuer_url);
254
255        let response = self.http_client
256            .get(&discovery_url)
257            .send()
258            .await?;
259
260        if !response.status().is_success() {
261            return Err(AuthError::DiscoveryFailed(format!(
262                "HTTP {}: {}",
263                response.status(),
264                response.text().await.unwrap_or_default()
265            )));
266        }
267
268        let metadata: OidcMetadata = response.json().await?;
269
270        // Update config with discovered endpoints
271        self.config.authorization_endpoint = Some(metadata.authorization_endpoint.clone());
272        self.config.token_endpoint = Some(metadata.token_endpoint.clone());
273        self.config.userinfo_endpoint = metadata.userinfo_endpoint.clone();
274        self.config.device_authorization_endpoint = metadata.device_authorization_endpoint.clone();
275        self.config.jwks_uri = Some(metadata.jwks_uri.clone());
276
277        self.metadata = Some(metadata);
278
279        Ok(())
280    }
281
282    /// Get authorization endpoint
283    pub fn authorization_endpoint(&self) -> Result<&str> {
284        self.config.authorization_endpoint
285            .as_deref()
286            .ok_or_else(|| AuthError::ConfigError("authorization_endpoint not configured".into()))
287    }
288
289    /// Get token endpoint
290    pub fn token_endpoint(&self) -> Result<&str> {
291        self.config.token_endpoint
292            .as_deref()
293            .ok_or_else(|| AuthError::ConfigError("token_endpoint not configured".into()))
294    }
295
296    /// Get device authorization endpoint
297    pub fn device_authorization_endpoint(&self) -> Result<&str> {
298        self.config.device_authorization_endpoint
299            .as_deref()
300            .ok_or_else(|| AuthError::ConfigError("device_authorization_endpoint not configured".into()))
301    }
302
303    /// Check if email domain is allowed
304    pub fn is_domain_allowed(&self, email: &str) -> bool {
305        if self.config.allowed_domains.is_empty() {
306            return true;
307        }
308
309        let domain = email.split('@').nth(1).unwrap_or("");
310        self.config.allowed_domains.iter().any(|d| d == domain)
311    }
312
313    /// Check if user is in required groups
314    pub fn is_in_required_group(&self, groups: &[String]) -> bool {
315        if self.config.required_groups.is_empty() {
316            return true;
317        }
318
319        self.config.required_groups.iter().any(|g| groups.contains(g))
320    }
321}
322
#[cfg(test)]
mod tests {
    //! Unit tests for the provider presets, configuration validation and the
    //! local policy checks on `OAuthProvider` (no network access).

    use super::*;

    #[test]
    fn test_google_config() {
        let config = ProviderConfig::google("client-id", "client-secret", Some("example.com"));

        assert_eq!(config.provider_type, ProviderType::Google);
        assert_eq!(config.allowed_domains, vec!["example.com"]);
        assert_eq!(
            config.additional_params.get("hd").map(String::as_str),
            Some("example.com")
        );
    }

    #[test]
    fn test_google_config_without_domain() {
        let config = ProviderConfig::google("client-id", "client-secret", None);

        assert!(config.allowed_domains.is_empty());
        assert!(!config.additional_params.contains_key("hd"));
    }

    #[test]
    fn test_microsoft_config() {
        let config = ProviderConfig::microsoft("client-id", "client-secret", "tenant-id");

        assert_eq!(config.provider_type, ProviderType::Microsoft);
        assert!(config.issuer_url.contains("tenant-id"));
        assert_eq!(
            config.token_endpoint.as_deref(),
            Some("https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token")
        );
        assert!(config.scopes.iter().any(|s| s == "offline_access"));
    }

    #[test]
    fn test_okta_config_default_auth_server() {
        let config = ProviderConfig::okta("id", "secret", "dev.okta.com", None);

        assert_eq!(config.provider_type, ProviderType::Okta);
        assert_eq!(config.issuer_url, "https://dev.okta.com/oauth2/default");
        assert_eq!(
            config.token_endpoint.as_deref(),
            Some("https://dev.okta.com/oauth2/default/v1/token")
        );
    }

    #[test]
    fn test_okta_config_custom_auth_server() {
        let config = ProviderConfig::okta("id", "secret", "dev.okta.com", Some("aus123"));

        assert_eq!(config.issuer_url, "https://dev.okta.com/oauth2/aus123");
    }

    #[test]
    fn test_generic_config_requires_discovery() {
        let config = ProviderConfig::generic("id", "secret", "https://idp.example.com");
        assert_eq!(config.provider_type, ProviderType::Generic);

        let provider = OAuthProvider::new(config);
        assert!(provider.authorization_endpoint().is_err());
        assert!(provider.token_endpoint().is_err());
        assert!(provider.device_authorization_endpoint().is_err());
    }

    #[test]
    fn test_validate() {
        assert!(ProviderConfig::generic("id", "secret", "https://idp.example.com")
            .validate()
            .is_ok());
        assert!(matches!(
            ProviderConfig::generic("", "secret", "https://idp.example.com").validate(),
            Err(AuthError::ConfigError(_))
        ));
        assert!(matches!(
            ProviderConfig::generic("id", "", "https://idp.example.com").validate(),
            Err(AuthError::ConfigError(_))
        ));
        assert!(matches!(
            ProviderConfig::generic("id", "secret", "").validate(),
            Err(AuthError::ConfigError(_))
        ));
    }

    #[test]
    fn test_provider_type_display() {
        assert_eq!(ProviderType::Google.to_string(), "google");
        assert_eq!(ProviderType::Microsoft.to_string(), "microsoft");
        assert_eq!(ProviderType::Okta.to_string(), "okta");
        assert_eq!(ProviderType::Generic.to_string(), "generic");
    }

    #[test]
    fn test_domain_check() {
        let config = ProviderConfig::google("id", "secret", Some("example.com"));
        let provider = OAuthProvider::new(config);

        assert!(provider.is_domain_allowed("user@example.com"));
        assert!(!provider.is_domain_allowed("user@other.com"));
    }

    #[test]
    fn test_domain_check_without_restriction() {
        let config = ProviderConfig::generic("id", "secret", "https://idp.example.com");
        let provider = OAuthProvider::new(config);

        assert!(provider.is_domain_allowed("anyone@anywhere.test"));
    }

    #[test]
    fn test_required_groups() {
        let groups = vec!["vpn-users".to_string(), "staff".to_string()];

        let open = OAuthProvider::new(ProviderConfig::generic("id", "secret", "https://idp"));
        assert!(open.is_in_required_group(&[]));

        let mut config = ProviderConfig::generic("id", "secret", "https://idp");
        config.required_groups = vec!["vpn-users".to_string()];
        let restricted = OAuthProvider::new(config);
        assert!(restricted.is_in_required_group(&groups));
        assert!(!restricted.is_in_required_group(&["staff".to_string()]));
    }
}