use chrono::Duration;
use chrono::prelude::*;
use crate::{Request, Response};
use crate::middleware::Middleware;
use http::Uri;
use log::*;
use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::RwLock;
/// A single cookie held by a [`CookieJar`], built from a `Set-Cookie`
/// response header.
pub struct Cookie {
    // --- identity ---
    /// Cookie name: the text before the first `=` of the leading pair.
    name: String,
    /// Cookie value: the text after the first `=` of the leading pair.
    value: String,

    // --- scope ---
    /// Domain the cookie is scoped to: the `Domain` attribute if one was
    /// accepted, otherwise the host of the request that set it.
    domain: String,
    /// `true` when no `Domain` attribute was accepted; the cookie is then only
    /// sent to a host equal to `domain`, never to its subdomains.
    host_only: bool,
    /// Path prefix a request path must match for the cookie to be sent.
    path: String,

    // --- attributes ---
    /// Only send this cookie over HTTPS.
    secure: bool,
    /// Absolute expiry time derived from `Expires` / `Max-Age`; `None` means
    /// the cookie is never treated as expired.
    expiration: Option<DateTime<Utc>>,
}
impl Cookie {
fn parse(header: &str, uri: &Uri) -> Option<Self> {
let mut attributes = header.split(";")
.map(str::trim)
.map(|item| item
.splitn(2, "=")
.map(str::trim));
let mut first_pair = attributes.next()?;
let cookie_name = first_pair.next()?.into();
let cookie_value = first_pair.next()?.into();
let mut cookie_domain = None;
let mut cookie_path = None;
let mut cookie_secure = false;
let mut cookie_expiration = None;
for mut attribute in attributes {
let name = attribute.next()?;
let value = attribute.next();
if name.eq_ignore_ascii_case("Expires") {
if cookie_expiration.is_none() {
if let Some(value) = value {
if let Ok(time) = DateTime::parse_from_rfc2822(value) {
cookie_expiration = Some(time.with_timezone(&Utc));
}
}
}
} else if name.eq_ignore_ascii_case("Domain") {
cookie_domain = value
.map(|s| s.trim_start_matches("."))
.map(str::to_lowercase);
} else if name.eq_ignore_ascii_case("Max-Age") {
if let Some(value) = value {
if let Ok(seconds) = value.parse() {
cookie_expiration = Some(Utc::now() + Duration::seconds(seconds));
}
}
} else if name.eq_ignore_ascii_case("Path") {
cookie_path = value.map(ToOwned::to_owned);
} else if name.eq_ignore_ascii_case("Secure") {
cookie_secure = true;
}
}
if let Some(domain) = cookie_domain.as_ref() {
if !Cookie::domain_matches(uri.host()?, domain) {
warn!("cookie '{}' dropped, domain '{}' not allowed to set cookies for '{}'", cookie_name, uri.host()?, domain);
return None;
}
#[cfg(feature = "psl")] {
use ::psl::Psl;
let list = ::psl::List::new();
if let Some(suffix) = list.suffix(domain) {
if domain == suffix.to_str() {
warn!("cookie '{}' dropped, setting cookies for domain '{}' is not allowed", cookie_name, domain);
return None;
}
}
}
}
Some(Self {
name: cookie_name,
value: cookie_value,
secure: cookie_secure,
expiration: cookie_expiration,
host_only: cookie_domain.is_none(),
domain: cookie_domain.or_else(|| {
uri.host().map(ToOwned::to_owned)
})?,
path: cookie_path.unwrap_or_else(|| {
Cookie::default_path(uri).to_owned()
}),
})
}
fn is_expired(&self) -> bool {
match self.expiration {
Some(time) => time < Utc::now(),
None => false,
}
}
fn key(&self) -> String {
format!("{}.{}.{}", self.domain, self.path, self.name)
}
fn matches(&self, uri: &Uri) -> bool {
if self.secure && uri.scheme_part() != Some(&::http::uri::Scheme::HTTPS) {
return false;
}
let request_host = uri.host().unwrap_or("");
if self.host_only {
if !self.domain.eq_ignore_ascii_case(request_host) {
return false;
}
} else {
if !Cookie::domain_matches(request_host, &self.domain) {
return false;
}
}
if !Cookie::path_matches(uri.path(), &self.path) {
return false;
}
if self.is_expired() {
return false;
}
true
}
fn domain_matches(string: &str, domain_string: &str) -> bool {
if domain_string.eq_ignore_ascii_case(string) {
return true;
}
let string = &string.to_lowercase();
let domain_string = &domain_string.to_lowercase();
string.ends_with(domain_string) &&
string.as_bytes()[string.len() - domain_string.len() - 1] == b'.' &&
string.parse::<Ipv4Addr>().is_err() &&
string.parse::<Ipv6Addr>().is_err()
}
fn path_matches(request_path: &str, cookie_path: &str) -> bool {
if request_path == cookie_path {
return true;
}
if request_path.starts_with(cookie_path) {
if cookie_path.ends_with("/") || request_path[cookie_path.len()..].starts_with("/") {
return true;
}
}
false
}
fn default_path(uri: &Uri) -> &str {
if uri.path().chars().next() != Some('/') {
return "/";
}
let rightmost_slash_idx = uri.path().rfind("/").unwrap();
if rightmost_slash_idx == 0 {
return "/";
}
&uri.path()[..rightmost_slash_idx]
}
}
/// In-memory cookie store shared behind a read/write lock.
///
/// Start with an empty jar via `CookieJar::default()`.
#[derive(Default)]
pub struct CookieJar {
    /// Stored cookies, keyed by a string built from each cookie's domain,
    /// path and name.
    cookies: RwLock<HashMap<String, Cookie>>,
}
impl CookieJar {
    /// Stores each cookie from `cookies`. A cookie replaces any stored cookie
    /// with the same key (domain, path and name). Afterwards, every cookie
    /// that has expired is evicted.
    ///
    /// # Panics
    ///
    /// Panics if the jar's lock is poisoned.
    pub fn add(&self, cookies: impl Iterator<Item=Cookie>) {
        let mut store = self.cookies.write().unwrap();
        store.extend(cookies.map(|cookie| (cookie.key(), cookie)));
        store.retain(|_, cookie| !cookie.is_expired());
    }

    /// Builds a `Cookie` header value from every stored cookie that applies
    /// to `uri`. Pairs are formatted as `name=value`, sorted lexicographically
    /// and joined with `"; "`. Returns `None` if no cookie applies.
    fn get_cookies(&self, uri: &Uri) -> Option<String> {
        let store = self.cookies.read().unwrap();
        let mut pairs = Vec::new();
        for cookie in store.values().filter(|candidate| candidate.matches(uri)) {
            pairs.push(format!("{}={}", cookie.name, cookie.value));
        }
        drop(store);

        pairs.sort();
        if pairs.is_empty() {
            None
        } else {
            Some(pairs.join("; "))
        }
    }
}
impl Middleware for CookieJar {
fn filter_request(&self, mut request: Request) -> Request {
if let Some(header) = self.get_cookies(request.uri()) {
request.headers_mut().insert(http::header::COOKIE, header.parse().unwrap());
}
request
}
fn filter_response(&self, response: Response) -> Response {
if response.headers().contains_key(http::header::SET_COOKIE) {
let cookies = response.headers()
.get_all(http::header::SET_COOKIE)
.into_iter()
.filter_map(|header| {
match header.to_str() {
Ok(header) => match Cookie::parse(header, response.extensions().get().unwrap()) {
Some(cookie) => return Some(cookie),
_ => warn!("could not parse Set-Cookie header"),
},
_ => warn!("invalid encoding in Set-Cookie header"),
}
None
});
self.add(cookies);
}
response
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_set_cookie_header() {
        let uri: Uri = "https://baz.com".parse().unwrap();
        let header = "foo=bar; path=/sub;Secure ; expires =Wed, 21 Oct 2015 07:28:00 GMT";
        let parsed = Cookie::parse(header, &uri).unwrap();

        assert_eq!(parsed.name, "foo");
        assert_eq!(parsed.value, "bar");
        assert_eq!(parsed.path, "/sub");
        assert_eq!(parsed.domain, "baz.com");
        assert!(parsed.secure);
        assert!(parsed.host_only);

        let expires_at = parsed.expiration.as_ref().map(|t| t.timestamp());
        assert_eq!(expires_at, Some(1445412480));
    }

    #[test]
    fn cookie_domain_not_allowed() {
        let uri: Uri = "https://bar.baz.com".parse().unwrap();
        let accepted = |header: &str| Cookie::parse(header, &uri).is_some();

        assert!(accepted("foo=bar"));
        assert!(accepted("foo=bar; domain=bar.baz.com"));
        assert!(accepted("foo=bar; domain=baz.com"));
        assert!(!accepted("foo=bar; domain=www.bar.baz.com"));

        if cfg!(feature = "psl") {
            assert!(!accepted("foo=bar; domain=com"));
            assert!(!accepted("foo=bar; domain=.com"));
        } else {
            assert!(accepted("foo=bar; domain=com"));
        }
    }

    #[test]
    fn domain_matches() {
        let cases: &[(&str, &str, bool)] = &[
            ("127.0.0.1", "127.0.0.1", true),
            (".127.0.0.1", "127.0.0.1", true),
            ("bar.com", "bar.com", true),
            ("baz.com", "bar.com", false),
            ("baz.bar.com", "bar.com", true),
            ("www.baz.com", "baz.com", true),
            ("baz.bar.com", "com", true),
        ];
        for &(host, domain, expected) in cases {
            assert_eq!(Cookie::domain_matches(host, domain), expected);
        }
    }

    #[test]
    fn path_matches() {
        let cases: &[(&str, &str, bool)] = &[
            ("/foo", "/foo", true),
            ("/Foo", "/foo", false),
            ("/fo", "/foo", false),
            ("/foo/bar", "/foo", true),
            ("/foo/bar/baz", "/foo", true),
            ("/foo/bar//baz", "/foo", true),
            ("/foobar", "/foo", false),
            ("/foo", "/foo/bar", false),
            ("/foobar", "/foo/bar", false),
            ("/foo/bar", "/foo/bar", true),
            ("/foo/bar/", "/foo/bar", true),
            ("/foo/bar/baz", "/foo/bar", true),
            ("/foo/bar", "/foo/bar/", false),
            ("/foo/bar/", "/foo/bar/", true),
            ("/foo/bar/baz", "/foo/bar/", true),
        ];
        for &(request_path, cookie_path, expected) in cases {
            assert_eq!(Cookie::path_matches(request_path, cookie_path), expected);
        }
    }

    #[test]
    fn cookie_lifecycle() {
        let uri: Uri = "https://example.com/foo".parse().unwrap();
        let jar = CookieJar::default();

        let response = http::Response::builder()
            .header(http::header::SET_COOKIE, "foo=bar")
            .header(http::header::SET_COOKIE, "baz=123")
            .extension(uri.clone())
            .body(crate::Body::default())
            .unwrap();
        jar.filter_response(response);

        let request = http::Request::builder()
            .uri(uri)
            .body(crate::Body::default())
            .unwrap();
        let request = jar.filter_request(request);

        assert_eq!(request.headers()[http::header::COOKIE], "baz=123; foo=bar");
    }

    #[test]
    fn expire_a_cookie() {
        let uri: Uri = "https://example.com/foo".parse().unwrap();
        let jar = CookieJar::default();

        let initial = Cookie::parse("foo=bar", &uri);
        jar.add(initial.into_iter());
        assert_eq!(jar.get_cookies(&uri).unwrap(), "foo=bar");

        let expired = Cookie::parse("foo=; expires=Wed, 21 Oct 2015 07:28:00 GMT", &uri);
        jar.add(expired.into_iter());
        assert_eq!(jar.get_cookies(&uri), None);
    }
}