1use crate::types::{AuthError, AuthResult, OAuth2Config, Permission, Result, User};
9use chrono::{DateTime, Duration, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
/// OAuth 2.0 / OpenID Connect client bound to a single configured provider.
///
/// The configuration and the map of pending authorization states are held
/// behind `Arc`, so clones of the service share the same in-flight states.
#[derive(Clone)]
pub struct OAuth2Service {
    /// Provider endpoints (`auth_url`, `token_url`, `user_info_url`), client
    /// credentials and requested scopes.
    config: Arc<OAuth2Config>,
    /// Pending authorization requests keyed by their `state` value. Entries are
    /// removed when redeemed by `exchange_code_for_token` or purged once expired.
    active_states: Arc<RwLock<HashMap<String, OAuth2State>>>,
    /// HTTP client used for token-endpoint and userinfo requests.
    client: reqwest::Client,
}
22
/// A pending authorization request, created when the authorization URL is
/// generated and consumed when the callback's code is exchanged.
#[derive(Debug, Clone)]
pub struct OAuth2State {
    /// Random value sent as the `state` query parameter (CSRF protection).
    pub state: String,
    /// PKCE code verifier, present only when the URL was generated with PKCE.
    pub code_verifier: Option<String>,
    /// Redirect URI used in the authorization request; the token exchange must
    /// present exactly the same value.
    pub redirect_uri: String,
    /// When the state was created.
    pub created_at: DateTime<Utc>,
    /// After this instant the state is rejected (set 10 minutes after creation
    /// by `generate_authorization_url`).
    pub expires_at: DateTime<Utc>,
}
32
/// Tokens obtained from the provider's token endpoint, normalized so that
/// optional response fields always have a value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Token {
    /// Bearer credential used for API / userinfo calls.
    pub access_token: String,
    /// Token type reported by the provider (RFC 6749 `token_type`).
    pub token_type: String,
    /// Access-token lifetime as reported by the provider (RFC 6749
    /// `expires_in`, seconds); this module substitutes 3600 when it is absent.
    pub expires_in: u64,
    /// Refresh token, if the provider issued (or previously issued) one.
    pub refresh_token: Option<String>,
    /// Granted scopes; empty when the provider did not echo them.
    pub scope: String,
    /// OpenID Connect ID token, when returned. It is not validated by this module.
    pub id_token: Option<String>,
    /// Local time at which the response was received; pair with `expires_in`
    /// to estimate expiry.
    pub issued_at: DateTime<Utc>,
}
44
/// Subset of claims read from the provider's userinfo endpoint.
///
/// Claims not listed here are ignored during deserialization. `groups` and
/// `roles` are not standard OIDC claims; whether they are present depends on
/// the provider's configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OIDCUserInfo {
    /// Provider-scoped subject identifier.
    pub sub: String,
    /// Display name.
    pub name: Option<String>,
    /// E-mail address.
    pub email: Option<String>,
    /// Whether the provider has verified `email`.
    pub email_verified: Option<bool>,
    /// Group names; mapped to internal roles by `map_group_to_role`.
    pub groups: Option<Vec<String>>,
    /// Role names, used verbatim as internal roles.
    pub roles: Option<Vec<String>>,
}
55
/// Raw JSON body returned by the token endpoint for both the
/// authorization-code and refresh-token grants.
///
/// Optional fields are defaulted when converted into [`OAuth2Token`].
#[derive(Debug, Deserialize)]
struct OAuth2TokenResponse {
    access_token: String,
    token_type: String,
    expires_in: Option<u64>,
    refresh_token: Option<String>,
    scope: Option<String>,
    id_token: Option<String>,
}
66
67impl OAuth2Service {
68 #[must_use]
70 pub fn new(config: OAuth2Config) -> Self {
71 let client = reqwest::Client::builder()
72 .timeout(std::time::Duration::from_secs(30))
73 .build()
74 .unwrap_or_default();
75
76 Self {
77 config: Arc::new(config),
78 active_states: Arc::new(RwLock::new(HashMap::new())),
79 client,
80 }
81 }
82
83 pub async fn generate_authorization_url(
85 &self,
86 redirect_uri: &str,
87 use_pkce: bool,
88 ) -> Result<(String, String)> {
89 let state = uuid::Uuid::new_v4().to_string();
90 let scope_string = self.config.scopes.join(" ");
91
92 let mut url = format!(
93 "{}?response_type=code&client_id={}&redirect_uri={}&state={}&scope={}",
94 self.config.auth_url,
95 url_encode(&self.config.client_id),
96 url_encode(redirect_uri),
97 url_encode(&state),
98 url_encode(&scope_string)
99 );
100
101 let mut oauth_state = OAuth2State {
102 state: state.clone(),
103 code_verifier: None,
104 redirect_uri: redirect_uri.to_string(),
105 created_at: Utc::now(),
106 expires_at: Utc::now() + Duration::minutes(10),
107 };
108
109 if use_pkce {
111 let code_verifier = generate_code_verifier();
112 let code_challenge = generate_code_challenge(&code_verifier);
113
114 url.push_str("&code_challenge=");
115 url.push_str(&url_encode(&code_challenge));
116 url.push_str("&code_challenge_method=S256");
117
118 oauth_state.code_verifier = Some(code_verifier);
119 }
120
121 let mut states = self.active_states.write().await;
123 states.insert(state.clone(), oauth_state);
124
125 Ok((url, state))
126 }
127
128 pub async fn exchange_code_for_token(
130 &self,
131 code: &str,
132 state: &str,
133 redirect_uri: &str,
134 ) -> Result<OAuth2Token> {
135 let oauth_state = {
137 let mut states = self.active_states.write().await;
138 states.remove(state)
139 };
140
141 let oauth_state = oauth_state.ok_or(AuthError::OAuthError(
142 "Invalid or expired state".to_string(),
143 ))?;
144
145 if oauth_state.redirect_uri != redirect_uri {
146 return Err(AuthError::OAuthError("Redirect URI mismatch".to_string()));
147 }
148
149 if Utc::now() > oauth_state.expires_at {
150 return Err(AuthError::OAuthError("State expired".to_string()));
151 }
152
153 let mut params = vec![
155 ("grant_type", "authorization_code".to_string()),
156 ("code", code.to_string()),
157 ("redirect_uri", redirect_uri.to_string()),
158 ("client_id", self.config.client_id.clone()),
159 ("client_secret", self.config.client_secret.clone()),
160 ];
161
162 if let Some(code_verifier) = oauth_state.code_verifier {
164 params.push(("code_verifier", code_verifier));
165 }
166
167 let response = self
168 .client
169 .post(&self.config.token_url)
170 .form(¶ms)
171 .send()
172 .await
173 .map_err(|e| AuthError::OAuthError(format!("Token exchange failed: {e}")))?;
174
175 if !response.status().is_success() {
176 let error_text = response.text().await.unwrap_or_default();
177 return Err(AuthError::OAuthError(format!(
178 "Token exchange failed: {error_text}"
179 )));
180 }
181
182 let token_response: OAuth2TokenResponse = response
183 .json()
184 .await
185 .map_err(|e| AuthError::OAuthError(format!("Failed to parse token: {e}")))?;
186
187 Ok(OAuth2Token {
188 access_token: token_response.access_token,
189 token_type: token_response.token_type,
190 expires_in: token_response.expires_in.unwrap_or(3600),
191 refresh_token: token_response.refresh_token,
192 scope: token_response.scope.unwrap_or_default(),
193 id_token: token_response.id_token,
194 issued_at: Utc::now(),
195 })
196 }
197
198 pub async fn get_user_info(&self, access_token: &str) -> Result<OIDCUserInfo> {
200 let response = self
201 .client
202 .get(&self.config.user_info_url)
203 .bearer_auth(access_token)
204 .send()
205 .await
206 .map_err(|e| AuthError::OAuthError(format!("UserInfo request failed: {e}")))?;
207
208 if !response.status().is_success() {
209 return Err(AuthError::OAuthError(format!(
210 "UserInfo failed with status: {}",
211 response.status()
212 )));
213 }
214
215 response
216 .json()
217 .await
218 .map_err(|e| AuthError::OAuthError(format!("Failed to parse user info: {e}")))
219 }
220
221 pub async fn authenticate(&self, access_token: &str) -> Result<AuthResult> {
223 let user_info = self.get_user_info(access_token).await?;
224 let user = Self::map_oidc_user(user_info);
225 Ok(AuthResult::Authenticated(user))
226 }
227
228 pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuth2Token> {
230 let params = vec![
231 ("grant_type", "refresh_token"),
232 ("refresh_token", refresh_token),
233 ("client_id", &self.config.client_id),
234 ("client_secret", &self.config.client_secret),
235 ];
236
237 let response = self
238 .client
239 .post(&self.config.token_url)
240 .form(¶ms)
241 .send()
242 .await
243 .map_err(|e| AuthError::OAuthError(format!("Token refresh failed: {e}")))?;
244
245 if !response.status().is_success() {
246 return Err(AuthError::OAuthError(format!(
247 "Token refresh failed: {}",
248 response.status()
249 )));
250 }
251
252 let token_response: OAuth2TokenResponse = response
253 .json()
254 .await
255 .map_err(|e| AuthError::OAuthError(format!("Failed to parse refresh: {e}")))?;
256
257 Ok(OAuth2Token {
258 access_token: token_response.access_token,
259 token_type: token_response.token_type,
260 expires_in: token_response.expires_in.unwrap_or(3600),
261 refresh_token: token_response
262 .refresh_token
263 .or(Some(refresh_token.to_string())),
264 scope: token_response.scope.unwrap_or_default(),
265 id_token: token_response.id_token,
266 issued_at: Utc::now(),
267 })
268 }
269
270 fn map_oidc_user(user_info: OIDCUserInfo) -> User {
272 let username = user_info.email.as_ref().unwrap_or(&user_info.sub).clone();
273
274 let mut roles = Vec::new();
275 if let Some(oidc_roles) = &user_info.roles {
276 roles.extend(oidc_roles.clone());
277 }
278 if let Some(groups) = &user_info.groups {
279 for group in groups {
280 roles.push(map_group_to_role(group));
281 }
282 }
283 if roles.is_empty() {
284 roles.push("user".to_string());
285 }
286
287 let permissions = compute_permissions(&roles);
288
289 User {
290 username,
291 roles,
292 email: user_info.email,
293 full_name: user_info.name,
294 last_login: Some(Utc::now()),
295 permissions,
296 }
297 }
298
299 pub async fn cleanup_expired(&self) {
301 let now = Utc::now();
302 let mut states = self.active_states.write().await;
303 states.retain(|_, state| state.expires_at > now);
304 }
305}
306
307fn generate_code_verifier() -> String {
309 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
310 use rand::Rng;
311 let mut rng = rand::rng();
312
313 (0..128)
314 .map(|_| {
315 let idx = rng.random_range(0..CHARSET.len());
316 CHARSET[idx] as char
317 })
318 .collect()
319}
320
321fn generate_code_challenge(code_verifier: &str) -> String {
323 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
324 use sha2::{Digest, Sha256};
325 let digest = Sha256::digest(code_verifier.as_bytes());
326 URL_SAFE_NO_PAD.encode(digest)
327}
328
329fn url_encode(input: &str) -> String {
331 percent_encoding::utf8_percent_encode(input, percent_encoding::NON_ALPHANUMERIC).to_string()
332}
333
/// Maps an identity-provider group name to an internal role name.
///
/// Matching is case-insensitive:
/// - `admin` / `administrators` map to `"admin"`;
/// - `writers` / `editors` map to `"writer"`;
/// - `readers` / `viewers` map to `"reader"`;
/// - any other group yields the baseline `"user"` role.
fn map_group_to_role(group: &str) -> String {
    const GROUP_ROLES: [(&str, &str); 6] = [
        ("admin", "admin"),
        ("administrators", "admin"),
        ("writers", "writer"),
        ("editors", "writer"),
        ("readers", "reader"),
        ("viewers", "reader"),
    ];

    let normalized = group.to_lowercase();
    let role = GROUP_ROLES
        .iter()
        .find(|(name, _)| *name == normalized)
        .map_or("user", |&(_, role)| role);
    role.to_string()
}
343
344fn compute_permissions(roles: &[String]) -> Vec<Permission> {
346 let mut permissions = Vec::new();
347
348 for role in roles {
349 match role.as_str() {
350 "admin" => {
351 permissions.extend(vec![
352 Permission::GlobalAdmin,
353 Permission::GlobalRead,
354 Permission::GlobalWrite,
355 Permission::Admin,
356 ]);
357 }
358 "writer" => {
359 permissions.extend(vec![
360 Permission::GlobalRead,
361 Permission::GlobalWrite,
362 Permission::Write,
363 ]);
364 }
365 "reader" | "user" => {
366 permissions.extend(vec![Permission::GlobalRead, Permission::Read]);
367 }
368 _ => {}
369 }
370 }
371
372 permissions.sort();
373 permissions.dedup();
374 permissions
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Config pointing at a placeholder provider. None of these tests reach
    /// the network: every exchange below fails before the HTTP request.
    fn create_test_config() -> OAuth2Config {
        OAuth2Config {
            provider: "test".to_string(),
            client_id: "test_client_id".to_string(),
            client_secret: "test_secret".to_string(),
            auth_url: "https://provider.example.com/auth".to_string(),
            token_url: "https://provider.example.com/token".to_string(),
            user_info_url: "https://provider.example.com/userinfo".to_string(),
            scopes: vec!["openid".to_string(), "profile".to_string()],
        }
    }

    #[tokio::test]
    async fn test_oauth2_service_creation() {
        let config = create_test_config();
        let service = OAuth2Service::new(config);
        assert_eq!(service.config.provider, "test");
    }

    #[tokio::test]
    async fn test_authorization_url() {
        let service = OAuth2Service::new(create_test_config());

        let (url, state) = service
            .generate_authorization_url("http://localhost/callback", false)
            .await
            .unwrap();

        assert!(url.starts_with("https://provider.example.com/auth?"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("&client_id="));
        assert!(url.contains("&redirect_uri="));
        assert!(url.contains("&state="));
        assert!(url.contains("&scope="));
        assert!(!url.contains("code_challenge"));
        assert!(!state.is_empty());
    }

    #[tokio::test]
    async fn test_authorization_url_with_pkce() {
        let service = OAuth2Service::new(create_test_config());

        let (url, _state) = service
            .generate_authorization_url("http://localhost/callback", true)
            .await
            .unwrap();

        assert!(url.contains("&code_challenge="));
        assert!(url.contains("&code_challenge_method=S256"));
    }

    #[tokio::test]
    async fn test_pkce_generation() {
        let verifier = generate_code_verifier();
        let challenge = generate_code_challenge(&verifier);

        assert_eq!(verifier.len(), 128);
        assert!(verifier
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b"-._~".contains(&b)));
        assert!(!challenge.is_empty());
        assert!(!challenge.contains('='));
        assert_ne!(verifier, challenge);
    }

    #[test]
    fn test_code_challenge_rfc7636_vector() {
        // RFC 7636 Appendix B worked example.
        assert_eq!(
            generate_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCnaD-LbSwqaCDhNsc0wbgNM"
        );
    }

    #[tokio::test]
    async fn test_exchange_rejects_unknown_state() {
        let service = OAuth2Service::new(create_test_config());

        let result = service
            .exchange_code_for_token("code", "no-such-state", "http://localhost/callback")
            .await;

        assert!(matches!(
            result,
            Err(AuthError::OAuthError(ref msg)) if msg == "Invalid or expired state"
        ));
    }

    #[tokio::test]
    async fn test_exchange_rejects_redirect_mismatch_and_consumes_state() {
        let service = OAuth2Service::new(create_test_config());
        let (_url, state) = service
            .generate_authorization_url("http://localhost/callback", true)
            .await
            .unwrap();

        let mismatch = service
            .exchange_code_for_token("code", &state, "http://evil.example/callback")
            .await;
        assert!(matches!(
            mismatch,
            Err(AuthError::OAuthError(ref msg)) if msg == "Redirect URI mismatch"
        ));

        // The state is single-use: even with the correct redirect URI, the retry
        // fails on the state lookup (before any network I/O).
        let retry = service
            .exchange_code_for_token("code", &state, "http://localhost/callback")
            .await;
        assert!(matches!(
            retry,
            Err(AuthError::OAuthError(ref msg)) if msg == "Invalid or expired state"
        ));
    }

    #[test]
    fn test_group_mapping() {
        assert_eq!(map_group_to_role("admin"), "admin");
        assert_eq!(map_group_to_role("administrators"), "admin");
        assert_eq!(map_group_to_role("writers"), "writer");
        assert_eq!(map_group_to_role("Editors"), "writer");
        assert_eq!(map_group_to_role("VIEWERS"), "reader");
        assert_eq!(map_group_to_role("unknown"), "user");
    }

    #[test]
    fn test_permission_computation() {
        let perms = compute_permissions(&["admin".to_string()]);
        assert!(perms.contains(&Permission::GlobalAdmin));

        let perms = compute_permissions(&["reader".to_string()]);
        assert!(perms.contains(&Permission::Read));

        // Overlapping roles must not produce duplicate permissions.
        let perms = compute_permissions(&[
            "admin".to_string(),
            "admin".to_string(),
            "writer".to_string(),
        ]);
        assert_eq!(
            perms.iter().filter(|p| **p == Permission::GlobalRead).count(),
            1
        );

        assert!(compute_permissions(&["guest".to_string()]).is_empty());
    }

    #[test]
    fn test_map_oidc_user_defaults_to_sub_and_user_role() {
        let user = OAuth2Service::map_oidc_user(OIDCUserInfo {
            sub: "subject-123".to_string(),
            name: None,
            email: None,
            email_verified: None,
            groups: None,
            roles: None,
        });

        assert_eq!(user.username, "subject-123");
        assert_eq!(user.roles, vec!["user".to_string()]);
    }
}
441}