Skip to main content

corevpn_auth/
flow.rs

1//! OAuth2 Authentication Flows
2
3use std::collections::HashMap;
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8use secrecy::ExposeSecret;
9use tracing::{debug, warn};
10
11use crate::{AuthError, OAuthProvider, Result, TokenSet};
12use crate::session::RateLimiter;
13use std::sync::Arc;
14use parking_lot::RwLock;
15
16/// Authentication state (for CSRF protection)
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthState {
19    /// Random state value
20    pub state: String,
21    /// Nonce for ID token validation
22    pub nonce: String,
23    /// PKCE code verifier
24    pub code_verifier: String,
25    /// Creation timestamp
26    pub created_at: chrono::DateTime<chrono::Utc>,
27    /// Expiration timestamp
28    pub expires_at: chrono::DateTime<chrono::Utc>,
29    /// Additional metadata
30    pub metadata: HashMap<String, String>,
31}
32
33impl AuthState {
34    /// Create a new auth state
35    pub fn new(lifetime: Duration) -> Self {
36        let now = chrono::Utc::now();
37        Self {
38            state: Uuid::new_v4().to_string(),
39            nonce: Uuid::new_v4().to_string(),
40            code_verifier: Self::generate_code_verifier(),
41            created_at: now,
42            expires_at: now + chrono::Duration::from_std(lifetime)
43                .unwrap_or_else(|_| chrono::Duration::seconds(600)), // Fallback to 10 minutes
44            metadata: HashMap::new(),
45        }
46    }
47
48    /// Check if state is expired
49    pub fn is_expired(&self) -> bool {
50        chrono::Utc::now() > self.expires_at
51    }
52
53    /// Get PKCE code challenge
54    pub fn code_challenge(&self) -> String {
55        use sha2::{Sha256, Digest};
56        use base64::Engine;
57
58        let mut hasher = Sha256::new();
59        hasher.update(self.code_verifier.as_bytes());
60        let hash = hasher.finalize();
61
62        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
63    }
64
65    fn generate_code_verifier() -> String {
66        use base64::Engine;
67
68        let random_bytes: [u8; 32] = corevpn_crypto::random_bytes();
69        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(random_bytes)
70    }
71}
72
73/// OAuth2 Authorization Code Flow
74pub struct AuthFlow {
75    /// Provider
76    provider: OAuthProvider,
77    /// Redirect URI
78    redirect_uri: String,
79    /// Rate limiter for authentication attempts
80    rate_limiter: Arc<RwLock<RateLimiter>>,
81}
82
83impl AuthFlow {
84    /// Create a new auth flow
85    pub fn new(provider: OAuthProvider, redirect_uri: &str) -> Self {
86        Self {
87            provider,
88            redirect_uri: redirect_uri.to_string(),
89            rate_limiter: Arc::new(RwLock::new(RateLimiter::new(
90                5, // 5 attempts
91                std::time::Duration::from_secs(300), // per 5 minutes
92            ))),
93        }
94    }
95
96    /// Create a new auth flow with custom rate limiter
97    pub fn with_rate_limiter(provider: OAuthProvider, redirect_uri: &str, rate_limiter: RateLimiter) -> Self {
98        Self {
99            provider,
100            redirect_uri: redirect_uri.to_string(),
101            rate_limiter: Arc::new(RwLock::new(rate_limiter)),
102        }
103    }
104
105    /// Generate authorization URL
106    pub fn authorization_url(&self, state: &AuthState) -> Result<String> {
107        let endpoint = self.provider.authorization_endpoint()?;
108        let config = self.provider.config();
109
110        let code_challenge = state.code_challenge();
111        let mut params = vec![
112            ("client_id", config.client_id.as_str()),
113            ("response_type", "code"),
114            ("redirect_uri", &self.redirect_uri),
115            ("state", &state.state),
116            ("nonce", &state.nonce),
117            ("code_challenge", code_challenge.as_str()),
118            ("code_challenge_method", "S256"),
119        ];
120
121        // Add scopes
122        let scopes = config.scopes.join(" ");
123        params.push(("scope", &scopes));
124
125        // Build URL
126        let mut url = endpoint.to_string();
127        url.push('?');
128
129        for (i, (key, value)) in params.iter().enumerate() {
130            if i > 0 {
131                url.push('&');
132            }
133            url.push_str(key);
134            url.push('=');
135            url.push_str(&urlencoding::encode(value));
136        }
137
138        // Add any additional parameters
139        for (key, value) in &config.additional_params {
140            url.push('&');
141            url.push_str(key);
142            url.push('=');
143            url.push_str(&urlencoding::encode(value));
144        }
145
146        Ok(url)
147    }
148
149    /// Exchange authorization code for tokens
150    pub async fn exchange_code(&self, code: &str, state: &AuthState) -> Result<TokenSet> {
151        let endpoint = self.provider.token_endpoint()?;
152        let config = self.provider.config();
153
154        let client_secret = config.client_secret.expose_secret();
155        let params = [
156            ("grant_type", "authorization_code"),
157            ("client_id", config.client_id.as_str()),
158            ("client_secret", client_secret.as_str()),
159            ("code", code),
160            ("redirect_uri", self.redirect_uri.as_str()),
161            ("code_verifier", state.code_verifier.as_str()),
162        ];
163
164        let client = reqwest::Client::new();
165        let response = client
166            .post(endpoint)
167            .form(&params)
168            .send()
169            .await
170            .map_err(|e| {
171                warn!("Token exchange request failed: {}", e);
172                AuthError::OAuth2Error("Authentication failed".into())
173            })?;
174
175        if !response.status().is_success() {
176            let status = response.status();
177            let error_text = response.text().await.unwrap_or_default();
178            warn!("Token exchange failed with status {}: {}", status, error_text);
179            return Err(AuthError::OAuth2Error("Authentication failed".into()));
180        }
181
182        let token_response: TokenResponse = response.json().await?;
183
184        Ok(TokenSet {
185            access_token: token_response.access_token,
186            refresh_token: token_response.refresh_token,
187            id_token: token_response.id_token,
188            expires_at: chrono::Utc::now()
189                + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
190            token_type: token_response.token_type,
191            scopes: token_response.scope
192                .map(|s| s.split(' ').map(String::from).collect())
193                .unwrap_or_default(),
194        })
195    }
196
197    /// Refresh access token
198    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenSet> {
199        let endpoint = self.provider.token_endpoint()?;
200        let config = self.provider.config();
201
202        let client_secret = config.client_secret.expose_secret();
203        let params = [
204            ("grant_type", "refresh_token"),
205            ("client_id", config.client_id.as_str()),
206            ("client_secret", client_secret.as_str()),
207            ("refresh_token", refresh_token),
208        ];
209
210        let client = reqwest::Client::new();
211        let response = client
212            .post(endpoint)
213            .form(&params)
214            .send()
215            .await
216            .map_err(|e| {
217                warn!("Token refresh request failed: {}", e);
218                AuthError::TokenRefreshFailed("Token refresh failed".into())
219            })?;
220
221        if !response.status().is_success() {
222            let status = response.status();
223            let error_text = response.text().await.unwrap_or_default();
224            warn!("Token refresh failed with status {}: {}", status, error_text);
225            return Err(AuthError::TokenRefreshFailed("Token refresh failed".into()));
226        }
227
228        let token_response: TokenResponse = response.json().await?;
229
230        Ok(TokenSet {
231            access_token: token_response.access_token,
232            refresh_token: token_response.refresh_token,
233            id_token: token_response.id_token,
234            expires_at: chrono::Utc::now()
235                + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
236            token_type: token_response.token_type,
237            scopes: token_response.scope
238                .map(|s| s.split(' ').map(String::from).collect())
239                .unwrap_or_default(),
240        })
241    }
242}
243
244/// OAuth2 Device Authorization Flow (for CLI/headless)
245pub struct DeviceAuthFlow {
246    /// Provider
247    provider: OAuthProvider,
248    /// Rate limiter for device auth attempts
249    rate_limiter: Arc<RwLock<RateLimiter>>,
250}
251
252impl DeviceAuthFlow {
253    /// Create a new device auth flow
254    pub fn new(provider: OAuthProvider) -> Self {
255        Self {
256            provider,
257            rate_limiter: Arc::new(RwLock::new(RateLimiter::new(
258                10, // 10 attempts
259                std::time::Duration::from_secs(600), // per 10 minutes
260            ))),
261        }
262    }
263
264    /// Start device authorization
265    pub async fn start(&self, client_ip: Option<&str>) -> Result<DeviceAuthResponse> {
266        // Rate limit device auth attempts
267        let rate_limit_key = client_ip.unwrap_or("unknown");
268        if !self.rate_limiter.read().check(rate_limit_key) {
269            return Err(AuthError::OAuth2Error("Too many device authorization attempts".into()));
270        }
271
272        let endpoint = self.provider.device_authorization_endpoint()?;
273        let config = self.provider.config();
274
275        let scopes = config.scopes.join(" ");
276        let params = [
277            ("client_id", config.client_id.as_str()),
278            ("scope", &scopes),
279        ];
280
281        let client = reqwest::Client::new();
282        let response = client
283            .post(endpoint)
284            .form(&params)
285            .send()
286            .await?;
287
288        if !response.status().is_success() {
289            let status = response.status();
290            let error_text = response.text().await.unwrap_or_default();
291            warn!("Device authorization failed with status {}: {}", status, error_text);
292            return Err(AuthError::OAuth2Error("Device authorization failed".into()));
293        }
294
295        let device_response: DeviceAuthResponse = response.json().await?;
296        Ok(device_response)
297    }
298
299    /// Poll for token (call repeatedly until success or error)
300    pub async fn poll(&self, device_code: &str, client_ip: Option<&str>) -> Result<TokenSet> {
301        // Rate limit polling attempts
302        let rate_limit_key = client_ip.unwrap_or(device_code);
303        if !self.rate_limiter.read().check(rate_limit_key) {
304            return Err(AuthError::OAuth2Error("Too many polling attempts".into()));
305        }
306
307        let endpoint = self.provider.token_endpoint()?;
308        let config = self.provider.config();
309
310        let client_secret = config.client_secret.expose_secret();
311        let params = [
312            ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
313            ("client_id", config.client_id.as_str()),
314            ("client_secret", client_secret.as_str()),
315            ("device_code", device_code),
316        ];
317
318        let client = reqwest::Client::new();
319        let response = client
320            .post(endpoint)
321            .form(&params)
322            .send()
323            .await?;
324
325        if !response.status().is_success() {
326            let error_response: ErrorResponse = response.json().await
327                .map_err(|e| {
328                    warn!("Failed to parse error response: {}", e);
329                    AuthError::OAuth2Error("Device authorization failed".into())
330                })?;
331
332            return match error_response.error.as_str() {
333                "authorization_pending" => Err(AuthError::AuthorizationPending),
334                "slow_down" => Err(AuthError::AuthorizationPending),
335                "expired_token" => Err(AuthError::DeviceAuthExpired),
336                _ => {
337                    warn!("Device auth error: {}", error_response.error);
338                    Err(AuthError::OAuth2Error("Device authorization failed".into()))
339                }
340            };
341        }
342
343        let token_response: TokenResponse = response.json().await?;
344
345        Ok(TokenSet {
346            access_token: token_response.access_token,
347            refresh_token: token_response.refresh_token,
348            id_token: token_response.id_token,
349            expires_at: chrono::Utc::now()
350                + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
351            token_type: token_response.token_type,
352            scopes: token_response.scope
353                .map(|s| s.split(' ').map(String::from).collect())
354                .unwrap_or_default(),
355        })
356    }
357}
358
/// OAuth2 token response.
///
/// This is the JSON body of a successful token endpoint reply. The
/// authorization-code, refresh-token and device-code grants in this module all
/// deserialize into it.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Access token; the only field that must be present.
    access_token: String,
    /// Refresh token, if the provider issued one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// ID token, if the provider issued one.
    #[serde(default)]
    id_token: Option<String>,
    /// Token lifetime in seconds. The grants in this module treat a missing
    /// value as 3600.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Token type (e.g. `Bearer`); an empty string when the field is absent.
    #[serde(default)]
    token_type: String,
    /// Space-delimited granted scopes, if reported.
    #[serde(default)]
    scope: Option<String>,
}
374
375/// Device authorization response
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct DeviceAuthResponse {
378    /// Device code (for polling)
379    pub device_code: String,
380    /// User code (to enter on verification page)
381    pub user_code: String,
382    /// Verification URI
383    pub verification_uri: String,
384    /// Verification URI with code pre-filled (optional)
385    #[serde(default)]
386    pub verification_uri_complete: Option<String>,
387    /// Expiration in seconds
388    pub expires_in: u64,
389    /// Polling interval in seconds
390    #[serde(default = "default_interval")]
391    pub interval: u64,
392}
393
394fn default_interval() -> u64 {
395    5
396}
397
/// OAuth2 error response body.
///
/// `DeviceAuthFlow::poll` parses it when the token endpoint returns a non-2xx
/// status. It uses `error` to distinguish `authorization_pending`, `slow_down`
/// and `expired_token` from other failures.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Error code string returned by the provider.
    error: String,
    /// Optional human-readable description; `None` when absent.
    #[serde(default)]
    error_description: Option<String>,
}
405
406/// Generate authentication challenge for OpenVPN auth-user-pass
407pub fn generate_vpn_auth_challenge(device_response: &DeviceAuthResponse) -> String {
408    format!(
409        "CRV1:R,E:{}:Please visit {} and enter code: {}",
410        base64::Engine::encode(
411            &base64::engine::general_purpose::STANDARD,
412            device_response.device_code.as_bytes()
413        ),
414        device_response.verification_uri,
415        device_response.user_code
416    )
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_state() {
        let state = AuthState::new(Duration::from_secs(300));

        assert!(!state.is_expired());
        assert!(!state.state.is_empty());
        assert!(!state.nonce.is_empty());
        assert!(!state.code_verifier.is_empty());
        // state and nonce are generated independently
        assert_ne!(state.state, state.nonce);
        assert!(state.expires_at > state.created_at);
    }

    #[test]
    fn test_auth_state_expired() {
        let mut state = AuthState::new(Duration::from_secs(300));
        state.expires_at = chrono::Utc::now() - chrono::Duration::seconds(1);

        assert!(state.is_expired());
    }

    #[test]
    fn test_code_verifier_format() {
        // RFC 7636 §4.1: 43..=128 chars from [A-Za-z0-9-._~].
        // 32 bytes, base64url without padding, gives exactly 43 chars.
        let state = AuthState::new(Duration::from_secs(300));

        assert_eq!(state.code_verifier.len(), 43);
        assert!(state
            .code_verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn test_code_challenge() {
        let state = AuthState::new(Duration::from_secs(300));
        let challenge = state.code_challenge();

        // Should be base64url encoded SHA256 hash (32 bytes -> 43 chars, no padding)
        assert_eq!(challenge.len(), 43);
        assert!(!challenge.contains('+'));
        assert!(!challenge.contains('/'));
        assert!(!challenge.contains('='));
    }

    #[test]
    fn test_code_challenge_rfc7636_vector() {
        // Test vector from RFC 7636 Appendix B.
        let mut state = AuthState::new(Duration::from_secs(300));
        state.code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWtEjXk".to_string();

        assert_eq!(
            state.code_challenge(),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }
}