use std::{
convert::Infallible,
fmt::{self, Display, Formatter},
str::FromStr,
sync::Arc,
time::Duration,
};
use chrono::{DateTime, TimeZone, Utc};
use http::HeaderValue;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use crate::{
FromRequest, Request, RequestBody, Result,
error::ParseCookieError,
http::{HeaderMap, header},
};
/// The `SameSite` cookie attribute; an alias for the underlying cookie library's type.
pub type SameSite = libcookie::SameSite;
/// An HTTP cookie: a thin newtype over an owned (`'static`) cookie from the
/// underlying cookie library.
#[derive(Clone, Debug, PartialEq)]
pub struct Cookie(libcookie::Cookie<'static>);

impl Display for Cookie {
    /// Formats the cookie with its name and value percent-encoded, in the same
    /// form that is written into `Set-Cookie` headers.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let encoded = self.0.encoded();
        Display::fmt(&encoded, f)
    }
}
impl Cookie {
    /// Creates a cookie named `name` whose value is `value` serialized as JSON
    /// (with `serde_json`, or `sonic_rs` when the `sonic-rs` feature is on).
    ///
    /// If serialization fails, the cookie value is the empty string.
    pub fn new(name: impl Into<String>, value: impl Serialize) -> Self {
        #[cfg(not(feature = "sonic-rs"))]
        let json_string = serde_json::to_string(&value);
        #[cfg(feature = "sonic-rs")]
        let json_string = sonic_rs::to_string(&value);
        Self::new_with_str(name, json_string.unwrap_or_default())
    }

    /// Creates a cookie with a raw string value (no JSON serialization).
    pub fn new_with_str(name: impl Into<String>, value: impl Into<String>) -> Self {
        Self(libcookie::Cookie::new(name.into(), value.into()))
    }

    /// Creates a cookie named `name` with an empty value.
    pub fn named(name: impl Into<String>) -> Self {
        Self::new_with_str(name, "")
    }

    /// Parses a single cookie from its percent-encoded string form.
    ///
    /// # Errors
    ///
    /// Returns [`ParseCookieError::CookieIllegal`] if `s` is not a valid cookie.
    pub fn parse(s: impl AsRef<str>) -> Result<Self, ParseCookieError> {
        Ok(Self(
            libcookie::Cookie::parse_encoded(s.as_ref().to_string())
                .map_err(|_| ParseCookieError::CookieIllegal)?,
        ))
    }

    /// Returns the `Domain` attribute, if set.
    pub fn domain(&self) -> Option<&str> {
        self.0.domain()
    }

    /// Returns the `Expires` date as a UTC `DateTime`, truncated to whole seconds.
    ///
    /// Returns `None` when no expiry date-time is set or when it cannot be
    /// represented as a single chrono instant.
    pub fn expires(&self) -> Option<DateTime<Utc>> {
        self.0
            .expires_datetime()
            .and_then(|t| Utc.timestamp_opt(t.unix_timestamp(), 0).single())
    }

    /// Returns whether `HttpOnly` is set (`false` when unspecified).
    pub fn http_only(&self) -> bool {
        self.0.http_only().unwrap_or_default()
    }

    /// Delegates to the underlying library's `make_permanent`.
    pub fn make_permanent(&mut self) {
        self.0.make_permanent();
    }

    /// Delegates to the underlying library's `make_removal`.
    pub fn make_removal(&mut self) {
        self.0.make_removal();
    }

    /// Returns the `Max-Age` attribute as a std `Duration`.
    ///
    /// Negative seconds/nanoseconds components are clamped to zero, since
    /// `std::time::Duration` cannot be negative.
    pub fn max_age(&self) -> Option<Duration> {
        self.0.max_age().map(|d| {
            let seconds = d.whole_seconds().max(0) as u64;
            let nano_seconds = d.subsec_nanoseconds().max(0) as u32;
            Duration::new(seconds, nano_seconds)
        })
    }

    /// Returns the cookie name.
    pub fn name(&self) -> &str {
        self.0.name()
    }

    /// Returns the `Path` attribute, if set.
    pub fn path(&self) -> Option<&str> {
        self.0.path()
    }

    /// Returns the `SameSite` attribute, if set.
    pub fn same_site(&self) -> Option<SameSite> {
        self.0.same_site()
    }

    /// Returns whether `Secure` is set (`false` when unspecified).
    pub fn secure(&self) -> bool {
        self.0.secure().unwrap_or_default()
    }

    /// Returns whether `Partitioned` is set (`false` when unspecified).
    pub fn partitioned(&self) -> bool {
        self.0.partitioned().unwrap_or_default()
    }

    /// Sets the `Domain` attribute.
    pub fn set_domain(&mut self, domain: impl Into<String>) {
        self.0.set_domain(domain.into());
    }

    /// Sets the `Expires` attribute (whole-second precision).
    ///
    /// Instants that `time::OffsetDateTime` cannot represent are clamped
    /// instead of panicking. Far-future instants become the latest instant
    /// `time` supports. Far-past instants become the Unix epoch, so the cookie
    /// is still treated as already expired.
    pub fn set_expires(&mut self, time: DateTime<impl TimeZone>) {
        let timestamp = time.timestamp();
        let expires = time::OffsetDateTime::from_unix_timestamp(timestamp).unwrap_or_else(|_| {
            if timestamp < 0 {
                time::OffsetDateTime::UNIX_EPOCH
            } else {
                time::PrimitiveDateTime::MAX.assume_utc()
            }
        });
        self.0.set_expires(libcookie::Expiration::DateTime(expires));
    }

    /// Sets or clears (`None`) the `HttpOnly` attribute.
    pub fn set_http_only(&mut self, value: impl Into<Option<bool>>) {
        self.0.set_http_only(value);
    }

    /// Sets the `Max-Age` attribute.
    ///
    /// Durations longer than `i64::MAX` seconds saturate to `i64::MAX` seconds.
    /// An unchecked `as i64` cast would wrap them to a negative max-age, which
    /// would delete the cookie instead of keeping it.
    pub fn set_max_age(&mut self, value: Duration) {
        let seconds = i64::try_from(value.as_secs()).unwrap_or(i64::MAX);
        // `subsec_nanos` is always < 1_000_000_000, which fits in an `i32`.
        let nanoseconds = value.subsec_nanos() as i32;
        self.0
            .set_max_age(Some(time::Duration::new(seconds, nanoseconds)));
    }

    /// Sets the cookie name.
    pub fn set_name(&mut self, name: impl Into<String>) {
        self.0.set_name(name.into());
    }

    /// Sets the `Path` attribute.
    pub fn set_path(&mut self, path: impl Into<String>) {
        self.0.set_path(path.into());
    }

    /// Sets or clears (`None`) the `SameSite` attribute.
    pub fn set_same_site(&mut self, value: impl Into<Option<SameSite>>) {
        self.0.set_same_site(value);
    }

    /// Sets or clears (`None`) the `Secure` attribute.
    pub fn set_secure(&mut self, value: impl Into<Option<bool>>) {
        self.0.set_secure(value);
    }

    /// Sets or clears (`None`) the `Partitioned` attribute.
    pub fn set_partitioned(&mut self, value: impl Into<Option<bool>>) {
        self.0.set_partitioned(value);
    }

    /// Sets the raw string value (no JSON serialization).
    pub fn set_value_str(&mut self, value: impl Into<String>) {
        self.0.set_value(value.into());
    }

    /// Sets the value to `value` serialized as JSON.
    ///
    /// If serialization fails, the existing value is left unchanged.
    pub fn set_value(&mut self, value: impl Serialize) {
        #[cfg(not(feature = "sonic-rs"))]
        let json_string = serde_json::to_string(&value);
        #[cfg(feature = "sonic-rs")]
        let json_string = sonic_rs::to_string(&value);
        if let Ok(value) = json_string {
            self.0.set_value(value);
        }
    }

    /// Returns the raw string value.
    pub fn value_str(&self) -> &str {
        self.0.value()
    }

    /// Deserializes the value as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseCookieError::ParseJsonValue`] if the value is not valid
    /// JSON for `T`.
    pub fn value<'de, T: Deserialize<'de>>(&'de self) -> Result<T, ParseCookieError> {
        #[cfg(not(feature = "sonic-rs"))]
        {
            serde_json::from_str(self.0.value()).map_err(ParseCookieError::ParseJsonValue)
        }
        #[cfg(feature = "sonic-rs")]
        {
            sonic_rs::from_str(self.0.value()).map_err(ParseCookieError::ParseJsonValue)
        }
    }
}
impl<'a> FromRequest<'a> for Cookie {
    /// Extracts a cookie by parsing the request's `Cookie` header.
    ///
    /// Fails with `CookieHeaderRequired` if the header is missing, and with
    /// `CookieIllegal` if it is not a valid string or does not parse as a cookie.
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let Some(raw_header) = req.headers().get(header::COOKIE) else {
            return Err(ParseCookieError::CookieHeaderRequired.into());
        };
        let header_text = raw_header
            .to_str()
            .map_err(|_| ParseCookieError::CookieIllegal)?;
        let parsed = libcookie::Cookie::parse_encoded(String::from(header_text))
            .map_err(|_| ParseCookieError::CookieIllegal)?;
        Ok(Self(parsed))
    }
}
impl From<libcookie::Cookie<'static>> for Cookie {
fn from(value: libcookie::Cookie<'static>) -> Self {
Self(value)
}
}
#[derive(Default, Clone)]
pub struct CookieJar {
jar: Arc<Mutex<libcookie::CookieJar>>,
pub(crate) key: Option<Arc<CookieKey>>,
}
impl CookieJar {
pub fn add(&self, cookie: Cookie) {
self.jar.lock().add(cookie.0);
}
pub fn remove(&self, name: impl AsRef<str>) {
self.jar
.lock()
.remove(libcookie::Cookie::build(name.as_ref().to_string()));
}
pub fn get(&self, name: &str) -> Option<Cookie> {
self.jar.lock().get(name).cloned().map(Cookie)
}
pub fn get_ignore_ascii_case(&self, name: &str) -> Option<Cookie> {
self.jar
.lock()
.iter()
.find(|cookie| cookie.name().eq_ignore_ascii_case(name))
.cloned()
.map(Cookie)
}
pub fn reset_delta(&self) {
self.jar.lock().reset_delta();
}
pub fn with_cookies<F, R>(&self, f: F) -> R
where
F: FnOnce(libcookie::Iter) -> R,
{
let jar = self.jar.lock();
let iter = jar.iter();
f(iter)
}
pub fn private_with_key<'a>(&'a self, key: &'a CookieKey) -> PrivateCookieJar<'a> {
PrivateCookieJar {
key,
cookie_jar: self,
}
}
pub fn private(&self) -> PrivateCookieJar<'_> {
self.private_with_key(
self.key
.as_ref()
.expect("You must use the `CookieJarManager::with_key` to specify a `CookieKey`."),
)
}
pub fn signed_with_key<'a>(&'a self, key: &'a CookieKey) -> SignedCookieJar<'a> {
SignedCookieJar {
key,
cookie_jar: self,
}
}
pub fn signed(&self) -> SignedCookieJar<'_> {
self.signed_with_key(
self.key
.as_ref()
.expect("You must use the `CookieJarManager::with_key` to specify a `CookieKey`."),
)
}
}
impl FromStr for CookieJar {
    type Err = Infallible;

    /// Parses a `Cookie`-header style string such as `a=1; b=2` into a jar.
    ///
    /// Segments that fail to parse are skipped, so parsing never fails. When a
    /// name appears more than once, the first occurrence is kept. This matches
    /// `CookieJar::extract_from_headers`, and RFC 6265 §5.4, which lists
    /// more-specific (longer-path) cookies first. Parsed cookies are stored as
    /// originals, not as delta changes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut cookie_jar = libcookie::CookieJar::new();
        for cookie_str in s.split(';').map(str::trim) {
            if let Ok(cookie) = libcookie::Cookie::parse_encoded(cookie_str) {
                // Skip names already present so the first occurrence wins.
                if cookie_jar.get(cookie.name()).is_none() {
                    cookie_jar.add_original(cookie.into_owned());
                }
            }
        }
        Ok(CookieJar {
            jar: Arc::new(Mutex::new(cookie_jar)),
            key: None,
        })
    }
}
impl<'a> FromRequest<'a> for &'a CookieJar {
    /// Borrows the request's cookie jar; this extractor never returns an error.
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let jar: &'a CookieJar = req.cookie();
        Ok(jar)
    }
}
impl CookieJar {
    /// Builds a jar from every `Cookie` request header in `headers`.
    ///
    /// Header values rejected by `to_str` and `;`-separated segments that fail
    /// to parse are skipped. When a name repeats, the first occurrence is kept.
    /// Cookies are stored as originals, not as delta changes.
    pub(crate) fn extract_from_headers(headers: &HeaderMap) -> Self {
        let mut originals = libcookie::CookieJar::new();
        let parsed = headers
            .get_all(header::COOKIE)
            .into_iter()
            .filter_map(|raw| raw.to_str().ok())
            .flat_map(|line| line.split(';'))
            .map(str::trim)
            .filter_map(|segment| libcookie::Cookie::parse_encoded(segment).ok());
        for candidate in parsed {
            // First occurrence wins: ignore later cookies with a name already seen.
            if originals.get(candidate.name()).is_none() {
                originals.add_original(candidate.into_owned());
            }
        }
        Self {
            jar: Arc::new(Mutex::new(originals)),
            key: None,
        }
    }

    /// Appends one `Set-Cookie` header per changed cookie in the jar's delta.
    ///
    /// Cookies whose encoded form is not a valid header value are skipped.
    pub(crate) fn append_delta_to_headers(&self, headers: &mut HeaderMap) {
        let guard = self.jar.lock();
        let header_values = guard
            .delta()
            .filter_map(|changed| HeaderValue::from_str(&changed.encoded().to_string()).ok());
        for header_value in header_values {
            headers.append(header::SET_COOKIE, header_value);
        }
    }
}
/// Key used by [`PrivateCookieJar`] and [`SignedCookieJar`]; an alias for the
/// underlying cookie library's `Key` type.
pub type CookieKey = libcookie::Key;
/// A view over a [`CookieJar`] that encrypts cookie values with a [`CookieKey`].
///
/// Values added through this view are stored encrypted in the shared jar and
/// can only be read back through a view built with the same key.
pub struct PrivateCookieJar<'a> {
    key: &'a CookieKey,
    cookie_jar: &'a CookieJar,
}

impl PrivateCookieJar<'_> {
    /// Encrypts `cookie` and adds it to the underlying jar.
    pub fn add(&self, cookie: Cookie) {
        let mut cookie_jar = self.cookie_jar.jar.lock();
        let mut private_cookie_jar = cookie_jar.private_mut(self.key);
        private_cookie_jar.add(cookie.0);
    }

    /// Removes the cookie named `name` from the underlying jar.
    pub fn remove(&self, name: impl AsRef<str>) {
        let mut cookie_jar = self.cookie_jar.jar.lock();
        let mut private_cookie_jar = cookie_jar.private_mut(self.key);
        private_cookie_jar.remove(libcookie::Cookie::build(name.as_ref().to_string()));
    }

    /// Returns the decrypted cookie named `name`, or `None` if it is absent or
    /// cannot be decrypted with this view's key.
    pub fn get(&self, name: &str) -> Option<Cookie> {
        let cookie_jar = self.cookie_jar.jar.lock();
        let private_cookie_jar = cookie_jar.private(self.key);
        private_cookie_jar.get(name).map(Cookie)
    }

    /// Like [`get`](Self::get), but matches `name` ignoring ASCII case.
    ///
    /// Every cookie whose name matches is tried in turn, and the first one that
    /// decrypts is returned. A cookie that cannot be decrypted, and whose name
    /// differs only in case, therefore cannot hide a valid one.
    pub fn get_ignore_ascii_case(&self, name: &str) -> Option<Cookie> {
        let cookie_jar = self.cookie_jar.jar.lock();
        let private_cookie_jar = cookie_jar.private(self.key);
        let decrypted = cookie_jar
            .iter()
            .filter(|cookie| cookie.name().eq_ignore_ascii_case(name))
            .find_map(|cookie| private_cookie_jar.decrypt(cookie.clone()));
        decrypted.map(Cookie)
    }
}
/// A view over a [`CookieJar`] that signs cookie values with a [`CookieKey`].
///
/// Values stay readable in the shared jar but carry a signature. They can only
/// be read back through this view when verified with the same key.
pub struct SignedCookieJar<'a> {
    key: &'a CookieKey,
    cookie_jar: &'a CookieJar,
}

impl SignedCookieJar<'_> {
    /// Signs `cookie` and adds it to the underlying jar.
    pub fn add(&self, cookie: Cookie) {
        let mut cookie_jar = self.cookie_jar.jar.lock();
        let mut signed_cookie_jar = cookie_jar.signed_mut(self.key);
        signed_cookie_jar.add(cookie.0);
    }

    /// Removes the cookie named `name` from the underlying jar.
    pub fn remove(&self, name: impl AsRef<str>) {
        let mut cookie_jar = self.cookie_jar.jar.lock();
        let mut signed_cookie_jar = cookie_jar.signed_mut(self.key);
        signed_cookie_jar.remove(libcookie::Cookie::build(name.as_ref().to_string()));
    }

    /// Returns the verified cookie named `name`, or `None` if it is absent or
    /// its signature does not verify with this view's key.
    pub fn get(&self, name: &str) -> Option<Cookie> {
        let cookie_jar = self.cookie_jar.jar.lock();
        let signed_cookie_jar = cookie_jar.signed(self.key);
        signed_cookie_jar.get(name).map(Cookie)
    }

    /// Like [`get`](Self::get), but matches `name` ignoring ASCII case.
    ///
    /// Every cookie whose name matches is tried in turn, and the first one whose
    /// signature verifies is returned. A cookie that fails verification, and
    /// whose name differs only in case, therefore cannot hide a valid one.
    pub fn get_ignore_ascii_case(&self, name: &str) -> Option<Cookie> {
        let cookie_jar = self.cookie_jar.jar.lock();
        let signed_cookie_jar = cookie_jar.signed(self.key);
        let verified = cookie_jar
            .iter()
            .filter(|cookie| cookie.name().eq_ignore_ascii_case(name))
            .find_map(|cookie| signed_cookie_jar.verify(cookie.clone()));
        verified.map(Cookie)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parsing, delta tracking and removal on a plain (unkeyed) jar.
    #[test]
    fn test_cookie_jar() {
        let [a, b, c] = [("a", 100), ("b", 200), ("c", 300)]
            .map(|(name, number)| Cookie::new_with_str(name, number.to_string()));
        let cookie_jar = format!("{a}; {b}").parse::<CookieJar>().unwrap();
        assert_eq!(cookie_jar.get("a").unwrap(), a);
        assert_eq!(cookie_jar.get("b").unwrap(), b);

        // Adding a cookie surfaces exactly that cookie as a Set-Cookie header.
        cookie_jar.add(c.clone());
        let mut added_headers = HeaderMap::new();
        cookie_jar.append_delta_to_headers(&mut added_headers);
        let mut set_cookies = added_headers.get_all(header::SET_COOKIE).into_iter();
        let expected = HeaderValue::from_str(&c.to_string()).unwrap();
        assert_eq!(set_cookies.next().unwrap(), &expected);
        assert!(set_cookies.next().is_none());

        // After resetting the delta, removing "a" yields only its removal cookie.
        cookie_jar.reset_delta();
        cookie_jar.remove("a");
        let mut removed_headers = HeaderMap::new();
        cookie_jar.append_delta_to_headers(&mut removed_headers);
        let mut set_cookies = removed_headers.get_all(header::SET_COOKIE).into_iter();
        let first = set_cookies.next().unwrap();
        let removal = Cookie::parse(first.to_str().unwrap()).unwrap();
        assert_eq!(removal.name(), "a");
        assert_eq!(removal.value_str(), "");
        assert!(set_cookies.next().is_none());
    }

    #[tokio::test]
    async fn test_cookie_extractor() {
        let header_value = Cookie::new_with_str("a", "1").to_string();
        let (req, mut body) = Request::builder()
            .header(header::COOKIE, header_value)
            .finish()
            .split();
        let extracted = Cookie::from_request(&req, &mut body).await.unwrap();
        assert_eq!((extracted.name(), extracted.value_str()), ("a", "1"));
    }

    #[tokio::test]
    async fn private() {
        let jar = CookieJar::default();
        let key = CookieKey::generate();
        let encrypted_view = jar.private_with_key(&key);
        encrypted_view.add(Cookie::new_with_str("a", "123"));
        assert_eq!(encrypted_view.get("a").unwrap().value_str(), "123");
        // The raw stored value must not expose the plaintext.
        let stored = jar.get("a").unwrap();
        assert!(!stored.value_str().contains("123"));
        // A view with a different key cannot read it back.
        let other_key = CookieKey::generate();
        assert_eq!(jar.private_with_key(&other_key).get("a"), None);
    }

    #[tokio::test]
    async fn signed() {
        let jar = CookieJar::default();
        let key = CookieKey::generate();
        let signed_view = jar.signed_with_key(&key);
        signed_view.add(Cookie::new_with_str("a", "123"));
        assert_eq!(signed_view.get("a").unwrap().value_str(), "123");
        // Signed values remain readable in the raw jar.
        let stored = jar.get("a").unwrap();
        assert!(stored.value_str().contains("123"));
        // A view with a different key rejects the signature.
        let other_key = CookieKey::generate();
        assert_eq!(jar.signed_with_key(&other_key).get("a"), None);
    }

    #[test]
    fn test_extract_from_multiple_cookie_headers() {
        let mut headers = HeaderMap::new();
        for raw in ["a=1", "b=2; c=3"] {
            headers.append(header::COOKIE, HeaderValue::from_static(raw));
        }
        let cookie_jar = CookieJar::extract_from_headers(&headers);
        for (name, expected) in [("a", "1"), ("b", "2"), ("c", "3")] {
            assert_eq!(cookie_jar.get(name).unwrap().value_str(), expected);
        }
    }

    #[test]
    fn with_cookies() {
        let key = CookieKey::generate();
        let cookie_jar = CookieJar::default();
        cookie_jar.add(Cookie::new_with_str("a", "123"));
        cookie_jar
            .signed_with_key(&key)
            .add(Cookie::new_with_str("b", "456"));
        cookie_jar
            .private_with_key(&key)
            .add(Cookie::new_with_str("c", "789"));

        assert_eq!(cookie_jar.with_cookies(|cookies| cookies.count()), 3);

        let mut names: Vec<String> = cookie_jar
            .with_cookies(|cookies| cookies.map(|cookie| cookie.name().to_owned()).collect());
        names.sort_unstable();
        assert_eq!(names, ["a", "b", "c"]);
    }
}