Skip to main content

oxidite_auth/oauth2/
provider.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use uuid::Uuid;
6use crate::{AuthError, Result};
7use crate::oauth2::grants::AuthorizationCodeGrant;
8use base64::Engine;
9
10/// Authorization request
11#[derive(Debug, Clone, Deserialize)]
12pub struct AuthorizationRequest {
13    pub client_id: String,
14    pub redirect_uri: String,
15    pub response_type: String,
16    pub scope: Option<String>,
17    pub state: Option<String>,
18    pub code_challenge: Option<String>,
19    pub code_challenge_method: Option<String>,
20}
21
22/// Token request
23#[derive(Debug, Clone, Deserialize)]
24pub struct TokenRequest {
25    pub grant_type: String,
26    pub code: Option<String>,
27    pub redirect_uri: Option<String>,
28    pub client_id: String,
29    pub client_secret: String,
30    pub code_verifier: Option<String>,
31    pub refresh_token: Option<String>,
32}
33
34/// Token response
35#[derive(Debug, Clone, Serialize)]
36pub struct TokenResponse {
37    pub access_token: String,
38    pub token_type: String,
39    pub expires_in: u64,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub refresh_token: Option<String>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub scope: Option<String>,
44}
45
46/// OAuth2 provider
47pub struct OAuth2Provider {
48    codes: Arc<RwLock<HashMap<String, AuthorizationCodeGrant>>>,
49    clients: Arc<RwLock<HashMap<String, ClientConfig>>>,
50}
51
/// Registration data for a single OAuth2 client.
///
/// `Debug` is implemented by hand rather than derived so that `client_secret`
/// never ends up in logs or panic messages.
#[derive(Clone)]
pub struct ClientConfig {
    /// Identifier the client is registered under.
    pub client_id: String,
    /// Secret the client must present when redeeming an authorization code.
    pub client_secret: String,
    /// Redirect URIs the client may use; an authorization request must match
    /// one of these exactly.
    pub redirect_uris: Vec<String>,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same layout as the derived impl, with the secret masked.
        f.debug_struct("ClientConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[REDACTED]")
            .field("redirect_uris", &self.redirect_uris)
            .finish()
    }
}
58
59impl OAuth2Provider {
60    pub fn new() -> Self {
61        Self {
62            codes: Arc::new(RwLock::new(HashMap::new())),
63            clients: Arc::new(RwLock::new(HashMap::new())),
64        }
65    }
66
67    /// Register a client
68    pub async fn register_client(&self, config: ClientConfig) -> Result<()> {
69        let mut clients = self.clients.write().await;
70        clients.insert(config.client_id.clone(), config);
71        Ok(())
72    }
73
74    /// Handle authorization request
75    pub async fn authorize(&self, req: AuthorizationRequest, _user_id: String) -> Result<String> {
76        // Validate client
77        let clients = self.clients.read().await;
78        let client = clients.get(&req.client_id)
79            .ok_or(AuthError::InvalidCredentials)?;
80
81        // Validate redirect URI
82        if !client.redirect_uris.contains(&req.redirect_uri) {
83            return Err(AuthError::InvalidCredentials);
84        }
85
86        // Generate authorization code
87        let mut grant = AuthorizationCodeGrant::new(
88            req.client_id.clone(),
89            req.redirect_uri.clone(),
90            600, // 10 minutes
91        );
92
93        if let Some(challenge) = req.code_challenge {
94            grant = grant.with_pkce(challenge);
95        }
96
97        let code = grant.code.clone();
98        let mut codes = self.codes.write().await;
99        codes.insert(code.clone(), grant);
100
101        Ok(code)
102    }
103
104    /// Exchange authorization code for access token
105    pub async fn exchange_code(&self, req: TokenRequest) -> Result<TokenResponse> {
106        let code = req.code.ok_or(AuthError::InvalidToken)?;
107
108        // Get and remove authorization code
109        let mut codes = self.codes.write().await;
110        let grant = codes.remove(&code).ok_or(AuthError::InvalidToken)?;
111
112        // Validate client
113        let clients = self.clients.read().await;
114        let client = clients.get(&req.client_id)
115            .ok_or(AuthError::InvalidCredentials)?;
116
117        if client.client_secret != req.client_secret {
118            return Err(AuthError::InvalidCredentials);
119        }
120
121        // Validate redirect URI
122        if let Some(redirect_uri) = req.redirect_uri {
123            if grant.redirect_uri != redirect_uri {
124                return Err(AuthError::InvalidCredentials);
125            }
126        }
127
128        // Check expiration
129        if grant.is_expired() {
130            return Err(AuthError::TokenExpired);
131        }
132
133        // Validate PKCE if used
134        if let Some(challenge) = grant.code_challenge {
135            let verifier = req.code_verifier.ok_or(AuthError::InvalidToken)?;
136            
137            // Verify PKCE challenge using SHA256
138            use sha2::{Sha256, Digest};
139            let mut hasher = Sha256::new();
140            hasher.update(verifier.as_bytes());
141            let computed_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD
142                .encode(hasher.finalize());
143            
144            if computed_challenge != challenge {
145                return Err(AuthError::InvalidToken);
146            }
147        }
148
149        // Generate access token
150        let access_token = Uuid::new_v4().to_string();
151        let refresh_token = Uuid::new_v4().to_string();
152
153        Ok(TokenResponse {
154            access_token,
155            token_type: "Bearer".to_string(),
156            expires_in: 3600,
157            refresh_token: Some(refresh_token),
158            scope: None,
159        })
160    }
161}
162
163impl Default for OAuth2Provider {
164    fn default() -> Self {
165        Self::new()
166    }
167}