use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq as _;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PKCEChallenge {
pub code_verifier: String,
pub code_challenge: String,
pub code_challenge_method: String,
}
impl PKCEChallenge {
pub fn new() -> Self {
use sha2::{Digest, Sha256};
let verifier = format!("{}", uuid::Uuid::new_v4());
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let digest = hasher.finalize();
let challenge = urlencoding::encode_binary(&digest).to_string();
Self {
code_verifier: verifier,
code_challenge: challenge,
code_challenge_method: "S256".to_string(),
}
}
pub fn verify(&self, verifier: &str) -> bool {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let digest = hasher.finalize();
let computed_challenge = urlencoding::encode_binary(&digest).to_string();
computed_challenge.as_bytes().ct_eq(self.code_challenge.as_bytes()).into()
}
}
impl Default for PKCEChallenge {
fn default() -> Self {
Self::new()
}
}
/// An opaque `state` value that stays valid for ten minutes after creation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateParameter {
    /// Random value, generated as a hyphenated UUIDv4 string.
    pub state: String,
    /// Instant at (and after) which the value counts as expired.
    pub expires_at: DateTime<Utc>,
}

impl StateParameter {
    /// Lifetime of a newly issued state value, in minutes.
    const LIFETIME_MINUTES: i64 = 10;

    /// Issues a new random state value that expires ten minutes from now.
    pub fn new() -> Self {
        let state = uuid::Uuid::new_v4().to_string();
        let expires_at = Utc::now() + Duration::minutes(Self::LIFETIME_MINUTES);
        Self { state, expires_at }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        Utc::now() >= self.expires_at
    }

    /// Accepts `provided_state` only if it equals the stored value and has not expired.
    ///
    /// Equality is checked in constant time via `subtle`; expiry is consulted only
    /// after the values match.
    pub fn verify(&self, provided_state: &str) -> bool {
        let matches = bool::from(self.state.as_bytes().ct_eq(provided_state.as_bytes()));
        if !matches {
            return false;
        }
        !self.is_expired()
    }
}

impl Default for StateParameter {
    /// Same as [`StateParameter::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// An opaque `nonce` value that stays valid for ten minutes after creation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NonceParameter {
    /// Random value, generated as a hyphenated UUIDv4 string.
    pub nonce: String,
    /// Instant at (and after) which the value counts as expired.
    pub expires_at: DateTime<Utc>,
}

impl NonceParameter {
    /// Lifetime of a newly issued nonce, in minutes.
    const LIFETIME_MINUTES: i64 = 10;

    /// Issues a new random nonce that expires ten minutes from now.
    pub fn new() -> Self {
        let nonce = uuid::Uuid::new_v4().to_string();
        let expires_at = Utc::now() + Duration::minutes(Self::LIFETIME_MINUTES);
        Self { nonce, expires_at }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        Utc::now() >= self.expires_at
    }

    /// Accepts `provided_nonce` only if it equals the stored value and has not expired.
    ///
    /// Equality is checked in constant time via `subtle`; expiry is consulted only
    /// after the values match.
    pub fn verify(&self, provided_nonce: &str) -> bool {
        let matches = bool::from(self.nonce.as_bytes().ct_eq(provided_nonce.as_bytes()));
        if !matches {
            return false;
        }
        !self.is_expired()
    }
}

impl Default for NonceParameter {
    /// Same as [`NonceParameter::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Unit tests for the PKCE, state and nonce helpers, including the expiry and
/// length-mismatch branches of `verify`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pkce_challenge_method_is_s256() {
        let challenge = PKCEChallenge::new();
        assert_eq!(challenge.code_challenge_method, "S256", "PKCE challenge method must be S256");
    }

    #[test]
    fn test_pkce_verifier_is_uuid_format() {
        let challenge = PKCEChallenge::new();
        assert!(
            uuid::Uuid::parse_str(&challenge.code_verifier).is_ok(),
            "PKCE code_verifier must be a valid UUID"
        );
    }

    #[test]
    fn test_pkce_challenge_is_not_empty() {
        let challenge = PKCEChallenge::new();
        assert!(!challenge.code_challenge.is_empty(), "PKCE code_challenge must not be empty");
    }

    #[test]
    fn test_pkce_verify_correct_verifier() {
        let challenge = PKCEChallenge::new();
        let verifier = challenge.code_verifier.clone();
        assert!(
            challenge.verify(&verifier),
            "PKCEChallenge::verify must succeed for the original verifier"
        );
    }

    #[test]
    fn test_pkce_verify_wrong_verifier_fails() {
        let challenge = PKCEChallenge::new();
        assert!(
            !challenge.verify("definitely-wrong-verifier"),
            "PKCEChallenge::verify must fail for an incorrect verifier"
        );
    }

    #[test]
    fn test_pkce_verify_empty_verifier_fails() {
        let challenge = PKCEChallenge::new();
        assert!(
            !challenge.verify(""),
            "PKCEChallenge::verify must fail for an empty verifier"
        );
    }

    #[test]
    fn test_pkce_two_challenges_differ() {
        let c1 = PKCEChallenge::new();
        let c2 = PKCEChallenge::new();
        assert_ne!(
            c1.code_verifier, c2.code_verifier,
            "consecutive PKCE challenges must have unique verifiers"
        );
        assert_ne!(
            c1.code_challenge, c2.code_challenge,
            "consecutive PKCE challenges must have unique challenges"
        );
    }

    #[test]
    fn test_state_parameter_not_expired_on_creation() {
        let state = StateParameter::new();
        assert!(!state.is_expired(), "freshly created StateParameter must not be expired");
    }

    #[test]
    fn test_state_verify_correct_value() {
        let state = StateParameter::new();
        let value = state.state.clone();
        assert!(
            state.verify(&value),
            "StateParameter::verify must succeed for the correct state value"
        );
    }

    #[test]
    fn test_state_verify_wrong_value_fails() {
        let state = StateParameter::new();
        assert!(
            !state.verify("wrong-state-value"),
            "StateParameter::verify must fail for an incorrect state value"
        );
    }

    #[test]
    fn test_state_verify_truncated_value_fails() {
        let state = StateParameter::new();
        // UUID strings are ASCII, so byte slicing is on a char boundary.
        let truncated = &state.state[..state.state.len() - 1];
        assert!(
            !state.verify(truncated),
            "StateParameter::verify must fail for a prefix of the correct value"
        );
    }

    #[test]
    fn test_state_verify_expired_fails() {
        let mut state = StateParameter::new();
        state.expires_at = chrono::Utc::now() - chrono::Duration::seconds(1);
        let value = state.state.clone();
        assert!(state.is_expired(), "StateParameter with a past expires_at must be expired");
        assert!(
            !state.verify(&value),
            "StateParameter::verify must fail once expired, even for the correct value"
        );
    }

    #[test]
    fn test_state_parameters_are_unique() {
        let s1 = StateParameter::new();
        let s2 = StateParameter::new();
        assert_ne!(s1.state, s2.state, "consecutive StateParameter values must be unique");
    }

    #[test]
    fn test_nonce_not_expired_on_creation() {
        let nonce = NonceParameter::new();
        assert!(!nonce.is_expired(), "freshly created NonceParameter must not be expired");
    }

    #[test]
    fn test_nonce_verify_correct_value() {
        let nonce = NonceParameter::new();
        let value = nonce.nonce.clone();
        assert!(
            nonce.verify(&value),
            "NonceParameter::verify must succeed for the correct nonce value"
        );
    }

    #[test]
    fn test_nonce_verify_wrong_value_fails() {
        let nonce = NonceParameter::new();
        assert!(
            !nonce.verify("wrong-nonce-value"),
            "NonceParameter::verify must fail for an incorrect nonce value"
        );
    }

    #[test]
    fn test_nonce_verify_truncated_value_fails() {
        let nonce = NonceParameter::new();
        // UUID strings are ASCII, so byte slicing is on a char boundary.
        let truncated = &nonce.nonce[..nonce.nonce.len() - 1];
        assert!(
            !nonce.verify(truncated),
            "NonceParameter::verify must fail for a prefix of the correct value"
        );
    }

    #[test]
    fn test_nonce_verify_expired_fails() {
        let mut nonce = NonceParameter::new();
        nonce.expires_at = chrono::Utc::now() - chrono::Duration::seconds(1);
        let value = nonce.nonce.clone();
        assert!(nonce.is_expired(), "NonceParameter with a past expires_at must be expired");
        assert!(
            !nonce.verify(&value),
            "NonceParameter::verify must fail once expired, even for the correct value"
        );
    }

    #[test]
    fn test_nonce_parameters_are_unique() {
        let n1 = NonceParameter::new();
        let n2 = NonceParameter::new();
        assert_ne!(n1.nonce, n2.nonce, "consecutive NonceParameter values must be unique");
    }
}