auth_framework/server/oauth/
par.rs1use crate::errors::{AuthError, Result};
7use crate::storage::AuthStorage;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use uuid::Uuid;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PushedAuthorizationRequest {
17 pub client_id: String,
19
20 pub response_type: String,
22
23 pub redirect_uri: String,
25
26 pub scope: Option<String>,
28
29 pub state: Option<String>,
31
32 pub code_challenge: Option<String>,
34
35 pub code_challenge_method: Option<String>,
37
38 #[serde(flatten)]
40 pub additional_params: HashMap<String, String>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct PushedAuthorizationResponse {
46 pub request_uri: String,
48
49 pub expires_in: u64,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct StoredPushedRequest {
56 pub request: PushedAuthorizationRequest,
58
59 pub created_at: SystemTime,
61
62 pub expires_at: SystemTime,
64
65 pub used: bool,
67}
68
69use std::fmt;
71
72#[derive(Clone)]
73pub struct PARManager {
74 storage: Arc<dyn AuthStorage>,
76
77 requests: Arc<tokio::sync::RwLock<HashMap<String, StoredPushedRequest>>>,
79
80 default_expiration: Duration,
82}
83
84impl fmt::Debug for PARManager {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 f.debug_struct("PARManager")
87 .field("storage", &"<dyn AuthStorage>")
88 .field("default_expiration", &self.default_expiration)
89 .finish()
90 }
91}
92
93impl PARManager {
94 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
96 Self {
97 storage,
98 requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
99 default_expiration: Duration::from_secs(90), }
101 }
102
103 pub fn with_expiration(storage: Arc<dyn AuthStorage>, expiration: Duration) -> Self {
105 Self {
106 storage,
107 requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
108 default_expiration: expiration,
109 }
110 }
111
112 pub async fn store_request(
114 &self,
115 request: PushedAuthorizationRequest,
116 ) -> Result<PushedAuthorizationResponse> {
117 self.validate_request(&request)?;
119
120 let request_id = Uuid::new_v4().to_string();
122 let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", request_id);
123
124 let now = SystemTime::now();
126 let expires_at = now + self.default_expiration;
127
128 let stored_request = StoredPushedRequest {
130 request: request.clone(),
131 created_at: now,
132 expires_at,
133 used: false,
134 };
135
136 let storage_key = format!("par:{}", request_uri);
138 let serialized = serde_json::to_string(&stored_request)
139 .map_err(|e| AuthError::internal(format!("Failed to serialize PAR request: {}", e)))?;
140
141 self.storage
142 .store_kv(
143 &storage_key,
144 &serialized.into_bytes(),
145 Some(self.default_expiration),
146 )
147 .await
148 .map_err(|e| AuthError::internal(format!("Failed to store PAR request: {}", e)))?;
149
150 let mut requests = self.requests.write().await;
152 requests.insert(request_uri.clone(), stored_request);
153
154 self.cleanup_expired_requests(&mut requests, now);
156
157 Ok(PushedAuthorizationResponse {
158 request_uri,
159 expires_in: self.default_expiration.as_secs(),
160 })
161 }
162
163 pub async fn consume_request(&self, request_uri: &str) -> Result<PushedAuthorizationRequest> {
165 let storage_key = format!("par:{}", request_uri);
166
167 let stored_request = if let Some(data) = self.storage.get_kv(&storage_key).await? {
169 let serialized = String::from_utf8(data)
170 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored PAR data"))?;
171
172 serde_json::from_str::<StoredPushedRequest>(&serialized).map_err(|e| {
173 AuthError::internal(format!("Failed to deserialize PAR request: {}", e))
174 })?
175 } else {
176 let requests = self.requests.read().await;
178 requests
179 .get(request_uri)
180 .cloned()
181 .ok_or_else(|| AuthError::auth_method("par", "Invalid request_uri"))?
182 };
183
184 let now = SystemTime::now();
186 if now > stored_request.expires_at {
187 let _ = self.storage.delete_kv(&storage_key).await;
189 let mut requests = self.requests.write().await;
190 requests.remove(request_uri);
191 return Err(AuthError::auth_method("par", "Request URI expired"));
192 }
193
194 if stored_request.used {
196 return Err(AuthError::auth_method("par", "Request URI already used"));
197 }
198
199 self.storage
201 .delete_kv(&storage_key)
202 .await
203 .map_err(|e| AuthError::internal(format!("Failed to consume PAR request: {}", e)))?;
204
205 let mut requests = self.requests.write().await;
207 requests.remove(request_uri);
208
209 Ok(stored_request.request)
210 }
211
212 fn validate_request(&self, request: &PushedAuthorizationRequest) -> Result<()> {
214 if request.client_id.is_empty() {
216 return Err(AuthError::auth_method("par", "Missing client_id"));
217 }
218
219 if request.response_type.is_empty() {
220 return Err(AuthError::auth_method("par", "Missing response_type"));
221 }
222
223 if request.redirect_uri.is_empty() {
224 return Err(AuthError::auth_method("par", "Missing redirect_uri"));
225 }
226
227 if url::Url::parse(&request.redirect_uri).is_err() {
229 return Err(AuthError::auth_method("par", "Invalid redirect_uri format"));
230 }
231
232 if let (Some(challenge), Some(method)) =
234 (&request.code_challenge, &request.code_challenge_method)
235 {
236 if method != "S256" && method != "plain" {
237 return Err(AuthError::auth_method(
238 "par",
239 "Invalid code_challenge_method",
240 ));
241 }
242
243 if challenge.is_empty() {
244 return Err(AuthError::auth_method("par", "Empty code_challenge"));
245 }
246 }
247
248 Ok(())
249 }
250
251 fn cleanup_expired_requests(
253 &self,
254 requests: &mut HashMap<String, StoredPushedRequest>,
255 now: SystemTime,
256 ) {
257 requests.retain(|_, stored_request| now <= stored_request.expires_at);
258 }
259
260 pub async fn get_statistics(&self) -> PARStatistics {
262 let requests = self.requests.read().await;
263 let now = SystemTime::now();
264
265 let total_count = requests.len();
266 let expired_count = requests.values().filter(|req| now > req.expires_at).count();
267 let used_count = requests.values().filter(|req| req.used).count();
268
269 PARStatistics {
270 total_requests: total_count,
271 expired_requests: expired_count,
272 used_requests: used_count,
273 active_requests: total_count - expired_count - used_count,
274 }
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct PARStatistics {
281 pub total_requests: usize,
283
284 pub expired_requests: usize,
286
287 pub used_requests: usize,
289
290 pub active_requests: usize,
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use std::sync::Arc;
    use tokio::time::sleep;

    /// Returns a well-formed request (S256 PKCE) that passes validation.
    fn create_test_request() -> PushedAuthorizationRequest {
        PushedAuthorizationRequest {
            client_id: "test_client".to_string(),
            response_type: "code".to_string(),
            redirect_uri: "https://example.com/callback".to_string(),
            scope: Some("openid profile".to_string()),
            state: Some("test_state".to_string()),
            code_challenge: Some("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()),
            code_challenge_method: Some("S256".to_string()),
            additional_params: HashMap::new(),
        }
    }

    /// Builds a manager over fresh in-memory storage with the default expiration.
    fn manager() -> PARManager {
        PARManager::new(Arc::new(MemoryStorage::new()))
    }

    #[tokio::test]
    async fn test_store_and_consume_request() {
        let par = manager();
        let original = create_test_request();

        let pushed = par.store_request(original.clone()).await.unwrap();
        assert!(
            pushed
                .request_uri
                .starts_with("urn:ietf:params:oauth:request_uri:")
        );
        assert_eq!(pushed.expires_in, 90);

        let consumed = par.consume_request(&pushed.request_uri).await.unwrap();
        assert_eq!(consumed.client_id, original.client_id);
        assert_eq!(consumed.response_type, original.response_type);

        // request_uri values are single-use, so replaying one must fail.
        assert!(par.consume_request(&pushed.request_uri).await.is_err());
    }

    #[tokio::test]
    async fn test_request_expiration() {
        let par = PARManager::with_expiration(
            Arc::new(MemoryStorage::new()),
            Duration::from_millis(50),
        );

        let pushed = par.store_request(create_test_request()).await.unwrap();

        // Outlive the 50 ms lifetime before trying to redeem the request.
        sleep(Duration::from_millis(100)).await;

        assert!(par.consume_request(&pushed.request_uri).await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_request_validation() {
        let par = manager();

        // Each mutation breaks exactly one validation rule of an otherwise valid request.
        let mutations: [fn(&mut PushedAuthorizationRequest); 3] = [
            |req| req.client_id = "".to_string(),
            |req| req.redirect_uri = "invalid-uri".to_string(),
            |req| req.code_challenge_method = Some("invalid".to_string()),
        ];

        for mutate in mutations.iter() {
            let mut request = create_test_request();
            mutate(&mut request);
            assert!(par.store_request(request).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_statistics() {
        let par = manager();

        assert_eq!(par.get_statistics().await.total_requests, 0);

        let pushed = par.store_request(create_test_request()).await.unwrap();
        let stats = par.get_statistics().await;
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.active_requests, 1);

        par.consume_request(&pushed.request_uri).await.unwrap();
        // Consuming removes the entry from the in-memory map entirely.
        assert_eq!(par.get_statistics().await.total_requests, 0);
    }
}
418
419