// oxidite_auth/oauth2/provider.rs

use std::collections::HashMap;
use std::sync::Arc;

use base64::Engine;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use uuid::Uuid;

use crate::oauth2::grants::AuthorizationCodeGrant;
use crate::{AuthError, Result};

10#[derive(Debug, Clone, Deserialize)]
12pub struct AuthorizationRequest {
13 pub client_id: String,
14 pub redirect_uri: String,
15 pub response_type: String,
16 pub scope: Option<String>,
17 pub state: Option<String>,
18 pub code_challenge: Option<String>,
19 pub code_challenge_method: Option<String>,
20}
21
22#[derive(Debug, Clone, Deserialize)]
24pub struct TokenRequest {
25 pub grant_type: String,
26 pub code: Option<String>,
27 pub redirect_uri: Option<String>,
28 pub client_id: String,
29 pub client_secret: String,
30 pub code_verifier: Option<String>,
31 pub refresh_token: Option<String>,
32}
33
34#[derive(Debug, Clone, Serialize)]
36pub struct TokenResponse {
37 pub access_token: String,
38 pub token_type: String,
39 pub expires_in: u64,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub refresh_token: Option<String>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub scope: Option<String>,
44}
45
46pub struct OAuth2Provider {
48 codes: Arc<RwLock<HashMap<String, AuthorizationCodeGrant>>>,
49 clients: Arc<RwLock<HashMap<String, ClientConfig>>>,
50}
51
/// Registration data for an OAuth2 client.
///
/// `Debug` is implemented by hand so that `client_secret` is never printed.
#[derive(Clone)]
pub struct ClientConfig {
    /// Client identifier; also the key under which the client is registered.
    pub client_id: String,
    /// Shared secret the client presents when exchanging a code.
    pub client_secret: String,
    /// Redirect URIs this client may use; matched exactly.
    pub redirect_uris: Vec<String>,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[REDACTED]")
            .field("redirect_uris", &self.redirect_uris)
            .finish()
    }
}

59impl OAuth2Provider {
60 pub fn new() -> Self {
61 Self {
62 codes: Arc::new(RwLock::new(HashMap::new())),
63 clients: Arc::new(RwLock::new(HashMap::new())),
64 }
65 }
66
67 pub async fn register_client(&self, config: ClientConfig) -> Result<()> {
69 let mut clients = self.clients.write().await;
70 clients.insert(config.client_id.clone(), config);
71 Ok(())
72 }
73
74 pub async fn authorize(&self, req: AuthorizationRequest, _user_id: String) -> Result<String> {
76 let clients = self.clients.read().await;
78 let client = clients.get(&req.client_id)
79 .ok_or(AuthError::InvalidCredentials)?;
80
81 if !client.redirect_uris.contains(&req.redirect_uri) {
83 return Err(AuthError::InvalidCredentials);
84 }
85
86 let mut grant = AuthorizationCodeGrant::new(
88 req.client_id.clone(),
89 req.redirect_uri.clone(),
90 600, );
92
93 if let Some(challenge) = req.code_challenge {
94 grant = grant.with_pkce(challenge);
95 }
96
97 let code = grant.code.clone();
98 let mut codes = self.codes.write().await;
99 codes.insert(code.clone(), grant);
100
101 Ok(code)
102 }
103
104 pub async fn exchange_code(&self, req: TokenRequest) -> Result<TokenResponse> {
106 let code = req.code.ok_or(AuthError::InvalidToken)?;
107
108 let mut codes = self.codes.write().await;
110 let grant = codes.remove(&code).ok_or(AuthError::InvalidToken)?;
111
112 let clients = self.clients.read().await;
114 let client = clients.get(&req.client_id)
115 .ok_or(AuthError::InvalidCredentials)?;
116
117 if client.client_secret != req.client_secret {
118 return Err(AuthError::InvalidCredentials);
119 }
120
121 if let Some(redirect_uri) = req.redirect_uri {
123 if grant.redirect_uri != redirect_uri {
124 return Err(AuthError::InvalidCredentials);
125 }
126 }
127
128 if grant.is_expired() {
130 return Err(AuthError::TokenExpired);
131 }
132
133 if let Some(challenge) = grant.code_challenge {
135 let verifier = req.code_verifier.ok_or(AuthError::InvalidToken)?;
136
137 use sha2::{Sha256, Digest};
139 let mut hasher = Sha256::new();
140 hasher.update(verifier.as_bytes());
141 let computed_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD
142 .encode(hasher.finalize());
143
144 if computed_challenge != challenge {
145 return Err(AuthError::InvalidToken);
146 }
147 }
148
149 let access_token = Uuid::new_v4().to_string();
151 let refresh_token = Uuid::new_v4().to_string();
152
153 Ok(TokenResponse {
154 access_token,
155 token_type: "Bearer".to_string(),
156 expires_in: 3600,
157 refresh_token: Some(refresh_token),
158 scope: None,
159 })
160 }
161}
162
163impl Default for OAuth2Provider {
164 fn default() -> Self {
165 Self::new()
166 }
167}