use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use sha2::{Digest, Sha256};
use crate::error::{ConfigError, Error, Result};
use crate::random::generate_random_bytes;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PkceMethod {
Plain,
#[default]
S256,
}
impl PkceMethod {
pub fn parse(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"plain" => Ok(PkceMethod::Plain),
"s256" => Ok(PkceMethod::S256),
_ => Err(Error::Config(ConfigError::InvalidValue {
key: "code_challenge_method".to_string(),
message: format!("unsupported PKCE method: {}", s),
})),
}
}
pub fn as_str(&self) -> &'static str {
match self {
PkceMethod::Plain => "plain",
PkceMethod::S256 => "S256",
}
}
}
/// Formats the method using its canonical wire spelling (see `PkceMethod::as_str`).
impl std::fmt::Display for PkceMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
impl std::str::FromStr for PkceMethod {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::parse(s)
}
}
/// Parameters controlling how `PkceChallenge::with_config` generates a verifier.
#[derive(Debug, Clone)]
pub struct PkceConfig {
    /// Number of random bytes drawn for the verifier *before* encoding. The
    /// unpadded base64url verifier is `ceil(4 * verifier_length / 3)`
    /// characters long, which `validate` requires to be within 43..=128.
    pub verifier_length: usize,
    /// Transformation used to derive the challenge from the verifier.
    pub method: PkceMethod,
}

impl Default for PkceConfig {
    /// 32 random bytes (a 43-character verifier, the shortest accepted) with `S256`.
    fn default() -> Self {
        Self {
            verifier_length: 32,
            method: PkceMethod::S256,
        }
    }
}

impl PkceConfig {
    /// 64 random bytes (an 86-character verifier) with `S256`.
    pub fn high_security() -> Self {
        Self {
            verifier_length: 64,
            method: PkceMethod::S256,
        }
    }

    /// Checks that `verifier_length` produces an encoded verifier of 43 to
    /// 128 characters (inclusive).
    ///
    /// # Errors
    ///
    /// Returns `Error::Config(ConfigError::InvalidValue { key: "verifier_length", .. })`
    /// when the encoded length is below 43 or above 128.
    pub fn validate(&self) -> Result<()> {
        const MIN_ENCODED_LEN: u128 = 43;
        const MAX_ENCODED_LEN: u128 = 128;

        // Widen before multiplying: `verifier_length * 4` in `usize` overflows
        // for very large lengths. That panics in debug builds and wraps in
        // release builds, where e.g. `2^62 + 32` on 64-bit wraps to 128 and
        // would "validate" as a 43-character verifier. `usize -> u128` is
        // lossless on every supported target.
        let encoded_len = (self.verifier_length as u128 * 4).div_ceil(3);

        if encoded_len < MIN_ENCODED_LEN {
            return Err(Error::Config(ConfigError::InvalidValue {
                key: "verifier_length".to_string(),
                message: format!(
                    "verifier too short: encoded length {} < 43 minimum",
                    encoded_len
                ),
            }));
        }
        if encoded_len > MAX_ENCODED_LEN {
            return Err(Error::Config(ConfigError::InvalidValue {
                key: "verifier_length".to_string(),
                message: format!(
                    "verifier too long: encoded length {} > 128 maximum",
                    encoded_len
                ),
            }));
        }
        Ok(())
    }
}
#[derive(Debug, Clone)]
pub struct PkceChallenge {
verifier: String,
challenge: String,
method: PkceMethod,
}
impl PkceChallenge {
pub fn new(method: PkceMethod) -> Result<Self> {
Self::with_config(PkceConfig {
method,
..Default::default()
})
}
pub fn with_config(config: PkceConfig) -> Result<Self> {
config.validate()?;
let verifier_bytes = generate_random_bytes(config.verifier_length)?;
let verifier = URL_SAFE_NO_PAD.encode(&verifier_bytes);
let challenge = Self::compute_challenge(&verifier, config.method);
Ok(Self {
verifier,
challenge,
method: config.method,
})
}
pub fn from_verifier(verifier: String, method: PkceMethod) -> Result<Self> {
if verifier.len() < 43 || verifier.len() > 128 {
return Err(Error::Config(ConfigError::InvalidValue {
key: "code_verifier".to_string(),
message: format!(
"verifier length must be 43-128 characters, got {}",
verifier.len()
),
}));
}
if !verifier
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
{
return Err(Error::Config(ConfigError::InvalidValue {
key: "code_verifier".to_string(),
message: "verifier contains invalid characters".to_string(),
}));
}
let challenge = Self::compute_challenge(&verifier, method);
Ok(Self {
verifier,
challenge,
method,
})
}
fn compute_challenge(verifier: &str, method: PkceMethod) -> String {
match method {
PkceMethod::Plain => verifier.to_string(),
PkceMethod::S256 => {
let hash = Sha256::digest(verifier.as_bytes());
URL_SAFE_NO_PAD.encode(hash)
}
}
}
pub fn verifier(&self) -> &str {
&self.verifier
}
pub fn challenge(&self) -> &str {
&self.challenge
}
pub fn method(&self) -> PkceMethod {
self.method
}
pub fn authorization_params(&self) -> (&str, &str) {
(&self.challenge, self.method.as_str())
}
pub fn verify(verifier: &str, challenge: &str, method: PkceMethod) -> bool {
let computed = Self::compute_challenge(verifier, method);
crate::random::constant_time_compare_str(&computed, challenge)
}
}
/// An owned copy of a PKCE `code_verifier`, detached from its challenge.
///
/// `Debug` is implemented by hand and redacts the value, because the
/// verifier is a secret that should not appear in logs.
#[derive(Clone)]
pub struct PkceVerifier(String);

impl std::fmt::Debug for PkceVerifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("PkceVerifier").field(&"[redacted]").finish()
    }
}

impl PkceVerifier {
    /// Copies the verifier out of `challenge`.
    pub fn from_challenge(challenge: &PkceChallenge) -> Self {
        Self(challenge.verifier.clone())
    }

    /// Borrows the verifier string.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Consumes the wrapper and returns the verifier string.
    pub fn into_string(self) -> String {
        self.0
    }
}
/// Lets a `PkceVerifier` be passed wherever an `AsRef<str>` is accepted.
impl AsRef<str> for PkceVerifier {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// The challenge half of a PKCE pair (challenge string plus method), which
/// can check a later-presented verifier.
#[derive(Debug, Clone)]
pub struct PkceCodeChallenge {
    /// The `code_challenge` value.
    pub challenge: String,
    /// The method that produced `challenge`.
    pub method: PkceMethod,
}

impl PkceCodeChallenge {
    /// Wraps an existing challenge string and its method.
    pub fn new(challenge: String, method: PkceMethod) -> Self {
        PkceCodeChallenge { challenge, method }
    }

    /// Copies the challenge and method out of a full `PkceChallenge`.
    pub fn from_challenge(pkce: &PkceChallenge) -> Self {
        Self::new(pkce.challenge().to_owned(), pkce.method())
    }

    /// Returns `true` if `verifier` maps to this challenge under this method
    /// (delegates to `PkceChallenge::verify`).
    pub fn verify(&self, verifier: &str) -> bool {
        PkceChallenge::verify(verifier, self.challenge.as_str(), self.method)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Freshly generated S256 pair; generation is expected to succeed in tests.
    fn fresh_s256() -> PkceChallenge {
        PkceChallenge::new(PkceMethod::S256).expect("S256 generation should succeed")
    }

    #[test]
    fn test_pkce_s256_challenge() {
        let pkce = fresh_s256();
        let verifier_len = pkce.verifier().len();
        assert!((43..=128).contains(&verifier_len));
        // A 32-byte SHA-256 digest encodes to 43 unpadded base64url chars.
        assert_eq!(pkce.challenge().len(), 43);
        assert!(PkceChallenge::verify(
            pkce.verifier(),
            pkce.challenge(),
            PkceMethod::S256
        ));
    }

    #[test]
    fn test_pkce_plain_challenge() {
        let pkce = PkceChallenge::new(PkceMethod::Plain).expect("plain generation should succeed");
        assert_eq!(pkce.verifier(), pkce.challenge());
        assert!(PkceChallenge::verify(
            pkce.verifier(),
            pkce.challenge(),
            PkceMethod::Plain
        ));
    }

    #[test]
    fn test_pkce_verification_fails_with_wrong_verifier() {
        let pkce = fresh_s256();
        let impostor = "wrong_verifier_that_is_long_enough_to_pass_length_check";
        assert!(!PkceChallenge::verify(impostor, pkce.challenge(), PkceMethod::S256));
    }

    #[test]
    fn test_pkce_verification_fails_with_wrong_method() {
        let pkce = fresh_s256();
        assert!(!PkceChallenge::verify(
            pkce.verifier(),
            pkce.challenge(),
            PkceMethod::Plain
        ));
    }

    #[test]
    fn test_pkce_method_from_str() {
        let accepted = [
            ("S256", PkceMethod::S256),
            ("s256", PkceMethod::S256),
            ("plain", PkceMethod::Plain),
            ("PLAIN", PkceMethod::Plain),
        ];
        for &(input, expected) in &accepted {
            assert_eq!(PkceMethod::from_str(input).unwrap(), expected);
        }
        assert!(PkceMethod::from_str("invalid").is_err());
        assert_eq!(PkceMethod::parse("S256").unwrap(), PkceMethod::S256);
        assert!(PkceMethod::parse("invalid").is_err());
    }

    #[test]
    fn test_pkce_method_as_str() {
        let expected = [(PkceMethod::S256, "S256"), (PkceMethod::Plain, "plain")];
        for &(method, name) in &expected {
            assert_eq!(method.as_str(), name);
        }
    }

    #[test]
    fn test_pkce_from_verifier() {
        let original = fresh_s256();
        let restored =
            PkceChallenge::from_verifier(original.verifier().to_owned(), PkceMethod::S256)
                .unwrap();
        assert_eq!(original.verifier(), restored.verifier());
        assert_eq!(original.challenge(), restored.challenge());
    }

    #[test]
    fn test_pkce_from_verifier_invalid_length() {
        let rejected = vec!["short".to_string(), "a".repeat(129)];
        for candidate in rejected {
            assert!(PkceChallenge::from_verifier(candidate, PkceMethod::S256).is_err());
        }
    }

    #[test]
    fn test_pkce_from_verifier_invalid_chars() {
        let tainted = format!("{}!@#", "a".repeat(43));
        assert!(PkceChallenge::from_verifier(tainted, PkceMethod::S256).is_err());
    }

    #[test]
    fn test_pkce_config_validation() {
        assert!(PkceConfig::default().validate().is_ok());
        assert!(PkceConfig::high_security().validate().is_ok());
        for verifier_length in [10, 200] {
            let out_of_range = PkceConfig {
                verifier_length,
                method: PkceMethod::S256,
            };
            assert!(out_of_range.validate().is_err());
        }
    }

    #[test]
    fn test_pkce_authorization_params() {
        let pkce = fresh_s256();
        let (code_challenge, method) = pkce.authorization_params();
        assert_eq!(code_challenge, pkce.challenge());
        assert_eq!(method, "S256");
    }

    #[test]
    fn test_pkce_code_challenge() {
        let pkce = fresh_s256();
        let public_half = PkceCodeChallenge::from_challenge(&pkce);
        assert!(public_half.verify(pkce.verifier()));
        assert!(!public_half.verify("wrong_verifier_long_enough_to_pass"));
    }

    #[test]
    fn test_pkce_verifier() {
        let pkce = fresh_s256();
        let verifier = PkceVerifier::from_challenge(&pkce);
        assert_eq!(verifier.as_str(), pkce.verifier());
        assert_eq!(verifier.as_ref(), pkce.verifier());
    }

    #[test]
    fn test_pkce_uniqueness() {
        let (first, second) = (fresh_s256(), fresh_s256());
        assert_ne!(first.verifier(), second.verifier());
        assert_ne!(first.challenge(), second.challenge());
    }

    #[test]
    fn test_pkce_high_security_config() {
        let pkce = PkceChallenge::with_config(PkceConfig::high_security()).unwrap();
        assert!(pkce.verifier().len() > 43);
        assert!(PkceChallenge::verify(
            pkce.verifier(),
            pkce.challenge(),
            PkceMethod::S256
        ));
    }
}