use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use subtle::ConstantTimeEq;
use crate::error::{Error, Result};
use crate::random::generate_random_bytes;
/// HMAC instantiated with SHA-256; `CsrfProtection::sign` uses it to authenticate tokens.
type HmacSha256 = Hmac<Sha256>;
/// Settings shared by the CSRF helpers in this module.
///
/// `Debug` is implemented by hand instead of being derived so that the signing secret
/// cannot leak into logs or panic messages through `{:?}` formatting (including via
/// types that embed a `CsrfConfig` and derive `Debug` themselves).
#[derive(Clone)]
pub struct CsrfConfig {
    /// Key material used as the HMAC key when signing tokens.
    secret: Vec<u8>,
    /// Number of random bytes drawn for each generated token.
    token_length: usize,
    /// How long a token stays valid after it has been issued.
    ttl: Duration,
}

impl std::fmt::Debug for CsrfConfig {
    /// Formats the configuration with the secret replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CsrfConfig")
            // Never print key material, not even its length.
            .field("secret", &format_args!("<redacted>"))
            .field("token_length", &self.token_length)
            .field("ttl", &self.ttl)
            .finish()
    }
}
impl Default for CsrfConfig {
    /// Builds the default configuration: a freshly generated 32-byte secret,
    /// 32 random bytes per token and a one-hour time-to-live.
    ///
    /// # Panics
    ///
    /// Panics if `generate_random_bytes` fails to produce the secret.
    fn default() -> Self {
        const SECRET_LEN: usize = 32;
        const TOKEN_LEN: usize = 32;
        const TTL_SECS: u64 = 3600;

        Self {
            secret: generate_random_bytes(SECRET_LEN).expect("Failed to generate random secret"),
            token_length: TOKEN_LEN,
            ttl: Duration::from_secs(TTL_SECS),
        }
    }
}
impl CsrfConfig {
    /// Creates a configuration with the default settings (see the `Default` impl).
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with its signing secret replaced by a copy of `secret`.
    pub fn with_secret(self, secret: &[u8]) -> Self {
        Self {
            secret: secret.to_vec(),
            ..self
        }
    }

    /// Returns the configuration with the number of random bytes per token set to `length`.
    pub fn with_token_length(self, length: usize) -> Self {
        Self {
            token_length: length,
            ..self
        }
    }

    /// Returns the configuration with the token lifetime set to `ttl`.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }
}
/// A CSRF token string together with its validity window.
#[derive(Debug, Clone)]
pub struct CsrfToken {
    /// Encoded token string.
    pub token: String,
    /// Issue time, in seconds since the Unix epoch.
    pub created_at: u64,
    /// Expiry time, in seconds since the Unix epoch.
    pub expires_at: u64,
}

impl CsrfToken {
    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A clock that reads earlier than the epoch is treated as `0` rather than an error.
    fn seconds_since_epoch() -> u64 {
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        }
    }

    /// Returns `true` once the current time is strictly later than `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < Self::seconds_since_epoch()
    }

    /// Returns the number of seconds left before expiry, or `None` if the token
    /// has already expired.
    pub fn remaining_ttl(&self) -> Option<u64> {
        // `checked_sub` yields `None` exactly when the current time exceeds `expires_at`.
        self.expires_at.checked_sub(Self::seconds_since_epoch())
    }
}
#[derive(Debug, Clone)]
pub struct CsrfProtection {
config: CsrfConfig,
}
impl CsrfProtection {
pub fn new(config: CsrfConfig) -> Self {
Self { config }
}
pub fn with_default() -> Self {
Self::new(CsrfConfig::default())
}
pub fn with_secret(secret: &[u8]) -> Self {
Self::new(CsrfConfig::default().with_secret(secret))
}
pub fn generate_token(&self) -> Result<CsrfToken> {
let random_data = generate_random_bytes(self.config.token_length)?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::internal(format!("系统时间错误: {}", e)))?
.as_secs();
let expires_at = now + self.config.ttl.as_secs();
let random_b64 = URL_SAFE_NO_PAD.encode(&random_data);
let timestamp_b64 = URL_SAFE_NO_PAD.encode(now.to_be_bytes());
let signature = self.sign(&random_data, now)?;
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
let token = format!("{}.{}.{}", random_b64, timestamp_b64, signature_b64);
Ok(CsrfToken {
token,
created_at: now,
expires_at,
})
}
pub fn verify(&self, token: &str) -> Result<bool> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Ok(false);
}
let random_data = match URL_SAFE_NO_PAD.decode(parts[0]) {
Ok(data) => data,
Err(_) => return Ok(false),
};
let timestamp_bytes = match URL_SAFE_NO_PAD.decode(parts[1]) {
Ok(data) => data,
Err(_) => return Ok(false),
};
let provided_signature = match URL_SAFE_NO_PAD.decode(parts[2]) {
Ok(data) => data,
Err(_) => return Ok(false),
};
if timestamp_bytes.len() != 8 {
return Ok(false);
}
let mut timestamp_arr = [0u8; 8];
timestamp_arr.copy_from_slice(×tamp_bytes);
let timestamp = u64::from_be_bytes(timestamp_arr);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::internal(format!("系统时间错误: {}", e)))?
.as_secs();
let expires_at = timestamp + self.config.ttl.as_secs();
if now > expires_at {
return Ok(false);
}
let expected_signature = self.sign(&random_data, timestamp)?;
Ok(provided_signature.ct_eq(&expected_signature).into())
}
pub fn verify_and_decode(&self, token: &str) -> Result<CsrfToken> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(Error::validation("无效的 token 格式"));
}
let random_data = URL_SAFE_NO_PAD
.decode(parts[0])
.map_err(|_| Error::validation("无效的 token 格式"))?;
let timestamp_bytes = URL_SAFE_NO_PAD
.decode(parts[1])
.map_err(|_| Error::validation("无效的 token 格式"))?;
let provided_signature = URL_SAFE_NO_PAD
.decode(parts[2])
.map_err(|_| Error::validation("无效的 token 格式"))?;
if timestamp_bytes.len() != 8 {
return Err(Error::validation("无效的时间戳"));
}
let mut timestamp_arr = [0u8; 8];
timestamp_arr.copy_from_slice(×tamp_bytes);
let timestamp = u64::from_be_bytes(timestamp_arr);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error::internal(format!("系统时间错误: {}", e)))?
.as_secs();
let expires_at = timestamp + self.config.ttl.as_secs();
if now > expires_at {
return Err(Error::validation("Token 已过期"));
}
let expected_signature = self.sign(&random_data, timestamp)?;
if !bool::from(provided_signature.ct_eq(&expected_signature)) {
return Err(Error::validation("签名验证失败"));
}
Ok(CsrfToken {
token: token.to_string(),
created_at: timestamp,
expires_at,
})
}
fn sign(&self, data: &[u8], timestamp: u64) -> Result<Vec<u8>> {
let mut mac = HmacSha256::new_from_slice(&self.config.secret)
.map_err(|e| Error::internal(format!("HMAC 初始化失败: {}", e)))?;
mac.update(data);
mac.update(×tamp.to_be_bytes());
Ok(mac.finalize().into_bytes().to_vec())
}
}
/// Unsigned double-submit CSRF helper: the same random token is handed out twice
/// (one copy for the cookie, one for the request) and the two copies are compared
/// on verification.
#[derive(Debug, Clone)]
pub struct DoubleSubmitCsrf {
    config: CsrfConfig,
}

impl DoubleSubmitCsrf {
    /// Creates a helper that uses `config` (only `token_length` is read here).
    pub fn new(config: CsrfConfig) -> Self {
        Self { config }
    }

    /// Creates a helper with `CsrfConfig::default()`.
    pub fn with_default() -> Self {
        Self::new(CsrfConfig::default())
    }

    /// Generates a `(cookie_token, request_token)` pair; both elements are the same
    /// URL-safe, unpadded base64 encoding of `token_length` random bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if random bytes cannot be generated.
    pub fn generate_token_pair(&self) -> Result<(String, String)> {
        let random_data = generate_random_bytes(self.config.token_length)?;
        let token = URL_SAFE_NO_PAD.encode(&random_data);
        Ok((token.clone(), token))
    }

    /// Returns `true` when both tokens are non-empty and byte-for-byte equal.
    ///
    /// Empty tokens are always rejected: a request that carries neither the cookie
    /// nor the request token typically reaches this point as two empty strings, and
    /// treating those as a match would let it bypass the check. As a consequence, a
    /// configuration with `token_length == 0` produces pairs that never verify.
    pub fn verify(&self, cookie_token: &str, request_token: &str) -> bool {
        if cookie_token.is_empty() || request_token.is_empty() {
            return false;
        }
        // Constant-time comparison so the check does not reveal how many bytes matched.
        bool::from(cookie_token.as_bytes().ct_eq(request_token.as_bytes()))
    }
}
/// Double-submit CSRF helper whose cookie token is a signed `CsrfProtection` token
/// and whose request token is that token's first dot-separated segment.
#[derive(Debug, Clone)]
pub struct SignedDoubleSubmitCsrf {
    protection: CsrfProtection,
}

impl SignedDoubleSubmitCsrf {
    /// Creates a helper whose underlying `CsrfProtection` uses `config`.
    pub fn new(config: CsrfConfig) -> Self {
        let protection = CsrfProtection::new(config);
        Self { protection }
    }

    /// Creates a helper whose underlying `CsrfProtection` signs with `secret`.
    pub fn with_secret(secret: &[u8]) -> Self {
        let protection = CsrfProtection::with_secret(secret);
        Self { protection }
    }

    /// Generates a `(cookie_token, request_token)` pair.
    ///
    /// The cookie token is the full signed token; the request token is the text
    /// before its first `.`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `CsrfProtection::generate_token`.
    pub fn generate_token_pair(&self) -> Result<(String, String)> {
        let signed = self.protection.generate_token()?.token;
        // `split` always yields at least one segment, so the fallback is never used.
        let request_token = signed.split('.').next().unwrap_or("").to_string();
        Ok((signed, request_token))
    }

    /// Returns `Ok(true)` when `cookie_token` passes `CsrfProtection::verify` and its
    /// first segment equals `request_token`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `CsrfProtection::verify`.
    pub fn verify(&self, cookie_token: &str, request_token: &str) -> Result<bool> {
        let cookie_is_valid = self.protection.verify(cookie_token)?;
        if !cookie_is_valid {
            return Ok(false);
        }

        match cookie_token.split('.').next() {
            // Constant-time comparison of the first segment against the request token.
            Some(first_segment) => Ok(bool::from(
                first_segment.as_bytes().ct_eq(request_token.as_bytes()),
            )),
            None => Ok(false),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a verifier whose tokens are valid for `secs` seconds.
    fn protection_with_ttl(secs: u64) -> CsrfProtection {
        let config = CsrfConfig::new().with_ttl(Duration::from_secs(secs));
        CsrfProtection::new(config)
    }

    /// Blocks for two seconds so that a one-second TTL has elapsed.
    fn outlast_one_second_ttl() {
        thread::sleep(Duration::from_secs(2));
    }

    /// A freshly issued token is non-empty and has a forward-looking validity window.
    #[test]
    fn test_csrf_token_generation() {
        let issued = CsrfProtection::with_default().generate_token().unwrap();
        assert!(!issued.token.is_empty());
        assert!(issued.created_at > 0);
        assert!(issued.expires_at > issued.created_at);
    }

    /// A token verifies against the instance that issued it.
    #[test]
    fn test_csrf_token_verification() {
        let protection = CsrfProtection::with_default();
        let issued = protection.generate_token().unwrap();
        let accepted = protection.verify(&issued.token).unwrap();
        assert!(accepted);
    }

    /// Malformed inputs are never accepted.
    #[test]
    fn test_csrf_invalid_token() {
        let protection = CsrfProtection::with_default();
        for candidate in ["invalid", "a.b.c", ""].iter() {
            assert!(!protection.verify(candidate).unwrap_or(false));
        }
    }

    /// Appending a character to a valid token invalidates it.
    #[test]
    fn test_csrf_tampered_token() {
        let protection = CsrfProtection::with_default();
        let issued = protection.generate_token().unwrap();
        let tampered = format!("{}x", issued.token);
        assert!(!protection.verify(&tampered).unwrap_or(false));
    }

    /// A token signed with one secret is rejected under another.
    #[test]
    fn test_csrf_different_secrets() {
        let issuer = CsrfProtection::with_secret(b"secret1-32-bytes-long-xxxxxxxxx");
        let other = CsrfProtection::with_secret(b"secret2-32-bytes-long-xxxxxxxxx");
        let issued = issuer.generate_token().unwrap();
        let accepted_by_other = other.verify(&issued.token).unwrap();
        assert!(!accepted_by_other);
    }

    /// `verify` accepts a token before its TTL elapses and rejects it afterwards.
    #[test]
    fn test_csrf_expiration() {
        let protection = protection_with_ttl(1);
        let issued = protection.generate_token().unwrap();
        assert!(protection.verify(&issued.token).unwrap());
        outlast_one_second_ttl();
        assert!(!protection.verify(&issued.token).unwrap());
    }

    /// `CsrfToken::is_expired` flips once the TTL has elapsed.
    #[test]
    fn test_csrf_token_is_expired() {
        let issued = protection_with_ttl(1).generate_token().unwrap();
        assert!(!issued.is_expired());
        outlast_one_second_ttl();
        assert!(issued.is_expired());
    }

    /// The remaining TTL of a fresh token lies within the configured lifetime.
    #[test]
    fn test_csrf_remaining_ttl() {
        let issued = protection_with_ttl(60).generate_token().unwrap();
        let seconds_left = issued.remaining_ttl().unwrap();
        assert!(seconds_left > 0);
        assert!(seconds_left <= 60);
    }

    /// The unsigned double-submit pair matches itself and nothing else.
    #[test]
    fn test_double_submit_csrf() {
        let guard = DoubleSubmitCsrf::with_default();
        let (cookie_token, request_token) = guard.generate_token_pair().unwrap();
        assert!(guard.verify(&cookie_token, &request_token));
        assert!(!guard.verify(&cookie_token, "wrong"));
    }

    /// The signed double-submit pair matches itself and nothing else.
    #[test]
    fn test_signed_double_submit_csrf() {
        let guard = SignedDoubleSubmitCsrf::with_secret(b"my-secret-key-32-bytes-long!!!!");
        let (cookie_token, request_token) = guard.generate_token_pair().unwrap();
        assert!(guard.verify(&cookie_token, &request_token).unwrap());
        assert!(!guard.verify(&cookie_token, "wrong").unwrap());
    }

    /// Decoding a valid token reproduces its validity window.
    #[test]
    fn test_verify_and_decode() {
        let protection = CsrfProtection::with_default();
        let issued = protection.generate_token().unwrap();
        let decoded = protection.verify_and_decode(&issued.token).unwrap();
        assert_eq!(decoded.created_at, issued.created_at);
        assert_eq!(decoded.expires_at, issued.expires_at);
    }

    /// Decoding an expired token is an error.
    #[test]
    fn test_verify_and_decode_expired() {
        let protection = protection_with_ttl(1);
        let issued = protection.generate_token().unwrap();
        outlast_one_second_ttl();
        let outcome = protection.verify_and_decode(&issued.token);
        assert!(outcome.is_err());
    }

    /// The builder's settings are reflected in the generated token.
    #[test]
    fn test_config_builder() {
        let config = CsrfConfig::new()
            .with_secret(b"test-secret-key-32-bytes-long!!")
            .with_token_length(64)
            .with_ttl(Duration::from_secs(7200));
        let issued = CsrfProtection::new(config).generate_token().unwrap();
        assert!(issued.token.len() > 100);
    }
}