use crate::{Error, Result};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::time::Duration;
/// HMAC-SHA256, the MAC used by `sign_cookie` / `verify_cookie`.
type HmacSha256 = Hmac<Sha256>;
/// Delimiter between the base64url-encoded payload and its base64url-encoded tag.
/// Unambiguous because the URL-safe base64 alphabet (`A-Z a-z 0-9 - _`) has no `.`.
const SIGNATURE_SEPARATOR: &str = ".";
/// Value of a cookie's `SameSite` attribute, rendered by its `Display` impl.
///
/// The derived `Default` is `Lax`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum SameSite {
/// Rendered as `SameSite=Strict`.
Strict,
/// Rendered as `SameSite=Lax`; the default.
#[default]
Lax,
/// Rendered as `SameSite=None`. Per RFC 6265bis, user agents reject such a cookie
/// unless it also carries `Secure`.
None,
}
impl std::fmt::Display for SameSite {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SameSite::Strict => write!(f, "Strict"),
SameSite::Lax => write!(f, "Lax"),
SameSite::None => write!(f, "None"),
}
}
}
/// A cookie name/value plus the attributes rendered by `SecureCookie::to_header_value`.
///
/// Construct with `SecureCookie::new` or `SecureCookie::session` and refine with the
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureCookie {
/// Cookie name, emitted verbatim before `=`.
pub name: String,
/// Cookie value as sent to the client; after `signed()` this holds the signed form.
pub value: String,
/// Emit the `HttpOnly` flag.
#[serde(default)]
pub http_only: bool,
/// Emit the `Secure` flag.
#[serde(default)]
pub secure: bool,
/// `SameSite` attribute; always emitted.
#[serde(default)]
pub same_site: SameSite,
/// `Max-Age`; rendered as whole seconds.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_age: Option<Duration>,
/// `Expires`, emitted verbatim (not validated as an HTTP date).
#[serde(skip_serializing_if = "Option::is_none")]
pub expires: Option<String>,
/// `Path` attribute, emitted verbatim.
#[serde(skip_serializing_if = "Option::is_none")]
pub path: Option<String>,
/// `Domain` attribute, emitted verbatim.
#[serde(skip_serializing_if = "Option::is_none")]
pub domain: Option<String>,
/// Set by `signed()`; reported by `is_signed()`.
#[serde(default)]
is_signed: bool,
/// Value captured by `signed()` before signing.
/// NOTE(review): not read anywhere in this file, yet it is included in Debug and
/// serialized output — confirm it is still needed.
#[serde(skip_serializing_if = "Option::is_none")]
original_value: Option<String>,
}
impl SecureCookie {
    /// Creates an unsigned cookie with default attributes.
    ///
    /// Defaults: `HttpOnly` and `Secure` off, `SameSite=Lax`, and no `Max-Age`,
    /// `Expires`, `Path` or `Domain`.
    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            value: value.into(),
            http_only: false,
            secure: false,
            same_site: SameSite::Lax,
            max_age: None,
            expires: None,
            path: None,
            domain: None,
            is_signed: false,
            original_value: None,
        }
    }

    /// Creates a cookie with `HttpOnly`, `Secure` and `SameSite=Strict` set.
    pub fn session(name: impl Into<String>, value: impl Into<String>) -> Self {
        Self::new(name, value)
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
    }

    /// Sets whether the `HttpOnly` flag is emitted.
    pub fn http_only(mut self, http_only: bool) -> Self {
        self.http_only = http_only;
        self
    }

    /// Sets whether the `Secure` flag is emitted.
    pub fn secure(mut self, secure: bool) -> Self {
        self.secure = secure;
        self
    }

    /// Sets the `SameSite` attribute.
    pub fn same_site(mut self, same_site: SameSite) -> Self {
        self.same_site = same_site;
        self
    }

    /// Sets `Max-Age`. Only whole seconds are rendered; sub-second precision is dropped.
    pub fn max_age(mut self, max_age: Duration) -> Self {
        self.max_age = Some(max_age);
        self
    }

    /// Sets `Max-Age` from a number of seconds.
    pub fn max_age_secs(mut self, secs: u64) -> Self {
        self.max_age = Some(Duration::from_secs(secs));
        self
    }

    /// Sets the `Expires` attribute. The string is emitted verbatim and is not
    /// validated as an HTTP date.
    pub fn expires(mut self, expires: impl Into<String>) -> Self {
        self.expires = Some(expires.into());
        self
    }

    /// Sets the `Path` attribute.
    pub fn path(mut self, path: impl Into<String>) -> Self {
        self.path = Some(path.into());
        self
    }

    /// Sets the `Domain` attribute.
    pub fn domain(mut self, domain: impl Into<String>) -> Self {
        self.domain = Some(domain.into());
        self
    }

    /// Replaces the value with its signed form (see [`sign_cookie`]) and marks the
    /// cookie as signed.
    ///
    /// The pre-signing value is kept in `original_value`. Calling this twice signs the
    /// already-signed value again.
    pub fn signed(mut self, secret: &[u8]) -> Self {
        self.original_value = Some(self.value.clone());
        self.value = sign_cookie(&self.value, secret);
        self.is_signed = true;
        self
    }

    /// Returns the cookie name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the value as it will be sent (the signed form if `signed` was applied).
    pub fn value(&self) -> &str {
        &self.value
    }

    /// Returns whether `signed` has been applied to this cookie.
    pub fn is_signed(&self) -> bool {
        self.is_signed
    }

    /// Verifies the current value as a signed cookie and returns the plaintext.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`verify_cookie`] (malformed value, bad signature,
    /// invalid encoding).
    pub fn verify_value(&self, secret: &[u8]) -> Result<String> {
        verify_cookie(&self.value, secret)
    }

    /// Renders the cookie as a `Set-Cookie` header value.
    ///
    /// Order: `name=value`, `HttpOnly`, `Secure`, `SameSite` (always present),
    /// `Max-Age`, `Expires`, `Path`, `Domain`, joined with `"; "`.
    ///
    /// Name, value, `Expires`, `Path` and `Domain` are emitted verbatim with no
    /// escaping or validation, so callers must not put `;`, CR or LF in them.
    /// `SameSite=None` is rendered even when `secure` is false. User agents following
    /// RFC 6265bis reject such cookies.
    pub fn to_header_value(&self) -> String {
        let mut parts = vec![format!("{}={}", self.name, self.value)];
        if self.http_only {
            parts.push("HttpOnly".to_string());
        }
        if self.secure {
            parts.push("Secure".to_string());
        }
        parts.push(format!("SameSite={}", self.same_site));
        if let Some(max_age) = self.max_age {
            parts.push(format!("Max-Age={}", max_age.as_secs()));
        }
        if let Some(ref expires) = self.expires {
            parts.push(format!("Expires={}", expires));
        }
        if let Some(ref path) = self.path {
            parts.push(format!("Path={}", path));
        }
        if let Some(ref domain) = self.domain {
            parts.push(format!("Domain={}", domain));
        }
        parts.join("; ")
    }

    /// Parses the leading `name=value` pair of a cookie string.
    ///
    /// Only the text before the first `;` is considered. Trailing attributes such as
    /// `HttpOnly` or `Path=/` are ignored, and the result carries the defaults of
    /// [`SecureCookie::new`]. Name and value are trimmed of surrounding whitespace.
    /// The value may itself contain `=`.
    ///
    /// Returns `None` when the first segment has no `=` or the name is empty.
    pub fn parse(cookie_str: &str) -> Option<Self> {
        // Isolate the pair first: attributes after `;` may contain `=` themselves
        // (e.g. `Path=/`). Splitting the whole string on `=` would fuse them into the name.
        let pair = cookie_str.split(';').next()?;
        let (name, value) = pair.split_once('=')?;
        let name = name.trim();
        if name.is_empty() {
            return None;
        }
        Some(Self::new(name, value.trim()))
    }
}
/// Signs `value` with HMAC-SHA256 and returns `<payload>.<tag>`.
///
/// `payload` is `value` as unpadded base64url. `tag` is the unpadded base64url encoding
/// of HMAC-SHA256 keyed by `secret` over the *encoded* payload string.
///
/// The payload is only encoded, not encrypted, so anyone holding the cookie can read
/// the original value. The cookie name is not part of the MAC input, so a signed value
/// verifies under any cookie name that uses the same secret.
pub fn sign_cookie(value: &str, secret: &[u8]) -> String {
    let payload = URL_SAFE_NO_PAD.encode(value.as_bytes());
    let tag = {
        let mut hmac =
            HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
        hmac.update(payload.as_bytes());
        URL_SAFE_NO_PAD.encode(hmac.finalize().into_bytes())
    };
    [payload.as_str(), SIGNATURE_SEPARATOR, tag.as_str()].concat()
}
pub fn verify_cookie(signed_value: &str, secret: &[u8]) -> Result<String> {
let parts: Vec<&str> = signed_value.rsplitn(2, SIGNATURE_SEPARATOR).collect();
if parts.len() != 2 {
return Err(Error::validation("Invalid signed cookie format"));
}
let encoded_signature = parts[0];
let encoded_value = parts[1];
let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(encoded_value.as_bytes());
let expected_signature = URL_SAFE_NO_PAD
.decode(encoded_signature)
.map_err(|_| Error::validation("Invalid signature encoding"))?;
mac.verify_slice(&expected_signature)
.map_err(|_| Error::validation("Cookie signature verification failed"))?;
let value_bytes = URL_SAFE_NO_PAD
.decode(encoded_value)
.map_err(|_| Error::validation("Invalid value encoding"))?;
String::from_utf8(value_bytes).map_err(|_| Error::validation("Invalid UTF-8 in cookie value"))
}
/// Builds a `Set-Cookie` header value that expires the cookie `name`.
///
/// The cookie is expired through both `Max-Age=0` and an `Expires` date in 1970.
/// `path` should be the `Path` the cookie was set with. Browsers identify a cookie by
/// name, domain and path, so a mismatched path targets a different cookie. No `Domain`
/// attribute is emitted, so this cannot remove a cookie that was set with one.
///
/// Names with the `__Secure-` or `__Host-` prefix (matched ASCII case-insensitively)
/// get the `Secure` attribute. User agents ignore a Set-Cookie for such names unless
/// `Secure` is present, so without it the deletion would silently do nothing.
pub fn delete_cookie_header(name: &str, path: Option<&str>) -> String {
    let mut parts = vec![
        format!("{}=", name),
        "Max-Age=0".to_string(),
        "Expires=Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
    ];
    // `get` (not slicing) so a multi-byte char straddling the prefix length cannot panic.
    let needs_secure = ["__Secure-", "__Host-"].iter().any(|prefix| {
        matches!(name.get(..prefix.len()), Some(head) if head.eq_ignore_ascii_case(prefix))
    });
    if needs_secure {
        parts.push("Secure".to_string());
    }
    if let Some(p) = path {
        parts.push(format!("Path={}", p));
    }
    parts.join("; ")
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SECRET: &[u8] = b"test-secret-key-at-least-32-bytes!!";

    #[test]
    fn test_sign_and_verify() {
        let value = "test_value_123";
        let signed = sign_cookie(value, TEST_SECRET);
        assert!(signed.contains(SIGNATURE_SEPARATOR));
        let verified = verify_cookie(&signed, TEST_SECRET).unwrap();
        assert_eq!(verified, value);
    }

    #[test]
    fn test_verify_with_wrong_secret() {
        let value = "test_value";
        let signed = sign_cookie(value, TEST_SECRET);
        let wrong_secret = b"wrong-secret-key-at-least-32-bytes!!";
        let result = verify_cookie(&signed, wrong_secret);
        assert!(result.is_err());
    }

    /// Changing one character of the signature must invalidate the cookie.
    #[test]
    fn test_verify_tampered_value() {
        let value = "original_value";
        let signed = sign_cookie(value, TEST_SECRET);
        // Assert the shape rather than skipping silently: if the format ever changes,
        // this test must fail instead of becoming a no-op.
        let (encoded_value, encoded_sig) = signed
            .rsplit_once(SIGNATURE_SEPARATOR)
            .expect("signed cookie must contain the separator");
        let mut sig_bytes = encoded_sig.as_bytes().to_vec();
        assert!(!sig_bytes.is_empty(), "signature must not be empty");
        sig_bytes[0] = if sig_bytes[0] == b'A' { b'B' } else { b'A' };
        let tampered_sig = String::from_utf8(sig_bytes).unwrap();
        let tampered = format!("{}{}{}", encoded_value, SIGNATURE_SEPARATOR, tampered_sig);
        assert!(verify_cookie(&tampered, TEST_SECRET).is_err());
    }

    /// A valid signature must not authenticate a different payload.
    #[test]
    fn test_verify_tampered_payload() {
        let signed = sign_cookie("user=alice", TEST_SECRET);
        let (_, encoded_sig) = signed
            .rsplit_once(SIGNATURE_SEPARATOR)
            .expect("signed cookie must contain the separator");
        let forged_payload = URL_SAFE_NO_PAD.encode("user=admin");
        let forged = format!("{}{}{}", forged_payload, SIGNATURE_SEPARATOR, encoded_sig);
        assert!(verify_cookie(&forged, TEST_SECRET).is_err());
    }

    #[test]
    fn test_verify_invalid_signature_encoding() {
        assert!(verify_cookie("dmFsdWU.not*base64", TEST_SECRET).is_err());
    }

    #[test]
    fn test_verify_invalid_format() {
        let result = verify_cookie("no_separator_here", TEST_SECRET);
        assert!(result.is_err());
    }

    #[test]
    fn test_secure_cookie_new() {
        let cookie = SecureCookie::new("session", "abc123");
        assert_eq!(cookie.name(), "session");
        assert_eq!(cookie.value(), "abc123");
        assert!(!cookie.http_only);
        assert!(!cookie.secure);
        assert_eq!(cookie.same_site, SameSite::Lax);
    }

    #[test]
    fn test_secure_cookie_session() {
        let cookie = SecureCookie::session("session", "abc123");
        assert!(cookie.http_only);
        assert!(cookie.secure);
        assert_eq!(cookie.same_site, SameSite::Strict);
    }

    #[test]
    fn test_secure_cookie_builder() {
        let cookie = SecureCookie::new("session", "abc123")
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .max_age(Duration::from_secs(3600))
            .path("/app")
            .domain("example.com");
        assert!(cookie.http_only);
        assert!(cookie.secure);
        assert_eq!(cookie.same_site, SameSite::Strict);
        assert_eq!(cookie.max_age, Some(Duration::from_secs(3600)));
        assert_eq!(cookie.path, Some("/app".to_string()));
        assert_eq!(cookie.domain, Some("example.com".to_string()));
    }

    #[test]
    fn test_secure_cookie_signed() {
        let cookie = SecureCookie::new("session", "user123").signed(TEST_SECRET);
        assert!(cookie.is_signed());
        assert!(cookie.value().contains(SIGNATURE_SEPARATOR));
        let verified = cookie.verify_value(TEST_SECRET).unwrap();
        assert_eq!(verified, "user123");
    }

    #[test]
    fn test_to_header_value() {
        let cookie = SecureCookie::new("session", "abc123")
            .http_only(true)
            .secure(true)
            .same_site(SameSite::Strict)
            .max_age(Duration::from_secs(3600))
            .path("/");
        let header = cookie.to_header_value();
        assert!(header.contains("session=abc123"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("Secure"));
        assert!(header.contains("SameSite=Strict"));
        assert!(header.contains("Max-Age=3600"));
        assert!(header.contains("Path=/"));
    }

    #[test]
    fn test_to_header_value_minimal() {
        let cookie = SecureCookie::new("name", "value");
        let header = cookie.to_header_value();
        assert!(header.contains("name=value"));
        assert!(header.contains("SameSite=Lax"));
        assert!(!header.contains("HttpOnly"));
        assert!(!header.contains("Secure"));
    }

    #[test]
    fn test_delete_cookie_header() {
        let header = delete_cookie_header("session", Some("/"));
        assert!(header.contains("session="));
        assert!(header.contains("Max-Age=0"));
        assert!(header.contains("Expires="));
        assert!(header.contains("Path=/"));
    }

    #[test]
    fn test_parse_cookie() {
        let cookie = SecureCookie::parse("session=abc123; HttpOnly; Secure").unwrap();
        assert_eq!(cookie.name(), "session");
        assert_eq!(cookie.value(), "abc123");
    }

    #[test]
    fn test_parse_cookie_invalid() {
        let result = SecureCookie::parse("invalid");
        assert!(result.is_none());
    }

    #[test]
    fn test_same_site_display() {
        assert_eq!(format!("{}", SameSite::Strict), "Strict");
        assert_eq!(format!("{}", SameSite::Lax), "Lax");
        assert_eq!(format!("{}", SameSite::None), "None");
    }

    #[test]
    fn test_sign_cookie_with_special_chars() {
        let value = "user@example.com|token=abc123&data=xyz";
        let signed = sign_cookie(value, TEST_SECRET);
        let verified = verify_cookie(&signed, TEST_SECRET).unwrap();
        assert_eq!(verified, value);
    }

    #[test]
    fn test_sign_cookie_with_unicode() {
        let value = "用户名: 张三";
        let signed = sign_cookie(value, TEST_SECRET);
        let verified = verify_cookie(&signed, TEST_SECRET).unwrap();
        assert_eq!(verified, value);
    }

    #[test]
    fn test_max_age_secs() {
        let cookie = SecureCookie::new("test", "value").max_age_secs(7200);
        assert_eq!(cookie.max_age, Some(Duration::from_secs(7200)));
    }

    #[test]
    fn test_same_site_none_with_secure() {
        let cookie = SecureCookie::new("cross_site", "value")
            .same_site(SameSite::None)
            .secure(true);
        let header = cookie.to_header_value();
        assert!(header.contains("SameSite=None"));
        assert!(header.contains("Secure"));
    }
}