#[cfg(test)]
mod tests;
use rand_core::{OsRng, RngCore};
use crate::application::Application;
use crate::error::{AppError, IntoResponse};
use crate::header::Header;
use crate::middleware::Middleware;
use crate::request::Request;
use crate::response::Response;
use crate::server::ConnectionInfo;
/// Name of the request header through which the CSRF middleware hands the current token to
/// downstream handlers; read back by `CsrfToken::from_request`.
const INJECTED_HEADER: &str = "X-Rws-Csrf-Token";

/// CSRF protection middleware implementing the double-submit cookie pattern.
///
/// The token is stored in a cookie. Requests whose method is not GET/HEAD/OPTIONS/TRACE must
/// echo the same token back in a request header or an urlencoded form field, otherwise they
/// are rejected. Configure it with the builder methods on `CsrfLayer`.
#[derive(Debug, Clone)]
pub struct CsrfLayer {
    /// Name of the cookie holding the token.
    cookie_name: String,
    /// Name of the `application/x-www-form-urlencoded` field checked for the submitted token.
    field_name: String,
    /// Name of the request header checked for the submitted token.
    header_name: String,
    /// Whether the token cookie carries the `HttpOnly` attribute.
    http_only: bool,
    /// Whether the token cookie carries the `Secure` attribute.
    secure: bool,
    /// Value written into the token cookie's `SameSite` attribute.
    same_site: String,
}

impl Default for CsrfLayer {
    /// Defaults: cookie and form field `_csrf`, header `X-CSRF-Token`, `SameSite=Strict`,
    /// and neither `HttpOnly` nor `Secure`.
    ///
    /// `HttpOnly` is off presumably so client-side script can read the cookie and copy it into
    /// the token header. `Secure` is off so the cookie also works over plain HTTP; enable it
    /// when serving over TLS.
    fn default() -> Self {
        Self {
            cookie_name: String::from("_csrf"),
            field_name: String::from("_csrf"),
            header_name: String::from("X-CSRF-Token"),
            http_only: false,
            secure: false,
            same_site: String::from("Strict"),
        }
    }
}
impl CsrfLayer {
    /// Creates a layer with the default settings (same as `CsrfLayer::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the name of the cookie that carries the CSRF token.
    pub fn cookie_name(mut self, name: &str) -> Self {
        self.cookie_name = name.to_string();
        self
    }

    /// Sets the name of the urlencoded form field consulted when no token header is submitted.
    pub fn field_name(mut self, name: &str) -> Self {
        self.field_name = name.to_string();
        self
    }

    /// Sets the name of the request header consulted first for the submitted token
    /// (matched ASCII case-insensitively).
    pub fn header_name(mut self, name: &str) -> Self {
        self.header_name = name.to_string();
        self
    }

    /// Controls whether the token cookie is emitted with the `HttpOnly` attribute.
    pub fn http_only(mut self, http_only: bool) -> Self {
        self.http_only = http_only;
        self
    }

    /// Controls whether the token cookie is emitted with the `Secure` attribute.
    pub fn secure(mut self, secure: bool) -> Self {
        self.secure = secure;
        self
    }

    /// Sets the token cookie's `SameSite` attribute (e.g. `"Strict"`, `"Lax"`, `"None"`).
    ///
    /// The string is written into the `Set-Cookie` header verbatim. Browsers reject
    /// `SameSite=None` cookies that are not also `Secure`, so pair `"None"` with `secure(true)`.
    pub fn same_site(mut self, same_site: &str) -> Self {
        self.same_site = same_site.to_string();
        self
    }

    /// Generates a fresh token: 32 random bytes from `OsRng`, hex-encoded (64 lowercase chars).
    fn generate_token(&self) -> String {
        use std::fmt::Write;

        let mut bytes = [0u8; 32];
        OsRng.fill_bytes(&mut bytes);
        let mut token = String::with_capacity(bytes.len() * 2);
        for b in bytes.iter() {
            // Writing into a `String` cannot fail.
            let _ = write!(token, "{:02x}", b);
        }
        token
    }

    /// Returns the token stored in the request's cookie, if present and non-empty.
    ///
    /// An empty value (`<cookie_name>=`) is treated as absent. Otherwise it would be echoed
    /// back on every safe request, so no real token would ever be minted. On unsafe requests
    /// it would also match an empty submitted form field, so a forged form containing
    /// `<field_name>=` would pass the check.
    fn get_cookie_token(&self, req: &Request) -> Option<String> {
        get_cookie(req, &self.cookie_name).filter(|token| !token.is_empty())
    }

    /// Returns the token the client submitted alongside an unsafe request.
    ///
    /// The first `header_name` header whose trimmed value is non-empty wins. Failing that, if
    /// the request is `application/x-www-form-urlencoded`, the `field_name` form field is used.
    /// An empty field counts as missing, mirroring the header rule.
    fn get_submitted_token(&self, req: &Request) -> Option<String> {
        let from_header = req
            .headers
            .iter()
            .filter(|h| h.name.eq_ignore_ascii_case(&self.header_name))
            .map(|h| h.value.trim())
            .find(|v| !v.is_empty());
        if let Some(value) = from_header {
            return Some(value.to_string());
        }
        let is_form = req.headers.iter().any(|h| {
            h.name.eq_ignore_ascii_case("content-type")
                && h.value
                    .to_lowercase()
                    .contains("application/x-www-form-urlencoded")
        });
        if is_form {
            return get_form_field(&req.body, &self.field_name).filter(|v| !v.is_empty());
        }
        None
    }

    /// Builds the `Set-Cookie` value for `token`: `name=token; Path=/; SameSite=...`, plus
    /// `; HttpOnly` and `; Secure` when those options are enabled.
    fn cookie_header_value(&self, token: &str) -> String {
        let mut s = format!(
            "{}={}; Path=/; SameSite={}",
            self.cookie_name, token, self.same_site
        );
        if self.http_only {
            s.push_str("; HttpOnly");
        }
        if self.secure {
            s.push_str("; Secure");
        }
        s
    }
}
impl Middleware for CsrfLayer {
    /// Enforces double-submit CSRF protection around `next`.
    ///
    /// - **Safe methods** (`GET`, `HEAD`, `OPTIONS`, `TRACE`):
    ///   - reuse the cookie's token, or generate a new one;
    ///   - expose it to the handler through the injected header;
    ///   - append a `Set-Cookie` header carrying it to the handler's response.
    /// - **All other methods**:
    ///   - require both a cookie token and a submitted token (header or urlencoded form
    ///     field), compared with `ct_eq`;
    ///   - if either is missing or they differ, return a `Forbidden` response without
    ///     calling `next`;
    ///   - on success, inject the verified token for the handler.
    ///
    /// # Errors
    /// Propagates any error returned by `next.execute`.
    fn handle(
        &self,
        request: &Request,
        connection: &ConnectionInfo,
        next: &dyn Application,
    ) -> Result<Response, String> {
        if is_safe_method(&request.method) {
            let token = self
                .get_cookie_token(request)
                .unwrap_or_else(|| self.generate_token());
            let req = with_injected_token(request, &token);
            let mut response = next.execute(&req, connection)?;
            response.headers.push(Header {
                name: "Set-Cookie".to_string(),
                value: self.cookie_header_value(&token),
            });
            Ok(response)
        } else {
            let cookie_token = match self.get_cookie_token(request) {
                Some(t) => t,
                None => return Ok(AppError::Forbidden.into_response()),
            };
            let submitted_token = match self.get_submitted_token(request) {
                Some(t) => t,
                None => return Ok(AppError::Forbidden.into_response()),
            };
            if !ct_eq(cookie_token.as_bytes(), submitted_token.as_bytes()) {
                return Ok(AppError::Forbidden.into_response());
            }
            let req = with_injected_token(request, &cookie_token);
            next.execute(&req, connection)
        }
    }
}

/// Returns a copy of `request` carrying exactly one `INJECTED_HEADER`, whose value is `token`.
///
/// Any header with that name already on the request came from the client and is dropped.
/// `CsrfToken::from_request` returns the first match, so a spoofed copy would otherwise
/// shadow the value set by the middleware.
fn with_injected_token(request: &Request, token: &str) -> Request {
    let mut req = request.clone();
    req.headers
        .retain(|h| !h.name.eq_ignore_ascii_case(INJECTED_HEADER));
    req.headers.push(Header {
        name: INJECTED_HEADER.to_string(),
        value: token.to_string(),
    });
    req
}
/// A CSRF token read from the header the CSRF middleware uses to pass the token to handlers
/// (e.g. for embedding in an HTML form).
pub struct CsrfToken(String);

impl CsrfToken {
    /// Borrows the token text.
    pub fn value(&self) -> &str {
        self.0.as_str()
    }

    /// Returns the value of the first `X-Rws-Csrf-Token` header on `req` (header name matched
    /// ASCII case-insensitively), or `None` if there is no such header.
    pub fn from_request(req: &Request) -> Option<Self> {
        req.headers
            .iter()
            .find(|header| header.name.eq_ignore_ascii_case(INJECTED_HEADER))
            .map(|header| Self(header.value.clone()))
    }
}
impl std::fmt::Display for CsrfToken {
    /// Writes the raw token text; width/alignment flags on the outer format spec are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Reports whether `method` is one the middleware treats as safe (`GET`, `HEAD`, `OPTIONS`,
/// `TRACE`); such requests skip token verification.
///
/// The comparison is exact and case-sensitive, as HTTP method names are.
fn is_safe_method(method: &str) -> bool {
    const SAFE_METHODS: [&str; 4] = ["GET", "HEAD", "OPTIONS", "TRACE"];
    SAFE_METHODS.contains(&method)
}
/// Returns the value of the first cookie named `name` across all `Cookie` headers of `req`.
///
/// Header names are matched ASCII case-insensitively, as HTTP header names are. Cookie names,
/// however, are compared exactly: RFC 6265 treats `_csrf` and `_CSRF` as distinct cookies.
/// A differently-cased cookie (e.g. one set by a sibling subdomain) must therefore not be
/// mistaken for ours. Whitespace around each cookie name and value is ignored.
fn get_cookie(req: &Request, name: &str) -> Option<String> {
    req.headers
        .iter()
        .filter(|h| h.name.eq_ignore_ascii_case("cookie"))
        .flat_map(|h| h.value.split(';'))
        .find_map(|part| {
            let pos = part.find('=')?;
            if part[..pos].trim() == name {
                Some(part[pos + 1..].trim().to_string())
            } else {
                None
            }
        })
}
/// Returns the value of the first `field` entry in an `application/x-www-form-urlencoded` body.
///
/// Keys and values are decoded the way browsers encode them (`+` becomes a space, `%XX`
/// becomes that byte). A configured field name containing reserved characters therefore still
/// matches; for example `csrf[token]` is sent as `csrf%5Btoken%5D`. Whitespace around the
/// raw key is ignored, and an entry without `=` yields an empty value. Returns `None` if the
/// body is not valid UTF-8 or the field is absent.
fn get_form_field(body: &[u8], field: &str) -> Option<String> {
    let s = std::str::from_utf8(body).ok()?;
    s.split('&').find_map(|pair| {
        let (raw_key, raw_value) = match pair.find('=') {
            Some(pos) => (&pair[..pos], &pair[pos + 1..]),
            None => (pair, ""),
        };
        if decode_form_component(raw_key.trim()) == field {
            Some(decode_form_component(raw_value))
        } else {
            None
        }
    })
}

/// Decodes one urlencoded component: `+` becomes a space and `%XX` (two hex digits) becomes
/// that byte. Malformed `%` escapes are kept literally, and invalid UTF-8 in the decoded bytes
/// is replaced with U+FFFD, like the lenient WHATWG urlencoded parser.
fn decode_form_component(input: &str) -> String {
    fn hex_value(b: u8) -> Option<u8> {
        // `to_digit(16)` yields 0..=15, so the narrowing cast cannot truncate.
        (b as char).to_digit(16).map(|d| d as u8)
    }

    let bytes = input.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let byte = bytes[i];
        if byte == b'+' {
            out.push(b' ');
            i += 1;
            continue;
        }
        if byte == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_value(bytes[i + 1]), hex_value(bytes[i + 2])) {
                out.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        out.push(byte);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
/// Compares two byte strings for equality without a data-dependent early exit.
///
/// Only the length check short-circuits. Every byte pair is XOR-accumulated, so, at the
/// source level, the running time does not depend on where the contents first differ.
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}