use std::{convert::TryInto, fmt, sync::Arc, time::SystemTime};
use arc_swap::ArcSwap;
use bytes::Bytes;
use cookie::{Cookie as RawCookie, CookieJar, Expiration, SameSite};
use http::Uri;
use crate::{
IntoUri,
error::Error,
ext::UriExt,
hash::{HASHER, HashMap},
header::HeaderValue,
};
/// The cookie header value(s) a [`CookieStore`] produces for an outgoing request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Cookies {
    /// All applicable cookies joined into a single header value.
    Compressed(HeaderValue),
    /// One header value per applicable cookie.
    Uncompressed(Vec<HeaderValue>),
    /// No cookies apply.
    Empty,
}
/// Storage backend for cookies.
///
/// Implementors must be `Send + Sync` so a single store can be shared between threads.
pub trait CookieStore: Send + Sync {
    /// Records the cookies carried by `cookie_headers`, associated with `uri`.
    ///
    /// NOTE(review): presumably these are response `Set-Cookie` values; the bundled `Jar`
    /// parses each value as one cookie — confirm against callers.
    fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, uri: &Uri);
    /// Returns the cookie header value(s) to attach to a request for `uri`.
    fn cookies(&self, uri: &Uri) -> Cookies;
}
/// Conversion into a shareable `Arc<dyn CookieStore>`.
///
/// Implemented for any concrete store, an `Arc` of a store, and an existing trait object.
pub trait IntoCookieStore {
    /// Performs the conversion.
    fn into_cookie_store(self) -> Arc<dyn CookieStore>;
}
/// Conversion into an owned [`Cookie`].
pub trait IntoCookie {
    /// Returns `None` when the input cannot be turned into a cookie (e.g. an unparseable `&str`).
    fn into_cookie(self) -> Option<Cookie<'static>>;
}
/// A parsed HTTP cookie: a thin wrapper around [`cookie::Cookie`] with a read-only accessor API.
#[derive(Debug, Clone)]
pub struct Cookie<'a>(RawCookie<'a>);
/// An in-memory cookie store.
///
/// Cookies are bucketed as `domain -> path -> CookieJar`, and each innermost jar is looked up
/// by cookie name. The whole map sits behind an [`ArcSwap`]. Readers load a snapshot; writers
/// clone the map, modify the copy, and swap it in.
pub struct Jar {
    // Whether `cookies()` joins all matching cookies into one header value.
    compression: bool,
    // Shared storage; handles made by `compressed()` / `uncompressed()` point at the same map.
    store: Arc<ArcSwap<HashMap<String, HashMap<String, CookieJar>>>>,
}
// NOTE(review): registers `Arc<dyn CookieStore>` with the crate's request-config machinery;
// the exact semantics live in `impl_request_config_value!`, defined elsewhere.
impl_request_config_value!(Arc<dyn CookieStore>);
/// A trait object is already the target type and is handed back unchanged.
impl IntoCookieStore for Arc<dyn CookieStore> {
    #[inline]
    fn into_cookie_store(self) -> Arc<dyn CookieStore> {
        self
    }
}

/// An `Arc` of a concrete store is unsized into a trait object without reallocating.
impl<R> IntoCookieStore for Arc<R>
where
    R: CookieStore + 'static,
{
    #[inline]
    fn into_cookie_store(self) -> Arc<dyn CookieStore> {
        self as Arc<dyn CookieStore>
    }
}

/// A bare store is moved into a fresh `Arc`.
impl<R> IntoCookieStore for R
where
    R: CookieStore + 'static,
{
    #[inline]
    fn into_cookie_store(self) -> Arc<dyn CookieStore> {
        Arc::new(self) as Arc<dyn CookieStore>
    }
}
impl IntoCookie for Cookie<'_> {
#[inline]
fn into_cookie(self) -> Option<Cookie<'static>> {
Some(self.into_owned())
}
}
impl IntoCookie for RawCookie<'_> {
#[inline]
fn into_cookie(self) -> Option<Cookie<'static>> {
Some(Cookie(self.into_owned()))
}
}
impl IntoCookie for &str {
#[inline]
fn into_cookie(self) -> Option<Cookie<'static>> {
RawCookie::parse(self).map(|c| Cookie(c.into_owned())).ok()
}
}
impl<'a> Cookie<'a> {
pub(crate) fn parse(value: &'a HeaderValue) -> crate::Result<Cookie<'a>> {
std::str::from_utf8(value.as_bytes())
.map_err(cookie::ParseError::from)
.and_then(cookie::Cookie::parse)
.map_err(Error::decode)
.map(Cookie)
}
#[inline]
pub fn name(&self) -> &str {
self.0.name()
}
#[inline]
pub fn value(&self) -> &str {
self.0.value()
}
#[inline]
pub fn http_only(&self) -> bool {
self.0.http_only().unwrap_or(false)
}
#[inline]
pub fn secure(&self) -> bool {
self.0.secure().unwrap_or(false)
}
#[inline]
pub fn same_site_lax(&self) -> bool {
self.0.same_site() == Some(SameSite::Lax)
}
#[inline]
pub fn same_site_strict(&self) -> bool {
self.0.same_site() == Some(SameSite::Strict)
}
#[inline]
pub fn path(&self) -> Option<&str> {
self.0.path()
}
#[inline]
pub fn domain(&self) -> Option<&str> {
self.0.domain()
}
#[inline]
pub fn max_age(&self) -> Option<std::time::Duration> {
self.0.max_age().and_then(|d| d.try_into().ok())
}
#[inline]
pub fn expires(&self) -> Option<SystemTime> {
match self.0.expires() {
Some(Expiration::DateTime(offset)) => Some(SystemTime::from(offset)),
None | Some(Expiration::Session) => None,
}
}
#[inline]
pub fn into_owned(self) -> Cookie<'static> {
Cookie(self.0.into_owned())
}
}
impl fmt::Display for Cookie<'_> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl<'c> From<RawCookie<'c>> for Cookie<'c> {
#[inline]
fn from(cookie: RawCookie<'c>) -> Cookie<'c> {
Cookie(cookie)
}
}
impl<'c> From<Cookie<'c>> for RawCookie<'c> {
#[inline]
fn from(cookie: Cookie<'c>) -> RawCookie<'c> {
cookie.0
}
}
/// Evaluates `$expr.into_uri()` and yields the converted value.
///
/// On conversion failure it executes a bare `return;`, so it may only be used inside
/// functions that return `()`.
macro_rules! into_uri {
    ($expr:expr) => {{
        let Ok(converted) = $expr.into_uri() else {
            return;
        };
        converted
    }};
}
impl Jar {
pub fn new(compression: bool) -> Self {
Self {
compression,
store: Arc::new(ArcSwap::from_pointee(HashMap::with_hasher(HASHER))),
}
}
pub fn compressed(self: &Arc<Self>) -> Arc<Self> {
Arc::new(Jar {
compression: true,
store: self.store.clone(),
})
}
pub fn uncompressed(self: &Arc<Self>) -> Arc<Self> {
Arc::new(Jar {
compression: false,
store: self.store.clone(),
})
}
pub fn get<U: IntoUri>(&self, name: &str, uri: U) -> Option<Cookie<'static>> {
let uri = uri.into_uri().ok()?;
let store = self.store.load();
let cookie = store
.get(uri.host()?)?
.get(uri.path())?
.get(name)?
.clone()
.into_owned();
Some(Cookie(cookie))
}
pub fn get_all(&self) -> impl Iterator<Item = Cookie<'static>> {
let store = self.store.load();
store
.values()
.flat_map(|path_map| {
path_map.values().flat_map(|name_map| {
name_map
.iter()
.map(|cookie| Cookie(cookie.clone().into_owned()))
})
})
.collect::<Vec<_>>()
.into_iter()
}
pub fn add_cookie_str<U: IntoUri>(&self, cookie: &str, uri: U) {
self.add(cookie, uri);
}
pub fn add<C, U>(&self, cookie: C, uri: U)
where
C: IntoCookie,
U: IntoUri,
{
if let Some(cookie) = cookie.into_cookie() {
let cookie: RawCookie<'static> = cookie.into();
let uri = into_uri!(uri);
let domain = cookie
.domain()
.map(normalize_domain)
.or_else(|| uri.host())
.unwrap_or_default();
let path = cookie.path().unwrap_or_else(|| normalize_path(&uri));
self.store.rcu(|current| {
let mut inner = (**current).clone();
let name_map = inner
.entry(domain.to_owned())
.or_insert_with(|| HashMap::with_hasher(HASHER))
.entry(path.to_owned())
.or_default();
let expired = cookie
.expires_datetime()
.is_some_and(|dt| dt <= SystemTime::now())
|| cookie.max_age().is_some_and(|age| age.is_zero());
if expired {
name_map.remove(cookie.clone());
} else {
name_map.add(cookie.clone());
}
inner
});
}
}
pub fn add_cookie<C, U>(&self, cookie: C, uri: U)
where
C: Into<RawCookie<'static>>,
U: IntoUri,
{
let cookie: RawCookie<'static> = cookie.into();
let uri = into_uri!(uri);
let domain = cookie
.domain()
.map(normalize_domain)
.or_else(|| uri.host())
.unwrap_or_default();
let path = cookie.path().unwrap_or_else(|| normalize_path(&uri));
self.store.rcu(|current| {
let mut inner = (**current).clone();
let name_map = inner
.entry(domain.to_owned())
.or_insert_with(|| HashMap::with_hasher(HASHER))
.entry(path.to_owned())
.or_default();
let expired = cookie
.expires_datetime()
.is_some_and(|dt| dt <= SystemTime::now())
|| cookie.max_age().is_some_and(|age| age.is_zero());
if expired {
name_map.remove(cookie.clone());
} else {
name_map.add(cookie.clone());
}
inner
});
}
pub fn remove<C, U>(&self, cookie: C, uri: U)
where
C: Into<RawCookie<'static>>,
U: IntoUri,
{
let cookie_raw = cookie.into();
let uri = into_uri!(uri);
if let Some(host) = uri.host() {
let path = uri.path().to_owned();
self.store.rcu(|current| {
let mut inner = (**current).clone();
if let Some(path_map) = inner.get_mut(host)
&& let Some(name_map) = path_map.get_mut(&path)
{
name_map.remove(cookie_raw.clone());
}
inner
});
}
}
pub fn clear(&self) {
self.store.store(Arc::new(HashMap::with_hasher(HASHER)));
}
}
impl CookieStore for Jar {
    /// Parses each header value as a single cookie and stores it for `uri` via
    /// [`Jar::add_cookie`].
    ///
    /// Values that are not valid UTF-8 or do not parse are skipped.
    fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, uri: &Uri) {
        let cookies = cookie_headers
            .map(Cookie::parse)
            .filter_map(Result::ok)
            .map(|cookie| cookie.0.into_owned());
        for cookie in cookies {
            self.add_cookie(cookie, uri);
        }
    }

    /// Collects the cookies to attach to a request for `uri`.
    ///
    /// A stored cookie is included only if all of the following hold:
    /// - its bucket domain domain-matches the request host;
    /// - its bucket path path-matches the request path;
    /// - it is not `Secure` while the request is plain http;
    /// - its expiry date, if any, is still in the future.
    ///
    /// With compression the result is one `name=value; name=value` header value; otherwise
    /// it is one header value per cookie. Returns [`Cookies::Empty`] in either mode when
    /// nothing matches, and also when `uri` has no host.
    fn cookies(&self, uri: &Uri) -> Cookies {
        let Some(host) = uri.host() else {
            return Cookies::Empty;
        };
        // Judge every cookie against the same instant instead of re-reading the clock.
        let now = SystemTime::now();
        let store = self.store.load();
        let iter = store
            .iter()
            .filter(|(domain, _)| domain_match(host, domain))
            .flat_map(|(_, path_map)| {
                path_map
                    .iter()
                    .filter(|(path, _)| path_match(uri.path(), path))
                    .flat_map(|(_, name_map)| name_map.iter())
            })
            .filter(|cookie| {
                // Never send a `Secure` cookie over plain http.
                if cookie.secure() == Some(true) && uri.is_http() {
                    return false;
                }
                !cookie.expires_datetime().is_some_and(|dt| dt <= now)
            });

        if self.compression {
            let cookies = iter.fold(String::new(), |mut cookies, cookie| {
                if !cookies.is_empty() {
                    cookies.push_str("; ");
                }
                cookies.push_str(cookie.name());
                cookies.push('=');
                cookies.push_str(cookie.value());
                cookies
            });
            if cookies.is_empty() {
                return Cookies::Empty;
            }
            // An invalid header byte anywhere drops the whole combined value.
            HeaderValue::from_maybe_shared(Bytes::from(cookies))
                .map(Cookies::Compressed)
                .unwrap_or(Cookies::Empty)
        } else {
            let cookies: Vec<HeaderValue> = iter
                .map(|cookie| {
                    let name = cookie.name();
                    let value = cookie.value();
                    let mut cookie_str = String::with_capacity(name.len() + 1 + value.len());
                    cookie_str.push_str(name);
                    cookie_str.push('=');
                    cookie_str.push_str(value);
                    HeaderValue::from_maybe_shared(Bytes::from(cookie_str))
                })
                // Cookies that are not valid header values are skipped individually.
                .filter_map(Result::ok)
                .collect();
            // Report "nothing to send" the same way as the compressed branch.
            if cookies.is_empty() {
                Cookies::Empty
            } else {
                Cookies::Uncompressed(cookies)
            }
        }
    }
}
impl Default for Jar {
fn default() -> Self {
Self::new(true)
}
}
impl fmt::Debug for Jar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Jar")
.field("compression", &self.compression)
.finish()
}
}
/// Fallback cookie path.
///
/// `path_match` and `normalize_path` also use it as the path-segment separator.
const DEFAULT_PATH: &str = "/";
/// RFC 6265 §5.1.3 domain-match: does request `host` fall under cookie `domain`?
///
/// `host` matches when it equals `domain`, or when it ends with `.` followed by `domain`.
/// Comparison is ASCII case-insensitive, because host names are case-insensitive and the
/// RFC compares canonicalized (lower-cased) forms. An empty `domain` never matches.
fn domain_match(host: &str, domain: &str) -> bool {
    if domain.is_empty() {
        return false;
    }
    if host.eq_ignore_ascii_case(domain) {
        return true;
    }
    // Work on bytes so the suffix split can never land inside a multi-byte char.
    let (host, domain) = (host.as_bytes(), domain.as_bytes());
    host.len() > domain.len() && {
        let split = host.len() - domain.len();
        // The suffix must start on a label boundary ("www.example.com" yes, "badexample.com" no).
        host[split - 1] == b'.' && host[split..].eq_ignore_ascii_case(domain)
    }
}
/// RFC 6265 §5.1.4 path-match: does `req_path` fall under `cookie_path`?
///
/// It matches when the two are identical, or when `cookie_path` is a prefix of `req_path`
/// and either `cookie_path` ends with `/` or the remainder of `req_path` starts with `/`.
fn path_match(req_path: &str, cookie_path: &str) -> bool {
    if req_path == cookie_path {
        return true;
    }
    match req_path.strip_prefix(cookie_path) {
        None => false,
        Some(rest) => cookie_path.ends_with(DEFAULT_PATH) || rest.starts_with(DEFAULT_PATH),
    }
}
/// Strips everything from the first `:` onward (e.g. a port), returning the leading part.
///
/// Input without a `:` is returned unchanged.
fn normalize_domain(domain: &str) -> &str {
    match domain.split_once(':') {
        Some((host, _rest)) => host,
        None => domain,
    }
}
/// RFC 6265 §5.1.4 default-path of `uri`.
///
/// Returns `/` when the path does not start with `/` or contains only the leading `/`.
/// Otherwise returns everything before the right-most `/`.
fn normalize_path(uri: &Uri) -> &str {
    let path = uri.path();
    if !path.starts_with(DEFAULT_PATH) {
        return DEFAULT_PATH;
    }
    match path.rfind(DEFAULT_PATH) {
        Some(last_slash) if last_slash > 0 => &path[..last_slash],
        _ => DEFAULT_PATH,
    }
}