use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
/// PKCE code challenge stored alongside an authorization request.
///
/// Serialized as an internally tagged object whose `"type"` field carries the variant name,
/// e.g. `{"type":"S256","code_challenge":"..."}` or `{"type":"None"}`. The variant and field
/// names are therefore part of the wire format.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum CodeChallenge {
    /// No challenge is attached; `verify` rejects every verifier.
    None,
    /// `plain` method: the stored challenge is compared directly with the verifier.
    Plain { code_challenge: String },
    /// `S256` method: the stored challenge is compared with
    /// `BASE64URL-NOPAD(SHA256(verifier))`.
    S256 { code_challenge: String },
}
impl CodeChallenge {
    /// Moves the current challenge out of `self`, leaving [`CodeChallenge::None`] in its place.
    ///
    /// NOTE(review): presumably used so a stored challenge can be consumed only once — confirm
    /// against callers.
    pub(crate) fn take(&mut self) -> Self {
        std::mem::replace(self, Self::None)
    }

    /// Checks a PKCE `code_verifier` against this challenge (RFC 7636 §4.6).
    ///
    /// * [`CodeChallenge::None`] never verifies.
    /// * [`CodeChallenge::Plain`] compares the verifier with the stored challenge as-is.
    /// * [`CodeChallenge::S256`] compares `BASE64URL-NOPAD(SHA256(verifier))` with the stored
    ///   challenge.
    ///
    /// An empty verifier is always rejected. RFC 7636 §4.1 requires 43–128 characters, so `""`
    /// is never legitimate. Accepting it would let a request that omits the verifier satisfy a
    /// `Plain` challenge that happens to be empty, or an `S256` challenge equal to SHA256("").
    ///
    /// Byte comparisons go through `subtle::ConstantTimeEq`, so matching contents are not
    /// revealed through timing. Inputs of different lengths are rejected without comparing
    /// bytes.
    pub fn verify(&self, verifier: &str) -> bool {
        if verifier.is_empty() {
            return false;
        }
        match self {
            Self::None => false,
            Self::Plain { code_challenge } => {
                verifier.as_bytes().ct_eq(code_challenge.as_bytes()).into()
            }
            Self::S256 { code_challenge } => {
                let digest = Sha256::digest(verifier.as_bytes());
                let computed = BASE64_URL_SAFE_NO_PAD.encode(digest);
                computed.as_bytes().ct_eq(code_challenge.as_bytes()).into()
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::CodeChallenge;

    /// Matching and non-matching verifiers for both the `plain` and `S256` methods.
    #[test]
    fn test_code_challenge_verify() {
        let plain_challenge = CodeChallenge::Plain { code_challenge: "verifier".to_string() };
        let s256_challenge = CodeChallenge::S256 {
            code_challenge: "qoJXAtQ-gjzfDmoMrHt1a2AFVe1Tn3-HX0VC2_UtezA".to_string(),
        };
        assert!(plain_challenge.verify("verifier"));
        assert!(!plain_challenge.verify("wrong_verifier"));
        assert!(!plain_challenge.verify(""));
        assert!(s256_challenge.verify("code_challenge"));
        assert!(!s256_challenge.verify("wrong_code_challenge"));
        assert!(!s256_challenge.verify(""));
        // S256 must hash the verifier; presenting the stored challenge itself must not pass.
        assert!(!s256_challenge.verify("qoJXAtQ-gjzfDmoMrHt1a2AFVe1Tn3-HX0VC2_UtezA"));
    }

    /// A missing challenge must never accept any verifier.
    #[test]
    fn test_none_never_verifies() {
        assert!(!CodeChallenge::None.verify("verifier"));
        assert!(!CodeChallenge::None.verify("code_challenge"));
        assert!(!CodeChallenge::None.verify(""));
    }

    /// `take` hands back the stored challenge and leaves `None`, which then rejects everything.
    #[test]
    fn test_take_leaves_none() {
        let mut challenge = CodeChallenge::Plain { code_challenge: "verifier".to_string() };
        let taken = challenge.take();
        assert_eq!(taken, CodeChallenge::Plain { code_challenge: "verifier".to_string() });
        assert_eq!(challenge, CodeChallenge::None);
        assert!(taken.verify("verifier"));
        assert!(!challenge.verify("verifier"));
    }
}