auth_framework/server/oauth/
par.rs1use crate::errors::{AuthError, Result};
7use crate::storage::AuthStorage;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use uuid::Uuid;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PushedAuthorizationRequest {
17 pub client_id: String,
19
20 pub response_type: String,
22
23 pub redirect_uri: String,
25
26 pub scope: Option<String>,
28
29 pub state: Option<String>,
31
32 pub code_challenge: Option<String>,
34
35 pub code_challenge_method: Option<String>,
37
38 #[serde(flatten)]
40 pub additional_params: HashMap<String, String>,
41}
42
43impl PushedAuthorizationRequest {
44 pub fn builder(
46 client_id: impl Into<String>,
47 response_type: impl Into<String>,
48 redirect_uri: impl Into<String>,
49 ) -> PushedAuthorizationRequestBuilder {
50 PushedAuthorizationRequestBuilder {
51 client_id: client_id.into(),
52 response_type: response_type.into(),
53 redirect_uri: redirect_uri.into(),
54 scope: None,
55 state: None,
56 code_challenge: None,
57 code_challenge_method: None,
58 additional_params: HashMap::new(),
59 }
60 }
61}
62
/// Step-by-step builder for a `PushedAuthorizationRequest`.
///
/// Holds the same set of values as the finished request; the fields stay
/// private and are populated through the builder's methods.
pub struct PushedAuthorizationRequestBuilder {
    /// Mandatory `client_id` value.
    client_id: String,
    /// Mandatory `response_type` value.
    response_type: String,
    /// Mandatory `redirect_uri` value.
    redirect_uri: String,
    /// Optional `scope` value.
    scope: Option<String>,
    /// Optional `state` value.
    state: Option<String>,
    /// Optional PKCE `code_challenge` value.
    code_challenge: Option<String>,
    /// Optional PKCE `code_challenge_method` value.
    code_challenge_method: Option<String>,
    /// Extra parameters accumulated so far.
    additional_params: HashMap<String, String>,
}
74
75impl PushedAuthorizationRequestBuilder {
76 pub fn scope(mut self, scope: impl Into<String>) -> Self {
78 self.scope = Some(scope.into());
79 self
80 }
81
82 pub fn state(mut self, state: impl Into<String>) -> Self {
84 self.state = Some(state.into());
85 self
86 }
87
88 pub fn code_challenge(mut self, challenge: impl Into<String>) -> Self {
90 self.code_challenge = Some(challenge.into());
91 self
92 }
93
94 pub fn code_challenge_method(mut self, method: impl Into<String>) -> Self {
96 self.code_challenge_method = Some(method.into());
97 self
98 }
99
100 pub fn pkce(mut self, challenge: impl Into<String>, method: impl Into<String>) -> Self {
102 self.code_challenge = Some(challenge.into());
103 self.code_challenge_method = Some(method.into());
104 self
105 }
106
107 pub fn add_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109 self.additional_params.insert(key.into(), value.into());
110 self
111 }
112
113 pub fn build(self) -> PushedAuthorizationRequest {
115 PushedAuthorizationRequest {
116 client_id: self.client_id,
117 response_type: self.response_type,
118 redirect_uri: self.redirect_uri,
119 scope: self.scope,
120 state: self.state,
121 code_challenge: self.code_challenge,
122 code_challenge_method: self.code_challenge_method,
123 additional_params: self.additional_params,
124 }
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct PushedAuthorizationResponse {
131 pub request_uri: String,
133
134 pub expires_in: u64,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct StoredPushedRequest {
141 pub request: PushedAuthorizationRequest,
143
144 pub created_at: SystemTime,
146
147 pub expires_at: SystemTime,
149
150 pub used: bool,
152}
153
154use std::fmt;
156
157#[derive(Clone)]
158pub struct PARManager {
159 storage: Arc<dyn AuthStorage>,
161
162 requests: Arc<tokio::sync::RwLock<HashMap<String, StoredPushedRequest>>>,
164
165 default_expiration: Duration,
167}
168
169impl fmt::Debug for PARManager {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("PARManager")
172 .field("storage", &"<dyn AuthStorage>")
173 .field("default_expiration", &self.default_expiration)
174 .finish()
175 }
176}
177
178impl PARManager {
179 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
181 Self {
182 storage,
183 requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
184 default_expiration: Duration::from_secs(90), }
186 }
187
188 pub fn with_expiration(storage: Arc<dyn AuthStorage>, expiration: Duration) -> Self {
190 Self {
191 storage,
192 requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
193 default_expiration: expiration,
194 }
195 }
196
197 pub fn expiration(mut self, expiration: Duration) -> Self {
201 self.default_expiration = expiration;
202 self
203 }
204
205 pub async fn store_request(
207 &self,
208 request: PushedAuthorizationRequest,
209 ) -> Result<PushedAuthorizationResponse> {
210 self.validate_request(&request)?;
212
213 let request_id = Uuid::new_v4().to_string();
215 let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", request_id);
216
217 let now = SystemTime::now();
219 let expires_at = now + self.default_expiration;
220
221 let stored_request = StoredPushedRequest {
223 request: request.clone(),
224 created_at: now,
225 expires_at,
226 used: false,
227 };
228
229 let storage_key = format!("par:{}", request_uri);
231 let serialized = serde_json::to_string(&stored_request)
232 .map_err(|e| AuthError::internal(format!("Failed to serialize PAR request: {}", e)))?;
233
234 self.storage
235 .store_kv(
236 &storage_key,
237 &serialized.into_bytes(),
238 Some(self.default_expiration),
239 )
240 .await
241 .map_err(|e| AuthError::internal(format!("Failed to store PAR request: {}", e)))?;
242
243 let mut requests = self.requests.write().await;
245 requests.insert(request_uri.clone(), stored_request);
246
247 self.cleanup_expired_requests(&mut requests, now);
249
250 Ok(PushedAuthorizationResponse {
251 request_uri,
252 expires_in: self.default_expiration.as_secs(),
253 })
254 }
255
256 pub async fn consume_request(&self, request_uri: &str) -> Result<PushedAuthorizationRequest> {
258 let storage_key = format!("par:{}", request_uri);
259
260 let stored_request = if let Some(data) = self.storage.get_kv(&storage_key).await? {
262 let serialized = String::from_utf8(data)
263 .map_err(|_| AuthError::internal("Invalid UTF-8 in stored PAR data"))?;
264
265 serde_json::from_str::<StoredPushedRequest>(&serialized).map_err(|e| {
266 AuthError::internal(format!("Failed to deserialize PAR request: {}", e))
267 })?
268 } else {
269 let requests = self.requests.read().await;
271 requests
272 .get(request_uri)
273 .cloned()
274 .ok_or_else(|| AuthError::auth_method("par", "Invalid request_uri"))?
275 };
276
277 let now = SystemTime::now();
279 if now > stored_request.expires_at {
280 let _ = self.storage.delete_kv(&storage_key).await;
282 let mut requests = self.requests.write().await;
283 requests.remove(request_uri);
284 return Err(AuthError::auth_method("par", "Request URI expired"));
285 }
286
287 if stored_request.used {
289 return Err(AuthError::auth_method("par", "Request URI already used"));
290 }
291
292 self.storage
294 .delete_kv(&storage_key)
295 .await
296 .map_err(|e| AuthError::internal(format!("Failed to consume PAR request: {}", e)))?;
297
298 let mut requests = self.requests.write().await;
300 requests.remove(request_uri);
301
302 Ok(stored_request.request)
303 }
304
305 fn validate_request(&self, request: &PushedAuthorizationRequest) -> Result<()> {
307 if request.client_id.is_empty() {
309 return Err(AuthError::auth_method("par", "Missing client_id"));
310 }
311
312 if request.response_type.is_empty() {
313 return Err(AuthError::auth_method("par", "Missing response_type"));
314 }
315
316 if request.redirect_uri.is_empty() {
317 return Err(AuthError::auth_method("par", "Missing redirect_uri"));
318 }
319
320 if url::Url::parse(&request.redirect_uri).is_err() {
322 return Err(AuthError::auth_method("par", "Invalid redirect_uri format"));
323 }
324
325 if let (Some(challenge), Some(method)) =
327 (&request.code_challenge, &request.code_challenge_method)
328 {
329 if method != "S256" && method != "plain" {
330 return Err(AuthError::auth_method(
331 "par",
332 "Invalid code_challenge_method",
333 ));
334 }
335
336 if challenge.is_empty() {
337 return Err(AuthError::auth_method("par", "Empty code_challenge"));
338 }
339 }
340
341 Ok(())
342 }
343
344 fn cleanup_expired_requests(
346 &self,
347 requests: &mut HashMap<String, StoredPushedRequest>,
348 now: SystemTime,
349 ) {
350 requests.retain(|_, stored_request| now <= stored_request.expires_at);
351 }
352
353 pub async fn get_statistics(&self) -> PARStatistics {
355 let requests = self.requests.read().await;
356 let now = SystemTime::now();
357
358 let total_count = requests.len();
359 let expired_count = requests.values().filter(|req| now > req.expires_at).count();
360 let used_count = requests.values().filter(|req| req.used).count();
361
362 PARStatistics {
363 total_requests: total_count,
364 expired_requests: expired_count,
365 used_requests: used_count,
366 active_requests: total_count - expired_count - used_count,
367 }
368 }
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct PARStatistics {
374 pub total_requests: usize,
376
377 pub expired_requests: usize,
379
380 pub used_requests: usize,
382
383 pub active_requests: usize,
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    /// Manager backed by a fresh in-memory store, using the default lifetime.
    fn default_manager() -> PARManager {
        PARManager::new(Arc::new(MemoryStorage::new()))
    }

    /// A request that passes validation; individual tests break one field.
    fn create_test_request() -> PushedAuthorizationRequest {
        PushedAuthorizationRequest {
            client_id: "test_client".to_string(),
            response_type: "code".to_string(),
            redirect_uri: "https://example.com/callback".to_string(),
            scope: Some("openid profile".to_string()),
            state: Some("test_state".to_string()),
            code_challenge: Some("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()),
            code_challenge_method: Some("S256".to_string()),
            additional_params: HashMap::new(),
        }
    }

    /// The builder carries every supplied value into the finished request.
    #[test]
    fn test_par_request_builder() {
        let built =
            PushedAuthorizationRequest::builder("client_id", "code", "https://app/callback")
                .scope("openid profile")
                .state("state123")
                .pkce("challenge_abc", "S256")
                .add_param("custom", "value")
                .build();

        assert_eq!(built.client_id, "client_id");
        assert_eq!(built.response_type, "code");
        assert_eq!(built.redirect_uri, "https://app/callback");
        assert_eq!(built.scope, Some("openid profile".to_string()));
        assert_eq!(built.state, Some("state123".to_string()));
        assert_eq!(built.code_challenge, Some("challenge_abc".to_string()));
        assert_eq!(built.code_challenge_method, Some("S256".to_string()));
        assert_eq!(
            built.additional_params.get("custom").map(String::as_str),
            Some("value")
        );
    }

    /// A stored request can be redeemed exactly once.
    #[tokio::test]
    async fn test_store_and_consume_request() {
        let par_manager = default_manager();
        let original = create_test_request();

        let response = par_manager.store_request(original.clone()).await.unwrap();
        assert!(
            response
                .request_uri
                .starts_with("urn:ietf:params:oauth:request_uri:")
        );
        assert_eq!(response.expires_in, 90);

        let consumed = par_manager
            .consume_request(&response.request_uri)
            .await
            .unwrap();
        assert_eq!(consumed.client_id, original.client_id);
        assert_eq!(consumed.response_type, original.response_type);

        // Second redemption of the same URI must be refused.
        let replay = par_manager.consume_request(&response.request_uri).await;
        assert!(replay.is_err());
    }

    /// A request cannot be redeemed once its lifetime has elapsed.
    #[tokio::test]
    async fn test_request_expiration() {
        let par_manager = PARManager::with_expiration(
            Arc::new(MemoryStorage::new()),
            Duration::from_millis(50),
        );

        let response = par_manager
            .store_request(create_test_request())
            .await
            .unwrap();

        // Wait past the 50 ms lifetime.
        sleep(Duration::from_millis(100)).await;

        let outcome = par_manager.consume_request(&response.request_uri).await;
        assert!(outcome.is_err());
    }

    /// Malformed requests are rejected before anything is stored.
    #[tokio::test]
    async fn test_invalid_request_validation() {
        let par_manager = default_manager();

        let mut missing_client = create_test_request();
        missing_client.client_id = "".to_string();
        assert!(par_manager.store_request(missing_client).await.is_err());

        let mut bad_redirect = create_test_request();
        bad_redirect.redirect_uri = "invalid-uri".to_string();
        assert!(par_manager.store_request(bad_redirect).await.is_err());

        let mut bad_method = create_test_request();
        bad_method.code_challenge_method = Some("invalid".to_string());
        assert!(par_manager.store_request(bad_method).await.is_err());
    }

    /// Counters follow a request through store and consume.
    #[tokio::test]
    async fn test_statistics() {
        let par_manager = default_manager();

        let before = par_manager.get_statistics().await;
        assert_eq!(before.total_requests, 0);

        let response = par_manager
            .store_request(create_test_request())
            .await
            .unwrap();
        let after_store = par_manager.get_statistics().await;
        assert_eq!(after_store.total_requests, 1);
        assert_eq!(after_store.active_requests, 1);

        par_manager
            .consume_request(&response.request_uri)
            .await
            .unwrap();
        let after_consume = par_manager.get_statistics().await;
        assert_eq!(after_consume.total_requests, 0);
    }
}