use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use hmac::{Hmac, KeyInit, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
/// HMAC instantiated with SHA-256; the MAC used to sign and check tokens in this file.
type HmacSha256 = Hmac<Sha256>;
/// Issues and checks CSRF tokens that are signed with a secret key.
///
/// `Debug` is implemented by hand rather than derived so that the secret key
/// can never end up in logs or panic messages.
#[derive(Clone)]
pub struct CsrfTokens {
    /// Secret key fed to the MAC when signing and verifying tokens.
    key: Vec<u8>,
}

impl std::fmt::Debug for CsrfTokens {
    /// Prints the type name only; the key material is deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CsrfTokens").finish_non_exhaustive()
    }
}
impl CsrfTokens {
    /// Wraps the given secret so it can be used to issue and check tokens.
    ///
    /// # Panics
    ///
    /// Panics when the supplied key is shorter than 32 bytes.
    pub fn new(key: impl Into<Vec<u8>>) -> Self {
        let secret: Vec<u8> = key.into();
        assert!(secret.len() >= 32, "veer csrf secret must be >= 32 bytes");
        Self { key: secret }
    }

    /// Computes the HMAC of `payload` under the stored key and returns it as
    /// unpadded URL-safe base64 text.
    fn sign(&self, payload: &[u8]) -> String {
        let tag = {
            let mut mac = HmacSha256::new_from_slice(&self.key).expect("hmac key");
            mac.update(payload);
            mac.finalize().into_bytes()
        };
        URL_SAFE_NO_PAD.encode(tag)
    }

    /// Creates a fresh token of the form `<nonce>.<signature>`.
    ///
    /// The nonce is 32 random bytes encoded as unpadded URL-safe base64; the
    /// signature is computed over that encoded text, not over the raw bytes.
    ///
    /// # Panics
    ///
    /// Panics if `getrandom::fill` reports an error.
    pub fn generate(&self) -> String {
        let mut nonce = [0u8; 32];
        getrandom::fill(&mut nonce).expect("getrandom");
        let nonce_b64 = URL_SAFE_NO_PAD.encode(nonce);
        let tag_b64 = self.sign(nonce_b64.as_bytes());
        [nonce_b64.as_str(), tag_b64.as_str()].join(".")
    }

    /// Reports whether `token` carries a signature that matches its nonce part
    /// under the stored key.
    ///
    /// Returns `false` for anything that cannot be parsed: a token without a
    /// `.` separator, or a signature part that is not valid unpadded URL-safe
    /// base64.
    pub(crate) fn is_valid(&self, token: &str) -> bool {
        // Everything before the first '.' is the signed text; the rest is the tag.
        let (signed_text, tag_b64) = match token.split_once('.') {
            Some(parts) => parts,
            None => return false,
        };
        let tag = match URL_SAFE_NO_PAD.decode(tag_b64) {
            Ok(bytes) => bytes,
            Err(_) => return false,
        };
        match HmacSha256::new_from_slice(&self.key) {
            Ok(mut mac) => {
                mac.update(signed_text.as_bytes());
                // Let the MAC implementation perform the tag comparison.
                mac.verify_slice(&tag).is_ok()
            }
            Err(_) => false,
        }
    }

    /// Accepts a request only when the cookie and header carry the same bytes
    /// and that shared value is a validly signed token.
    pub fn verify(&self, cookie_value: &str, header_value: &str) -> bool {
        // Mismatched values are rejected before any signature work is done.
        if !ct_eq(cookie_value.as_bytes(), header_value.as_bytes()) {
            return false;
        }
        self.is_valid(cookie_value)
    }
}
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
a.ct_eq(b).into()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 40-byte secret, above the 32-byte minimum enforced by `CsrfTokens::new`.
    const KEY: &[u8] = b"0123456789012345678901234567890123456789";

    /// Builds the issuer under test from the shared `KEY`.
    fn issuer() -> CsrfTokens {
        CsrfTokens::new(KEY.to_vec())
    }

    #[test]
    fn generate_then_verify_roundtrips() {
        let csrf = issuer();
        let fresh = csrf.generate();
        assert!(csrf.verify(&fresh, &fresh));
        assert!(csrf.is_valid(&fresh));
    }

    #[test]
    fn tampered_signature_fails() {
        let csrf = issuer();
        let original = csrf.generate();

        // A token with one extra trailing character must be rejected.
        let mut extended = original.clone();
        extended.push('x');
        assert!(!csrf.is_valid(&extended));
        assert!(!csrf.verify(&extended, &extended));

        // Swap a single character in the middle of the signature while
        // keeping the overall token length unchanged.
        let (nonce_part, sig_part) = original.split_once('.').unwrap();
        let mut sig_chars = sig_part.as_bytes().to_vec();
        let middle = sig_chars.len() / 2;
        sig_chars[middle] = match sig_chars[middle] {
            b'A' => b'B',
            _ => b'A',
        };
        let swapped = format!("{nonce_part}.{}", String::from_utf8(sig_chars).unwrap());
        assert_eq!(swapped.len(), original.len());
        assert!(!csrf.is_valid(&swapped));
        assert!(!csrf.verify(&swapped, &swapped));
    }

    #[test]
    fn token_from_other_key_fails() {
        let signer = issuer();
        let stranger = CsrfTokens::new(b"abcdefghabcdefghabcdefghabcdefgh".to_vec());
        let foreign = signer.generate();
        assert!(!stranger.verify(&foreign, &foreign));
    }

    #[test]
    fn cookie_header_mismatch_fails_even_when_both_valid() {
        let csrf = issuer();
        let cookie = csrf.generate();
        let header = csrf.generate();
        // Each token is individually well-signed...
        assert!(csrf.is_valid(&cookie) && csrf.is_valid(&header));
        // ...but the pair is rejected because the two values differ.
        assert!(!csrf.verify(&cookie, &header));
    }

    #[test]
    fn malformed_token_fails() {
        let csrf = issuer();
        let junk = "no-dot";
        assert!(!csrf.is_valid(junk));
        assert!(!csrf.verify(junk, junk));
    }
}