use std::{fmt, sync::Arc, time::SystemTime};
use bytes::Bytes;
use cookie::{Cookie as RawCookie, CookieJar, Expiration, SameSite};
use http::Uri;
use scc::HashMap as SccHashMap;
use crate::{IntoUri, error::Error, ext::UriExt, header::HeaderValue};
/// Cookie header value(s) produced by a [`CookieStore`] for an outgoing request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Cookies {
/// Every cookie packed into a single header value (the `Jar` store joins
/// `name=value` pairs with `"; "`).
Compressed(HeaderValue),
/// One header value per cookie.
Uncompressed(Vec<HeaderValue>),
/// Nothing to send.
Empty,
}
/// Storage backend that records cookies from responses and supplies them for requests.
///
/// Implementors must be `Send + Sync` so one store can be shared across requests.
pub trait CookieStore: Send + Sync {
/// Records the cookies carried by `cookie_headers` for a response received from `uri`.
///
/// NOTE(review): the values are presumably the response's `Set-Cookie` headers
/// (the `Jar` implementation parses them with cookie syntax) — confirm at the call site.
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, uri: &Uri);
/// Returns the cookie header value(s) to attach to a request for `uri`.
fn cookies(&self, uri: &Uri) -> Cookies;
}
/// Conversion into a shared, type-erased [`CookieStore`].
pub trait IntoCookieStore {
/// Wraps (or passes through) `self` as an `Arc<dyn CookieStore>`.
fn into_cookie_store(self) -> Arc<dyn CookieStore>;
}
/// Conversion into an owned [`Cookie`].
pub trait IntoCookie {
/// Returns the owned cookie, or `None` when `self` cannot be turned into one
/// (for example, a string that does not parse as a cookie).
fn into_cookie(self) -> Option<Cookie<'static>>;
}
/// A single HTTP cookie: a thin wrapper around the `cookie` crate's `Cookie`.
#[derive(Debug, Clone)]
pub struct Cookie<'a>(RawCookie<'a>);
/// In-memory [`CookieStore`] backed by concurrent `scc` hash maps.
pub struct Jar {
// When true, `CookieStore::cookies` joins all matching cookies into one header value.
compression: bool,
// Cookies keyed by domain, then by path. The `Arc` lets the handles returned by
// `compressed`/`uncompressed` share the same storage.
store: Arc<SccHashMap<String, SccHashMap<String, CookieJar>>>,
}
// NOTE(review): crate-local macro defined elsewhere; presumably registers
// `Arc<dyn CookieStore>` as a per-request configuration value — confirm
// against the macro definition.
impl_request_config_value!(Arc<dyn CookieStore>);
/// An already type-erased store is handed back untouched.
impl IntoCookieStore for Arc<dyn CookieStore> {
    #[inline]
    fn into_cookie_store(self) -> Arc<dyn CookieStore> {
        self
    }
}

/// A shared concrete store is erased by unsizing the existing `Arc`; no new
/// allocation is made.
impl<R> IntoCookieStore for Arc<R>
where
    R: CookieStore + 'static,
{
    #[inline]
    fn into_cookie_store(self) -> Arc<dyn CookieStore> {
        self as Arc<dyn CookieStore>
    }
}

/// An owned concrete store is moved into a fresh `Arc`.
impl<R> IntoCookieStore for R
where
    R: CookieStore + 'static,
{
    #[inline]
    fn into_cookie_store(self) -> Arc<dyn CookieStore> {
        let shared: Arc<dyn CookieStore> = Arc::new(self);
        shared
    }
}
impl IntoCookie for Cookie<'_> {
#[inline]
fn into_cookie(self) -> Option<Cookie<'static>> {
Some(self.into_owned())
}
}
impl IntoCookie for RawCookie<'_> {
#[inline]
fn into_cookie(self) -> Option<Cookie<'static>> {
Some(Cookie(self.into_owned()))
}
}
impl IntoCookie for &str {
#[inline]
fn into_cookie(self) -> Option<Cookie<'static>> {
RawCookie::parse(self).map(|c| Cookie(c.into_owned())).ok()
}
}
impl<'a> Cookie<'a> {
pub(crate) fn parse(value: &'a HeaderValue) -> crate::Result<Cookie<'a>> {
std::str::from_utf8(value.as_bytes())
.map_err(cookie::ParseError::from)
.and_then(cookie::Cookie::parse)
.map_err(Error::decode)
.map(Cookie)
}
#[inline]
pub fn name(&self) -> &str {
self.0.name()
}
#[inline]
pub fn value(&self) -> &str {
self.0.value()
}
#[inline]
pub fn http_only(&self) -> bool {
self.0.http_only().unwrap_or(false)
}
#[inline]
pub fn secure(&self) -> bool {
self.0.secure().unwrap_or(false)
}
#[inline]
pub fn same_site_lax(&self) -> bool {
self.0.same_site() == Some(SameSite::Lax)
}
#[inline]
pub fn same_site_strict(&self) -> bool {
self.0.same_site() == Some(SameSite::Strict)
}
#[inline]
pub fn path(&self) -> Option<&str> {
self.0.path()
}
#[inline]
pub fn domain(&self) -> Option<&str> {
self.0.domain()
}
#[inline]
pub fn max_age(&self) -> Option<std::time::Duration> {
self.0.max_age().and_then(|d| d.try_into().ok())
}
#[inline]
pub fn expires(&self) -> Option<SystemTime> {
match self.0.expires() {
Some(Expiration::DateTime(offset)) => Some(SystemTime::from(offset)),
None | Some(Expiration::Session) => None,
}
}
#[inline]
pub fn into_owned(self) -> Cookie<'static> {
Cookie(self.0.into_owned())
}
}
impl fmt::Display for Cookie<'_> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}
impl<'c> From<RawCookie<'c>> for Cookie<'c> {
#[inline]
fn from(cookie: RawCookie<'c>) -> Cookie<'c> {
Cookie(cookie)
}
}
impl<'c> From<Cookie<'c>> for RawCookie<'c> {
#[inline]
fn from(cookie: Cookie<'c>) -> RawCookie<'c> {
cookie.0
}
}
impl Jar {
pub fn new(compression: bool) -> Self {
Self {
compression,
store: Arc::new(SccHashMap::new()),
}
}
pub fn compressed(self: &Arc<Self>) -> Arc<Self> {
Arc::new(Jar {
compression: true,
store: self.store.clone(),
})
}
pub fn uncompressed(self: &Arc<Self>) -> Arc<Self> {
Arc::new(Jar {
compression: false,
store: self.store.clone(),
})
}
pub fn get<U: IntoUri>(&self, name: &str, uri: U) -> Option<Cookie<'static>> {
let uri = uri.into_uri().ok()?;
let host = uri.host()?;
let path = uri.path();
let path_entry = self.store.get_sync(host)?;
let cookie_jar = path_entry.get_sync(path)?;
cookie_jar
.get()
.get(name)
.map(|c| Cookie(c.clone().into_owned()))
}
pub fn get_all(&self) -> Vec<Cookie<'static>> {
let mut cookies = Vec::new();
self.store.iter_sync(|_, path_map| {
path_map.iter_sync(|_, cookie_jar| {
for cookie in cookie_jar.iter() {
cookies.push(Cookie(cookie.clone().into_owned()));
}
true
});
true
});
cookies
}
pub fn add_cookie_str<U: IntoUri>(&self, cookie: &str, uri: U) {
self.add(cookie, uri);
}
#[allow(clippy::unwrap_or_default)]
pub fn add<C, U>(&self, cookie: C, uri: U)
where
C: IntoCookie,
U: IntoUri,
{
let Some(cookie) = cookie.into_cookie() else {
return;
};
let cookie: RawCookie<'static> = cookie.into();
let Ok(uri) = uri.into_uri() else { return };
let domain = cookie
.domain()
.map(normalize_domain)
.or_else(|| uri.host())
.unwrap_or_default();
let path = cookie.path().unwrap_or_else(|| normalize_path(&uri));
let expired = cookie
.expires_datetime()
.is_some_and(|dt| dt <= SystemTime::now())
|| cookie.max_age().is_some_and(|age| age.is_zero());
self.store
.entry_sync(domain.to_owned())
.or_insert_with(SccHashMap::new)
.get()
.entry_sync(path.to_owned())
.or_insert_with(CookieJar::default)
.get_mut()
.remove(cookie.name().to_owned());
if !expired {
self.store
.entry_sync(domain.to_owned())
.or_insert_with(SccHashMap::new)
.get()
.entry_sync(path.to_owned())
.or_insert_with(CookieJar::default)
.get_mut()
.add(cookie);
}
}
#[allow(clippy::unwrap_or_default)]
pub fn add_cookie<C, U>(&self, cookie: C, uri: U)
where
C: Into<RawCookie<'static>>,
U: IntoUri,
{
let cookie: RawCookie<'static> = cookie.into();
let Ok(uri) = uri.into_uri() else { return };
let domain = cookie
.domain()
.map(normalize_domain)
.or_else(|| uri.host())
.unwrap_or_default();
let path = cookie.path().unwrap_or_else(|| normalize_path(&uri));
let expired = cookie
.expires_datetime()
.is_some_and(|dt| dt <= SystemTime::now())
|| cookie.max_age().is_some_and(|age| age.is_zero());
self.store
.entry_sync(domain.to_owned())
.or_insert_with(SccHashMap::new)
.get()
.entry_sync(path.to_owned())
.or_insert_with(CookieJar::default)
.get_mut()
.remove(cookie.name().to_owned());
if !expired {
self.store
.entry_sync(domain.to_owned())
.or_insert_with(SccHashMap::new)
.get()
.entry_sync(path.to_owned())
.or_insert_with(CookieJar::default)
.get_mut()
.add(cookie);
}
}
#[allow(clippy::unwrap_or_default)]
pub fn remove<C, U>(&self, cookie: C, uri: U)
where
C: Into<RawCookie<'static>>,
U: IntoUri,
{
let cookie_raw = cookie.into();
let Ok(uri) = uri.into_uri() else { return };
let Some(host) = uri.host() else { return };
let path = uri.path().to_owned();
self.store
.entry_sync(host.to_owned())
.or_insert_with(SccHashMap::new)
.get()
.entry_sync(path)
.or_insert_with(CookieJar::default)
.get_mut()
.remove(cookie_raw.name().to_owned());
}
pub fn clear(&self) {
self.store.clear_sync();
}
}
impl CookieStore for Jar {
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, uri: &Uri) {
let cookies = cookie_headers
.map(Cookie::parse)
.filter_map(Result::ok)
.map(|cookie| cookie.0.into_owned());
for cookie in cookies {
self.add_cookie(cookie, uri);
}
}
fn cookies(&self, uri: &Uri) -> Cookies {
let host = match uri.host() {
Some(h) => h,
None => return Cookies::Empty,
};
if self.compression {
let mut cookies = String::new();
self.store.iter_sync(|domain, path_map| {
if !domain_match(host, domain) {
return true;
}
path_map.iter_sync(|path, cookie_jar| {
if !path_match(uri.path(), path) {
return true;
}
for cookie in cookie_jar.iter() {
if cookie.secure() == Some(true) && uri.is_http() {
continue;
}
if cookie
.expires_datetime()
.is_some_and(|dt| dt <= SystemTime::now())
{
continue;
}
if !cookies.is_empty() {
cookies.push_str("; ");
}
cookies.push_str(cookie.name());
cookies.push('=');
cookies.push_str(cookie.value());
}
true
});
true
});
if cookies.is_empty() {
return Cookies::Empty;
}
HeaderValue::from_maybe_shared(Bytes::from(cookies))
.map(Cookies::Compressed)
.unwrap_or(Cookies::Empty)
} else {
let mut cookie_strs = Vec::new();
self.store.iter_sync(|domain, path_map| {
if !domain_match(host, domain) {
return true;
}
path_map.iter_sync(|path, cookie_jar| {
if !path_match(uri.path(), path) {
return true;
}
for cookie in cookie_jar.iter() {
if cookie.secure() == Some(true) && uri.is_http() {
continue;
}
if cookie
.expires_datetime()
.is_some_and(|dt| dt <= SystemTime::now())
{
continue;
}
let name = cookie.name();
let value = cookie.value();
let mut s = String::with_capacity(name.len() + 1 + value.len());
s.push_str(name);
s.push('=');
s.push_str(value);
cookie_strs.push(s);
}
true
});
true
});
let cookies = cookie_strs
.into_iter()
.filter_map(|s| HeaderValue::from_maybe_shared(Bytes::from(s)).ok())
.collect();
Cookies::Uncompressed(cookies)
}
}
}
impl Default for Jar {
fn default() -> Self {
Self::new(true)
}
}
impl fmt::Debug for Jar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Jar")
.field("compression", &self.compression)
.finish()
}
}
/// Root path `/`: the fallback default-path for cookies and the path separator
/// used when matching and normalizing paths.
const DEFAULT_PATH: &str = "/";
/// Returns `true` when `host` domain-matches the cookie `domain` (RFC 6265 §5.1.3).
///
/// The comparison ignores ASCII case. Suffix matching (`www.example.com`
/// against `example.com`) needs a `.` right before the suffix. It is never
/// applied to IP-address hosts, so `1.2.3.4` does not match `2.3.4`. An empty
/// `domain` never matches.
fn domain_match(host: &str, domain: &str) -> bool {
    if domain.is_empty() {
        return false;
    }
    if host.eq_ignore_ascii_case(domain) {
        return true;
    }
    // IP literals (bracketed IPv6, or bare IPv4/IPv6) only ever match exactly.
    if host.starts_with('[') || host.parse::<std::net::IpAddr>().is_ok() {
        return false;
    }
    // Work on bytes so a multi-byte character can never cause a str slice
    // off a char boundary.
    let (host, domain) = (host.as_bytes(), domain.as_bytes());
    match host.len().checked_sub(domain.len()) {
        Some(split) if split > 0 => {
            host[split - 1] == b'.' && host[split..].eq_ignore_ascii_case(domain)
        }
        _ => false,
    }
}
/// Returns `true` when the request path path-matches the cookie path
/// (RFC 6265 §5.1.4).
///
/// A match is either an identical path, or the cookie path is a prefix of the
/// request path and the prefix ends on a `/` boundary. The boundary can be the
/// cookie path's last character or the next character of the request path.
fn path_match(req_path: &str, cookie_path: &str) -> bool {
    if req_path == cookie_path {
        return true;
    }
    match req_path.strip_prefix(cookie_path) {
        Some(rest) => cookie_path.ends_with(DEFAULT_PATH) || rest.starts_with(DEFAULT_PATH),
        None => false,
    }
}
/// Strips a trailing `:port` from a cookie `Domain` value.
///
/// A bracketed IPv6 literal such as `[::1]` or `[::1]:8080` keeps its brackets
/// and loses only the port. The colons inside the brackets are part of the
/// address, not a port separator.
fn normalize_domain(domain: &str) -> &str {
    if domain.starts_with('[') {
        return match domain.find(']') {
            Some(end) => &domain[..=end],
            None => domain,
        };
    }
    domain.split_once(':').map_or(domain, |(host, _port)| host)
}
/// Computes the RFC 6265 default-path (§5.1.4) for a cookie set by `uri`.
///
/// If the path does not start with `/`, or its only `/` is the leading one,
/// the result is `/`. Otherwise the result is everything before the last `/`,
/// so `/a/b/c` gives `/a/b`.
fn normalize_path(uri: &Uri) -> &str {
    let path = uri.path();
    if path.starts_with(DEFAULT_PATH) {
        if let Some(cut) = path.rfind(DEFAULT_PATH).filter(|&cut| cut > 0) {
            return &path[..cut];
        }
    }
    DEFAULT_PATH
}