use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime};
use async_trait::async_trait;
use crate::error::{Error, ErrorKind, Result};
use crate::middleware::{Middleware, Next};
use crate::request::Request;
use crate::response::Response;
use crate::url::Url;
/// A single cookie as held inside a [`CookieJar`].
#[derive(Clone, Debug, Eq, PartialEq)]
struct StoredCookie {
    /// Cookie name, whitespace-trimmed when parsed from `Set-Cookie`.
    name: String,
    /// Cookie value, whitespace-trimmed when parsed from `Set-Cookie`.
    value: String,
    /// Lower-cased domain: the request host, or the `Domain` attribute with leading dots
    /// removed.
    domain: String,
    /// `true` when no usable `Domain` attribute was given; the cookie is then only sent to a
    /// host exactly equal to `domain`, never to its subdomains.
    host_only: bool,
    /// Path scope: the `Path` attribute (normalized to start with `/`) or a default derived
    /// from the request URL.
    path: String,
    /// When set, the cookie is only sent to `https`/`wss` URLs.
    secure: bool,
    /// Monotonic expiry deadline; `None` means the cookie never expires on its own.
    expires_at: Option<Instant>,
}
#[derive(Clone, Default)]
pub struct CookieJar {
inner: Arc<Mutex<Vec<StoredCookie>>>,
}
impl CookieJar {
pub fn new() -> Self {
Self::default()
}
pub fn get_cookie_header(&self, url: &Url) -> Option<String> {
let mut cookies = self.inner.lock().unwrap_or_else(|err| err.into_inner());
cookies.retain(|cookie| !cookie.is_expired());
let mut pairs = Vec::new();
for cookie in cookies.iter() {
if cookie_matches(cookie, url) {
pairs.push(format!("{}={}", cookie.name, cookie.value));
}
}
if pairs.is_empty() {
None
} else {
Some(pairs.join("; "))
}
}
pub fn store_set_cookie(&self, url: &Url, header: &str) -> Result<()> {
let cookie = parse_set_cookie(url, header)?;
let mut cookies = self.inner.lock().unwrap_or_else(|err| err.into_inner());
cookies.retain(|existing| {
!(existing.name == cookie.name
&& existing.domain == cookie.domain
&& existing.path == cookie.path)
});
if !cookie.is_expired() {
cookies.push(cookie);
}
Ok(())
}
pub fn store_set_cookies<'a, I>(&self, url: &Url, headers: I) -> Result<()>
where
I: IntoIterator<Item = &'a str>,
{
for header in headers {
self.store_set_cookie(url, header)?;
}
Ok(())
}
}
/// Middleware that attaches cookies from a [`CookieJar`] to outgoing requests and stores the
/// `Set-Cookie` headers of incoming responses back into it.
pub struct CookieMiddleware {
    jar: CookieJar,
}

impl CookieMiddleware {
    /// Creates the middleware around `jar`.
    ///
    /// `CookieJar` clones share storage, so passing in a clone lets the caller keep observing
    /// the same cookies.
    pub fn new(jar: CookieJar) -> Self {
        CookieMiddleware { jar }
    }

    /// Borrows the jar this middleware reads from and writes to.
    pub fn jar(&self) -> &CookieJar {
        let CookieMiddleware { jar } = self;
        jar
    }
}
impl StoredCookie {
    /// Reports whether the cookie's deadline has been reached.
    ///
    /// A cookie without a deadline never expires on its own. A deadline equal to the current
    /// instant already counts as expired.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= Instant::now(),
            None => false,
        }
    }
}
#[async_trait]
impl Middleware for CookieMiddleware {
    /// Attaches matching cookies from the jar to `req` as a `cookie` header and forwards it
    /// down the chain. Every `Set-Cookie` header of the response is recorded before the
    /// response is returned.
    ///
    /// # Errors
    ///
    /// Fails if the `cookie` header cannot be inserted, if the rest of the chain fails, or if
    /// the jar rejects any `Set-Cookie` header of the response.
    async fn handle(&self, mut req: Request, next: Next<'_>) -> Result<Response> {
        // `req` is moved into the chain, but response cookies must be scoped to the URL that
        // was requested, so keep our own copy of it.
        let request_url = req.url().clone();
        let outgoing = self.jar.get_cookie_header(&request_url);
        if let Some(header_value) = outgoing {
            req.headers_mut().insert("cookie", header_value)?;
        }

        let response = next.run(req).await?;
        self.jar
            .store_set_cookies(&request_url, response.headers().get_all("set-cookie"))?;
        Ok(response)
    }
}
/// Parses one `Set-Cookie` header value received in response to a request for `url`.
///
/// The first `;`-separated segment must be `name=value`; both sides are trimmed. The
/// remaining segments are attributes whose names are matched case-insensitively. `Secure`,
/// `Domain`, `Path`, `Max-Age` and `Expires` are recognised. Anything else (e.g. `HttpOnly`,
/// `SameSite`) is ignored.
///
/// Without a `Domain` attribute the cookie is host-only for `url`'s host. Without a `Path`
/// attribute it gets a default path derived from `url`.
///
/// Expiry follows RFC 6265 §5.3 step 3: `Max-Age` takes precedence over `Expires`, whichever
/// comes first. A non-positive `Max-Age` yields an already-expired cookie. A `Max-Age` too
/// large to add to the current `Instant` is treated as "no expiry" instead of panicking.
///
/// NOTE: this function does not verify that a `Domain` attribute domain-matches `url`'s
/// host (RFC 6265 §5.3 step 6).
///
/// # Errors
///
/// Returns an `ErrorKind::Decode` error when:
/// - the first segment has no `=`,
/// - `Max-Age` is not an integer, or
/// - `Expires` is not a valid HTTP date.
fn parse_set_cookie(url: &Url, header: &str) -> Result<StoredCookie> {
    let mut parts = header.split(';');
    let name_value = parts
        .next()
        .ok_or_else(|| Error::new(ErrorKind::Decode, "set-cookie header is empty"))?;
    let (name, value) = name_value
        .split_once('=')
        .ok_or_else(|| Error::new(ErrorKind::Decode, "set-cookie is missing name/value"))?;
    let mut domain = url.host().to_ascii_lowercase();
    let mut host_only = true;
    let mut path = default_cookie_path(url);
    let mut secure = false;
    // Outer `Option`: whether a `Max-Age` attribute was seen. Inner: the resulting deadline,
    // where `None` means the delta overflowed `Instant` (effectively never expires).
    let mut max_age_deadline: Option<Option<Instant>> = None;
    let mut expires_deadline: Option<Instant> = None;
    for attribute in parts {
        let attribute = attribute.trim();
        if attribute.eq_ignore_ascii_case("secure") {
            secure = true;
            continue;
        }
        let Some((key, attr_value)) = attribute.split_once('=') else {
            continue;
        };
        let key = key.trim();
        let attr_value = attr_value.trim();
        if key.eq_ignore_ascii_case("domain") && !attr_value.is_empty() {
            domain = attr_value.trim_start_matches('.').to_ascii_lowercase();
            host_only = false;
        } else if key.eq_ignore_ascii_case("path") && !attr_value.is_empty() {
            path = normalize_cookie_path(attr_value);
        } else if key.eq_ignore_ascii_case("max-age") {
            let seconds: i64 = attr_value.parse().map_err(|_| {
                Error::new(ErrorKind::Decode, "invalid max-age attribute in set-cookie")
            })?;
            let now = Instant::now();
            max_age_deadline = Some(if seconds <= 0 {
                Some(now)
            } else {
                // `Instant + Duration` panics on overflow (e.g. `Max-Age=9223372036854775807`).
                now.checked_add(Duration::from_secs(seconds.unsigned_abs()))
            });
        } else if key.eq_ignore_ascii_case("expires") {
            expires_deadline = Some(parse_cookie_expires(attr_value)?);
        }
    }
    Ok(StoredCookie {
        name: name.trim().to_owned(),
        value: value.trim().to_owned(),
        domain,
        host_only,
        path,
        secure,
        // Max-Age wins over Expires regardless of attribute order (RFC 6265 §5.3 step 3).
        expires_at: max_age_deadline.unwrap_or(expires_deadline),
    })
}
/// Computes the RFC 6265 §5.1.4 default-path for a cookie set without a `Path` attribute.
///
/// The default path is the request path up to, but not including, its right-most `/`. It
/// falls back to `/` when the path does not start with `/` or contains only the leading `/`.
/// Examples: `/account/login` → `/account`, `/users` → `/`, `/` → `/`. The query string is
/// ignored.
fn default_cookie_path(url: &Url) -> String {
    let path = url.path_and_query().split('?').next().unwrap_or("/");
    if !path.starts_with('/') {
        return "/".to_owned();
    }
    match path.rsplit_once('/') {
        // Only the leading slash is present (e.g. `/` or `/users`).
        Some(("", _)) | None => "/".to_owned(),
        Some((prefix, _)) => prefix.to_owned(),
    }
}
/// Returns `path` as an owned string, guaranteed to start with `/`.
///
/// A missing leading slash is prepended, so `account` becomes `/account` and `""` becomes
/// `/`.
fn normalize_cookie_path(path: &str) -> String {
    let mut normalized = String::with_capacity(path.len() + 1);
    if !path.starts_with('/') {
        normalized.push('/');
    }
    normalized.push_str(path);
    normalized
}
/// Decides whether `cookie` should be sent with a request to `url`.
///
/// Three conditions must hold:
/// - Secure cookies require an `https` or `wss` scheme.
/// - The cookie's domain must match the lower-cased host.
/// - The cookie's path must match the request path, with the query string excluded.
fn cookie_matches(cookie: &StoredCookie, url: &Url) -> bool {
    if cookie.secure && !matches!(url.scheme(), "https" | "wss") {
        return false;
    }
    let host = url.host().to_ascii_lowercase();
    if !cookie_domain_matches(cookie, &host) {
        return false;
    }
    let path_only = url.path_and_query().split('?').next().unwrap_or("/");
    path_matches(&cookie.path, path_only)
}
/// Returns `true` when `host` equals `cookie_domain` or is a subdomain of it, i.e. ends with
/// `"." + cookie_domain`.
///
/// The comparison is case-sensitive; the callers in this file lower-case both sides first.
/// No IP-address exclusion is applied.
fn domain_matches(cookie_domain: &str, host: &str) -> bool {
    match host.strip_suffix(cookie_domain) {
        Some(rest) => rest.is_empty() || rest.ends_with('.'),
        None => false,
    }
}
/// Decides whether `cookie` may be sent to `host`; the sole caller passes a lower-cased
/// host.
///
/// Domain cookies also match subdomains of their domain. Host-only cookies (no `Domain`
/// attribute) require an exact host match.
fn cookie_domain_matches(cookie: &StoredCookie, host: &str) -> bool {
    if !cookie.host_only {
        return domain_matches(&cookie.domain, host);
    }
    cookie.domain == host
}
/// Path-match as defined by RFC 6265 §5.1.4.
///
/// `request_path` matches `cookie_path` when they are identical. It also matches when
/// `cookie_path` is a prefix of `request_path` that ends on a segment boundary. That is,
/// either `cookie_path` itself ends in `/`, or the next character of `request_path` is `/`.
/// So `/account` matches `/account` and `/account/me` but not `/accountant`.
fn path_matches(cookie_path: &str, request_path: &str) -> bool {
    if request_path == cookie_path {
        return true;
    }
    match request_path.strip_prefix(cookie_path) {
        Some(rest) => cookie_path.ends_with('/') || rest.starts_with('/'),
        None => false,
    }
}
fn parse_cookie_expires(value: &str) -> Result<Instant> {
let expires_at = httpdate::parse_http_date(value).map_err(|err| {
Error::with_source(
ErrorKind::Decode,
"invalid expires attribute in set-cookie",
err,
)
})?;
match expires_at.duration_since(SystemTime::now()) {
Ok(duration) => Ok(Instant::now() + duration),
Err(_) => Ok(Instant::now()),
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{CookieJar, parse_set_cookie};
    use crate::Url;

    /// Parses a fixture URL; fixtures are static, so a failure here is a bug in the test.
    fn parse_url(raw: &str) -> Url {
        Url::parse(raw).unwrap()
    }

    /// Returns a fresh jar that has already accepted `set_cookie` from `origin`.
    fn jar_with(origin: &Url, set_cookie: &str) -> CookieJar {
        let jar = CookieJar::new();
        jar.store_set_cookie(origin, set_cookie).unwrap();
        jar
    }

    #[test]
    fn stores_and_formats_cookie_header() {
        let origin = parse_url("https://api.example.com/users");
        let jar = jar_with(&origin, "session=abc; Path=/; Secure");
        assert_eq!(jar.get_cookie_header(&origin).as_deref(), Some("session=abc"));
    }

    #[test]
    fn cookie_path_must_match_request_path() {
        let jar = jar_with(
            &parse_url("https://api.example.com/account/login"),
            "session=abc; Path=/account",
        );
        let inside = parse_url("https://api.example.com/account/me");
        let outside = parse_url("https://api.example.com/admin");
        assert_eq!(jar.get_cookie_header(&inside).as_deref(), Some("session=abc"));
        assert_eq!(jar.get_cookie_header(&outside), None);
    }

    #[test]
    fn parses_domain_and_secure_attributes() {
        let origin = parse_url("https://api.example.com/users");
        let cookie =
            parse_set_cookie(&origin, "theme=dark; Domain=.example.com; Path=/; Secure").unwrap();
        assert_eq!(cookie.domain, "example.com");
        assert!(!cookie.host_only);
        assert_eq!(cookie.path, "/");
        assert!(cookie.secure);
    }

    #[test]
    fn host_only_cookie_does_not_match_subdomains() {
        let jar = jar_with(&parse_url("https://api.example.com/users"), "session=abc; Path=/");
        let same_host = parse_url("https://api.example.com/me");
        let subdomain = parse_url("https://sub.api.example.com/me");
        assert_eq!(jar.get_cookie_header(&same_host).as_deref(), Some("session=abc"));
        assert_eq!(jar.get_cookie_header(&subdomain), None);
    }

    #[test]
    fn domain_cookie_matches_subdomains() {
        let jar = jar_with(
            &parse_url("https://api.example.com/users"),
            "session=abc; Domain=example.com; Path=/",
        );
        let sibling = parse_url("https://sub.example.com/me");
        assert_eq!(jar.get_cookie_header(&sibling).as_deref(), Some("session=abc"));
    }

    #[test]
    fn max_age_zero_removes_existing_cookie() {
        let origin = parse_url("https://api.example.com/users");
        let jar = jar_with(&origin, "session=abc; Path=/");
        jar.store_set_cookie(&origin, "session=gone; Path=/; Max-Age=0")
            .unwrap();
        assert_eq!(jar.get_cookie_header(&origin), None);
    }

    #[test]
    fn expired_cookie_is_not_returned() {
        let origin = parse_url("https://api.example.com/users");
        let jar = jar_with(&origin, "session=abc; Path=/; Max-Age=1");
        // Sleep well past the one-second Max-Age so the deadline has certainly elapsed.
        std::thread::sleep(Duration::from_secs(2));
        assert_eq!(jar.get_cookie_header(&origin), None);
    }

    #[test]
    fn future_expires_cookie_is_returned() {
        let origin = parse_url("https://api.example.com/users");
        let jar = jar_with(&origin, "session=abc; Path=/; Expires=Wed, 01 Jan 3000 00:00:00 GMT");
        assert_eq!(jar.get_cookie_header(&origin).as_deref(), Some("session=abc"));
    }

    #[test]
    fn past_expires_cookie_is_not_returned() {
        let origin = parse_url("https://api.example.com/users");
        let jar = jar_with(&origin, "session=abc; Path=/; Expires=Sat, 01 Jan 2000 00:00:00 GMT");
        assert_eq!(jar.get_cookie_header(&origin), None);
    }
}