use crate::core::{Error, Result};
/// One HTTP header expressed as a name/value pair.
///
/// `SafeHeader::make` and `SafeHeader::make_safe` populate this with the
/// trimmed, validated inputs. Both fields are public, so a `Header` built
/// directly with a struct literal bypasses that validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Header {
    /// Header field name, e.g. `Content-Type`.
    pub name: String,
    /// Header field value.
    pub value: String,
}
/// Cookie `SameSite` policy; the variants mirror the attribute values
/// `Strict`, `Lax` and `None`.
///
/// NOTE(review): presumably intended for building `Set-Cookie` headers; no
/// consumer is visible in this file — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SameSite {
    /// Corresponds to `SameSite=Strict`.
    Strict,
    /// Corresponds to `SameSite=Lax`.
    Lax,
    /// Corresponds to `SameSite=None`.
    None,
}
pub struct SafeHeader;
impl SafeHeader {
pub fn has_crlf(s: &str) -> bool {
s.contains('\r') || s.contains('\n')
}
pub fn is_valid_name(name: &str) -> bool {
if name.is_empty() || name.len() > 256 {
return false;
}
name.chars().all(Self::is_valid_token_char)
}
fn is_valid_token_char(c: char) -> bool {
matches!(c,
'!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' |
'^' | '_' | '`' | '|' | '~' |
'a'..='z' | 'A'..='Z' | '0'..='9'
)
}
const DANGEROUS_HEADERS: &'static [&'static str] = &[
"proxy-authorization",
"proxy-authenticate",
"proxy-connection",
"transfer-encoding",
"content-length",
"host",
"connection",
"keep-alive",
"upgrade",
"te",
"trailer",
];
pub fn is_dangerous(name: &str) -> bool {
let lower = name.to_lowercase();
Self::DANGEROUS_HEADERS.contains(&lower.as_str())
}
pub fn make(name: &str, value: &str) -> Result<Header> {
let trimmed_name = name.trim();
let trimmed_value = value.trim();
if !Self::is_valid_name(trimmed_name) {
return Err(Error::InvalidFormat("Invalid header name".into()));
}
if Self::has_crlf(trimmed_value) {
return Err(Error::InvalidFormat("Header value contains CRLF".into()));
}
if trimmed_value.len() > 8192 {
return Err(Error::TooLong("Header value too long".into()));
}
Ok(Header {
name: trimmed_name.to_string(),
value: trimmed_value.to_string(),
})
}
pub fn make_safe(name: &str, value: &str) -> Result<Header> {
if Self::is_dangerous(name) {
return Err(Error::InvalidFormat("Dangerous header not allowed".into()));
}
Self::make(name, value)
}
pub fn render(header: &Header) -> String {
format!("{}: {}", header.name, header.value)
}
pub fn build_hsts(max_age: u64, include_subdomains: bool, preload: bool) -> String {
let mut value = format!("max-age={}", max_age);
if include_subdomains {
value.push_str("; includeSubDomains");
}
if preload {
value.push_str("; preload");
}
value
}
pub fn build_csp(directives: &[(String, Vec<String>)]) -> String {
directives
.iter()
.map(|(name, sources)| {
if sources.is_empty() {
name.clone()
} else {
format!("{} {}", name, sources.join(" "))
}
})
.collect::<Vec<_>>()
.join("; ")
}
pub fn security_headers() -> Vec<Header> {
vec![
Header {
name: "X-Frame-Options".to_string(),
value: "DENY".to_string(),
},
Header {
name: "X-Content-Type-Options".to_string(),
value: "nosniff".to_string(),
},
Header {
name: "Referrer-Policy".to_string(),
value: "strict-origin-when-cross-origin".to_string(),
},
Header {
name: "X-XSS-Protection".to_string(),
value: "1; mode=block".to_string(),
},
]
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the header helpers on `SafeHeader`.
    use super::*;

    #[test]
    fn test_has_crlf() {
        // A plain media type carries no line terminators.
        let clean = "application/json";
        assert!(!SafeHeader::has_crlf(clean));

        // A full CRLF pair and a bare LF must both be detected.
        let injected = ["text\r\nX-Injected: evil", "value\ninjected"];
        for sample in injected.iter() {
            assert!(SafeHeader::has_crlf(sample));
        }
    }

    #[test]
    fn test_is_valid_name() {
        for good in ["Content-Type", "X-Custom-Header"].iter() {
            assert!(SafeHeader::is_valid_name(good));
        }
        // ':' is not a token character, and empty names are rejected outright.
        for bad in ["Content:Type", ""].iter() {
            assert!(!SafeHeader::is_valid_name(bad));
        }
    }

    #[test]
    fn test_is_dangerous() {
        for blocked in ["Host", "Transfer-Encoding"].iter() {
            assert!(SafeHeader::is_dangerous(blocked));
        }
        let harmless = "X-Custom-Header";
        assert!(!SafeHeader::is_dangerous(harmless));
    }

    #[test]
    fn test_build_hsts() {
        let expected = "max-age=31536000; includeSubDomains; preload";
        assert_eq!(SafeHeader::build_hsts(31_536_000, true, true), expected);
    }

    #[test]
    fn test_make_header() {
        let Header { name, value } =
            SafeHeader::make("Content-Type", "application/json").unwrap();
        assert_eq!(name, "Content-Type");
        assert_eq!(value, "application/json");
    }

    #[test]
    fn test_reject_crlf() {
        let outcome = SafeHeader::make("Content-Type", "text\r\nX-Injected: evil");
        assert!(outcome.is_err());
    }
}