oxify_authn/
oauth.rs

1//! OAuth2/OIDC authentication support
2//!
3//! Ported from `OxiRS` (<https://github.com/cool-japan/oxirs>)
4//! Original implementation: Copyright (c) `OxiRS` Contributors
5//! Adapted for `OxiFY` (simplified for maintainability)
6//! License: MIT OR Apache-2.0 (compatible with `OxiRS`)
7
8use crate::types::{AuthError, AuthResult, OAuth2Config, Permission, Result, User};
9use chrono::{DateTime, Duration, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15/// `OAuth2` authentication service
16#[derive(Clone)]
17pub struct OAuth2Service {
18    config: Arc<OAuth2Config>,
19    active_states: Arc<RwLock<HashMap<String, OAuth2State>>>,
20    client: reqwest::Client,
21}
22
23/// `OAuth2` state for authorization flow
24#[derive(Debug, Clone)]
25pub struct OAuth2State {
26    pub state: String,
27    pub code_verifier: Option<String>, // For PKCE
28    pub redirect_uri: String,
29    pub created_at: DateTime<Utc>,
30    pub expires_at: DateTime<Utc>,
31}
32
33/// `OAuth2` access token information
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct OAuth2Token {
36    pub access_token: String,
37    pub token_type: String,
38    pub expires_in: u64,
39    pub refresh_token: Option<String>,
40    pub scope: String,
41    pub id_token: Option<String>,
42    pub issued_at: DateTime<Utc>,
43}
44
45/// OIDC user information
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct OIDCUserInfo {
48    pub sub: String,
49    pub name: Option<String>,
50    pub email: Option<String>,
51    pub email_verified: Option<bool>,
52    pub groups: Option<Vec<String>>,
53    pub roles: Option<Vec<String>>,
54}
55
56/// `OAuth2` token response from provider
57#[derive(Debug, Deserialize)]
58struct OAuth2TokenResponse {
59    access_token: String,
60    token_type: String,
61    expires_in: Option<u64>,
62    refresh_token: Option<String>,
63    scope: Option<String>,
64    id_token: Option<String>,
65}
66
67impl OAuth2Service {
68    /// Create new `OAuth2` service
69    #[must_use]
70    pub fn new(config: OAuth2Config) -> Self {
71        let client = reqwest::Client::builder()
72            .timeout(std::time::Duration::from_secs(30))
73            .build()
74            .unwrap_or_default();
75
76        Self {
77            config: Arc::new(config),
78            active_states: Arc::new(RwLock::new(HashMap::new())),
79            client,
80        }
81    }
82
83    /// Generate authorization URL for `OAuth2` flow
84    pub async fn generate_authorization_url(
85        &self,
86        redirect_uri: &str,
87        use_pkce: bool,
88    ) -> Result<(String, String)> {
89        let state = uuid::Uuid::new_v4().to_string();
90        let scope_string = self.config.scopes.join(" ");
91
92        let mut url = format!(
93            "{}?response_type=code&client_id={}&redirect_uri={}&state={}&scope={}",
94            self.config.auth_url,
95            url_encode(&self.config.client_id),
96            url_encode(redirect_uri),
97            url_encode(&state),
98            url_encode(&scope_string)
99        );
100
101        let mut oauth_state = OAuth2State {
102            state: state.clone(),
103            code_verifier: None,
104            redirect_uri: redirect_uri.to_string(),
105            created_at: Utc::now(),
106            expires_at: Utc::now() + Duration::minutes(10),
107        };
108
109        // Add PKCE if requested
110        if use_pkce {
111            let code_verifier = generate_code_verifier();
112            let code_challenge = generate_code_challenge(&code_verifier);
113
114            url.push_str("&code_challenge=");
115            url.push_str(&url_encode(&code_challenge));
116            url.push_str("&code_challenge_method=S256");
117
118            oauth_state.code_verifier = Some(code_verifier);
119        }
120
121        // Store state
122        let mut states = self.active_states.write().await;
123        states.insert(state.clone(), oauth_state);
124
125        Ok((url, state))
126    }
127
128    /// Exchange authorization code for access token
129    pub async fn exchange_code_for_token(
130        &self,
131        code: &str,
132        state: &str,
133        redirect_uri: &str,
134    ) -> Result<OAuth2Token> {
135        // Validate and remove state
136        let oauth_state = {
137            let mut states = self.active_states.write().await;
138            states.remove(state)
139        };
140
141        let oauth_state = oauth_state.ok_or(AuthError::OAuthError(
142            "Invalid or expired state".to_string(),
143        ))?;
144
145        if oauth_state.redirect_uri != redirect_uri {
146            return Err(AuthError::OAuthError("Redirect URI mismatch".to_string()));
147        }
148
149        if Utc::now() > oauth_state.expires_at {
150            return Err(AuthError::OAuthError("State expired".to_string()));
151        }
152
153        // Prepare token request
154        let mut params = vec![
155            ("grant_type", "authorization_code".to_string()),
156            ("code", code.to_string()),
157            ("redirect_uri", redirect_uri.to_string()),
158            ("client_id", self.config.client_id.clone()),
159            ("client_secret", self.config.client_secret.clone()),
160        ];
161
162        // Add PKCE if used
163        if let Some(code_verifier) = oauth_state.code_verifier {
164            params.push(("code_verifier", code_verifier));
165        }
166
167        let response = self
168            .client
169            .post(&self.config.token_url)
170            .form(&params)
171            .send()
172            .await
173            .map_err(|e| AuthError::OAuthError(format!("Token exchange failed: {e}")))?;
174
175        if !response.status().is_success() {
176            let error_text = response.text().await.unwrap_or_default();
177            return Err(AuthError::OAuthError(format!(
178                "Token exchange failed: {error_text}"
179            )));
180        }
181
182        let token_response: OAuth2TokenResponse = response
183            .json()
184            .await
185            .map_err(|e| AuthError::OAuthError(format!("Failed to parse token: {e}")))?;
186
187        Ok(OAuth2Token {
188            access_token: token_response.access_token,
189            token_type: token_response.token_type,
190            expires_in: token_response.expires_in.unwrap_or(3600),
191            refresh_token: token_response.refresh_token,
192            scope: token_response.scope.unwrap_or_default(),
193            id_token: token_response.id_token,
194            issued_at: Utc::now(),
195        })
196    }
197
198    /// Get user information from OIDC userinfo endpoint
199    pub async fn get_user_info(&self, access_token: &str) -> Result<OIDCUserInfo> {
200        let response = self
201            .client
202            .get(&self.config.user_info_url)
203            .bearer_auth(access_token)
204            .send()
205            .await
206            .map_err(|e| AuthError::OAuthError(format!("UserInfo request failed: {e}")))?;
207
208        if !response.status().is_success() {
209            return Err(AuthError::OAuthError(format!(
210                "UserInfo failed with status: {}",
211                response.status()
212            )));
213        }
214
215        response
216            .json()
217            .await
218            .map_err(|e| AuthError::OAuthError(format!("Failed to parse user info: {e}")))
219    }
220
221    /// Authenticate user using OAuth2/OIDC
222    pub async fn authenticate(&self, access_token: &str) -> Result<AuthResult> {
223        let user_info = self.get_user_info(access_token).await?;
224        let user = Self::map_oidc_user(user_info);
225        Ok(AuthResult::Authenticated(user))
226    }
227
228    /// Refresh access token
229    pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuth2Token> {
230        let params = vec![
231            ("grant_type", "refresh_token"),
232            ("refresh_token", refresh_token),
233            ("client_id", &self.config.client_id),
234            ("client_secret", &self.config.client_secret),
235        ];
236
237        let response = self
238            .client
239            .post(&self.config.token_url)
240            .form(&params)
241            .send()
242            .await
243            .map_err(|e| AuthError::OAuthError(format!("Token refresh failed: {e}")))?;
244
245        if !response.status().is_success() {
246            return Err(AuthError::OAuthError(format!(
247                "Token refresh failed: {}",
248                response.status()
249            )));
250        }
251
252        let token_response: OAuth2TokenResponse = response
253            .json()
254            .await
255            .map_err(|e| AuthError::OAuthError(format!("Failed to parse refresh: {e}")))?;
256
257        Ok(OAuth2Token {
258            access_token: token_response.access_token,
259            token_type: token_response.token_type,
260            expires_in: token_response.expires_in.unwrap_or(3600),
261            refresh_token: token_response
262                .refresh_token
263                .or(Some(refresh_token.to_string())),
264            scope: token_response.scope.unwrap_or_default(),
265            id_token: token_response.id_token,
266            issued_at: Utc::now(),
267        })
268    }
269
270    /// Map OIDC user info to internal user
271    fn map_oidc_user(user_info: OIDCUserInfo) -> User {
272        let username = user_info.email.as_ref().unwrap_or(&user_info.sub).clone();
273
274        let mut roles = Vec::new();
275        if let Some(oidc_roles) = &user_info.roles {
276            roles.extend(oidc_roles.clone());
277        }
278        if let Some(groups) = &user_info.groups {
279            for group in groups {
280                roles.push(map_group_to_role(group));
281            }
282        }
283        if roles.is_empty() {
284            roles.push("user".to_string());
285        }
286
287        let permissions = compute_permissions(&roles);
288
289        User {
290            username,
291            roles,
292            email: user_info.email,
293            full_name: user_info.name,
294            last_login: Some(Utc::now()),
295            permissions,
296        }
297    }
298
299    /// Cleanup expired states
300    pub async fn cleanup_expired(&self) {
301        let now = Utc::now();
302        let mut states = self.active_states.write().await;
303        states.retain(|_, state| state.expires_at > now);
304    }
305}
306
307/// Generate code verifier for PKCE
308fn generate_code_verifier() -> String {
309    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
310    use rand::Rng;
311    let mut rng = rand::rng();
312
313    (0..128)
314        .map(|_| {
315            let idx = rng.random_range(0..CHARSET.len());
316            CHARSET[idx] as char
317        })
318        .collect()
319}
320
321/// Generate code challenge for PKCE (S256 method)
322fn generate_code_challenge(code_verifier: &str) -> String {
323    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
324    use sha2::{Digest, Sha256};
325    let digest = Sha256::digest(code_verifier.as_bytes());
326    URL_SAFE_NO_PAD.encode(digest)
327}
328
329/// URL encoding helper
330fn url_encode(input: &str) -> String {
331    percent_encoding::utf8_percent_encode(input, percent_encoding::NON_ALPHANUMERIC).to_string()
332}
333
/// Translate a provider group name into an internal role name.
///
/// Matching is done on the lower-cased group name; groups that are not
/// listed in the table fall back to `"user"`.
fn map_group_to_role(group: &str) -> String {
    // (lower-cased group name, internal role)
    const GROUP_ROLES: [(&str, &str); 6] = [
        ("admin", "admin"),
        ("administrators", "admin"),
        ("writers", "writer"),
        ("editors", "writer"),
        ("readers", "reader"),
        ("viewers", "reader"),
    ];

    let normalized = group.to_lowercase();
    let role = GROUP_ROLES
        .iter()
        .find(|&&(name, _)| name == normalized)
        .map_or("user", |&(_, role)| role);

    role.to_string()
}
343
344/// Compute permissions from roles
345fn compute_permissions(roles: &[String]) -> Vec<Permission> {
346    let mut permissions = Vec::new();
347
348    for role in roles {
349        match role.as_str() {
350            "admin" => {
351                permissions.extend(vec![
352                    Permission::GlobalAdmin,
353                    Permission::GlobalRead,
354                    Permission::GlobalWrite,
355                    Permission::Admin,
356                ]);
357            }
358            "writer" => {
359                permissions.extend(vec![
360                    Permission::GlobalRead,
361                    Permission::GlobalWrite,
362                    Permission::Write,
363                ]);
364            }
365            "reader" | "user" => {
366                permissions.extend(vec![Permission::GlobalRead, Permission::Read]);
367            }
368            _ => {}
369        }
370    }
371
372    permissions.sort();
373    permissions.dedup();
374    permissions
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Redirect URI used by the flow tests.
    const CALLBACK: &str = "http://localhost/callback";

    /// Configuration pointing at a placeholder provider; no test performs a
    /// network request against it.
    fn create_test_config() -> OAuth2Config {
        OAuth2Config {
            provider: "test".to_string(),
            client_id: "test_client_id".to_string(),
            client_secret: "test_secret".to_string(),
            auth_url: "https://provider.example.com/auth".to_string(),
            token_url: "https://provider.example.com/token".to_string(),
            user_info_url: "https://provider.example.com/userinfo".to_string(),
            scopes: vec!["openid".to_string(), "profile".to_string()],
        }
    }

    #[tokio::test]
    async fn test_oauth2_service_creation() {
        let config = create_test_config();
        let service = OAuth2Service::new(config);
        assert_eq!(service.config.provider, "test");
        assert!(service.active_states.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_authorization_url() {
        let config = create_test_config();
        let service = OAuth2Service::new(config);

        let (url, state) = service
            .generate_authorization_url(CALLBACK, false)
            .await
            .unwrap();

        assert!(url.starts_with("https://provider.example.com/auth?"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains(&format!("client_id={}", url_encode("test_client_id"))));
        assert!(url.contains(&format!("redirect_uri={}", url_encode(CALLBACK))));
        assert!(url.contains(&format!("state={}", url_encode(&state))));
        assert!(url.contains(&format!("scope={}", url_encode("openid profile"))));
        assert!(!url.contains("code_challenge"));
        assert!(!state.is_empty());

        // The state must be remembered so the callback can be validated.
        let states = service.active_states.read().await;
        let stored = states.get(&state).expect("state should be recorded");
        assert_eq!(stored.state, state);
        assert_eq!(stored.redirect_uri, CALLBACK);
        assert!(stored.code_verifier.is_none());
        assert!(stored.expires_at > stored.created_at);
    }

    #[tokio::test]
    async fn test_authorization_url_with_pkce() {
        let service = OAuth2Service::new(create_test_config());

        let (url, state) = service
            .generate_authorization_url(CALLBACK, true)
            .await
            .unwrap();

        assert!(url.contains("&code_challenge_method=S256"));

        // The challenge in the URL must be derived from the stored verifier.
        let states = service.active_states.read().await;
        let stored = states.get(&state).expect("state should be recorded");
        let verifier = stored
            .code_verifier
            .as_deref()
            .expect("PKCE verifier should be stored");
        let expected = format!(
            "&code_challenge={}",
            url_encode(&generate_code_challenge(verifier))
        );
        assert!(url.contains(&expected));
    }

    #[tokio::test]
    async fn test_exchange_rejects_unknown_state() {
        let service = OAuth2Service::new(create_test_config());

        // Fails on the state lookup, before any request is made.
        let result = service
            .exchange_code_for_token("code", "never-issued", CALLBACK)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_exchange_rejects_redirect_mismatch_and_consumes_state() {
        let service = OAuth2Service::new(create_test_config());
        let (_, state) = service
            .generate_authorization_url(CALLBACK, false)
            .await
            .unwrap();

        let result = service
            .exchange_code_for_token("code", &state, "http://other.example/callback")
            .await;

        assert!(result.is_err());
        // A state is single-use even when validation fails.
        assert!(!service.active_states.read().await.contains_key(&state));
    }

    #[tokio::test]
    async fn test_cleanup_expired_drops_only_stale_states() {
        let service = OAuth2Service::new(create_test_config());
        let now = Utc::now();
        let make_state = |key: &str, expires_at| OAuth2State {
            state: key.to_string(),
            code_verifier: None,
            redirect_uri: CALLBACK.to_string(),
            created_at: now - Duration::minutes(20),
            expires_at,
        };

        {
            let mut states = service.active_states.write().await;
            states.insert(
                "stale".to_string(),
                make_state("stale", now - Duration::minutes(10)),
            );
            states.insert(
                "fresh".to_string(),
                make_state("fresh", now + Duration::minutes(10)),
            );
        }

        service.cleanup_expired().await;

        let states = service.active_states.read().await;
        assert!(!states.contains_key("stale"));
        assert!(states.contains_key("fresh"));
    }

    #[test]
    fn test_pkce_generation() {
        let verifier = generate_code_verifier();
        let challenge = generate_code_challenge(&verifier);

        assert_eq!(verifier.len(), 128);
        assert!(verifier
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b"-._~".contains(&b)));
        // SHA-256 is 32 bytes, which is 43 characters of unpadded base64.
        assert_eq!(challenge.len(), 43);
        assert_ne!(verifier, challenge);
    }

    #[test]
    fn test_pkce_challenge_matches_rfc7636_vector() {
        // Example from RFC 7636, Appendix B.
        assert_eq!(
            generate_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw_cM"
        );
    }

    #[test]
    fn test_group_mapping() {
        assert_eq!(map_group_to_role("admin"), "admin");
        assert_eq!(map_group_to_role("administrators"), "admin");
        assert_eq!(map_group_to_role("writers"), "writer");
        assert_eq!(map_group_to_role("editors"), "writer");
        assert_eq!(map_group_to_role("readers"), "reader");
        assert_eq!(map_group_to_role("viewers"), "reader");
        assert_eq!(map_group_to_role("ADMIN"), "admin");
        assert_eq!(map_group_to_role("unknown"), "user");
    }

    #[test]
    fn test_permission_computation() {
        let perms = compute_permissions(&["admin".to_string()]);
        assert!(perms.contains(&Permission::GlobalAdmin));

        let perms = compute_permissions(&["reader".to_string()]);
        assert!(perms.contains(&Permission::Read));
        assert!(!perms.contains(&Permission::GlobalAdmin));

        let perms = compute_permissions(&["no-such-role".to_string()]);
        assert!(perms.is_empty());
    }
}
441}