use crate::{
HttpRequest,
di::Container,
error::Error,
headers::HeaderMap,
http::{
Extensions, Parts, Request,
body::Incoming,
cookie::get_cookies,
endpoints::args::{
FromPayload, FromRawRequest, FromRequestParts, FromRequestRef, Payload, Source,
},
},
};
use cookie::{CookieJar, Key, SignedJar};
use futures_util::future::{Ready, ready};
use std::{fs::File, io::Read, path::Path};
/// Secret key used to sign and verify cookies.
///
/// A newtype over [`cookie::Key`]. Its `Debug` implementation prints a
/// redacted placeholder instead of the key material.
#[derive(Clone)]
pub struct SignedKey(Key);
/// A [`CookieJar`] paired with the [`SignedKey`] used to sign and verify its cookies.
pub struct SignedCookies(SignedKey, CookieJar);
impl Default for SignedKey {
    /// Returns a freshly generated, random signing key.
    ///
    /// The default cannot be derived from an empty byte slice: `cookie::Key::from`
    /// requires at least 64 bytes of key material and panics otherwise, so doing so
    /// would make every call to `SignedKey::default()` panic.
    ///
    /// Each call produces a different key, so cookies signed with one default key
    /// cannot be verified with another. Share a single instance (or load the key
    /// from configuration) when cookies must survive restarts.
    #[inline]
    fn default() -> Self {
        Self::generate()
    }
}
impl std::fmt::Debug for SignedCookies {
    /// Formats the value as a one-field tuple struct holding a redaction
    /// placeholder, so neither the key nor the cookies are printed.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut redacted = f.debug_tuple("SignedCookies");
        redacted.field(&"[redacted]");
        redacted.finish()
    }
}
impl std::fmt::Debug for SignedKey {
    /// Formats the key as a one-field tuple struct holding a redaction
    /// placeholder, so the key material is never printed.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut redacted = f.debug_tuple("SignedKey");
        redacted.field(&"[redacted]");
        redacted.finish()
    }
}
impl SignedKey {
    /// Builds a signing key from raw key material.
    ///
    /// # Panics
    ///
    /// NOTE(review): `cookie::Key::from` is documented to panic when fewer than
    /// 64 bytes are supplied — confirm against the `cookie` version in use.
    #[inline]
    pub fn from(bytes: &[u8]) -> Self {
        let key = Key::from(bytes);
        Self(key)
    }

    /// Builds a signing key from the first 64 bytes of the file at `path`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened, or if 64 bytes cannot be read from it.
    #[inline]
    pub fn from_file(path: impl AsRef<Path>) -> Self {
        let mut key_bytes = [0u8; 64];
        File::open(path)
            .expect("File must exists")
            .read_exact(&mut key_bytes)
            .expect("File must be readable");
        Self(Key::from(&key_bytes))
    }

    /// Creates a new key via `cookie::Key::generate`.
    #[inline]
    pub fn generate() -> Self {
        let key = Key::generate();
        Self(key)
    }
}
impl SignedCookies {
#[inline]
pub fn new(key: SignedKey) -> Self {
Self(key, CookieJar::default())
}
#[inline]
pub fn from_headers(key: SignedKey, headers: &HeaderMap) -> Self {
let mut jar = CookieJar::new();
let mut signed_jar = jar.signed_mut(&key.0);
for cookie in get_cookies(headers) {
if let Some(cookie) = signed_jar.verify(cookie) {
signed_jar.add_original(cookie);
}
}
Self(key, jar)
}
#[inline]
pub fn into_parts(self) -> (SignedKey, CookieJar) {
(self.0, self.1)
}
pub fn get(&self, name: &str) -> Option<cookie::Cookie<'static>> {
self.signed().get(name)
}
#[allow(clippy::should_implement_trait)]
pub fn add<C: Into<cookie::Cookie<'static>>>(mut self, cookie: C) -> Self {
self.signed_mut().add(cookie);
self
}
pub fn remove<C: Into<cookie::Cookie<'static>>>(mut self, cookie: C) -> Self {
self.signed_mut().remove(cookie);
self
}
pub fn verify(&self, cookie: cookie::Cookie<'static>) -> Option<cookie::Cookie<'static>> {
self.signed().verify(cookie)
}
pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> + '_ {
self.1.iter()
}
#[inline]
fn signed(&self) -> SignedJar<&'_ CookieJar> {
self.1.signed(&self.0.0)
}
#[inline]
fn signed_mut(&mut self) -> SignedJar<&'_ mut CookieJar> {
self.1.signed_mut(&self.0.0)
}
}
impl TryFrom<&Parts> for SignedCookies {
    type Error = Error;

    /// Resolves the [`SignedKey`] from the DI container obtained from `parts`
    /// and reads the signed cookies from `parts.headers`.
    ///
    /// # Errors
    ///
    /// Fails when the container cannot be obtained from `parts` or when the
    /// key cannot be resolved.
    #[inline]
    fn try_from(parts: &Parts) -> Result<Self, Self::Error> {
        let container = Container::try_from(parts)?;
        match container.resolve::<SignedKey>() {
            Ok(key) => Ok(SignedCookies::from_headers(key, &parts.headers)),
            Err(err) => Err(err.into()),
        }
    }
}
impl TryFrom<(&Extensions, &HeaderMap)> for SignedCookies {
    type Error = Error;

    /// Resolves the [`SignedKey`] from the DI container obtained from
    /// `extensions` and reads the signed cookies from `headers`.
    ///
    /// # Errors
    ///
    /// Fails when the container cannot be obtained from `extensions` or when
    /// the key cannot be resolved.
    #[inline]
    fn try_from((extensions, headers): (&Extensions, &HeaderMap)) -> Result<Self, Self::Error> {
        let container = Container::try_from(extensions)?;
        match container.resolve::<SignedKey>() {
            Ok(key) => Ok(SignedCookies::from_headers(key, headers)),
            Err(err) => Err(err.into()),
        }
    }
}
impl FromRequestRef for SignedCookies {
    /// Extracts the signed cookies from a borrowed request, using its
    /// extensions (for key lookup) and its headers (for the cookies).
    #[inline]
    fn from_request(req: &HttpRequest) -> Result<Self, Error> {
        let source = (req.extensions(), req.headers());
        Self::try_from(source)
    }
}
impl FromRawRequest for SignedCookies {
    /// Extracts the signed cookies from an owned raw request.
    ///
    /// The extraction happens synchronously; the result is wrapped in an
    /// immediately ready future.
    #[inline]
    fn from_request(req: Request<Incoming>) -> impl Future<Output = Result<Self, Error>> + Send {
        let extracted = Self::try_from((req.extensions(), req.headers()));
        ready(extracted)
    }
}
impl FromRequestParts for SignedCookies {
    /// Extracts the signed cookies from request parts via the
    /// `TryFrom<&Parts>` conversion.
    #[inline]
    fn from_parts(parts: &Parts) -> Result<Self, Error> {
        Self::try_from(parts)
    }
}
impl FromPayload for SignedCookies {
    type Future = Ready<Result<Self, Error>>;

    /// The extractor declares request parts as its payload source.
    const SOURCE: Source = Source::Parts;

    /// Extracts the signed cookies from a `Payload::Parts` payload.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!()` if called with any other payload variant.
    #[inline]
    fn from_payload(payload: Payload<'_>) -> Self::Future {
        match payload {
            Payload::Parts(parts) => ready(Self::from_parts(parts)),
            _ => unreachable!(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::di::ContainerBuilder;
    use crate::headers::{COOKIE, SET_COOKIE};
    use crate::http::{HttpBody, HttpRequest, Request, cookie::set_cookies};

    /// Fixed key material for tests that need a reproducible key
    /// (longer than the 64 bytes the signing key requires).
    const KEY_MATERIAL: &str = "f3d9e2a44c6b172a1ea9b9d05e5fe1bcaa8679d032ccae271c503af9618bb2ef7c4e51452dbfcd96f6e9c9d09166a3de77e";

    #[test]
    fn it_creates_empty_cookies() {
        let key = SignedKey(Key::generate());
        let cookies = SignedCookies::new(key);

        assert_eq!(cookies.iter().count(), 0);
    }

    #[test]
    fn it_creates_cookies_from_empty_headers() {
        let headers = HeaderMap::new();
        let cookies = SignedCookies::from_headers(SignedKey::generate(), &headers);

        assert_eq!(cookies.iter().count(), 0);
    }

    #[test]
    fn it_creates_cookies() {
        let key = SignedKey::from(KEY_MATERIAL.as_bytes());
        let cookies = SignedCookies::new(key.clone()).add(("session", "abc123"));

        let mut headers = HeaderMap::new();
        set_cookies_for_request(cookies.1, &mut headers);

        let cookies = SignedCookies::from_headers(key, &headers);
        let cookie = cookies.get("session").expect("Cookie should exist");

        assert_eq!(cookie.value(), "abc123");
    }

    #[test]
    fn it_creates_from_multiple_cookies() {
        let key = SignedKey::generate();
        let cookies = SignedCookies::new(key.clone())
            .add(("session", "abc123"))
            .add(("user", "john"))
            .add(("theme", "dark"));

        let mut headers = HeaderMap::new();
        set_cookies_for_request(cookies.1, &mut headers);

        let cookies = SignedCookies::from_headers(key, &headers);

        assert_eq!(cookies.get("session").unwrap().value(), "abc123");
        assert_eq!(cookies.get("user").unwrap().value(), "john");
        assert_eq!(cookies.get("theme").unwrap().value(), "dark");
    }

    #[test]
    fn it_rejects_unsigned_cookies() {
        // A plain, unsigned cookie must not survive verification.
        let mut headers = HeaderMap::new();
        if let Ok(header_value) = "session=abc123".parse() {
            headers.append(COOKIE, header_value);
        }
        assert!(headers.contains_key(COOKIE));

        let cookies = SignedCookies::from_headers(SignedKey::generate(), &headers);

        assert!(cookies.get("session").is_none());
        assert_eq!(cookies.iter().count(), 0);
    }

    #[test]
    fn it_adds_and_removes_cookies() {
        let key = SignedKey::generate();
        let mut cookies = SignedCookies::new(key);

        cookies = cookies.add(cookie::Cookie::new("test", "value"));
        assert_eq!(cookies.get("test").unwrap().value(), "value");

        cookies = cookies.remove(cookie::Cookie::new("test", ""));
        assert!(cookies.get("test").is_none());
    }

    #[test]
    fn it_sets_cookies_to_headers() {
        let key = SignedKey::generate();
        let cookies = SignedCookies::new(key).add(cookie::Cookie::new("session", "xyz789"));

        let mut headers = HeaderMap::new();
        set_cookies(cookies.1, &mut headers);

        let cookie_header = headers
            .get(SET_COOKIE)
            .expect("Cookie header should be set");

        assert!(cookie_header.to_str().unwrap().contains("session"));
    }

    #[tokio::test]
    async fn it_extracts_from_payload() {
        let (parts, _) = request_with_signed_cookie(()).into_parts();

        let payload = Payload::Parts(&parts);
        let cookies = SignedCookies::from_payload(payload).await.unwrap();

        assert_eq!(cookies.get("test").unwrap().value(), "value");
    }

    #[test]
    fn it_extracts_from_request_ref() {
        let (parts, body) = request_with_signed_cookie(HttpBody::empty()).into_parts();
        let req = HttpRequest::from_parts(parts, body);

        let cookies = <SignedCookies as FromRequestRef>::from_request(&req).unwrap();

        assert_eq!(cookies.get("test").unwrap().value(), "value");
    }

    #[test]
    fn it_extracts_from_parts() {
        let (parts, _) = request_with_signed_cookie(()).into_parts();

        let cookies = SignedCookies::from_parts(&parts).unwrap();

        assert_eq!(cookies.get("test").unwrap().value(), "value");
    }

    #[test]
    fn it_tries_extracts_from_parts() {
        let (parts, _) = request_with_signed_cookie(()).into_parts();

        let cookies = SignedCookies::try_from(&parts).unwrap();

        assert_eq!(cookies.get("test").unwrap().value(), "value");
    }

    #[test]
    fn it_tries_extracts_from_extensions_and_headers() {
        let request = request_with_signed_cookie(());

        let cookies = SignedCookies::try_from((request.extensions(), request.headers())).unwrap();

        assert_eq!(cookies.get("test").unwrap().value(), "value");
    }

    #[test]
    fn it_returns_parts_source() {
        assert_eq!(SignedCookies::SOURCE, Source::Parts);
    }

    #[tokio::test]
    async fn it_reads_signed_key_from_file() {
        let temp_file = crate::test::TempFile::new(KEY_MATERIAL).await;

        let key_from_file = SignedKey::from_file(&temp_file.path);
        let key_from_bytes = SignedKey::from(KEY_MATERIAL.as_bytes());

        // A cookie signed with the in-memory key must verify with the key
        // loaded from the file holding the same material.
        let signed = SignedCookies::new(key_from_bytes).add(("session", "abc123"));
        let mut headers = HeaderMap::new();
        set_cookies_for_request(signed.1, &mut headers);

        let cookies = SignedCookies::from_headers(key_from_file, &headers);

        assert_eq!(cookies.get("session").unwrap().value(), "abc123");
    }

    #[test]
    fn it_debugs() {
        let key = SignedKey::generate();
        let cookies = SignedCookies::new(key.clone());

        assert_eq!(format!("{key:?}"), r#"SignedKey("[redacted]")"#);
        assert_eq!(format!("{cookies:?}"), r#"SignedCookies("[redacted]")"#);
    }

    /// Builds a request with the given body that carries a signed `test=value`
    /// cookie and a DI scope in which the matching `SignedKey` is registered.
    fn request_with_signed_cookie<B>(body: B) -> Request<B> {
        let key = SignedKey::generate();
        let cookies = SignedCookies::new(key.clone()).add(("test", "value"));

        let mut headers = HeaderMap::new();
        set_cookies_for_request(cookies.1, &mut headers);

        let mut container = ContainerBuilder::new();
        container.register_singleton(key);
        let container = container.build();

        let mut request = Request::builder()
            .extension(container.create_scope())
            .body(body)
            .unwrap();
        request.headers_mut().extend(headers);
        request
    }

    /// Copies the jar's delta cookies into `headers` as request `Cookie` headers.
    fn set_cookies_for_request(jar: CookieJar, headers: &mut HeaderMap) {
        for cookie in jar.delta() {
            if let Ok(header_value) = cookie.encoded().to_string().parse() {
                headers.append(COOKIE, header_value);
            }
        }
    }
}