use anyhow::{Context, Result};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use sha2::{Digest, Sha256};
/// Length, in characters, of every generated PKCE code verifier.
///
/// RFC 7636 §4.1 requires a verifier of 43 to 128 characters; the const
/// assertion below turns an out-of-range edit into a compile error instead of
/// silently producing verifiers that authorization servers may reject.
const CODE_VERIFIER_LENGTH: usize = 64;
const _: () = assert!(matches!(CODE_VERIFIER_LENGTH, 43..=128));

/// Characters a code verifier may contain: the RFC 7636 `unreserved` set
/// (`ALPHA / DIGIT / "-" / "." / "_" / "~"`), 66 characters in total.
const CODE_VERIFIER_CHARSET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
/// A PKCE (RFC 7636) verifier/challenge pair.
///
/// Per RFC 7636, `code_verifier` is the client's secret and must not be
/// disclosed; only `code_challenge` and `code_challenge_method` are meant to be
/// shared in the authorization request. For that reason the `Debug`
/// implementation below redacts the verifier so it cannot leak into logs.
#[derive(Clone)]
pub struct PkceChallenge {
    /// Secret random string proving possession of the original challenge.
    pub code_verifier: String,
    /// Value derived from `code_verifier` using `code_challenge_method`.
    pub code_challenge: String,
    /// Name of the transformation used to derive `code_challenge` (e.g. `"S256"`).
    pub code_challenge_method: String,
}

impl std::fmt::Debug for PkceChallenge {
    /// Formats the challenge with `code_verifier` replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PkceChallenge")
            .field("code_verifier", &"<redacted>")
            .field("code_challenge", &self.code_challenge)
            .field("code_challenge_method", &self.code_challenge_method)
            .finish()
    }
}
impl PkceChallenge {
    /// Builds a [`PkceChallenge`] from a caller-supplied verifier using the `S256` method.
    ///
    /// The verifier is used as-is (no length or character validation happens
    /// here); its challenge is computed by `compute_s256_challenge`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while computing the challenge.
    pub fn from_verifier(code_verifier: String) -> Result<Self> {
        compute_s256_challenge(&code_verifier).map(|code_challenge| Self {
            code_verifier,
            code_challenge,
            code_challenge_method: String::from("S256"),
        })
    }
}
/// Creates a fresh `S256` PKCE challenge backed by a newly generated verifier.
///
/// # Errors
///
/// Propagates any error from verifier generation or challenge computation.
pub fn generate_pkce_challenge() -> Result<PkceChallenge> {
    generate_code_verifier().and_then(PkceChallenge::from_verifier)
}
fn generate_code_verifier() -> Result<String> {
use std::time::{SystemTime, UNIX_EPOCH};
let mut verifier = String::with_capacity(CODE_VERIFIER_LENGTH);
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.context("System time before UNIX epoch")?
.as_nanos();
let pid = std::process::id() as u128;
let mut state = nanos.wrapping_add(pid);
for _ in 0..CODE_VERIFIER_LENGTH {
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
let extra = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or(0);
state = state.wrapping_add(extra);
let idx = (state % CODE_VERIFIER_CHARSET.len() as u128) as usize;
verifier.push(CODE_VERIFIER_CHARSET[idx] as char);
}
Ok(verifier)
}
/// Computes the RFC 7636 `S256` challenge: unpadded base64url of SHA-256(verifier).
///
/// The body never fails; the `Result` return type is kept for API compatibility.
fn compute_s256_challenge(code_verifier: &str) -> Result<String> {
    let digest = Sha256::digest(code_verifier.as_bytes());
    let encoded = URL_SAFE_NO_PAD.encode(digest);
    Ok(encoded)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `verifier` has the configured length and only allowed characters.
    fn assert_valid_verifier(verifier: &str) {
        assert_eq!(verifier.len(), CODE_VERIFIER_LENGTH);
        for c in verifier.chars() {
            assert!(
                CODE_VERIFIER_CHARSET.contains(&(c as u8)),
                "Invalid character in verifier: {}",
                c
            );
        }
    }

    #[test]
    fn test_generate_pkce_challenge() {
        let challenge = generate_pkce_challenge().unwrap();
        assert_valid_verifier(&challenge.code_verifier);
        assert_eq!(challenge.code_challenge_method, "S256");
        // SHA-256 is 32 bytes, which base64url-encodes to 43 chars without padding.
        assert_eq!(challenge.code_challenge.len(), 43);
    }

    #[test]
    fn test_many_verifiers_are_valid() {
        // Exercises the byte-rejection path repeatedly.
        for _ in 0..50 {
            let challenge = generate_pkce_challenge().unwrap();
            assert_valid_verifier(&challenge.code_verifier);
        }
    }

    #[test]
    fn test_deterministic_challenge() {
        let verifier = "test_verifier_string_for_deterministic_test";
        let challenge1 = PkceChallenge::from_verifier(verifier.to_string()).unwrap();
        let challenge2 = PkceChallenge::from_verifier(verifier.to_string()).unwrap();
        assert_eq!(challenge1.code_challenge, challenge2.code_challenge);
    }

    #[test]
    fn test_rfc7636_appendix_b_vector() {
        // Known-answer test from RFC 7636 Appendix B: pins the exact hash and encoding,
        // which the self-consistency test above cannot.
        let challenge = PkceChallenge::from_verifier(
            "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(),
        )
        .unwrap();
        assert_eq!(
            challenge.code_challenge,
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
        assert_eq!(challenge.code_challenge_method, "S256");
    }

    #[test]
    fn test_unique_verifiers() {
        let c1 = generate_pkce_challenge().unwrap();
        let c2 = generate_pkce_challenge().unwrap();
        assert_ne!(c1.code_verifier, c2.code_verifier);
    }
}