corevpn_auth/
flow.rs

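//! OAuth2 authentication flows.
//!
//! Provides the authorization-code flow with PKCE ([`AuthFlow`]), the device
//! authorization flow for CLI/headless clients ([`DeviceAuthFlow`]), and a
//! helper that renders a device authorization as an OpenVPN `CRV1` challenge.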
2
3use std::collections::HashMap;
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9use crate::{AuthError, OAuthProvider, Result, TokenSet};
10
11/// Authentication state (for CSRF protection)
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AuthState {
14    /// Random state value
15    pub state: String,
16    /// Nonce for ID token validation
17    pub nonce: String,
18    /// PKCE code verifier
19    pub code_verifier: String,
20    /// Creation timestamp
21    pub created_at: chrono::DateTime<chrono::Utc>,
22    /// Expiration timestamp
23    pub expires_at: chrono::DateTime<chrono::Utc>,
24    /// Additional metadata
25    pub metadata: HashMap<String, String>,
26}
27
28impl AuthState {
29    /// Create a new auth state
30    pub fn new(lifetime: Duration) -> Self {
31        let now = chrono::Utc::now();
32        Self {
33            state: Uuid::new_v4().to_string(),
34            nonce: Uuid::new_v4().to_string(),
35            code_verifier: Self::generate_code_verifier(),
36            created_at: now,
37            expires_at: now + chrono::Duration::from_std(lifetime).unwrap(),
38            metadata: HashMap::new(),
39        }
40    }
41
42    /// Check if state is expired
43    pub fn is_expired(&self) -> bool {
44        chrono::Utc::now() > self.expires_at
45    }
46
47    /// Get PKCE code challenge
48    pub fn code_challenge(&self) -> String {
49        use sha2::{Sha256, Digest};
50        use base64::Engine;
51
52        let mut hasher = Sha256::new();
53        hasher.update(self.code_verifier.as_bytes());
54        let hash = hasher.finalize();
55
56        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
57    }
58
59    fn generate_code_verifier() -> String {
60        use base64::Engine;
61
62        let random_bytes: [u8; 32] = corevpn_crypto::random_bytes();
63        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(random_bytes)
64    }
65}
66
67/// OAuth2 Authorization Code Flow
68pub struct AuthFlow {
69    /// Provider
70    provider: OAuthProvider,
71    /// Redirect URI
72    redirect_uri: String,
73}
74
75impl AuthFlow {
76    /// Create a new auth flow
77    pub fn new(provider: OAuthProvider, redirect_uri: &str) -> Self {
78        Self {
79            provider,
80            redirect_uri: redirect_uri.to_string(),
81        }
82    }
83
84    /// Generate authorization URL
85    pub fn authorization_url(&self, state: &AuthState) -> Result<String> {
86        let endpoint = self.provider.authorization_endpoint()?;
87        let config = self.provider.config();
88
89        let code_challenge = state.code_challenge();
90        let mut params = vec![
91            ("client_id", config.client_id.as_str()),
92            ("response_type", "code"),
93            ("redirect_uri", &self.redirect_uri),
94            ("state", &state.state),
95            ("nonce", &state.nonce),
96            ("code_challenge", code_challenge.as_str()),
97            ("code_challenge_method", "S256"),
98        ];
99
100        // Add scopes
101        let scopes = config.scopes.join(" ");
102        params.push(("scope", &scopes));
103
104        // Build URL
105        let mut url = endpoint.to_string();
106        url.push('?');
107
108        for (i, (key, value)) in params.iter().enumerate() {
109            if i > 0 {
110                url.push('&');
111            }
112            url.push_str(key);
113            url.push('=');
114            url.push_str(&urlencoding::encode(value));
115        }
116
117        // Add any additional parameters
118        for (key, value) in &config.additional_params {
119            url.push('&');
120            url.push_str(key);
121            url.push('=');
122            url.push_str(&urlencoding::encode(value));
123        }
124
125        Ok(url)
126    }
127
128    /// Exchange authorization code for tokens
129    pub async fn exchange_code(&self, code: &str, state: &AuthState) -> Result<TokenSet> {
130        let endpoint = self.provider.token_endpoint()?;
131        let config = self.provider.config();
132
133        let params = [
134            ("grant_type", "authorization_code"),
135            ("client_id", &config.client_id),
136            ("client_secret", &config.client_secret),
137            ("code", code),
138            ("redirect_uri", &self.redirect_uri),
139            ("code_verifier", &state.code_verifier),
140        ];
141
142        let client = reqwest::Client::new();
143        let response = client
144            .post(endpoint)
145            .form(&params)
146            .send()
147            .await?;
148
149        if !response.status().is_success() {
150            let error_text = response.text().await.unwrap_or_default();
151            return Err(AuthError::OAuth2Error(error_text));
152        }
153
154        let token_response: TokenResponse = response.json().await?;
155
156        Ok(TokenSet {
157            access_token: token_response.access_token,
158            refresh_token: token_response.refresh_token,
159            id_token: token_response.id_token,
160            expires_at: chrono::Utc::now()
161                + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
162            token_type: token_response.token_type,
163            scopes: token_response.scope
164                .map(|s| s.split(' ').map(String::from).collect())
165                .unwrap_or_default(),
166        })
167    }
168
169    /// Refresh access token
170    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenSet> {
171        let endpoint = self.provider.token_endpoint()?;
172        let config = self.provider.config();
173
174        let params = [
175            ("grant_type", "refresh_token"),
176            ("client_id", &config.client_id),
177            ("client_secret", &config.client_secret),
178            ("refresh_token", refresh_token),
179        ];
180
181        let client = reqwest::Client::new();
182        let response = client
183            .post(endpoint)
184            .form(&params)
185            .send()
186            .await?;
187
188        if !response.status().is_success() {
189            let error_text = response.text().await.unwrap_or_default();
190            return Err(AuthError::TokenRefreshFailed(error_text));
191        }
192
193        let token_response: TokenResponse = response.json().await?;
194
195        Ok(TokenSet {
196            access_token: token_response.access_token,
197            refresh_token: token_response.refresh_token,
198            id_token: token_response.id_token,
199            expires_at: chrono::Utc::now()
200                + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
201            token_type: token_response.token_type,
202            scopes: token_response.scope
203                .map(|s| s.split(' ').map(String::from).collect())
204                .unwrap_or_default(),
205        })
206    }
207}
208
209/// OAuth2 Device Authorization Flow (for CLI/headless)
210pub struct DeviceAuthFlow {
211    /// Provider
212    provider: OAuthProvider,
213}
214
215impl DeviceAuthFlow {
216    /// Create a new device auth flow
217    pub fn new(provider: OAuthProvider) -> Self {
218        Self { provider }
219    }
220
221    /// Start device authorization
222    pub async fn start(&self) -> Result<DeviceAuthResponse> {
223        let endpoint = self.provider.device_authorization_endpoint()?;
224        let config = self.provider.config();
225
226        let scopes = config.scopes.join(" ");
227        let params = [
228            ("client_id", config.client_id.as_str()),
229            ("scope", &scopes),
230        ];
231
232        let client = reqwest::Client::new();
233        let response = client
234            .post(endpoint)
235            .form(&params)
236            .send()
237            .await?;
238
239        if !response.status().is_success() {
240            let error_text = response.text().await.unwrap_or_default();
241            return Err(AuthError::OAuth2Error(error_text));
242        }
243
244        let device_response: DeviceAuthResponse = response.json().await?;
245        Ok(device_response)
246    }
247
248    /// Poll for token (call repeatedly until success or error)
249    pub async fn poll(&self, device_code: &str) -> Result<TokenSet> {
250        let endpoint = self.provider.token_endpoint()?;
251        let config = self.provider.config();
252
253        let params = [
254            ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
255            ("client_id", &config.client_id),
256            ("client_secret", &config.client_secret),
257            ("device_code", device_code),
258        ];
259
260        let client = reqwest::Client::new();
261        let response = client
262            .post(endpoint)
263            .form(&params)
264            .send()
265            .await?;
266
267        if !response.status().is_success() {
268            let error_response: ErrorResponse = response.json().await?;
269
270            return match error_response.error.as_str() {
271                "authorization_pending" => Err(AuthError::AuthorizationPending),
272                "slow_down" => Err(AuthError::AuthorizationPending),
273                "expired_token" => Err(AuthError::DeviceAuthExpired),
274                _ => Err(AuthError::OAuth2Error(
275                    error_response.error_description.unwrap_or(error_response.error),
276                )),
277            };
278        }
279
280        let token_response: TokenResponse = response.json().await?;
281
282        Ok(TokenSet {
283            access_token: token_response.access_token,
284            refresh_token: token_response.refresh_token,
285            id_token: token_response.id_token,
286            expires_at: chrono::Utc::now()
287                + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
288            token_type: token_response.token_type,
289            scopes: token_response.scope
290                .map(|s| s.split(' ').map(String::from).collect())
291                .unwrap_or_default(),
292        })
293    }
294}
295
296/// OAuth2 token response
297#[derive(Debug, Deserialize)]
298struct TokenResponse {
299    access_token: String,
300    #[serde(default)]
301    refresh_token: Option<String>,
302    #[serde(default)]
303    id_token: Option<String>,
304    #[serde(default)]
305    expires_in: Option<u64>,
306    #[serde(default)]
307    token_type: String,
308    #[serde(default)]
309    scope: Option<String>,
310}
311
312/// Device authorization response
313#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct DeviceAuthResponse {
315    /// Device code (for polling)
316    pub device_code: String,
317    /// User code (to enter on verification page)
318    pub user_code: String,
319    /// Verification URI
320    pub verification_uri: String,
321    /// Verification URI with code pre-filled (optional)
322    #[serde(default)]
323    pub verification_uri_complete: Option<String>,
324    /// Expiration in seconds
325    pub expires_in: u64,
326    /// Polling interval in seconds
327    #[serde(default = "default_interval")]
328    pub interval: u64,
329}
330
/// Serde default for `DeviceAuthResponse::interval`: poll every 5 seconds when
/// the provider does not specify an interval.
fn default_interval() -> u64 {
    const FALLBACK_POLL_INTERVAL_SECS: u64 = 5;
    FALLBACK_POLL_INTERVAL_SECS
}
334
335/// Error response
336#[derive(Debug, Deserialize)]
337struct ErrorResponse {
338    error: String,
339    #[serde(default)]
340    error_description: Option<String>,
341}
342
343/// Generate authentication challenge for OpenVPN auth-user-pass
344pub fn generate_vpn_auth_challenge(device_response: &DeviceAuthResponse) -> String {
345    format!(
346        "CRV1:R,E:{}:Please visit {} and enter code: {}",
347        base64::Engine::encode(
348            &base64::engine::general_purpose::STANDARD,
349            device_response.device_code.as_bytes()
350        ),
351        device_response.verification_uri,
352        device_response.user_code
353    )
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_state() {
        let state = AuthState::new(Duration::from_secs(300));

        assert!(!state.is_expired());
        assert!(!state.state.is_empty());
        assert!(!state.nonce.is_empty());
        assert!(!state.code_verifier.is_empty());
        assert!(state.metadata.is_empty());
    }

    #[test]
    fn test_auth_state_expiry() {
        let mut state = AuthState::new(Duration::from_secs(300));
        assert!(state.expires_at > state.created_at);

        // Force the window into the past.
        state.expires_at = state.created_at - chrono::Duration::seconds(1);
        assert!(state.is_expired());
    }

    #[test]
    fn test_code_verifier_format() {
        // RFC 7636 §4.1: 43..=128 characters from [A-Z a-z 0-9 - . _ ~].
        let state = AuthState::new(Duration::from_secs(300));
        let verifier = &state.code_verifier;

        assert!((43..=128).contains(&verifier.len()));
        assert!(verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || "-._~".contains(c)));
    }

    #[test]
    fn test_code_challenge() {
        let state = AuthState::new(Duration::from_secs(300));
        let challenge = state.code_challenge();

        // Should be base64url encoded SHA256 hash
        assert!(!challenge.is_empty());
        assert!(!challenge.contains('+'));
        assert!(!challenge.contains('/'));
        assert!(!challenge.contains('='));
    }

    #[test]
    fn test_code_challenge_rfc7636_vector() {
        // Known-answer test from RFC 7636 Appendix B.
        let mut state = AuthState::new(Duration::from_secs(300));
        state.code_verifier = "dBjftJeZ4CVP-mB92qjxj8O5xs7pNuD8rSzdXbJp-tw".to_string();

        assert_eq!(
            state.code_challenge(),
            "E9Melhoa2OwvFKQfbl3E-XNcN8-zYM1m2vXJ5ORtp7M"
        );
    }
}