1use std::collections::HashMap;
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9use crate::{AuthError, OAuthProvider, Result, TokenSet};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AuthState {
14 pub state: String,
16 pub nonce: String,
18 pub code_verifier: String,
20 pub created_at: chrono::DateTime<chrono::Utc>,
22 pub expires_at: chrono::DateTime<chrono::Utc>,
24 pub metadata: HashMap<String, String>,
26}
27
28impl AuthState {
29 pub fn new(lifetime: Duration) -> Self {
31 let now = chrono::Utc::now();
32 Self {
33 state: Uuid::new_v4().to_string(),
34 nonce: Uuid::new_v4().to_string(),
35 code_verifier: Self::generate_code_verifier(),
36 created_at: now,
37 expires_at: now + chrono::Duration::from_std(lifetime).unwrap(),
38 metadata: HashMap::new(),
39 }
40 }
41
42 pub fn is_expired(&self) -> bool {
44 chrono::Utc::now() > self.expires_at
45 }
46
47 pub fn code_challenge(&self) -> String {
49 use sha2::{Sha256, Digest};
50 use base64::Engine;
51
52 let mut hasher = Sha256::new();
53 hasher.update(self.code_verifier.as_bytes());
54 let hash = hasher.finalize();
55
56 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
57 }
58
59 fn generate_code_verifier() -> String {
60 use base64::Engine;
61
62 let random_bytes: [u8; 32] = corevpn_crypto::random_bytes();
63 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(random_bytes)
64 }
65}
66
67pub struct AuthFlow {
69 provider: OAuthProvider,
71 redirect_uri: String,
73}
74
75impl AuthFlow {
76 pub fn new(provider: OAuthProvider, redirect_uri: &str) -> Self {
78 Self {
79 provider,
80 redirect_uri: redirect_uri.to_string(),
81 }
82 }
83
84 pub fn authorization_url(&self, state: &AuthState) -> Result<String> {
86 let endpoint = self.provider.authorization_endpoint()?;
87 let config = self.provider.config();
88
89 let code_challenge = state.code_challenge();
90 let mut params = vec![
91 ("client_id", config.client_id.as_str()),
92 ("response_type", "code"),
93 ("redirect_uri", &self.redirect_uri),
94 ("state", &state.state),
95 ("nonce", &state.nonce),
96 ("code_challenge", code_challenge.as_str()),
97 ("code_challenge_method", "S256"),
98 ];
99
100 let scopes = config.scopes.join(" ");
102 params.push(("scope", &scopes));
103
104 let mut url = endpoint.to_string();
106 url.push('?');
107
108 for (i, (key, value)) in params.iter().enumerate() {
109 if i > 0 {
110 url.push('&');
111 }
112 url.push_str(key);
113 url.push('=');
114 url.push_str(&urlencoding::encode(value));
115 }
116
117 for (key, value) in &config.additional_params {
119 url.push('&');
120 url.push_str(key);
121 url.push('=');
122 url.push_str(&urlencoding::encode(value));
123 }
124
125 Ok(url)
126 }
127
128 pub async fn exchange_code(&self, code: &str, state: &AuthState) -> Result<TokenSet> {
130 let endpoint = self.provider.token_endpoint()?;
131 let config = self.provider.config();
132
133 let params = [
134 ("grant_type", "authorization_code"),
135 ("client_id", &config.client_id),
136 ("client_secret", &config.client_secret),
137 ("code", code),
138 ("redirect_uri", &self.redirect_uri),
139 ("code_verifier", &state.code_verifier),
140 ];
141
142 let client = reqwest::Client::new();
143 let response = client
144 .post(endpoint)
145 .form(¶ms)
146 .send()
147 .await?;
148
149 if !response.status().is_success() {
150 let error_text = response.text().await.unwrap_or_default();
151 return Err(AuthError::OAuth2Error(error_text));
152 }
153
154 let token_response: TokenResponse = response.json().await?;
155
156 Ok(TokenSet {
157 access_token: token_response.access_token,
158 refresh_token: token_response.refresh_token,
159 id_token: token_response.id_token,
160 expires_at: chrono::Utc::now()
161 + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
162 token_type: token_response.token_type,
163 scopes: token_response.scope
164 .map(|s| s.split(' ').map(String::from).collect())
165 .unwrap_or_default(),
166 })
167 }
168
169 pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenSet> {
171 let endpoint = self.provider.token_endpoint()?;
172 let config = self.provider.config();
173
174 let params = [
175 ("grant_type", "refresh_token"),
176 ("client_id", &config.client_id),
177 ("client_secret", &config.client_secret),
178 ("refresh_token", refresh_token),
179 ];
180
181 let client = reqwest::Client::new();
182 let response = client
183 .post(endpoint)
184 .form(¶ms)
185 .send()
186 .await?;
187
188 if !response.status().is_success() {
189 let error_text = response.text().await.unwrap_or_default();
190 return Err(AuthError::TokenRefreshFailed(error_text));
191 }
192
193 let token_response: TokenResponse = response.json().await?;
194
195 Ok(TokenSet {
196 access_token: token_response.access_token,
197 refresh_token: token_response.refresh_token,
198 id_token: token_response.id_token,
199 expires_at: chrono::Utc::now()
200 + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
201 token_type: token_response.token_type,
202 scopes: token_response.scope
203 .map(|s| s.split(' ').map(String::from).collect())
204 .unwrap_or_default(),
205 })
206 }
207}
208
209pub struct DeviceAuthFlow {
211 provider: OAuthProvider,
213}
214
215impl DeviceAuthFlow {
216 pub fn new(provider: OAuthProvider) -> Self {
218 Self { provider }
219 }
220
221 pub async fn start(&self) -> Result<DeviceAuthResponse> {
223 let endpoint = self.provider.device_authorization_endpoint()?;
224 let config = self.provider.config();
225
226 let scopes = config.scopes.join(" ");
227 let params = [
228 ("client_id", config.client_id.as_str()),
229 ("scope", &scopes),
230 ];
231
232 let client = reqwest::Client::new();
233 let response = client
234 .post(endpoint)
235 .form(¶ms)
236 .send()
237 .await?;
238
239 if !response.status().is_success() {
240 let error_text = response.text().await.unwrap_or_default();
241 return Err(AuthError::OAuth2Error(error_text));
242 }
243
244 let device_response: DeviceAuthResponse = response.json().await?;
245 Ok(device_response)
246 }
247
248 pub async fn poll(&self, device_code: &str) -> Result<TokenSet> {
250 let endpoint = self.provider.token_endpoint()?;
251 let config = self.provider.config();
252
253 let params = [
254 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
255 ("client_id", &config.client_id),
256 ("client_secret", &config.client_secret),
257 ("device_code", device_code),
258 ];
259
260 let client = reqwest::Client::new();
261 let response = client
262 .post(endpoint)
263 .form(¶ms)
264 .send()
265 .await?;
266
267 if !response.status().is_success() {
268 let error_response: ErrorResponse = response.json().await?;
269
270 return match error_response.error.as_str() {
271 "authorization_pending" => Err(AuthError::AuthorizationPending),
272 "slow_down" => Err(AuthError::AuthorizationPending),
273 "expired_token" => Err(AuthError::DeviceAuthExpired),
274 _ => Err(AuthError::OAuth2Error(
275 error_response.error_description.unwrap_or(error_response.error),
276 )),
277 };
278 }
279
280 let token_response: TokenResponse = response.json().await?;
281
282 Ok(TokenSet {
283 access_token: token_response.access_token,
284 refresh_token: token_response.refresh_token,
285 id_token: token_response.id_token,
286 expires_at: chrono::Utc::now()
287 + chrono::Duration::seconds(token_response.expires_in.unwrap_or(3600) as i64),
288 token_type: token_response.token_type,
289 scopes: token_response.scope
290 .map(|s| s.split(' ').map(String::from).collect())
291 .unwrap_or_default(),
292 })
293 }
294}
295
296#[derive(Debug, Deserialize)]
298struct TokenResponse {
299 access_token: String,
300 #[serde(default)]
301 refresh_token: Option<String>,
302 #[serde(default)]
303 id_token: Option<String>,
304 #[serde(default)]
305 expires_in: Option<u64>,
306 #[serde(default)]
307 token_type: String,
308 #[serde(default)]
309 scope: Option<String>,
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct DeviceAuthResponse {
315 pub device_code: String,
317 pub user_code: String,
319 pub verification_uri: String,
321 #[serde(default)]
323 pub verification_uri_complete: Option<String>,
324 pub expires_in: u64,
326 #[serde(default = "default_interval")]
328 pub interval: u64,
329}
330
/// Serde default for `DeviceAuthResponse::interval`.
///
/// RFC 8628 §3.2 says clients must use 5 seconds when the server provides no interval.
fn default_interval() -> u64 {
    const DEFAULT_POLL_INTERVAL_SECS: u64 = 5;
    DEFAULT_POLL_INTERVAL_SECS
}
334
335#[derive(Debug, Deserialize)]
337struct ErrorResponse {
338 error: String,
339 #[serde(default)]
340 error_description: Option<String>,
341}
342
343pub fn generate_vpn_auth_challenge(device_response: &DeviceAuthResponse) -> String {
345 format!(
346 "CRV1:R,E:{}:Please visit {} and enter code: {}",
347 base64::Engine::encode(
348 &base64::engine::general_purpose::STANDARD,
349 device_response.device_code.as_bytes()
350 ),
351 device_response.verification_uri,
352 device_response.user_code
353 )
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_state() {
        let state = AuthState::new(Duration::from_secs(300));

        assert!(!state.is_expired());
        assert!(!state.state.is_empty());
        assert!(!state.nonce.is_empty());
        assert!(!state.code_verifier.is_empty());
        assert!(state.expires_at > state.created_at);
        assert!(state.metadata.is_empty());
    }

    #[test]
    fn test_auth_state_values_are_fresh() {
        let first = AuthState::new(Duration::from_secs(300));
        let second = AuthState::new(Duration::from_secs(300));

        assert_ne!(first.state, second.state);
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.code_verifier, second.code_verifier);
    }

    #[test]
    fn test_auth_state_expiry() {
        let mut state = AuthState::new(Duration::from_secs(300));
        state.expires_at = state.created_at - chrono::Duration::seconds(1);

        assert!(state.is_expired());
    }

    #[test]
    fn test_code_verifier_is_rfc7636_compliant() {
        let state = AuthState::new(Duration::from_secs(300));
        let verifier = &state.code_verifier;

        // RFC 7636 §4.1: 43..=128 characters from the unreserved set.
        assert!((43..=128).contains(&verifier.len()));
        assert!(verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || "-._~".contains(c)));
    }

    #[test]
    fn test_code_challenge() {
        let state = AuthState::new(Duration::from_secs(300));
        let challenge = state.code_challenge();

        assert!(!challenge.is_empty());
        // URL-safe alphabet, no padding.
        assert!(!challenge.contains('+'));
        assert!(!challenge.contains('/'));
        assert!(!challenge.contains('='));
        assert_eq!(challenge, state.code_challenge());
    }

    #[test]
    fn test_code_challenge_matches_rfc7636_vector() {
        // RFC 7636 Appendix B.
        let mut state = AuthState::new(Duration::from_secs(300));
        state.code_verifier = String::from("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk");

        assert_eq!(
            state.code_challenge(),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuEJSstw-cM"
        );
    }

    #[test]
    fn test_vpn_auth_challenge_mentions_uri_and_code() {
        let response = DeviceAuthResponse {
            device_code: String::from("dev"),
            user_code: String::from("ABCD-EFGH"),
            verification_uri: String::from("https://example.com/device"),
            verification_uri_complete: None,
            expires_in: 600,
            interval: 5,
        };

        let challenge = generate_vpn_auth_challenge(&response);

        // "ZGV2" is standard base64 for "dev", used as the CRV1 state id.
        assert!(challenge.starts_with("CRV1:R,E:ZGV2:"));
        assert!(challenge
            .ends_with("Please visit https://example.com/device and enter code: ABCD-EFGH"));
    }
}