1use std::collections::HashMap;
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8use secrecy::ExposeSecret;
9use tracing::{debug, warn};
10
11use crate::{AuthError, OAuthProvider, Result, TokenSet};
12use crate::session::RateLimiter;
13use std::sync::Arc;
14use parking_lot::RwLock;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthState {
19 pub state: String,
21 pub nonce: String,
23 pub code_verifier: String,
25 pub created_at: chrono::DateTime<chrono::Utc>,
27 pub expires_at: chrono::DateTime<chrono::Utc>,
29 pub metadata: HashMap<String, String>,
31}
32
33impl AuthState {
34 pub fn new(lifetime: Duration) -> Self {
36 let now = chrono::Utc::now();
37 Self {
38 state: Uuid::new_v4().to_string(),
39 nonce: Uuid::new_v4().to_string(),
40 code_verifier: Self::generate_code_verifier(),
41 created_at: now,
42 expires_at: now + chrono::Duration::from_std(lifetime)
43 .unwrap_or_else(|_| chrono::Duration::seconds(600)), metadata: HashMap::new(),
45 }
46 }
47
48 pub fn is_expired(&self) -> bool {
50 chrono::Utc::now() > self.expires_at
51 }
52
53 pub fn code_challenge(&self) -> String {
55 use sha2::{Sha256, Digest};
56 use base64::Engine;
57
58 let mut hasher = Sha256::new();
59 hasher.update(self.code_verifier.as_bytes());
60 let hash = hasher.finalize();
61
62 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
63 }
64
65 fn generate_code_verifier() -> String {
66 use base64::Engine;
67
68 let random_bytes: [u8; 32] = corevpn_crypto::random_bytes();
69 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(random_bytes)
70 }
71}
72
73pub struct AuthFlow {
75 provider: OAuthProvider,
77 redirect_uri: String,
79 rate_limiter: Arc<RwLock<RateLimiter>>,
81}
82
83impl AuthFlow {
84 pub fn new(provider: OAuthProvider, redirect_uri: &str) -> Self {
86 Self {
87 provider,
88 redirect_uri: redirect_uri.to_string(),
89 rate_limiter: Arc::new(RwLock::new(RateLimiter::new(
90 5, std::time::Duration::from_secs(300), ))),
93 }
94 }
95
96 pub fn with_rate_limiter(provider: OAuthProvider, redirect_uri: &str, rate_limiter: RateLimiter) -> Self {
98 Self {
99 provider,
100 redirect_uri: redirect_uri.to_string(),
101 rate_limiter: Arc::new(RwLock::new(rate_limiter)),
102 }
103 }
104
105 pub fn authorization_url(&self, state: &AuthState) -> Result<String> {
107 let endpoint = self.provider.authorization_endpoint()?;
108 let config = self.provider.config();
109
110 let code_challenge = state.code_challenge();
111 let mut params = vec![
112 ("client_id", config.client_id.as_str()),
113 ("response_type", "code"),
114 ("redirect_uri", &self.redirect_uri),
115 ("state", &state.state),
116 ("nonce", &state.nonce),
117 ("code_challenge", code_challenge.as_str()),
118 ("code_challenge_method", "S256"),
119 ];
120
121 let scopes = config.scopes.join(" ");
123 params.push(("scope", &scopes));
124
125 let mut url = endpoint.to_string();
127 url.push('?');
128
129 for (i, (key, value)) in params.iter().enumerate() {
130 if i > 0 {
131 url.push('&');
132 }
133 url.push_str(key);
134 url.push('=');
135 url.push_str(&urlencoding::encode(value));
136 }
137
138 for (key, value) in &config.additional_params {
140 url.push('&');
141 url.push_str(key);
142 url.push('=');
143 url.push_str(&urlencoding::encode(value));
144 }
145
146 Ok(url)
147 }
148
149 pub async fn exchange_code(&self, code: &str, state: &AuthState) -> Result<TokenSet> {
151 let endpoint = self.provider.token_endpoint()?;
152 let config = self.provider.config();
153
154 let client_secret = config.client_secret.expose_secret();
155 let params = [
156 ("grant_type", "authorization_code"),
157 ("client_id", config.client_id.as_str()),
158 ("client_secret", client_secret.as_str()),
159 ("code", code),
160 ("redirect_uri", self.redirect_uri.as_str()),
161 ("code_verifier", state.code_verifier.as_str()),
162 ];
163
164 let client = reqwest::Client::new();
165 let response = client
166 .post(endpoint)
167 .form(¶ms)
168 .send()
169 .await
170 .map_err(|e| {
171 warn!("Token exchange request failed: {}", e);
172 AuthError::OAuth2Error("Authentication failed".into())
173 })?;
174
175 if !response.status().is_success() {
176 let status = response.status();
177 let error_text = response.text().await.unwrap_or_default();
178 warn!("Token exchange failed with status {}: {}", status, error_text);
179 return Err(AuthError::OAuth2Error("Authentication failed".into()));
180 }
181
182 let token_response: TokenResponse = response.json().await?;
183
184 Ok(TokenSet {
185 access_token: token_response.access_token,
186 refresh_token: token_response.refresh_token,
187 id_token: token_response.id_token,
188 expires_at: chrono::Utc::now()
189 + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
190 token_type: token_response.token_type,
191 scopes: token_response.scope
192 .map(|s| s.split(' ').map(String::from).collect())
193 .unwrap_or_default(),
194 })
195 }
196
197 pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenSet> {
199 let endpoint = self.provider.token_endpoint()?;
200 let config = self.provider.config();
201
202 let client_secret = config.client_secret.expose_secret();
203 let params = [
204 ("grant_type", "refresh_token"),
205 ("client_id", config.client_id.as_str()),
206 ("client_secret", client_secret.as_str()),
207 ("refresh_token", refresh_token),
208 ];
209
210 let client = reqwest::Client::new();
211 let response = client
212 .post(endpoint)
213 .form(¶ms)
214 .send()
215 .await
216 .map_err(|e| {
217 warn!("Token refresh request failed: {}", e);
218 AuthError::TokenRefreshFailed("Token refresh failed".into())
219 })?;
220
221 if !response.status().is_success() {
222 let status = response.status();
223 let error_text = response.text().await.unwrap_or_default();
224 warn!("Token refresh failed with status {}: {}", status, error_text);
225 return Err(AuthError::TokenRefreshFailed("Token refresh failed".into()));
226 }
227
228 let token_response: TokenResponse = response.json().await?;
229
230 Ok(TokenSet {
231 access_token: token_response.access_token,
232 refresh_token: token_response.refresh_token,
233 id_token: token_response.id_token,
234 expires_at: chrono::Utc::now()
235 + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
236 token_type: token_response.token_type,
237 scopes: token_response.scope
238 .map(|s| s.split(' ').map(String::from).collect())
239 .unwrap_or_default(),
240 })
241 }
242}
243
244pub struct DeviceAuthFlow {
246 provider: OAuthProvider,
248 rate_limiter: Arc<RwLock<RateLimiter>>,
250}
251
252impl DeviceAuthFlow {
253 pub fn new(provider: OAuthProvider) -> Self {
255 Self {
256 provider,
257 rate_limiter: Arc::new(RwLock::new(RateLimiter::new(
258 10, std::time::Duration::from_secs(600), ))),
261 }
262 }
263
264 pub async fn start(&self, client_ip: Option<&str>) -> Result<DeviceAuthResponse> {
266 let rate_limit_key = client_ip.unwrap_or("unknown");
268 if !self.rate_limiter.read().check(rate_limit_key) {
269 return Err(AuthError::OAuth2Error("Too many device authorization attempts".into()));
270 }
271
272 let endpoint = self.provider.device_authorization_endpoint()?;
273 let config = self.provider.config();
274
275 let scopes = config.scopes.join(" ");
276 let params = [
277 ("client_id", config.client_id.as_str()),
278 ("scope", &scopes),
279 ];
280
281 let client = reqwest::Client::new();
282 let response = client
283 .post(endpoint)
284 .form(¶ms)
285 .send()
286 .await?;
287
288 if !response.status().is_success() {
289 let status = response.status();
290 let error_text = response.text().await.unwrap_or_default();
291 warn!("Device authorization failed with status {}: {}", status, error_text);
292 return Err(AuthError::OAuth2Error("Device authorization failed".into()));
293 }
294
295 let device_response: DeviceAuthResponse = response.json().await?;
296 Ok(device_response)
297 }
298
299 pub async fn poll(&self, device_code: &str, client_ip: Option<&str>) -> Result<TokenSet> {
301 let rate_limit_key = client_ip.unwrap_or(device_code);
303 if !self.rate_limiter.read().check(rate_limit_key) {
304 return Err(AuthError::OAuth2Error("Too many polling attempts".into()));
305 }
306
307 let endpoint = self.provider.token_endpoint()?;
308 let config = self.provider.config();
309
310 let client_secret = config.client_secret.expose_secret();
311 let params = [
312 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
313 ("client_id", config.client_id.as_str()),
314 ("client_secret", client_secret.as_str()),
315 ("device_code", device_code),
316 ];
317
318 let client = reqwest::Client::new();
319 let response = client
320 .post(endpoint)
321 .form(¶ms)
322 .send()
323 .await?;
324
325 if !response.status().is_success() {
326 let error_response: ErrorResponse = response.json().await
327 .map_err(|e| {
328 warn!("Failed to parse error response: {}", e);
329 AuthError::OAuth2Error("Device authorization failed".into())
330 })?;
331
332 return match error_response.error.as_str() {
333 "authorization_pending" => Err(AuthError::AuthorizationPending),
334 "slow_down" => Err(AuthError::AuthorizationPending),
335 "expired_token" => Err(AuthError::DeviceAuthExpired),
336 _ => {
337 warn!("Device auth error: {}", error_response.error);
338 Err(AuthError::OAuth2Error("Device authorization failed".into()))
339 }
340 };
341 }
342
343 let token_response: TokenResponse = response.json().await?;
344
345 Ok(TokenSet {
346 access_token: token_response.access_token,
347 refresh_token: token_response.refresh_token,
348 id_token: token_response.id_token,
349 expires_at: chrono::Utc::now()
350 + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
351 token_type: token_response.token_type,
352 scopes: token_response.scope
353 .map(|s| s.split(' ').map(String::from).collect())
354 .unwrap_or_default(),
355 })
356 }
357}
358
359#[derive(Debug, Deserialize)]
361struct TokenResponse {
362 access_token: String,
363 #[serde(default)]
364 refresh_token: Option<String>,
365 #[serde(default)]
366 id_token: Option<String>,
367 #[serde(default)]
368 expires_in: Option<u64>,
369 #[serde(default)]
370 token_type: String,
371 #[serde(default)]
372 scope: Option<String>,
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct DeviceAuthResponse {
378 pub device_code: String,
380 pub user_code: String,
382 pub verification_uri: String,
384 #[serde(default)]
386 pub verification_uri_complete: Option<String>,
387 pub expires_in: u64,
389 #[serde(default = "default_interval")]
391 pub interval: u64,
392}
393
/// Fallback for `DeviceAuthResponse::interval` when the provider omits it.
///
/// The value 5 matches the RFC 8628 §3.2 default polling interval (in seconds).
fn default_interval() -> u64 {
    const DEFAULT_POLL_INTERVAL: u64 = 5;
    DEFAULT_POLL_INTERVAL
}
397
398#[derive(Debug, Deserialize)]
400struct ErrorResponse {
401 error: String,
402 #[serde(default)]
403 error_description: Option<String>,
404}
405
406pub fn generate_vpn_auth_challenge(device_response: &DeviceAuthResponse) -> String {
408 format!(
409 "CRV1:R,E:{}:Please visit {} and enter code: {}",
410 base64::Engine::encode(
411 &base64::engine::general_purpose::STANDARD,
412 device_response.device_code.as_bytes()
413 ),
414 device_response.verification_uri,
415 device_response.user_code
416 )
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_state() {
        let state = AuthState::new(Duration::from_secs(300));

        assert!(!state.is_expired());
        assert!(!state.state.is_empty());
        assert!(!state.nonce.is_empty());
        assert!(!state.code_verifier.is_empty());
    }

    #[test]
    fn test_auth_state_values_differ_between_instances() {
        let first = AuthState::new(Duration::from_secs(300));
        let second = AuthState::new(Duration::from_secs(300));

        assert_ne!(first.state, second.state);
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.code_verifier, second.code_verifier);
    }

    #[test]
    fn test_auth_state_expires() {
        let mut state = AuthState::new(Duration::from_secs(300));
        state.expires_at = chrono::Utc::now() - chrono::Duration::seconds(1);

        assert!(state.is_expired());
    }

    #[test]
    fn test_code_verifier_length() {
        // RFC 7636 §4.1 allows 43..=128 characters; 32 bytes as unpadded base64url is 43.
        let state = AuthState::new(Duration::from_secs(300));

        assert_eq!(state.code_verifier.len(), 43);
    }

    #[test]
    fn test_code_challenge() {
        let state = AuthState::new(Duration::from_secs(300));
        let challenge = state.code_challenge();

        // Unpadded base64url of a 32-byte SHA-256 digest.
        assert_eq!(challenge.len(), 43);
        assert!(!challenge.contains('+'));
        assert!(!challenge.contains('/'));
        assert!(!challenge.contains('='));
    }

    #[test]
    fn test_code_challenge_known_digest() {
        // BASE64URL(SHA-256("")) pins the S256 transform to a well-known digest.
        let mut state = AuthState::new(Duration::from_secs(300));
        state.code_verifier = String::new();

        assert_eq!(
            state.code_challenge(),
            "47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU"
        );
    }

    #[test]
    fn test_generate_vpn_auth_challenge() {
        let device_response = DeviceAuthResponse {
            device_code: "abc".to_string(),
            user_code: "WDJB-MJHT".to_string(),
            verification_uri: "https://example.com/device".to_string(),
            verification_uri_complete: None,
            expires_in: 600,
            interval: default_interval(),
        };

        assert_eq!(
            generate_vpn_auth_challenge(&device_response),
            "CRV1:R,E:YWJj:Please visit https://example.com/device and enter code: WDJB-MJHT"
        );
    }

    #[test]
    fn test_default_interval() {
        assert_eq!(default_interval(), 5);
    }
}