use std::time::{Duration, SystemTime, UNIX_EPOCH};
use base64::Engine;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::{Digest, Sha256};
use crate::errors::RpcError;
/// HMAC-SHA256 MAC type used to sign and verify PKCE state cookies in this module.
type HmacSha256 = Hmac<Sha256>;
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0` instead of
/// failing.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// A PKCE code verifier together with its `S256` code challenge.
#[derive(Clone, Debug)]
pub struct PkcePair {
    /// Secret code verifier.
    ///
    /// When produced by `generate_pkce_pair`, it is 32 random bytes encoded as
    /// unpadded base64url.
    pub verifier: String,
    /// Unpadded base64url encoding of SHA-256 over `verifier`'s bytes.
    pub challenge: String,
}
/// Generate a fresh PKCE verifier/challenge pair using the `S256` method.
///
/// The verifier is 32 random bytes encoded as unpadded base64url (43 characters).
/// The challenge is the unpadded base64url encoding of SHA-256 over the verifier's
/// ASCII bytes.
pub fn generate_pkce_pair() -> PkcePair {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as B64;

    let mut entropy = [0u8; 32];
    rand::thread_rng().fill_bytes(&mut entropy);

    let verifier = B64.encode(entropy);
    let challenge = B64.encode(Sha256::digest(verifier.as_bytes()));
    PkcePair { verifier, challenge }
}
/// Output of `new_state_cookie`: the signed cookie value plus the random `state` token.
#[derive(Clone, Debug)]
pub struct PkceState {
    /// Signed, unpadded-base64url token carrying the state, timestamp, return target
    /// and verifier.
    pub cookie: String,
    /// Freshly generated random token, independent of the cookie's encoding.
    pub state: String,
}
/// Mint a signed, self-contained PKCE state cookie.
///
/// The cookie is `base64url_nopad(payload || b'|' || HMAC-SHA256(signing_key, payload))`,
/// where `payload = "{state}\n{created_at}\n{return_to}\n{verifier}"`:
/// - `state` is a fresh random token;
/// - `created_at` is the current Unix time in seconds;
/// - `verifier` is taken from `pair`.
///
/// The returned `state` is the same random token, so it can be echoed separately from
/// the cookie.
///
/// # Panics
///
/// Only if `HmacSha256::new_from_slice` rejects `signing_key`. HMAC accepts keys of
/// any length, so this is not expected to happen.
pub fn new_state_cookie(signing_key: &[u8], return_to: &str, pair: &PkcePair) -> PkceState {
    let state = random_state();
    let issued_at = unix_now();
    let verifier = &pair.verifier;
    let payload = format!("{state}\n{issued_at}\n{return_to}\n{verifier}");

    let tag = {
        let mut mac = HmacSha256::new_from_slice(signing_key).expect("hmac key");
        mac.update(payload.as_bytes());
        mac.finalize().into_bytes()
    };

    // payload, a single '|' separator, then the fixed-length 32-byte tag.
    let signed: Vec<u8> = [payload.as_bytes(), &b"|"[..], &tag[..]].concat();
    let cookie = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signed);
    PkceState { cookie, state }
}
/// Authenticate and decode a cookie produced by `new_state_cookie`.
///
/// On success returns `(state, return_to, verifier)`, exactly as they were signed.
///
/// # Errors
///
/// Every error is an `RpcError::value_error`. The message is one of the following.
///
/// `"malformed PKCE state cookie"` when any of these holds:
/// - the cookie is not unpadded base64url;
/// - it is too short to hold a 32-byte tag plus separator;
/// - the `|` separator is missing;
/// - the payload is not UTF-8;
/// - a field is missing;
/// - the timestamp is not a `u64`.
///
/// `"PKCE state cookie signature mismatch"` when the HMAC-SHA256 tag does not verify
/// under `signing_key`.
///
/// `"PKCE state cookie expired"` when more than `max_age` has elapsed since the cookie
/// was minted. Both sides are compared in whole seconds.
///
/// # Panics
///
/// Only if `HmacSha256::new_from_slice` rejects `signing_key`.
pub fn verify_state_cookie(
    signing_key: &[u8],
    cookie: &str,
    max_age: Duration,
) -> Result<(String, String, String), RpcError> {
    // Byte length of an HMAC-SHA256 tag.
    const SIG_LEN: usize = 32;
    let malformed = || RpcError::value_error("malformed PKCE state cookie");

    let raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(cookie.as_bytes())
        .map_err(|_| malformed())?;
    if raw.len() < SIG_LEN + 1 {
        return Err(malformed());
    }

    // Layout is payload | '|' | tag. The fixed-size tag is split off the end, so a '|'
    // inside the payload (e.g. in return_to) cannot confuse the split.
    let (payload_with_sep, sig) = raw.split_at(raw.len() - SIG_LEN);
    let payload = match payload_with_sep.split_last() {
        Some((&b'|', payload)) => payload,
        _ => return Err(malformed()),
    };

    // Authenticate before interpreting any field.
    let mut mac = HmacSha256::new_from_slice(signing_key).expect("hmac key");
    mac.update(payload);
    mac.verify_slice(sig)
        .map_err(|_| RpcError::value_error("PKCE state cookie signature mismatch"))?;

    let text = std::str::from_utf8(payload).map_err(|_| malformed())?;

    // Payload is "{state}\n{created_at}\n{return_to}\n{verifier}".
    // - state (base64url) and created_at (digits) never contain '\n'.
    // - A PKCE verifier's character set excludes '\n'.
    // - return_to is caller-supplied and may contain '\n'.
    // So the first two fields are peeled from the left and the verifier from the right;
    // everything in between is return_to, verbatim. Splitting left-to-right instead would
    // truncate return_to and corrupt the verifier.
    let (state, rest) = text.split_once('\n').ok_or_else(malformed)?;
    let (created_at, rest) = rest.split_once('\n').ok_or_else(malformed)?;
    let (return_to, verifier) = rest.rsplit_once('\n').ok_or_else(malformed)?;
    let created_at: u64 = created_at.parse().map_err(|_| malformed())?;

    // A timestamp ahead of this host's clock saturates to age 0.
    let age = unix_now().saturating_sub(created_at);
    if age > max_age.as_secs() {
        return Err(RpcError::value_error("PKCE state cookie expired"));
    }
    Ok((state.to_string(), return_to.to_string(), verifier.to_string()))
}
/// Check whether `return_to`'s origin is exactly one of the entries in `allow`.
///
/// The origin is everything up to the end of the authority: the text through the first
/// `"://"`, then up to (not including) the next `/`, `?` or `#`, or the end of the
/// string. The comparison is exact and case-sensitive, so the following are all
/// rejected:
/// - URLs without `"://"`, such as relative paths;
/// - explicit ports;
/// - userinfo (`user@host`).
pub fn is_allowed_return_origin(return_to: &str, allow: &[&str]) -> bool {
    match return_to.find("://") {
        None => false,
        Some(sep_at) => {
            let authority_start = sep_at + "://".len();
            let authority_len = return_to[authority_start..]
                .find(['/', '?', '#'])
                .unwrap_or(return_to.len() - authority_start);
            let origin = &return_to[..authority_start + authority_len];
            allow.contains(&origin)
        }
    }
}
/// Produce an unpredictable token for the `state` value.
///
/// Draws 24 bytes from `rand::thread_rng()` and encodes them as unpadded base64url.
/// The result is a 32-character string containing no newline, `|` or padding.
fn random_state() -> String {
    let mut rng = rand::thread_rng();
    let mut nonce = [0u8; 24];
    rng.fill_bytes(&mut nonce);
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&nonce)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sign an arbitrary payload with the same layout `new_state_cookie` produces, so tests
    /// can mint cookies with chosen fields and timestamps.
    fn sign_raw_cookie(key: &[u8], payload: &str) -> String {
        let mut mac = HmacSha256::new_from_slice(key).expect("hmac key");
        mac.update(payload.as_bytes());
        let sig = mac.finalize().into_bytes();
        let mut raw = payload.as_bytes().to_vec();
        raw.push(b'|');
        raw.extend_from_slice(&sig);
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw)
    }

    #[test]
    fn pair_challenge_matches_sha256() {
        let p = generate_pkce_pair();
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(Sha256::digest(p.verifier.as_bytes()));
        assert_eq!(p.challenge, expected);
    }

    #[test]
    fn cookie_round_trip() {
        let key = [9u8; 32];
        let pair = PkcePair {
            verifier: "v-abc".into(),
            challenge: "c-abc".into(),
        };
        let pkce = new_state_cookie(&key, "https://app.example/welcome", &pair);
        let (state, rt, verifier) =
            verify_state_cookie(&key, &pkce.cookie, Duration::from_secs(600)).unwrap();
        assert_eq!(state, pkce.state);
        assert_eq!(rt, "https://app.example/welcome");
        assert_eq!(verifier, "v-abc");
    }

    #[test]
    fn hand_signed_cookie_matches_real_format() {
        // Guards the helper: if it stops matching the real layout, the expiry test below
        // would pass for the wrong reason.
        let key = [5u8; 32];
        let cookie = sign_raw_cookie(&key, &format!("s\n{}\n/x\nv", unix_now()));
        let (state, rt, verifier) =
            verify_state_cookie(&key, &cookie, Duration::from_secs(600)).unwrap();
        assert_eq!(state, "s");
        assert_eq!(rt, "/x");
        assert_eq!(verifier, "v");
    }

    #[test]
    fn cookie_rejects_bad_signature() {
        let key = [1u8; 32];
        let pair = PkcePair {
            verifier: "v".into(),
            challenge: "c".into(),
        };
        let pkce = new_state_cookie(&key, "/x", &pair);
        let wrong_key = [2u8; 32];
        assert!(verify_state_cookie(&wrong_key, &pkce.cookie, Duration::from_secs(600)).is_err());
    }

    #[test]
    fn cookie_rejects_tampered_payload() {
        let key = [6u8; 32];
        let pair = PkcePair {
            verifier: "v".into(),
            challenge: "c".into(),
        };
        let pkce = new_state_cookie(&key, "/x", &pair);
        let mut raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(pkce.cookie.as_bytes())
            .unwrap();
        // Flip one bit in the payload without re-signing.
        raw[0] ^= 0x01;
        let forged = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
        let err = verify_state_cookie(&key, &forged, Duration::from_secs(600))
            .expect_err("a modified payload must not verify");
        assert!(err.message.contains("signature"), "{}", err.message);
    }

    #[test]
    fn cookie_rejects_garbage() {
        let key = [7u8; 32];
        for bad in ["", "not base64!", "AAAA"] {
            let err = verify_state_cookie(&key, bad, Duration::from_secs(600))
                .expect_err("garbage input must be rejected");
            assert!(err.message.contains("malformed"), "{}", err.message);
        }
    }

    #[test]
    fn cookie_rejects_when_expired() {
        let key = [3u8; 32];
        // new_state_cookie always stamps "now", and ages are whole seconds, so
        // Duration::ZERO does not reliably expire a fresh cookie. Mint an hour-old
        // cookie by hand instead, so the assertion always runs.
        let issued_at = unix_now().saturating_sub(3_600);
        let cookie = sign_raw_cookie(&key, &format!("s\n{issued_at}\n/x\nv"));
        let err = verify_state_cookie(&key, &cookie, Duration::from_secs(600))
            .expect_err("an hour-old cookie must be rejected under a 10-minute max_age");
        assert!(err.message.contains("expired"), "{}", err.message);
    }

    #[test]
    fn state_param_does_not_leak_verifier() {
        let key = [4u8; 32];
        let pair = PkcePair {
            verifier: "super-secret-verifier".into(),
            challenge: "c".into(),
        };
        let pkce = new_state_cookie(&key, "/x", &pair);
        assert_ne!(pkce.state, pkce.cookie);
        assert!(!pkce.state.contains("super-secret-verifier"));
        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(pkce.state.as_bytes())
            .unwrap_or_default();
        assert!(!String::from_utf8_lossy(&decoded).contains("super-secret-verifier"));
    }

    #[test]
    fn allowed_origin_matches_scheme_and_host() {
        let allow = ["https://app.example"];
        assert!(is_allowed_return_origin("https://app.example/x", &allow));
        assert!(!is_allowed_return_origin("https://evil.example/x", &allow));
        assert!(!is_allowed_return_origin("http://app.example/x", &allow));
    }
}