use crate::errors::{AuthError, Result};
use crate::storage::AuthStorage;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use uuid::Uuid;
/// Authorization request parameters pushed by a client to the PAR endpoint.
///
/// Serialized with `serde`: parameters without a dedicated field live in
/// `additional_params`, which is `#[serde(flatten)]`-ed into the same JSON
/// object as the named fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushedAuthorizationRequest {
    /// OAuth client identifier; `PARManager` rejects an empty value.
    pub client_id: String,
    /// Requested response type (e.g. `code`); `PARManager` rejects an empty value.
    pub response_type: String,
    /// Redirect target; `PARManager` requires it to be non-empty and to parse
    /// with `url::Url::parse`.
    pub redirect_uri: String,
    /// Optional requested scope string.
    pub scope: Option<String>,
    /// Optional opaque client state.
    pub state: Option<String>,
    /// Optional PKCE code challenge.
    pub code_challenge: Option<String>,
    /// Optional PKCE transformation; `PARManager` accepts `S256` or `plain`.
    pub code_challenge_method: Option<String>,
    /// Any further parameters, flattened into the serialized object.
    #[serde(flatten)]
    pub additional_params: HashMap<String, String>,
}
impl PushedAuthorizationRequest {
    /// Starts a [`PushedAuthorizationRequestBuilder`] seeded with the three
    /// mandatory parameters; every optional parameter begins unset and the
    /// extra-parameter map begins empty.
    pub fn builder(
        client_id: impl Into<String>,
        response_type: impl Into<String>,
        redirect_uri: impl Into<String>,
    ) -> PushedAuthorizationRequestBuilder {
        let client_id = client_id.into();
        let response_type = response_type.into();
        let redirect_uri = redirect_uri.into();

        PushedAuthorizationRequestBuilder {
            client_id,
            response_type,
            redirect_uri,
            scope: None,
            state: None,
            code_challenge: None,
            code_challenge_method: None,
            additional_params: HashMap::new(),
        }
    }
}
/// Fluent builder for [`PushedAuthorizationRequest`], created with
/// `PushedAuthorizationRequest::builder` and finished with `build()`.
///
/// Every setter consumes the builder and returns it, so ignoring a setter's
/// return value silently drops the configured state — hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing unless `build()` is called"]
pub struct PushedAuthorizationRequestBuilder {
    client_id: String,
    response_type: String,
    redirect_uri: String,
    scope: Option<String>,
    state: Option<String>,
    code_challenge: Option<String>,
    code_challenge_method: Option<String>,
    additional_params: HashMap<String, String>,
}
impl PushedAuthorizationRequestBuilder {
    /// Sets (or replaces) the `scope` parameter.
    pub fn scope(self, scope: impl Into<String>) -> Self {
        Self {
            scope: Some(scope.into()),
            ..self
        }
    }

    /// Sets (or replaces) the `state` parameter.
    pub fn state(self, state: impl Into<String>) -> Self {
        Self {
            state: Some(state.into()),
            ..self
        }
    }

    /// Sets (or replaces) the PKCE `code_challenge`.
    pub fn code_challenge(self, challenge: impl Into<String>) -> Self {
        Self {
            code_challenge: Some(challenge.into()),
            ..self
        }
    }

    /// Sets (or replaces) the PKCE `code_challenge_method`.
    pub fn code_challenge_method(self, method: impl Into<String>) -> Self {
        Self {
            code_challenge_method: Some(method.into()),
            ..self
        }
    }

    /// Convenience for setting both PKCE parameters in one call.
    pub fn pkce(self, challenge: impl Into<String>, method: impl Into<String>) -> Self {
        self.code_challenge(challenge).code_challenge_method(method)
    }

    /// Adds an extra parameter, overwriting any earlier value under `key`.
    pub fn add_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.additional_params.insert(key, value);
        self
    }

    /// Consumes the builder and yields the finished request.
    pub fn build(self) -> PushedAuthorizationRequest {
        let Self {
            client_id,
            response_type,
            redirect_uri,
            scope,
            state,
            code_challenge,
            code_challenge_method,
            additional_params,
        } = self;

        PushedAuthorizationRequest {
            client_id,
            response_type,
            redirect_uri,
            scope,
            state,
            code_challenge,
            code_challenge_method,
            additional_params,
        }
    }
}
/// Result of a successful `PARManager::store_request` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PushedAuthorizationResponse {
    /// Generated reference of the form `urn:ietf:params:oauth:request_uri:<uuid>`;
    /// redeem it with `PARManager::consume_request`.
    pub request_uri: String,
    /// Lifetime of `request_uri` in seconds (the manager's expiration,
    /// truncated to whole seconds).
    pub expires_in: u64,
}
/// Envelope kept for each pushed request: serialized as JSON into
/// `AuthStorage` and also held in the manager's in-memory map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredPushedRequest {
    /// The validated request as pushed by the client.
    pub request: PushedAuthorizationRequest,
    /// Time the request was stored.
    pub created_at: SystemTime,
    /// Time after which `consume_request` rejects the request as expired.
    pub expires_at: SystemTime,
    /// Redemption flag. `PARManager` writes `false` and deletes entries when
    /// they are consumed, so it never sets this to `true` itself.
    pub used: bool,
}
use std::fmt;
/// Stores and redeems pushed authorization requests.
///
/// Each request is written to `storage` under `par:<request_uri>` and mirrored
/// in an in-memory map. The map serves as a fallback when storage has no entry
/// and as the data source for statistics. Cloning the manager clones the `Arc`s,
/// so clones share both stores.
#[derive(Clone)]
pub struct PARManager {
    /// Key/value backend holding the JSON-serialized `StoredPushedRequest`s.
    storage: Arc<dyn AuthStorage>,
    /// In-memory mirror keyed by the full `request_uri`.
    requests: Arc<tokio::sync::RwLock<HashMap<String, StoredPushedRequest>>>,
    /// Lifetime applied to newly stored requests (90 s unless overridden).
    default_expiration: Duration,
}
impl fmt::Debug for PARManager {
    /// Prints the expiration and a placeholder for the storage backend.
    ///
    /// The in-memory request map sits behind an async lock and is not shown;
    /// `finish_non_exhaustive` renders a trailing `..` so the omission is
    /// visible instead of implying the struct has only these two fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PARManager")
            .field("storage", &"<dyn AuthStorage>")
            .field("default_expiration", &self.default_expiration)
            .finish_non_exhaustive()
    }
}
impl PARManager {
    /// Creates a manager backed by `storage` whose requests expire after 90 seconds.
    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
        Self::with_expiration(storage, Duration::from_secs(90))
    }

    /// Creates a manager backed by `storage` with a custom request lifetime.
    pub fn with_expiration(storage: Arc<dyn AuthStorage>, expiration: Duration) -> Self {
        Self {
            storage,
            requests: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            default_expiration: expiration,
        }
    }

    /// Builder-style setter that replaces the lifetime used for new requests.
    pub fn expiration(mut self, expiration: Duration) -> Self {
        self.default_expiration = expiration;
        self
    }

    /// Key under which the request for `request_uri` is kept in `AuthStorage`.
    fn storage_key(request_uri: &str) -> String {
        format!("par:{}", request_uri)
    }

    /// Validates `request`, persists it, and returns a fresh single-use `request_uri`.
    ///
    /// The serialized request is written to storage with a TTL equal to the
    /// manager's expiration and mirrored in the in-memory map. As a side effect,
    /// expired in-memory entries are pruned.
    ///
    /// # Errors
    /// Returns an `auth_method("par", ..)` error when validation fails, and an
    /// internal error when serialization or the storage write fails.
    pub async fn store_request(
        &self,
        request: PushedAuthorizationRequest,
    ) -> Result<PushedAuthorizationResponse> {
        self.validate_request(&request)?;

        let request_id = Uuid::new_v4().to_string();
        let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", request_id);
        let now = SystemTime::now();
        let stored_request = StoredPushedRequest {
            request,
            created_at: now,
            expires_at: now + self.default_expiration,
            used: false,
        };

        let storage_key = Self::storage_key(&request_uri);
        let serialized = serde_json::to_string(&stored_request)
            .map_err(|e| AuthError::internal(format!("Failed to serialize PAR request: {}", e)))?;
        self.storage
            .store_kv(
                &storage_key,
                &serialized.into_bytes(),
                Some(self.default_expiration),
            )
            .await
            .map_err(|e| AuthError::internal(format!("Failed to store PAR request: {}", e)))?;

        let mut requests = self.requests.write().await;
        requests.insert(request_uri.clone(), stored_request);
        self.cleanup_expired_requests(&mut requests, now);

        Ok(PushedAuthorizationResponse {
            request_uri,
            expires_in: self.default_expiration.as_secs(),
        })
    }

    /// Redeems `request_uri`, handing back the original request at most once.
    ///
    /// Storage is consulted first, with the in-memory map as a fallback. On
    /// success the entry is removed from both, so a second call fails.
    ///
    /// The in-memory write lock is held for the whole look-up/check/delete
    /// sequence. This stops two concurrent calls on this manager (or its clones)
    /// from both redeeming the same `request_uri`. Uniqueness across separate
    /// processes sharing one storage backend still depends on that backend.
    ///
    /// # Errors
    /// - `Invalid request_uri` when no entry exists;
    /// - `Request URI expired` when past `expires_at` (the entry is also removed);
    /// - `Request URI already used` when the entry is flagged `used`;
    /// - storage read errors, corrupt stored data, or a failed delete.
    pub async fn consume_request(&self, request_uri: &str) -> Result<PushedAuthorizationRequest> {
        let storage_key = Self::storage_key(request_uri);
        let mut requests = self.requests.write().await;

        let stored_request = match self.storage.get_kv(&storage_key).await? {
            Some(data) => {
                let serialized = String::from_utf8(data)
                    .map_err(|_| AuthError::internal("Invalid UTF-8 in stored PAR data"))?;
                serde_json::from_str::<StoredPushedRequest>(&serialized).map_err(|e| {
                    AuthError::internal(format!("Failed to deserialize PAR request: {}", e))
                })?
            }
            None => requests
                .get(request_uri)
                .cloned()
                .ok_or_else(|| AuthError::auth_method("par", "Invalid request_uri"))?,
        };

        if SystemTime::now() > stored_request.expires_at {
            // Best-effort cleanup: the caller gets the expiry error either way.
            let _ = self.storage.delete_kv(&storage_key).await;
            requests.remove(request_uri);
            return Err(AuthError::auth_method("par", "Request URI expired"));
        }
        if stored_request.used {
            return Err(AuthError::auth_method("par", "Request URI already used"));
        }

        self.storage
            .delete_kv(&storage_key)
            .await
            .map_err(|e| AuthError::internal(format!("Failed to consume PAR request: {}", e)))?;
        requests.remove(request_uri);
        Ok(stored_request.request)
    }

    /// Checks a pushed request before it is stored.
    ///
    /// Requires non-empty `client_id`, `response_type` and `redirect_uri`, and
    /// requires `redirect_uri` to parse as a URL. When present, the PKCE
    /// parameters are checked independently: the method must be `S256` or
    /// `plain`, and the challenge must be non-empty. `additional_params` may not
    /// contain `request_uri` or repeat a parameter that has its own field.
    fn validate_request(&self, request: &PushedAuthorizationRequest) -> Result<()> {
        // Parameters with dedicated struct fields. Repeating one in the
        // flattened `additional_params` map would emit a duplicate JSON key
        // next to the named field when the request is serialized for storage.
        const DEDICATED_PARAMS: [&str; 7] = [
            "client_id",
            "response_type",
            "redirect_uri",
            "scope",
            "state",
            "code_challenge",
            "code_challenge_method",
        ];

        if request.client_id.is_empty() {
            return Err(AuthError::auth_method("par", "Missing client_id"));
        }
        if request.response_type.is_empty() {
            return Err(AuthError::auth_method("par", "Missing response_type"));
        }
        if request.redirect_uri.is_empty() {
            return Err(AuthError::auth_method("par", "Missing redirect_uri"));
        }
        if url::Url::parse(&request.redirect_uri).is_err() {
            return Err(AuthError::auth_method("par", "Invalid redirect_uri format"));
        }
        // Each PKCE field is validated on its own, so a lone challenge or a
        // lone method can no longer bypass the checks.
        if let Some(method) = &request.code_challenge_method {
            if method != "S256" && method != "plain" {
                return Err(AuthError::auth_method(
                    "par",
                    "Invalid code_challenge_method",
                ));
            }
        }
        if request.code_challenge.as_deref() == Some("") {
            return Err(AuthError::auth_method("par", "Empty code_challenge"));
        }
        // RFC 9126 §2.1: `request_uri` MUST NOT be provided in a pushed request.
        if request.additional_params.contains_key("request_uri") {
            return Err(AuthError::auth_method("par", "request_uri must not be pushed"));
        }
        if DEDICATED_PARAMS
            .iter()
            .any(|&name| request.additional_params.contains_key(name))
        {
            return Err(AuthError::auth_method(
                "par",
                "Duplicate authorization parameter",
            ));
        }
        Ok(())
    }

    /// Drops in-memory entries whose `expires_at` is before `now`.
    fn cleanup_expired_requests(
        &self,
        requests: &mut HashMap<String, StoredPushedRequest>,
        now: SystemTime,
    ) {
        requests.retain(|_, stored_request| now <= stored_request.expires_at);
    }

    /// Summarizes the in-memory request map.
    ///
    /// `active_requests` counts entries that are neither expired nor used. It is
    /// counted directly rather than derived by subtraction, so an entry that is
    /// both expired and used is not subtracted twice (which could underflow).
    pub async fn get_statistics(&self) -> PARStatistics {
        let requests = self.requests.read().await;
        let now = SystemTime::now();

        let mut stats = PARStatistics {
            total_requests: requests.len(),
            expired_requests: 0,
            used_requests: 0,
            active_requests: 0,
        };
        for stored in requests.values() {
            let expired = now > stored.expires_at;
            if expired {
                stats.expired_requests += 1;
            }
            if stored.used {
                stats.used_requests += 1;
            }
            if !expired && !stored.used {
                stats.active_requests += 1;
            }
        }
        stats
    }
}
/// Snapshot of the manager's in-memory request map, produced by
/// `PARManager::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PARStatistics {
    /// Number of entries currently in the in-memory map.
    pub total_requests: usize,
    /// Entries whose `expires_at` is in the past.
    pub expired_requests: usize,
    /// Entries flagged `used`.
    pub used_requests: usize,
    /// Entries that are neither expired nor used.
    pub active_requests: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;
    use tokio::time::sleep;

    /// Fresh manager over an empty in-memory store with the default lifetime.
    fn memory_manager() -> PARManager {
        PARManager::new(Arc::new(MemoryStorage::new()))
    }

    /// A well-formed request that passes validation (S256 PKCE, https redirect).
    fn sample_request() -> PushedAuthorizationRequest {
        PushedAuthorizationRequest {
            client_id: "test_client".to_string(),
            response_type: "code".to_string(),
            redirect_uri: "https://example.com/callback".to_string(),
            scope: Some("openid profile".to_string()),
            state: Some("test_state".to_string()),
            code_challenge: Some("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()),
            code_challenge_method: Some("S256".to_string()),
            additional_params: HashMap::new(),
        }
    }

    #[test]
    fn test_par_request_builder() {
        let built =
            PushedAuthorizationRequest::builder("client_id", "code", "https://app/callback")
                .scope("openid profile")
                .state("state123")
                .pkce("challenge_abc", "S256")
                .add_param("custom", "value")
                .build();

        assert_eq!(built.client_id, "client_id");
        assert_eq!(built.response_type, "code");
        assert_eq!(built.redirect_uri, "https://app/callback");
        assert_eq!(built.scope.as_deref(), Some("openid profile"));
        assert_eq!(built.state.as_deref(), Some("state123"));
        assert_eq!(built.code_challenge.as_deref(), Some("challenge_abc"));
        assert_eq!(built.code_challenge_method.as_deref(), Some("S256"));
        assert_eq!(
            built.additional_params.get("custom").map(String::as_str),
            Some("value")
        );
    }

    #[tokio::test]
    async fn test_store_and_consume_request() {
        let manager = memory_manager();
        let original = sample_request();

        let pushed = manager.store_request(original.clone()).await.unwrap();
        assert!(pushed
            .request_uri
            .starts_with("urn:ietf:params:oauth:request_uri:"));
        assert_eq!(pushed.expires_in, 90);

        let redeemed = manager.consume_request(&pushed.request_uri).await.unwrap();
        assert_eq!(redeemed.client_id, original.client_id);
        assert_eq!(redeemed.response_type, original.response_type);

        // A request_uri is single-use: redeeming it again must fail.
        assert!(manager.consume_request(&pushed.request_uri).await.is_err());
    }

    #[tokio::test]
    async fn test_request_expiration() {
        let manager =
            PARManager::with_expiration(Arc::new(MemoryStorage::new()), Duration::from_millis(50));
        let pushed = manager.store_request(sample_request()).await.unwrap();

        sleep(Duration::from_millis(100)).await;

        assert!(manager.consume_request(&pushed.request_uri).await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_request_validation() {
        let manager = memory_manager();
        // Each entry breaks exactly one validation rule of an otherwise valid request.
        let breakers: [fn(&mut PushedAuthorizationRequest); 3] = [
            |r| r.client_id = String::new(),
            |r| r.redirect_uri = "invalid-uri".to_string(),
            |r| r.code_challenge_method = Some("invalid".to_string()),
        ];

        for breaker in breakers.iter() {
            let mut request = sample_request();
            breaker(&mut request);
            assert!(manager.store_request(request).await.is_err());
        }
    }

    #[tokio::test]
    async fn test_statistics() {
        let manager = memory_manager();
        assert_eq!(manager.get_statistics().await.total_requests, 0);

        let pushed = manager.store_request(sample_request()).await.unwrap();
        let after_store = manager.get_statistics().await;
        assert_eq!(after_store.total_requests, 1);
        assert_eq!(after_store.active_requests, 1);

        manager.consume_request(&pushed.request_uri).await.unwrap();
        assert_eq!(manager.get_statistics().await.total_requests, 0);
    }
}