use super::Cookie;
use http::Uri;
use std::{
collections::HashSet,
error::Error,
fmt,
hash::{Hash, Hasher},
net::{Ipv4Addr, Ipv6Addr},
sync::{Arc, RwLock},
};
/// Error returned by [`CookieJar::set`] when a cookie is refused.
///
/// It records why the cookie was refused ([`CookieRejectedError::kind`]) and hands the
/// rejected cookie back to the caller ([`CookieRejectedError::cookie`]).
#[derive(Clone, Debug)]
pub struct CookieRejectedError {
// Reason the cookie was refused.
kind: CookieRejectedErrorKind,
// The refused cookie, returned to the caller by `cookie()`.
cookie: Cookie,
}
/// The reason a cookie was refused by [`CookieJar::set`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CookieRejectedErrorKind {
/// The request URI the cookie arrived with has no host component.
InvalidRequestDomain,
/// The cookie's `Domain` attribute contains no `.` (e.g. a bare TLD such as `com`) or,
/// when the `psl` feature is enabled, is a public suffix.
InvalidCookieDomain,
/// The request host does not domain-match the cookie's `Domain` attribute (for example a
/// response from `bar.baz.com` trying to set a cookie for `www.bar.baz.com`).
DomainMismatch,
}
impl CookieRejectedError {
    /// Returns why the cookie was refused.
    pub fn kind(&self) -> CookieRejectedErrorKind {
        // `kind` is `Copy`, so it can be bound by value straight out of the borrowed error.
        let Self { kind, .. } = *self;
        kind
    }

    /// Consumes the error and hands back the cookie that was refused.
    pub fn cookie(self) -> Cookie {
        let Self { cookie, .. } = self;
        cookie
    }
}
impl fmt::Display for CookieRejectedError {
    /// Writes a fixed, reason-independent message; use
    /// [`CookieRejectedError::kind`] for the specific cause.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid cookie for given request URI")
    }
}

/// Marker impl so the rejection can be boxed or propagated as a standard error; it has no
/// underlying `source`.
impl Error for CookieRejectedError {}
/// A store of cookies received from servers, queried when building later requests.
///
/// The cookies live behind an `Arc<RwLock<_>>`, so cloning a `CookieJar` yields a handle to
/// the same underlying storage rather than a copy of it, and the jar can be shared across
/// threads.
#[derive(Clone, Debug, Default)]
pub struct CookieJar {
// Stored cookies, keyed by (domain, path, name) via the `Hash`/`Eq` impls of
// `CookieWithContext`.
cookies: Arc<RwLock<HashSet<CookieWithContext>>>,
}
impl CookieJar {
    /// Creates a new, empty cookie jar.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the cookie named `cookie_name` that would be sent with a request to `uri`.
    ///
    /// If several stored cookies share that name and all match `uri` (for example one
    /// scoped to `/` and another to `/api`), the most specific one is returned: the one
    /// with the longest path, then the longest domain. This makes the choice independent
    /// of hash-set iteration order, which is randomized per process.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn get_by_name(&self, uri: &Uri, cookie_name: &str) -> Option<Cookie> {
        self.cookies
            .read()
            .unwrap()
            .iter()
            .filter(|cookie| cookie.cookie.name() == cookie_name)
            .filter(|cookie| cookie.matches(uri))
            .max_by_key(|cookie| (cookie.path_value.len(), cookie.domain_value.len()))
            .map(|cookie| cookie.cookie.clone())
    }

    /// Returns every stored, unexpired cookie that should be sent with a request to `uri`.
    ///
    /// Cookies are ordered by name. Cookies sharing a name are ordered longest path first
    /// (as RFC 6265 §5.4 recommends), then longest domain first, so the result does not
    /// depend on hash-set iteration order.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn get_for_uri(&self, uri: &Uri) -> impl IntoIterator<Item = Cookie> {
        let jar = self.cookies.read().unwrap();
        let mut matching = jar
            .iter()
            .filter(|cookie| cookie.matches(uri))
            .collect::<Vec<_>>();
        matching.sort_by(|a, b| {
            a.cookie
                .name()
                .cmp(b.cookie.name())
                .then_with(|| b.path_value.len().cmp(&a.path_value.len()))
                .then_with(|| b.domain_value.len().cmp(&a.domain_value.len()))
        });
        matching
            .into_iter()
            .map(|cookie| cookie.cookie.clone())
            .collect::<Vec<_>>()
    }

    /// Removes every cookie from the jar. Because clones share storage, this also empties
    /// the jar as seen through every clone.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn clear(&self) {
        self.cookies.write().unwrap().clear();
    }

    /// Stores `cookie`, as received in a response to a request for `request_uri`.
    ///
    /// A cookie without a `Domain` attribute is host-only and is scoped to the request
    /// host. A cookie without a `Path` attribute gets a default path derived from
    /// `request_uri`. The stored domain is ASCII-lowercased, so cookies whose domains
    /// differ only in case replace each other instead of being stored twice.
    ///
    /// Returns the previously stored cookie with the same domain, path and name, if any.
    /// After inserting, all expired cookies are purged from the jar. This includes
    /// `cookie` itself when it is already expired, which is how a server deletes a cookie.
    ///
    /// # Errors
    ///
    /// Returns a [`CookieRejectedError`] carrying `cookie` back to the caller when:
    /// - `request_uri` has no host ([`CookieRejectedErrorKind::InvalidRequestDomain`]);
    /// - the request host does not domain-match the cookie's `Domain`
    ///   ([`CookieRejectedErrorKind::DomainMismatch`]);
    /// - the cookie's `Domain` contains no `.` or, with the `psl` feature, is a public
    ///   suffix ([`CookieRejectedErrorKind::InvalidCookieDomain`]).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn set(
        &self,
        cookie: Cookie,
        request_uri: &Uri,
    ) -> Result<Option<Cookie>, CookieRejectedError> {
        let request_host = if let Some(host) = request_uri.host() {
            host
        } else {
            tracing::warn!(
                "cookie '{}' dropped, no domain specified in request URI",
                cookie.name()
            );
            return Err(CookieRejectedError {
                kind: CookieRejectedErrorKind::InvalidRequestDomain,
                cookie,
            });
        };

        if let Some(domain) = cookie.domain() {
            // A host may only set cookies for itself or one of its parent domains.
            if !domain_matches(request_host, domain) {
                tracing::warn!(
                    "cookie '{}' dropped, domain '{}' not allowed to set cookies for '{}'",
                    cookie.name(),
                    request_host,
                    domain
                );
                return Err(CookieRejectedError {
                    kind: CookieRejectedErrorKind::DomainMismatch,
                    cookie,
                });
            }

            // Refuse cookies scoped to a bare top-level domain such as `com`.
            if !domain.contains('.') {
                tracing::warn!(
                    "cookie '{}' dropped, setting cookies for domain '{}' is not allowed",
                    cookie.name(),
                    domain
                );
                return Err(CookieRejectedError {
                    kind: CookieRejectedErrorKind::InvalidCookieDomain,
                    cookie,
                });
            }

            #[cfg(feature = "psl")]
            {
                if super::psl::is_public_suffix(domain) {
                    tracing::warn!(
                        "cookie '{}' dropped, setting cookies for domain '{}' is not allowed",
                        cookie.name(),
                        domain
                    );
                    return Err(CookieRejectedError {
                        kind: CookieRejectedErrorKind::InvalidCookieDomain,
                        cookie,
                    });
                }
            }
        }

        // Domains compare case-insensitively everywhere else in this module. Normalizing
        // here keeps the jar's (domain, path, name) identity consistent with that.
        let domain_value = cookie
            .domain()
            .map(|domain| domain.to_ascii_lowercase())
            .unwrap_or_else(|| request_host.to_ascii_lowercase());
        let path_value = cookie
            .path()
            .map(ToOwned::to_owned)
            .unwrap_or_else(|| default_path(request_uri).to_owned());
        let cookie_with_context = CookieWithContext {
            domain_value,
            path_value,
            cookie,
        };

        let mut jar = self.cookies.write().unwrap();
        let existing = jar
            .replace(cookie_with_context)
            .map(|cookie_with_context| cookie_with_context.cookie);
        jar.retain(|cookie| !cookie.cookie.is_expired());
        Ok(existing)
    }
}
/// A stored cookie together with the scope it was resolved to when it was set.
#[derive(Debug)]
struct CookieWithContext {
// The domain the cookie is scoped to: its `Domain` attribute if present, otherwise the
// host of the request that set it.
domain_value: String,
// The path the cookie is scoped to: its `Path` attribute if present, otherwise the
// default path computed from the request URI by `default_path`.
path_value: String,
// The cookie as received.
cookie: Cookie,
}
impl CookieWithContext {
    /// True when the cookie carried no `Domain` attribute, in which case it is only sent
    /// back to the exact host recorded in `domain_value`.
    fn is_host_only(&self) -> bool {
        matches!(self.cookie.domain(), None)
    }

    /// Decides whether this cookie should accompany a request to `uri`.
    ///
    /// All of the following must hold, checked in this order:
    /// - secure cookies are only sent over `https`;
    /// - host-only cookies need an exact (ASCII case-insensitive) host match, while other
    ///   cookies need the request host to domain-match `domain_value`;
    /// - the request path must path-match `path_value`;
    /// - the cookie must not be expired.
    fn matches(&self, uri: &Uri) -> bool {
        let host = uri.host().unwrap_or("");
        let transport_ok =
            !self.cookie.is_secure() || uri.scheme() == Some(&::http::uri::Scheme::HTTPS);
        // Closure so the host check is only evaluated when the transport check passed.
        let host_ok = || {
            if self.is_host_only() {
                self.domain_value.eq_ignore_ascii_case(host)
            } else {
                domain_matches(host, &self.domain_value)
            }
        };

        transport_ok
            && host_ok()
            && path_matches(uri.path(), &self.path_value)
            && !self.cookie.is_expired()
    }
}
/// A stored cookie's identity is its (domain, path, name) triple; the cookie's value and
/// other attributes do not participate. This is what lets `HashSet::replace` overwrite an
/// existing cookie with a newer one.
impl Hash for CookieWithContext {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Tuple hashing feeds each element to `state` in order.
        (&self.domain_value, &self.path_value, self.cookie.name()).hash(state);
    }
}

impl PartialEq for CookieWithContext {
    fn eq(&self, other: &Self) -> bool {
        (&self.domain_value, &self.path_value, self.cookie.name())
            == (&other.domain_value, &other.path_value, other.cookie.name())
    }
}

impl Eq for CookieWithContext {}
/// Domain-matching as described by RFC 6265 §5.1.3.
///
/// Returns `true` when `string` (typically a request host) is the same domain as
/// `domain_string`, compared case-insensitively. It also returns `true` when `string` is a
/// sub-domain of `domain_string`, i.e. `string` ends with `domain_string` preceded by a
/// `.`, provided `string` is not an IP address literal.
///
/// Never panics. Earlier index arithmetic underflowed for non-ASCII inputs that are equal
/// only after Unicode lowercasing (e.g. `"É"` vs `"é"`).
fn domain_matches(string: &str, domain_string: &str) -> bool {
    if domain_string.eq_ignore_ascii_case(string) {
        return true;
    }

    let string = string.to_lowercase();
    let domain_string = domain_string.to_lowercase();

    // Identical after canonicalizing to lower case: same domain.
    if string == domain_string {
        return true;
    }

    match string.strip_suffix(domain_string.as_str()) {
        // The character just before the matched suffix must be a label separator, and
        // suffix matching never applies to IP addresses.
        Some(prefix) => {
            prefix.ends_with('.')
                && string.parse::<Ipv4Addr>().is_err()
                && string.parse::<Ipv6Addr>().is_err()
        }
        None => false,
    }
}
/// Path-matching as described by RFC 6265 §5.1.4.
///
/// `request_path` matches `cookie_path` when they are identical, or when `cookie_path` is a
/// prefix of `request_path` that ends on a path-segment boundary. That holds if the prefix
/// itself ends with `/`, or if the remainder of `request_path` starts with `/`.
fn path_matches(request_path: &str, cookie_path: &str) -> bool {
    match request_path.strip_prefix(cookie_path) {
        // Nothing left over: the paths are identical.
        Some("") => true,
        Some(rest) => cookie_path.ends_with('/') || rest.starts_with('/'),
        None => false,
    }
}
/// Computes the default cookie path for a request URI (RFC 6265 §5.1.4).
///
/// Yields `"/"` when the URI path does not start with `/` or contains only the leading
/// `/`. Otherwise yields everything before the rightmost `/`, for example `"/a/b/c"` gives
/// `"/a/b"`.
fn default_path(uri: &Uri) -> &str {
    let path = uri.path();
    match path.rfind('/') {
        Some(last_slash) if path.starts_with('/') && last_slash > 0 => &path[..last_slash],
        _ => "/",
    }
}
// Unit tests for cookie-domain validation, expiry handling, and the RFC 6265 domain- and
// path-matching helpers.
#[cfg(test)]
mod tests {
use super::*;
use test_case::test_case;
// `CookieJar::set` accepts host-only cookies, cookies for the request host itself, and
// cookies for a parent domain. It rejects cookies for a sub-domain of the request host
// (DomainMismatch) and for bare top-level domains (InvalidCookieDomain). With the `psl`
// feature it also rejects public suffixes.
#[test]
fn cookie_domain_not_allowed() {
let jar = CookieJar::default();
// Host-only cookie: no Domain attribute.
assert!(jar
.set(
Cookie::parse("foo=bar").unwrap(),
&"https://bar.baz.com".parse().unwrap()
)
.is_ok());
// Domain equal to the request host.
assert!(jar
.set(
Cookie::parse("foo=bar; domain=bar.baz.com").unwrap(),
&"https://bar.baz.com".parse().unwrap()
)
.is_ok());
// Domain is a parent of the request host.
assert!(jar
.set(
Cookie::parse("foo=bar; domain=baz.com").unwrap(),
&"https://bar.baz.com".parse().unwrap()
)
.is_ok());
// Domain is a sub-domain of the request host: not allowed.
assert!(
jar.set(
Cookie::parse("foo=bar; domain=www.bar.baz.com").unwrap(),
&"https://bar.baz.com".parse().unwrap(),
)
.unwrap_err()
.kind()
== CookieRejectedErrorKind::DomainMismatch
);
// Bare top-level domain.
assert!(
jar.set(
Cookie::parse("foo=bar; domain=com").unwrap(),
&"https://bar.baz.com".parse().unwrap(),
)
.unwrap_err()
.kind()
== CookieRejectedErrorKind::InvalidCookieDomain
);
// `.com` is expected to be handled like `com` (presumably `Cookie::parse` drops the
// leading dot), so it is rejected for the same reason.
assert!(
jar.set(
Cookie::parse("foo=bar; domain=.com").unwrap(),
&"https://bar.baz.com".parse().unwrap(),
)
.unwrap_err()
.kind()
== CookieRejectedErrorKind::InvalidCookieDomain
);
// Only with the public suffix list compiled in can `wi.us` be recognized as a public
// suffix.
if cfg!(feature = "psl") {
assert!(
jar.set(
Cookie::parse("foo=bar; domain=wi.us").unwrap(),
&"https://www.state.wi.us".parse().unwrap(),
)
.unwrap_err()
.kind()
== CookieRejectedErrorKind::InvalidCookieDomain
);
}
}
// Setting an already-expired cookie with the same name, domain and path replaces the
// stored one, and the expired entry is then purged, so nothing is returned for the URI.
#[test]
fn expire_a_cookie() {
let uri: Uri = "https://example.com/foo".parse().unwrap();
let jar = CookieJar::default();
jar.set(Cookie::parse("foo=bar").unwrap(), &uri).unwrap();
assert_eq!(jar.get_by_name(&uri, "foo").unwrap(), "bar");
jar.set(
Cookie::parse("foo=; expires=Wed, 21 Oct 2015 07:28:00 GMT").unwrap(),
&uri,
)
.unwrap();
assert!(jar.get_for_uri(&uri).into_iter().next().is_none());
}
// (string, domain_string, expected). Identical strings always match, even IP literals;
// suffix matches require a `.` boundary.
#[test_case("127.0.0.1", "127.0.0.1", true)]
#[test_case(".127.0.0.2", "127.0.0.2", true)]
#[test_case("bar.com", "bar.com", true)]
#[test_case("baz.com", "bar.com", false)]
#[test_case("baz.bar.com", "bar.com", true)]
#[test_case("www.baz.com", "baz.com", true)]
#[test_case("baz.bar.com", "com", true)]
fn test_domain_matches(string: &str, domain_string: &str, should_match: bool) {
assert_eq!(domain_matches(string, domain_string), should_match);
}
// (request_path, cookie_path, expected). Matching is case-sensitive and must end on a
// path-segment boundary (`/foobar` does not match `/foo`).
#[test_case("/foo", "/foo", true)]
#[test_case("/Bar", "/bar", false)]
#[test_case("/fo", "/foo", false)]
#[test_case("/foo/bar", "/foo", true)]
#[test_case("/foo/bar/baz", "/foo", true)]
#[test_case("/foo/bar//baz2", "/foo", true)]
#[test_case("/foobar", "/foo", false)]
#[test_case("/foo", "/foo/bar", false)]
#[test_case("/foobar", "/foo/bar", false)]
#[test_case("/foo/bar", "/foo/bar", true)]
#[test_case("/foo/bar2/", "/foo/bar2", true)]
#[test_case("/foo/bar/baz", "/foo/bar", true)]
#[test_case("/foo/bar3", "/foo/bar3/", false)]
#[test_case("/foo/bar4/", "/foo/bar4/", true)]
#[test_case("/foo/bar/baz2", "/foo/bar/", true)]
fn test_path_matches(request_path: &str, cookie_path: &str, should_match: bool) {
assert_eq!(path_matches(request_path, cookie_path), should_match);
}
}