// auth_framework/server/oauth/par.rs

//! Pushed Authorization Requests (PAR) implementation - RFC 9126.
//!
//! This module implements RFC 9126 (OAuth 2.0 Pushed Authorization Requests),
//! which enhances security by allowing clients to push authorization request
//! parameters directly to the authorization server.
6use crate::errors::{AuthError, Result};
7use crate::storage::AuthStorage;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use uuid::Uuid;
13
14/// PAR request containing authorization parameters
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PushedAuthorizationRequest {
17    /// Client identifier
18    pub client_id: String,
19
20    /// Response type (e.g., "code")
21    pub response_type: String,
22
23    /// Redirect URI
24    pub redirect_uri: String,
25
26    /// Requested scopes
27    pub scope: Option<String>,
28
29    /// State parameter
30    pub state: Option<String>,
31
32    /// PKCE code challenge
33    pub code_challenge: Option<String>,
34
35    /// PKCE code challenge method
36    pub code_challenge_method: Option<String>,
37
38    /// Additional parameters
39    #[serde(flatten)]
40    pub additional_params: HashMap<String, String>,
41}
42
43/// PAR response containing request URI
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct PushedAuthorizationResponse {
46    /// Request URI to be used in subsequent authorization request
47    pub request_uri: String,
48
49    /// Expiration time in seconds
50    pub expires_in: u64,
51}
52
53/// Stored PAR request with metadata
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct StoredPushedRequest {
56    /// Original request parameters
57    pub request: PushedAuthorizationRequest,
58
59    /// When the request was created
60    pub created_at: SystemTime,
61
62    /// When the request expires
63    pub expires_at: SystemTime,
64
65    /// Whether the request has been used
66    pub used: bool,
67}
68
69/// PAR request manager with persistent storage
70use std::fmt;
71
72#[derive(Clone)]
73pub struct PARManager {
74    /// Persistent storage backend
75    storage: Arc<dyn AuthStorage>,
76
77    /// Memory cache for fast access
78    requests: Arc<tokio::sync::RwLock<HashMap<String, StoredPushedRequest>>>,
79
80    /// Default expiration time for PAR requests
81    default_expiration: Duration,
82}
83
84impl fmt::Debug for PARManager {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        f.debug_struct("PARManager")
87            .field("storage", &"<dyn AuthStorage>")
88            .field("default_expiration", &self.default_expiration)
89            .finish()
90    }
91}
92
93impl PARManager {
94    /// Create a new PAR manager with storage backend
95    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
96        Self {
97            storage,
98            requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
99            default_expiration: Duration::from_secs(90), // RFC 9126 recommendation
100        }
101    }
102
103    /// Create a new PAR manager with custom expiration
104    pub fn with_expiration(storage: Arc<dyn AuthStorage>, expiration: Duration) -> Self {
105        Self {
106            storage,
107            requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
108            default_expiration: expiration,
109        }
110    }
111
112    /// Store a pushed authorization request
113    pub async fn store_request(
114        &self,
115        request: PushedAuthorizationRequest,
116    ) -> Result<PushedAuthorizationResponse> {
117        // Validate the request
118        self.validate_request(&request)?;
119
120        // Generate request URI
121        let request_id = Uuid::new_v4().to_string();
122        let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", request_id);
123
124        // Calculate expiration
125        let now = SystemTime::now();
126        let expires_at = now + self.default_expiration;
127
128        // Store the request in persistent storage with TTL
129        let stored_request = StoredPushedRequest {
130            request: request.clone(),
131            created_at: now,
132            expires_at,
133            used: false,
134        };
135
136        // Store in persistent backend with TTL
137        let storage_key = format!("par:{}", request_uri);
138        let serialized = serde_json::to_string(&stored_request)
139            .map_err(|e| AuthError::internal(format!("Failed to serialize PAR request: {}", e)))?;
140
141        self.storage
142            .store_kv(
143                &storage_key,
144                &serialized.into_bytes(),
145                Some(self.default_expiration),
146            )
147            .await
148            .map_err(|e| AuthError::internal(format!("Failed to store PAR request: {}", e)))?;
149
150        // Also cache in memory for fast access
151        let mut requests = self.requests.write().await;
152        requests.insert(request_uri.clone(), stored_request);
153
154        // Clean up expired requests from memory cache
155        self.cleanup_expired_requests(&mut requests, now);
156
157        Ok(PushedAuthorizationResponse {
158            request_uri,
159            expires_in: self.default_expiration.as_secs(),
160        })
161    }
162
163    /// Retrieve and consume a pushed authorization request
164    pub async fn consume_request(&self, request_uri: &str) -> Result<PushedAuthorizationRequest> {
165        let storage_key = format!("par:{}", request_uri);
166
167        // Try to load from persistent storage first
168        let stored_request = if let Some(data) = self.storage.get_kv(&storage_key).await? {
169            let serialized = String::from_utf8(data)
170                .map_err(|_| AuthError::internal("Invalid UTF-8 in stored PAR data"))?;
171
172            serde_json::from_str::<StoredPushedRequest>(&serialized).map_err(|e| {
173                AuthError::internal(format!("Failed to deserialize PAR request: {}", e))
174            })?
175        } else {
176            // Fallback to memory cache (for backward compatibility during transition)
177            let requests = self.requests.read().await;
178            requests
179                .get(request_uri)
180                .cloned()
181                .ok_or_else(|| AuthError::auth_method("par", "Invalid request_uri"))?
182        };
183
184        // Check if expired
185        let now = SystemTime::now();
186        if now > stored_request.expires_at {
187            // Clean up from both storage and cache
188            let _ = self.storage.delete_kv(&storage_key).await;
189            let mut requests = self.requests.write().await;
190            requests.remove(request_uri);
191            return Err(AuthError::auth_method("par", "Request URI expired"));
192        }
193
194        // Check if already used
195        if stored_request.used {
196            return Err(AuthError::auth_method("par", "Request URI already used"));
197        }
198
199        // Mark as consumed by removing from storage (single use)
200        self.storage
201            .delete_kv(&storage_key)
202            .await
203            .map_err(|e| AuthError::internal(format!("Failed to consume PAR request: {}", e)))?;
204
205        // Also remove from memory cache
206        let mut requests = self.requests.write().await;
207        requests.remove(request_uri);
208
209        Ok(stored_request.request)
210    }
211
212    /// Validate a PAR request
213    fn validate_request(&self, request: &PushedAuthorizationRequest) -> Result<()> {
214        // Validate required parameters
215        if request.client_id.is_empty() {
216            return Err(AuthError::auth_method("par", "Missing client_id"));
217        }
218
219        if request.response_type.is_empty() {
220            return Err(AuthError::auth_method("par", "Missing response_type"));
221        }
222
223        if request.redirect_uri.is_empty() {
224            return Err(AuthError::auth_method("par", "Missing redirect_uri"));
225        }
226
227        // Validate redirect URI format
228        if url::Url::parse(&request.redirect_uri).is_err() {
229            return Err(AuthError::auth_method("par", "Invalid redirect_uri format"));
230        }
231
232        // Validate PKCE parameters if present
233        if let (Some(challenge), Some(method)) =
234            (&request.code_challenge, &request.code_challenge_method)
235        {
236            if method != "S256" && method != "plain" {
237                return Err(AuthError::auth_method(
238                    "par",
239                    "Invalid code_challenge_method",
240                ));
241            }
242
243            if challenge.is_empty() {
244                return Err(AuthError::auth_method("par", "Empty code_challenge"));
245            }
246        }
247
248        Ok(())
249    }
250
251    /// Clean up expired requests
252    fn cleanup_expired_requests(
253        &self,
254        requests: &mut HashMap<String, StoredPushedRequest>,
255        now: SystemTime,
256    ) {
257        requests.retain(|_, stored_request| now <= stored_request.expires_at);
258    }
259
260    /// Get statistics about stored requests
261    pub async fn get_statistics(&self) -> PARStatistics {
262        let requests = self.requests.read().await;
263        let now = SystemTime::now();
264
265        let total_count = requests.len();
266        let expired_count = requests.values().filter(|req| now > req.expires_at).count();
267        let used_count = requests.values().filter(|req| req.used).count();
268
269        PARStatistics {
270            total_requests: total_count,
271            expired_requests: expired_count,
272            used_requests: used_count,
273            active_requests: total_count - expired_count - used_count,
274        }
275    }
276}
277
278/// PAR statistics
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct PARStatistics {
281    /// Total number of stored requests
282    pub total_requests: usize,
283
284    /// Number of expired requests
285    pub expired_requests: usize,
286
287    /// Number of used requests
288    pub used_requests: usize,
289
290    /// Number of active (valid, unused) requests
291    pub active_requests: usize,
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    /// Manager over a fresh, empty in-memory store with the default lifetime.
    fn fresh_manager() -> PARManager {
        PARManager::new(Arc::new(MemoryStorage::new()))
    }

    /// A well-formed request that passes validation.
    fn create_test_request() -> PushedAuthorizationRequest {
        PushedAuthorizationRequest {
            client_id: "test_client".into(),
            response_type: "code".into(),
            redirect_uri: "https://example.com/callback".into(),
            scope: Some("openid profile".into()),
            state: Some("test_state".into()),
            code_challenge: Some("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".into()),
            code_challenge_method: Some("S256".into()),
            additional_params: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_store_and_consume_request() {
        let manager = fresh_manager();
        let original = create_test_request();

        let pushed = manager.store_request(original.clone()).await.unwrap();
        assert!(pushed
            .request_uri
            .starts_with("urn:ietf:params:oauth:request_uri:"));
        assert_eq!(pushed.expires_in, 90);

        let redeemed = manager.consume_request(&pushed.request_uri).await.unwrap();
        assert_eq!(redeemed.client_id, original.client_id);
        assert_eq!(redeemed.response_type, original.response_type);

        // A request_uri is single-use: redeeming it a second time must fail.
        assert!(manager.consume_request(&pushed.request_uri).await.is_err());
    }

    #[tokio::test]
    async fn test_request_expiration() {
        let manager = PARManager::with_expiration(
            Arc::new(MemoryStorage::new()),
            Duration::from_millis(50),
        );
        let pushed = manager.store_request(create_test_request()).await.unwrap();

        // Outlive the 50 ms lifetime before trying to redeem.
        sleep(Duration::from_millis(100)).await;

        assert!(manager.consume_request(&pushed.request_uri).await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_request_validation() {
        let manager = fresh_manager();

        // Each mutation breaks exactly one validation rule of an otherwise valid request.
        let mutations: [fn(&mut PushedAuthorizationRequest); 3] = [
            |r| r.client_id = String::new(),
            |r| r.redirect_uri = "invalid-uri".to_string(),
            |r| r.code_challenge_method = Some("invalid".to_string()),
        ];

        for mutate in mutations {
            let mut request = create_test_request();
            mutate(&mut request);
            assert!(manager.store_request(request).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_statistics() {
        let manager = fresh_manager();

        assert_eq!(manager.get_statistics().await.total_requests, 0);

        let pushed = manager.store_request(create_test_request()).await.unwrap();
        let after_push = manager.get_statistics().await;
        assert_eq!(after_push.total_requests, 1);
        assert_eq!(after_push.active_requests, 1);

        manager.consume_request(&pushed.request_uri).await.unwrap();
        // Redeemed requests are evicted from the cache rather than flagged as used.
        assert_eq!(manager.get_statistics().await.total_requests, 0);
    }
}
418
419