use crate::errors::{AuthError, Result};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
/// Structural and claim-level checks for JSON Web Tokens.
///
/// Nothing in this module verifies a token's signature; claims obtained here must be treated
/// as untrusted until the signature has been checked elsewhere.
pub mod jwt {
    use super::*;
    use jsonwebtoken::decode_header;

    /// Maximum accepted age of a token, measured from its `iat` claim (24 hours).
    const MAX_TOKEN_AGE_SECS: f64 = 24.0 * 60.0 * 60.0;

    /// Checks that `token` is non-empty, has exactly three dot-separated segments, and has a
    /// header that `jsonwebtoken::decode_header` can parse.
    ///
    /// # Errors
    /// Returns a validation error if the token is empty, does not have three segments, or its
    /// header cannot be decoded.
    pub fn validate_jwt_format(token: &str) -> Result<()> {
        if token.is_empty() {
            return Err(AuthError::validation("JWT token is empty"));
        }
        // Counting segments avoids allocating a Vec just to learn how many there are.
        if token.split('.').count() != 3 {
            return Err(AuthError::validation(
                "Invalid JWT format: must have 3 parts",
            ));
        }
        decode_header(token)
            .map_err(|e| AuthError::validation(format!("Invalid JWT header: {}", e)))?;
        Ok(())
    }

    /// Decodes the payload segment of `token` into JSON **without verifying its signature**.
    ///
    /// # Errors
    /// Returns a validation error if the token fails [`validate_jwt_format`], or if its payload
    /// is not unpadded base64url or not valid JSON.
    pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
        use base64::Engine as _;
        use base64::engine::general_purpose::URL_SAFE_NO_PAD;

        validate_jwt_format(token)?;
        // validate_jwt_format guarantees three segments, so the payload exists; `nth` keeps
        // this path free of an indexing panic regardless.
        let payload = token
            .split('.')
            .nth(1)
            .ok_or_else(|| AuthError::validation("Invalid JWT format: must have 3 parts"))?;
        let decoded = URL_SAFE_NO_PAD
            .decode(payload)
            .map_err(|e| AuthError::validation(format!("Invalid JWT payload encoding: {}", e)))?;
        let claims: serde_json::Value = serde_json::from_slice(&decoded)
            .map_err(|e| AuthError::validation(format!("Invalid JWT payload JSON: {}", e)))?;
        Ok(claims)
    }

    /// Reads the NumericDate claim `name` from `claims`.
    ///
    /// Returns `Ok(None)` when the claim is absent. NumericDate values may be integers or
    /// fractional numbers, so any JSON number is accepted. A present value that is not a
    /// number is rejected instead of ignored, so a malformed claim cannot skip its time check.
    fn numeric_date_claim(claims: &serde_json::Value, name: &str) -> Result<Option<f64>> {
        match claims.get(name) {
            None => Ok(None),
            Some(value) => value.as_f64().map(Some).ok_or_else(|| {
                AuthError::validation(format!(
                    "Invalid {} claim: must be a NumericDate",
                    name
                ))
            }),
        }
    }

    /// Validates the `exp`, `nbf` and `iat` time claims against the current system clock.
    ///
    /// Each claim is optional. A token is rejected when `now >= exp`, when `now < nbf`, or
    /// when it was issued more than 24 hours ago according to `iat`.
    ///
    /// # Errors
    /// Returns a validation error for an expired, not-yet-valid, or too-old token, or when one
    /// of these claims is present but not a JSON number.
    pub fn validate_time_claims(claims: &serde_json::Value) -> Result<()> {
        // A system clock set before the Unix epoch degrades to 0 instead of failing.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs() as f64;
        if let Some(exp) = numeric_date_claim(claims, "exp")?
            && now >= exp
        {
            return Err(AuthError::validation("Token has expired"));
        }
        if let Some(nbf) = numeric_date_claim(claims, "nbf")?
            && now < nbf
        {
            return Err(AuthError::validation("Token not yet valid (nbf)"));
        }
        // Computed in f64, so an extreme attacker-supplied `iat` (e.g. i64::MIN) cannot
        // overflow the subtraction.
        if let Some(iat) = numeric_date_claim(claims, "iat")?
            && now - iat > MAX_TOKEN_AGE_SECS
        {
            return Err(AuthError::validation("Token too old"));
        }
        Ok(())
    }

    /// Ensures every claim name in `required` is present in `claims`.
    ///
    /// Presence is all that is checked; a claim whose value is JSON `null` counts as present.
    ///
    /// # Errors
    /// Returns `Missing required claim: <name>` for the first absent claim.
    pub fn validate_required_claims(claims: &serde_json::Value, required: &[&str]) -> Result<()> {
        for claim in required {
            if claims.get(claim).is_none() {
                return Err(AuthError::validation(format!(
                    "Missing required claim: {}",
                    claim
                )));
            }
        }
        Ok(())
    }
}
/// Validation helpers for token types, raw token strings, and OAuth scope strings.
pub mod token {
    use super::*;

    /// Returns whether `c` may appear inside a single scope token.
    ///
    /// Only ASCII letters/digits and `:` `/` `.` `-` `_` are allowed. Scope tokens are ASCII
    /// by specification, so non-ASCII letters (which `char::is_alphanumeric` would accept)
    /// are rejected.
    fn is_scope_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || matches!(c, ':' | '/' | '.' | '-' | '_')
    }

    /// Checks that `token_type` appears in `allowed_types` (exact, case-sensitive match).
    ///
    /// # Errors
    /// Returns `Unsupported token type: <type>` when it is not allowed.
    pub fn validate_token_type(token_type: &str, allowed_types: &[&str]) -> Result<()> {
        if !allowed_types.contains(&token_type) {
            return Err(AuthError::validation(format!(
                "Unsupported token type: {}",
                token_type
            )));
        }
        Ok(())
    }

    /// Performs a cheap shape check on `token` according to its
    /// `urn:ietf:params:oauth:token-type:*` type URI.
    ///
    /// JWTs must pass [`jwt::validate_jwt_format`]; access tokens must be at least 10 bytes
    /// and refresh tokens at least 20 bytes long. No cryptographic validation is performed.
    ///
    /// # Errors
    /// Returns a validation error if the token is empty, malformed for its type, or if
    /// `token_type` is not one of the three recognized URIs.
    pub fn validate_token_format(token: &str, token_type: &str) -> Result<()> {
        if token.is_empty() {
            return Err(AuthError::validation("Token is empty"));
        }
        match token_type {
            "urn:ietf:params:oauth:token-type:jwt" => jwt::validate_jwt_format(token),
            "urn:ietf:params:oauth:token-type:access_token" => {
                if token.len() < 10 {
                    return Err(AuthError::validation("Access token too short"));
                }
                Ok(())
            }
            "urn:ietf:params:oauth:token-type:refresh_token" => {
                if token.len() < 20 {
                    return Err(AuthError::validation("Refresh token too short"));
                }
                Ok(())
            }
            _ => Err(AuthError::validation(format!(
                "Unsupported token type: {}",
                token_type
            ))),
        }
    }

    /// Splits a whitespace-delimited scope string into individual scope tokens.
    ///
    /// An empty or whitespace-only `scope` yields an empty list.
    ///
    /// # Errors
    /// Returns `Invalid scope format: <token>` for the first token containing a character
    /// outside ASCII letters/digits and `:` `/` `.` `-` `_`.
    pub fn validate_scope(scope: &str) -> Result<Vec<String>> {
        // `split_whitespace` never yields empty items, so no separate empty-token check is
        // needed; collecting into Result stops at the first invalid token.
        scope
            .split_whitespace()
            .map(|item| {
                if item.chars().all(is_scope_char) {
                    Ok(item.to_string())
                } else {
                    Err(AuthError::validation(format!(
                        "Invalid scope format: {}",
                        item
                    )))
                }
            })
            .collect()
    }
}
/// Validation helpers for OAuth client identifiers, redirect URIs, and grant types.
pub mod client {
    use super::*;

    /// Scheme prefixes accepted for redirect URIs.
    const REDIRECT_SCHEMES: [&str; 3] = ["http://", "https://", "custom://"];

    /// Checks that `client_id` is 3–255 characters of ASCII letters, digits, `-`, `_` or `.`.
    ///
    /// # Errors
    /// Returns a validation error if the ID is empty, too short, too long, or contains any
    /// other character (including non-ASCII letters).
    pub fn validate_client_id(client_id: &str) -> Result<()> {
        if client_id.is_empty() {
            return Err(AuthError::validation("Client ID is empty"));
        }
        // `len()` counts bytes; the ASCII-only rule below makes that equal to the character
        // count for every accepted ID.
        if client_id.len() < 3 {
            return Err(AuthError::validation("Client ID too short"));
        }
        if client_id.len() > 255 {
            return Err(AuthError::validation("Client ID too long"));
        }
        if !client_id
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
        {
            return Err(AuthError::validation(
                "Client ID contains invalid characters",
            ));
        }
        Ok(())
    }

    /// Checks that `uri` is an absolute `http://`, `https://` or `custom://` URI with
    /// something after the scheme and no fragment.
    ///
    /// # Errors
    /// Returns a validation error if the URI is empty, uses another scheme, consists of a
    /// bare scheme only, or contains `#`.
    pub fn validate_redirect_uri(uri: &str) -> Result<()> {
        if uri.is_empty() {
            return Err(AuthError::validation("Redirect URI is empty"));
        }
        let Some(rest) = REDIRECT_SCHEMES
            .iter()
            .find_map(|&scheme| uri.strip_prefix(scheme))
        else {
            return Err(AuthError::validation("Redirect URI must be absolute"));
        };
        // A bare scheme such as "https://" names no destination at all.
        if rest.is_empty() {
            return Err(AuthError::validation("Redirect URI must be absolute"));
        }
        if uri.contains('#') {
            return Err(AuthError::validation(
                "Redirect URI cannot contain fragments",
            ));
        }
        Ok(())
    }

    /// Checks that `grant_type` appears in `allowed_grants` (exact, case-sensitive match).
    ///
    /// # Errors
    /// Returns `Unsupported grant type: <grant>` when it is not allowed.
    pub fn validate_grant_type(grant_type: &str, allowed_grants: &[&str]) -> Result<()> {
        if !allowed_grants.contains(&grant_type) {
            return Err(AuthError::validation(format!(
                "Unsupported grant type: {}",
                grant_type
            )));
        }
        Ok(())
    }
}
/// Validation helpers for raw request parameters.
pub mod request {
    use super::*;

    /// Ensures every name in `required` is present in `params` with a non-blank value.
    ///
    /// # Errors
    /// Returns `Missing parameter: <name>` for the first required parameter that is absent or
    /// whose value is empty or whitespace-only.
    pub fn validate_required_params(
        params: &HashMap<String, String>,
        required: &[&str],
    ) -> Result<()> {
        for &param in required {
            // One lookup instead of `contains_key` followed by indexing.
            let has_value = params
                .get(param)
                .is_some_and(|value| !value.trim().is_empty());
            if !has_value {
                return Err(AuthError::validation(format!(
                    "Missing parameter: {}",
                    param
                )));
            }
        }
        Ok(())
    }

    /// Checks a single parameter value against a named `pattern`.
    ///
    /// `"alphanum"` requires every character to satisfy `char::is_alphanumeric` (which also
    /// accepts non-ASCII letters and digits). Any other pattern name only requires the value
    /// to be non-blank; unknown pattern names are not rejected.
    ///
    /// # Errors
    /// Returns a validation error naming `param_name` when the value is empty or does not
    /// match the pattern.
    pub fn validate_param_format(value: &str, param_name: &str, pattern: &str) -> Result<()> {
        if value.is_empty() {
            return Err(AuthError::validation(format!(
                "Parameter {} cannot be empty",
                param_name
            )));
        }
        match pattern {
            "alphanum" => {
                if !value.chars().all(|c| c.is_alphanumeric()) {
                    return Err(AuthError::validation(format!(
                        "Parameter {} must be alphanumeric",
                        param_name
                    )));
                }
            }
            _ => {
                if value.trim().is_empty() {
                    return Err(AuthError::validation(format!(
                        "Parameter {} has invalid format",
                        param_name
                    )));
                }
            }
        }
        Ok(())
    }

    /// Accepts only the PKCE code challenge methods `"plain"` and `"S256"` (case-sensitive).
    ///
    /// # Errors
    /// Returns a validation error for any other method.
    pub fn validate_code_challenge_method(method: &str) -> Result<()> {
        match method {
            "plain" | "S256" => Ok(()),
            _ => Err(AuthError::validation("Invalid code challenge method")),
        }
    }

    /// Checks that `response_type` is a non-empty whitespace-separated list whose every
    /// element appears in `allowed_types`.
    ///
    /// # Errors
    /// Returns a validation error if `response_type` is empty or whitespace-only, or names an
    /// unsupported type.
    pub fn validate_response_type(response_type: &str, allowed_types: &[&str]) -> Result<()> {
        let mut requested = response_type.split_whitespace().peekable();
        // Without this check an empty value would vacuously pass the loop below.
        if requested.peek().is_none() {
            return Err(AuthError::validation("Response type is empty"));
        }
        for kind in requested {
            if !allowed_types.contains(&kind) {
                return Err(AuthError::validation(format!(
                    "Unsupported response type: {}",
                    kind
                )));
            }
        }
        Ok(())
    }
}
/// Validation helpers for HTTP(S) URLs.
pub mod url {
    use super::*;

    /// Checks that `url` starts with `http://` or `https://` and has a non-empty authority.
    ///
    /// # Errors
    /// Returns a validation error if the URL is empty, uses another scheme, or has nothing
    /// (or only a path, query or fragment) after the scheme, e.g. `"https://"` or
    /// `"https:///path"`.
    pub fn validate_url_format(url: &str) -> Result<()> {
        if url.is_empty() {
            return Err(AuthError::validation("URL is empty"));
        }
        let Some(rest) = url
            .strip_prefix("https://")
            .or_else(|| url.strip_prefix("http://"))
        else {
            return Err(AuthError::validation("URL must use HTTP or HTTPS scheme"));
        };
        // An empty authority leaves no host to address.
        if matches!(rest.chars().next(), None | Some('/' | '?' | '#')) {
            return Err(AuthError::validation("Invalid URL format"));
        }
        Ok(())
    }

    /// Like [`validate_url_format`], but additionally requires the `https://` scheme.
    ///
    /// # Errors
    /// Propagates any [`validate_url_format`] error, and returns `HTTPS is required` for an
    /// otherwise valid `http://` URL.
    pub fn validate_https_required(url: &str) -> Result<()> {
        validate_url_format(url)?;
        if !url.starts_with("https://") {
            return Err(AuthError::validation("HTTPS is required"));
        }
        Ok(())
    }
}
/// Merges a batch of validation outcomes into a single result.
///
/// Returns `Ok(())` when every outcome succeeded. Otherwise returns one validation error whose
/// message is the `Display` text of each failure, in input order, joined by `"; "`.
pub fn collect_validation_errors(validations: Vec<Result<()>>) -> Result<()> {
    let mut messages: Vec<String> = Vec::new();
    for outcome in validations {
        if let Err(err) = outcome {
            messages.push(err.to_string());
        }
    }
    if messages.is_empty() {
        return Ok(());
    }
    Err(AuthError::validation(messages.join("; ")))
}
#[cfg(test)]
mod tests {
    // Unit tests for the validation helpers above. The `&params` arguments below were
    // previously mis-encoded as `¶ms`, which is not valid Rust and broke compilation.
    use super::*;
    use serde_json::json;
    #[test]
    fn test_validate_jwt_format_empty() {
        assert!(jwt::validate_jwt_format("").is_err());
    }
    #[test]
    fn test_validate_jwt_format_wrong_parts() {
        assert!(jwt::validate_jwt_format("one.two").is_err());
        assert!(jwt::validate_jwt_format("a.b.c.d").is_err());
    }
    #[test]
    fn test_validate_required_claims_missing() {
        let claims = json!({"sub": "user1"});
        assert!(jwt::validate_required_claims(&claims, &["sub"]).is_ok());
        assert!(jwt::validate_required_claims(&claims, &["aud"]).is_err());
    }
    #[test]
    fn test_validate_time_claims_expired() {
        let claims = json!({"exp": 1000000});
        assert!(jwt::validate_time_claims(&claims).is_err());
    }
    #[test]
    fn test_validate_time_claims_future_nbf() {
        let far_future = (SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 999999) as i64;
        let claims = json!({"nbf": far_future});
        assert!(jwt::validate_time_claims(&claims).is_err());
    }
    #[test]
    fn test_validate_time_claims_valid_no_claims() {
        let claims = json!({});
        assert!(jwt::validate_time_claims(&claims).is_ok());
    }
    #[test]
    fn test_validate_token_type_success() {
        assert!(token::validate_token_type("bearer", &["bearer", "dpop"]).is_ok());
    }
    #[test]
    fn test_validate_token_type_unsupported() {
        assert!(token::validate_token_type("mac", &["bearer"]).is_err());
    }
    #[test]
    fn test_validate_token_format_empty() {
        assert!(token::validate_token_format("", "anything").is_err());
    }
    #[test]
    fn test_validate_token_format_access_token_too_short() {
        assert!(
            token::validate_token_format("short", "urn:ietf:params:oauth:token-type:access_token")
                .is_err()
        );
    }
    #[test]
    fn test_validate_token_format_refresh_token_too_short() {
        assert!(
            token::validate_token_format(
                "shorttoken",
                "urn:ietf:params:oauth:token-type:refresh_token"
            )
            .is_err()
        );
    }
    #[test]
    fn test_validate_scope_empty() {
        let scopes = token::validate_scope("").unwrap();
        assert!(scopes.is_empty());
    }
    #[test]
    fn test_validate_scope_valid() {
        let scopes = token::validate_scope("read write openid").unwrap();
        assert_eq!(scopes, vec!["read", "write", "openid"]);
    }
    #[test]
    fn test_validate_scope_invalid_chars() {
        assert!(token::validate_scope("read <script>").is_err());
    }
    #[test]
    fn test_validate_client_id_valid() {
        assert!(client::validate_client_id("my-client.app_01").is_ok());
    }
    #[test]
    fn test_validate_client_id_empty() {
        assert!(client::validate_client_id("").is_err());
    }
    #[test]
    fn test_validate_client_id_too_short() {
        assert!(client::validate_client_id("ab").is_err());
    }
    #[test]
    fn test_validate_client_id_too_long() {
        let long_id = "a".repeat(256);
        assert!(client::validate_client_id(&long_id).is_err());
    }
    #[test]
    fn test_validate_client_id_invalid_chars() {
        assert!(client::validate_client_id("my client!").is_err());
    }
    #[test]
    fn test_validate_redirect_uri_valid() {
        assert!(client::validate_redirect_uri("https://example.com/callback").is_ok());
        assert!(client::validate_redirect_uri("http://localhost:8080/cb").is_ok());
        assert!(client::validate_redirect_uri("custom://app/callback").is_ok());
    }
    #[test]
    fn test_validate_redirect_uri_empty() {
        assert!(client::validate_redirect_uri("").is_err());
    }
    #[test]
    fn test_validate_redirect_uri_not_absolute() {
        assert!(client::validate_redirect_uri("/callback").is_err());
    }
    #[test]
    fn test_validate_redirect_uri_with_fragment() {
        assert!(client::validate_redirect_uri("https://example.com/cb#section").is_err());
    }
    #[test]
    fn test_validate_grant_type_success() {
        assert!(
            client::validate_grant_type(
                "authorization_code",
                &["authorization_code", "refresh_token"]
            )
            .is_ok()
        );
    }
    #[test]
    fn test_validate_grant_type_unsupported() {
        assert!(client::validate_grant_type("implicit", &["authorization_code"]).is_err());
    }
    #[test]
    fn test_validate_required_params() {
        let mut params = HashMap::new();
        params.insert("code".to_string(), "abc123".to_string());
        assert!(request::validate_required_params(&params, &["code"]).is_ok());
        assert!(request::validate_required_params(&params, &["code", "state"]).is_err());
    }
    #[test]
    fn test_validate_required_params_empty_value() {
        let mut params = HashMap::new();
        params.insert("code".to_string(), " ".to_string());
        assert!(request::validate_required_params(&params, &["code"]).is_err());
    }
    #[test]
    fn test_validate_param_format_alphanum() {
        assert!(request::validate_param_format("abc123", "nonce", "alphanum").is_ok());
        assert!(request::validate_param_format("abc-123", "nonce", "alphanum").is_err());
    }
    #[test]
    fn test_validate_param_format_empty() {
        assert!(request::validate_param_format("", "nonce", "alphanum").is_err());
    }
    #[test]
    fn test_validate_code_challenge_method() {
        assert!(request::validate_code_challenge_method("S256").is_ok());
        assert!(request::validate_code_challenge_method("plain").is_ok());
        assert!(request::validate_code_challenge_method("S512").is_err());
    }
    #[test]
    fn test_validate_response_type() {
        assert!(request::validate_response_type("code", &["code", "token"]).is_ok());
        assert!(request::validate_response_type("id_token", &["code"]).is_err());
    }
    #[test]
    fn test_validate_url_format_valid() {
        assert!(url::validate_url_format("https://example.com").is_ok());
        assert!(url::validate_url_format("http://localhost:8080").is_ok());
    }
    #[test]
    fn test_validate_url_format_empty() {
        assert!(url::validate_url_format("").is_err());
    }
    #[test]
    fn test_validate_url_format_no_scheme() {
        assert!(url::validate_url_format("example.com").is_err());
    }
    #[test]
    fn test_validate_https_required() {
        assert!(url::validate_https_required("https://example.com").is_ok());
        assert!(url::validate_https_required("http://example.com").is_err());
    }
    #[test]
    fn test_collect_validation_errors_all_ok() {
        let validations = vec![Ok(()), Ok(())];
        assert!(collect_validation_errors(validations).is_ok());
    }
    #[test]
    fn test_collect_validation_errors_some_fail() {
        let validations = vec![
            Ok(()),
            Err(AuthError::validation("err1")),
            Err(AuthError::validation("err2")),
        ];
        let err = collect_validation_errors(validations).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("err1"));
        assert!(msg.contains("err2"));
    }
}