Skip to main content

auth_framework/server/oauth/
par.rs

//! Pushed Authorization Requests (PAR) implementation - RFC 9126
//!
//! This module implements RFC 9126 - OAuth 2.0 Pushed Authorization Requests,
//! which enhances security by allowing clients to push authorization request
//! parameters directly to the authorization server.
6use crate::errors::{AuthError, Result};
7use crate::storage::AuthStorage;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use uuid::Uuid;
13
/// Authorization request parameters pushed by a client to the PAR endpoint (RFC 9126).
///
/// When (de)serialized, the named fields and every entry of `additional_params`
/// appear side by side at the top level of the object (`#[serde(flatten)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushedAuthorizationRequest {
    /// Client identifier; must be non-empty to pass validation.
    pub client_id: String,

    /// Response type (e.g., "code"); must be non-empty to pass validation.
    pub response_type: String,

    /// Redirect URI; must be non-empty and parse as a URL to pass validation.
    pub redirect_uri: String,

    /// Requested scopes, carried as a single string value.
    pub scope: Option<String>,

    /// State parameter, carried through unchanged.
    pub state: Option<String>,

    /// PKCE code challenge.
    pub code_challenge: Option<String>,

    /// PKCE code challenge method (`"S256"` or `"plain"` are accepted by validation).
    pub code_challenge_method: Option<String>,

    /// Additional parameters. Because of `flatten`, on deserialization this collects
    /// the keys not matched by the named fields above.
    #[serde(flatten)]
    pub additional_params: HashMap<String, String>,
}
42
43impl PushedAuthorizationRequest {
44    /// Create a new builder for a PushedAuthorizationRequest.
45    pub fn builder(
46        client_id: impl Into<String>,
47        response_type: impl Into<String>,
48        redirect_uri: impl Into<String>,
49    ) -> PushedAuthorizationRequestBuilder {
50        PushedAuthorizationRequestBuilder {
51            client_id: client_id.into(),
52            response_type: response_type.into(),
53            redirect_uri: redirect_uri.into(),
54            scope: None,
55            state: None,
56            code_challenge: None,
57            code_challenge_method: None,
58            additional_params: HashMap::new(),
59        }
60    }
61}
62
/// Builder for [`PushedAuthorizationRequest`].
///
/// Obtained from `PushedAuthorizationRequest::builder`; call `build()` to
/// produce the request.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build()` is called"]
pub struct PushedAuthorizationRequestBuilder {
    client_id: String,
    response_type: String,
    redirect_uri: String,
    scope: Option<String>,
    state: Option<String>,
    code_challenge: Option<String>,
    code_challenge_method: Option<String>,
    additional_params: HashMap<String, String>,
}
74
75impl PushedAuthorizationRequestBuilder {
76    /// Set the scopes list
77    pub fn scope(mut self, scope: impl Into<String>) -> Self {
78        self.scope = Some(scope.into());
79        self
80    }
81
82    /// Set the state parameter
83    pub fn state(mut self, state: impl Into<String>) -> Self {
84        self.state = Some(state.into());
85        self
86    }
87
88    /// Set PKCE code challenge
89    pub fn code_challenge(mut self, challenge: impl Into<String>) -> Self {
90        self.code_challenge = Some(challenge.into());
91        self
92    }
93
94    /// Set PKCE code challenge method
95    pub fn code_challenge_method(mut self, method: impl Into<String>) -> Self {
96        self.code_challenge_method = Some(method.into());
97        self
98    }
99
100    /// Set PKCE challenge and method together
101    pub fn pkce(mut self, challenge: impl Into<String>, method: impl Into<String>) -> Self {
102        self.code_challenge = Some(challenge.into());
103        self.code_challenge_method = Some(method.into());
104        self
105    }
106
107    /// Add an additional custom parameter
108    pub fn add_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109        self.additional_params.insert(key.into(), value.into());
110        self
111    }
112
113    /// Build the request
114    pub fn build(self) -> PushedAuthorizationRequest {
115        PushedAuthorizationRequest {
116            client_id: self.client_id,
117            response_type: self.response_type,
118            redirect_uri: self.redirect_uri,
119            scope: self.scope,
120            state: self.state,
121            code_challenge: self.code_challenge,
122            code_challenge_method: self.code_challenge_method,
123            additional_params: self.additional_params,
124        }
125    }
126}
127
/// Response returned to the client after a successful push (RFC 9126 §2.2).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushedAuthorizationResponse {
    /// Request URI the client passes to the subsequent authorization request.
    pub request_uri: String,

    /// Lifetime of `request_uri`, in whole seconds.
    pub expires_in: u64,
}
137
/// Stored PAR request with metadata.
///
/// This is the value the manager serializes to JSON under the key
/// `par:<request_uri>` and also keeps in its in-process cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredPushedRequest {
    /// Original request parameters.
    pub request: PushedAuthorizationRequest,

    /// When the request was created.
    pub created_at: SystemTime,

    /// When the request expires; consumption is rejected once the current time is past this.
    pub expires_at: SystemTime,

    /// Whether the request has been used.
    ///
    /// NOTE(review): nothing in this module sets this to `true` (consumption deletes the
    /// entry instead), so it is presumably always `false` unless set elsewhere — confirm.
    pub used: bool,
}
153
154/// PAR request manager with persistent storage
155use std::fmt;
156
157#[derive(Clone)]
158pub struct PARManager {
159    /// Persistent storage backend
160    storage: Arc<dyn AuthStorage>,
161
162    /// Memory cache for fast access
163    requests: Arc<tokio::sync::RwLock<HashMap<String, StoredPushedRequest>>>,
164
165    /// Default expiration time for PAR requests
166    default_expiration: Duration,
167}
168
169impl fmt::Debug for PARManager {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("PARManager")
172            .field("storage", &"<dyn AuthStorage>")
173            .field("default_expiration", &self.default_expiration)
174            .finish()
175    }
176}
177
178impl PARManager {
179    /// Create a new PAR manager with storage backend
180    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
181        Self {
182            storage,
183            requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
184            default_expiration: Duration::from_secs(90), // RFC 9126 recommendation
185        }
186    }
187
188    /// Create a new PAR manager with custom expiration
189    pub fn with_expiration(storage: Arc<dyn AuthStorage>, expiration: Duration) -> Self {
190        Self {
191            storage,
192            requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
193            default_expiration: expiration,
194        }
195    }
196
197    /// Set the default expiration time for PAR requests (chainable).
198    ///
199    /// Default: 90 seconds (RFC 9126 recommendation).
200    pub fn expiration(mut self, expiration: Duration) -> Self {
201        self.default_expiration = expiration;
202        self
203    }
204
205    /// Store a pushed authorization request
206    pub async fn store_request(
207        &self,
208        request: PushedAuthorizationRequest,
209    ) -> Result<PushedAuthorizationResponse> {
210        // Validate the request
211        self.validate_request(&request)?;
212
213        // Generate request URI
214        let request_id = Uuid::new_v4().to_string();
215        let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", request_id);
216
217        // Calculate expiration
218        let now = SystemTime::now();
219        let expires_at = now + self.default_expiration;
220
221        // Store the request in persistent storage with TTL
222        let stored_request = StoredPushedRequest {
223            request: request.clone(),
224            created_at: now,
225            expires_at,
226            used: false,
227        };
228
229        // Store in persistent backend with TTL
230        let storage_key = format!("par:{}", request_uri);
231        let serialized = serde_json::to_string(&stored_request)
232            .map_err(|e| AuthError::internal(format!("Failed to serialize PAR request: {}", e)))?;
233
234        self.storage
235            .store_kv(
236                &storage_key,
237                &serialized.into_bytes(),
238                Some(self.default_expiration),
239            )
240            .await
241            .map_err(|e| AuthError::internal(format!("Failed to store PAR request: {}", e)))?;
242
243        // Also cache in memory for fast access
244        let mut requests = self.requests.write().await;
245        requests.insert(request_uri.clone(), stored_request);
246
247        // Clean up expired requests from memory cache
248        self.cleanup_expired_requests(&mut requests, now);
249
250        Ok(PushedAuthorizationResponse {
251            request_uri,
252            expires_in: self.default_expiration.as_secs(),
253        })
254    }
255
256    /// Retrieve and consume a pushed authorization request
257    pub async fn consume_request(&self, request_uri: &str) -> Result<PushedAuthorizationRequest> {
258        let storage_key = format!("par:{}", request_uri);
259
260        // Try to load from persistent storage first
261        let stored_request = if let Some(data) = self.storage.get_kv(&storage_key).await? {
262            let serialized = String::from_utf8(data)
263                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored PAR data"))?;
264
265            serde_json::from_str::<StoredPushedRequest>(&serialized).map_err(|e| {
266                AuthError::internal(format!("Failed to deserialize PAR request: {}", e))
267            })?
268        } else {
269            // Fallback to memory cache (for backward compatibility during transition)
270            let requests = self.requests.read().await;
271            requests
272                .get(request_uri)
273                .cloned()
274                .ok_or_else(|| AuthError::auth_method("par", "Invalid request_uri"))?
275        };
276
277        // Check if expired
278        let now = SystemTime::now();
279        if now > stored_request.expires_at {
280            // Clean up from both storage and cache
281            let _ = self.storage.delete_kv(&storage_key).await;
282            let mut requests = self.requests.write().await;
283            requests.remove(request_uri);
284            return Err(AuthError::auth_method("par", "Request URI expired"));
285        }
286
287        // Check if already used
288        if stored_request.used {
289            return Err(AuthError::auth_method("par", "Request URI already used"));
290        }
291
292        // Mark as consumed by removing from storage (single use)
293        self.storage
294            .delete_kv(&storage_key)
295            .await
296            .map_err(|e| AuthError::internal(format!("Failed to consume PAR request: {}", e)))?;
297
298        // Also remove from memory cache
299        let mut requests = self.requests.write().await;
300        requests.remove(request_uri);
301
302        Ok(stored_request.request)
303    }
304
305    /// Validate a PAR request
306    fn validate_request(&self, request: &PushedAuthorizationRequest) -> Result<()> {
307        // Validate required parameters
308        if request.client_id.is_empty() {
309            return Err(AuthError::auth_method("par", "Missing client_id"));
310        }
311
312        if request.response_type.is_empty() {
313            return Err(AuthError::auth_method("par", "Missing response_type"));
314        }
315
316        if request.redirect_uri.is_empty() {
317            return Err(AuthError::auth_method("par", "Missing redirect_uri"));
318        }
319
320        // Validate redirect URI format
321        if url::Url::parse(&request.redirect_uri).is_err() {
322            return Err(AuthError::auth_method("par", "Invalid redirect_uri format"));
323        }
324
325        // Validate PKCE parameters if present
326        if let (Some(challenge), Some(method)) =
327            (&request.code_challenge, &request.code_challenge_method)
328        {
329            if method != "S256" && method != "plain" {
330                return Err(AuthError::auth_method(
331                    "par",
332                    "Invalid code_challenge_method",
333                ));
334            }
335
336            if challenge.is_empty() {
337                return Err(AuthError::auth_method("par", "Empty code_challenge"));
338            }
339        }
340
341        Ok(())
342    }
343
344    /// Clean up expired requests
345    fn cleanup_expired_requests(
346        &self,
347        requests: &mut HashMap<String, StoredPushedRequest>,
348        now: SystemTime,
349    ) {
350        requests.retain(|_, stored_request| now <= stored_request.expires_at);
351    }
352
353    /// Get statistics about stored requests
354    pub async fn get_statistics(&self) -> PARStatistics {
355        let requests = self.requests.read().await;
356        let now = SystemTime::now();
357
358        let total_count = requests.len();
359        let expired_count = requests.values().filter(|req| now > req.expires_at).count();
360        let used_count = requests.values().filter(|req| req.used).count();
361
362        PARStatistics {
363            total_requests: total_count,
364            expired_requests: expired_count,
365            used_requests: used_count,
366            active_requests: total_count - expired_count - used_count,
367        }
368    }
369}
370
371/// PAR statistics
372#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct PARStatistics {
374    /// Total number of stored requests
375    pub total_requests: usize,
376
377    /// Number of expired requests
378    pub expired_requests: usize,
379
380    /// Number of used requests
381    pub used_requests: usize,
382
383    /// Number of active (valid, unused) requests
384    pub active_requests: usize,
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use std::sync::Arc;
    use tokio::time::sleep;

    /// Manager over a fresh in-memory backend with the default 90 s lifetime.
    fn manager() -> PARManager {
        PARManager::new(Arc::new(MemoryStorage::new()))
    }

    /// A fully valid request with S256 PKCE.
    fn create_test_request() -> PushedAuthorizationRequest {
        PushedAuthorizationRequest {
            client_id: "test_client".to_string(),
            response_type: "code".to_string(),
            redirect_uri: "https://example.com/callback".to_string(),
            scope: Some("openid profile".to_string()),
            state: Some("test_state".to_string()),
            code_challenge: Some("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()),
            code_challenge_method: Some("S256".to_string()),
            additional_params: HashMap::new(),
        }
    }

    /// Assert that `store_request` rejects `request`, naming the case on failure.
    async fn assert_rejected(
        par_manager: &PARManager,
        request: PushedAuthorizationRequest,
        case: &str,
    ) {
        assert!(
            par_manager.store_request(request).await.is_err(),
            "{} should be rejected",
            case
        );
    }

    #[test]
    fn test_par_request_builder() {
        let req = PushedAuthorizationRequest::builder("client_id", "code", "https://app/callback")
            .scope("openid profile")
            .state("state123")
            .pkce("challenge_abc", "S256")
            .add_param("custom", "value")
            .build();

        assert_eq!(req.client_id, "client_id");
        assert_eq!(req.response_type, "code");
        assert_eq!(req.redirect_uri, "https://app/callback");
        assert_eq!(req.scope, Some("openid profile".to_string()));
        assert_eq!(req.state, Some("state123".to_string()));
        assert_eq!(req.code_challenge, Some("challenge_abc".to_string()));
        assert_eq!(req.code_challenge_method, Some("S256".to_string()));
        assert_eq!(req.additional_params.get("custom").map(String::as_str), Some("value"));
    }

    #[tokio::test]
    async fn test_store_and_consume_request() {
        let par_manager = manager();
        let request = create_test_request();

        let response = par_manager.store_request(request.clone()).await.unwrap();
        assert!(response
            .request_uri
            .starts_with("urn:ietf:params:oauth:request_uri:"));
        assert_eq!(response.expires_in, 90);

        let consumed_request = par_manager
            .consume_request(&response.request_uri)
            .await
            .unwrap();
        assert_eq!(consumed_request.client_id, request.client_id);
        assert_eq!(consumed_request.response_type, request.response_type);

        // Single use: a second consumption must fail.
        let result = par_manager.consume_request(&response.request_uri).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_request_uri_is_rejected() {
        let result = manager()
            .consume_request("urn:ietf:params:oauth:request_uri:does-not-exist")
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_request_expiration() {
        let par_manager = PARManager::with_expiration(
            Arc::new(MemoryStorage::new()),
            Duration::from_millis(50),
        );

        let response = par_manager
            .store_request(create_test_request())
            .await
            .unwrap();

        sleep(Duration::from_millis(100)).await;

        let result = par_manager.consume_request(&response.request_uri).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_invalid_request_validation() {
        let par_manager = manager();

        let mut request = create_test_request();
        request.client_id = "".to_string();
        assert_rejected(&par_manager, request, "missing client_id").await;

        let mut request = create_test_request();
        request.response_type = "".to_string();
        assert_rejected(&par_manager, request, "missing response_type").await;

        let mut request = create_test_request();
        request.redirect_uri = "".to_string();
        assert_rejected(&par_manager, request, "missing redirect_uri").await;

        let mut request = create_test_request();
        request.redirect_uri = "invalid-uri".to_string();
        assert_rejected(&par_manager, request, "invalid redirect_uri").await;

        let mut request = create_test_request();
        request.code_challenge_method = Some("invalid".to_string());
        assert_rejected(&par_manager, request, "invalid PKCE method").await;

        let mut request = create_test_request();
        request.code_challenge = Some(String::new());
        assert_rejected(&par_manager, request, "empty code_challenge").await;
    }

    #[tokio::test]
    async fn test_statistics() {
        let par_manager = manager();

        let stats = par_manager.get_statistics().await;
        assert_eq!(stats.total_requests, 0);

        let response = par_manager
            .store_request(create_test_request())
            .await
            .unwrap();
        let stats = par_manager.get_statistics().await;
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.active_requests, 1);

        par_manager
            .consume_request(&response.request_uri)
            .await
            .unwrap();
        let stats = par_manager.get_statistics().await;
        assert_eq!(stats.total_requests, 0); // Removed after consumption
    }
}