use chrono::Utc;
use chrono::NaiveDate;
use std::time::UNIX_EPOCH;
use chrono::DateTime;
use std::str::FromStr;
use std::io::{Error, ErrorKind};
use std::time::{Duration, SystemTime};
use std::ops::Add;
use std::collections::{HashMap, HashSet};
use std::cmp::{PartialEq, Eq};
use std::hash::{Hash, Hasher};
use lazy_static::lazy_static;
use regex::Regex;
// NOTE(review): not referenced in this file; presumably the `Cookie` header name used elsewhere.
pub(crate) const COOKIE: &str = "cookie";
// Lower-case cookie attribute names, compared against lower-cased input when parsing.
pub(crate) const COOKIE_EXPIRES: &str = "expires";
pub(crate) const COOKIE_MAX_AGE: &str = "max-age";
pub(crate) const COOKIE_DOMAIN: &str = "domain";
pub(crate) const COOKIE_PATH: &str = "path";
pub(crate) const COOKIE_SAME_SITE: &str = "samesite";
// Accepted (lower-case) values of the `SameSite` attribute.
pub(crate) const COOKIE_SAME_SITE_STRICT: &str = "strict";
pub(crate) const COOKIE_SAME_SITE_LAX: &str = "lax";
pub(crate) const COOKIE_SAME_SITE_NONE: &str = "none";
// Value-less flag attributes.
pub(crate) const COOKIE_SECURE: &str = "secure";
pub(crate) const COOKIE_HTTP_ONLY: &str = "httponly";

/// Value of a cookie's `SameSite` attribute.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SameSiteValue {
    /// `SameSite=Strict`
    Strict,
    /// `SameSite=Lax`
    Lax,
    /// `SameSite=None`
    None,
}

impl FromStr for SameSiteValue {
    type Err = Error;

    /// Parses a `SameSite` attribute value (`strict`, `lax` or `none`), ignoring ASCII case.
    ///
    /// # Errors
    /// Returns an `ErrorKind::InvalidData` error for any other value; the message
    /// contains the original, un-lowered input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Attribute values are case-insensitive, so normalise before comparing
        // against the lower-case constants.
        match s.to_ascii_lowercase().as_str() {
            COOKIE_SAME_SITE_STRICT => Ok(SameSiteValue::Strict),
            COOKIE_SAME_SITE_LAX => Ok(SameSiteValue::Lax),
            COOKIE_SAME_SITE_NONE => Ok(SameSiteValue::None),
            _ => Err(Error::new(
                ErrorKind::InvalidData,
                format!("Invalid SameSite cookie directive value: {}", s),
            )),
        }
    }
}
/// A single HTTP cookie together with the attributes that scope and expire it.
#[derive(Debug,Clone)]
pub struct Cookie {
/// Cookie name (left of the first `=`, trimmed).
pub(crate) name: String,
/// Cookie value (right of the first `=`, trimmed).
pub (crate) value: String,
/// Domain the cookie is scoped to.
pub(crate) domain: String,
/// Path the cookie is scoped to.
pub(crate) path: String,
/// Absolute expiry time; `None` means no expiry was set (a session cookie).
pub(crate) expires: Option<SystemTime>,
/// `SameSite` attribute; `Cookie::new` defaults it to `Lax`.
pub(crate) same_site: SameSiteValue,
/// `Secure` flag: the cookie is only matched for requests flagged as secure.
pub(crate) secure: bool,
/// `HttpOnly` flag, recorded as parsed.
pub(crate) http_only: bool,
/// Unrecognised `key=value` attributes, keyed by lower-cased attribute name.
pub(crate) extensions: HashMap<String, String>
}
impl Cookie {
pub fn new (name: String, value: String, domain: String, path: String) -> Cookie {
Cookie {
name,
value,
domain,
path,
expires: None,
same_site: SameSiteValue::Lax,
secure: false,
http_only: false,
extensions: HashMap::new()
}
}
pub fn name(& self) -> &str {
self.name.as_str()
}
pub fn value(& self) -> &str {
self.value.as_str()
}
pub fn domain(& self) -> &str {
self.domain.as_str()
}
pub fn path(& self) -> &str {
self.path.as_str()
}
pub fn expires(& self) -> Option<SystemTime> {
self.expires.clone()
}
pub fn same_site(& self) -> SameSiteValue {
self.same_site
}
pub fn secure(& self) -> bool {
self.secure
}
pub fn http_only(& self) -> bool {
self.http_only
}
pub fn extensions(&self) -> &HashMap<String, String> {
&self.extensions
}
pub fn path_match(&self, request_path: &str) -> bool {
let cookie_path = self.path();
let cookie_path_len = cookie_path.len();
let request_path_len = request_path.len();
if !request_path.starts_with(cookie_path) { return false;
}
return request_path_len == cookie_path_len || cookie_path.chars().nth(cookie_path_len - 1).unwrap() == '/'
|| request_path.chars().nth(cookie_path_len).unwrap() == '/';
}
pub fn domain_match(&self, request_domain: &str) -> bool {
let cookie_domain = self.domain();
if let Some(index) = request_domain.rfind(cookie_domain) {
if index == 0 { return true;
}
return request_domain.chars().nth(index-1).unwrap() == '.';
}
return false;
}
pub fn request_match(&self, request_domain: &str, request_path: &str, secure: bool) -> bool {
if self.secure && !secure {
return false;
}
if self.same_site == SameSiteValue::Strict && self.domain != request_domain {
return false;
}
if self.same_site() == SameSiteValue::Lax && !self.domain_match(request_domain) {
return false;
}
if self.same_site == SameSiteValue::None && ! self.secure {
return false;
}
return self.path_match(request_path);
}
pub fn parse(s: &str, domain: &str, path: &str) -> Result<Cookie, Error> {
let mut components = s.split(';');
return if let Some(slice) = components.next() {
let (key, value) = parse_cookie_value(slice)?;
let mut cookie = Cookie::new(key, value, String::from(domain), String::from(path));
while let Some(param) = components.next() {
let directive = CookieDirective::from_str(param)?;
match directive {
CookieDirective::Expires(date) => {
if cookie.expires().is_none() { cookie.expires = Some(date);
}
},
CookieDirective::MaxAge(seconds) => {
cookie.expires = Some(SystemTime::now().add(seconds));
},
CookieDirective::Domain(url) => { cookie.domain = if let Some(stripped) = url.as_str().strip_prefix(".") {
String::from(stripped)
} else {
url
}
},
CookieDirective::Path(path) => cookie.path = path,
CookieDirective::SameSite(val) => cookie.same_site = val,
CookieDirective::Secure => cookie.secure = true,
CookieDirective::HttpOnly => cookie.http_only = true,
CookieDirective::Extension(name, value) => {
let _res = cookie.extensions.insert(name, value);
}
}
}
Ok(cookie)
} else {
if CookieDirective::from_str(s).is_ok() {
return Err(Error::new(ErrorKind::InvalidData, "Cookie has not got name/value"));
};
let (key, value) = parse_cookie_value(s)?;
Ok(Cookie::new(key, value, String::from(domain), String::from(path)))
}
}
}
/// Two cookies are the same cookie when their name, domain and path agree
/// (RFC 6265 §5.3, step 11); value and other attributes do not affect identity.
impl PartialEq for Cookie {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name && self.domain == other.domain && self.path == other.path
    }
}

impl Eq for Cookie {}

/// Hashes exactly the fields compared by `eq`, so that equal cookies always hash
/// equally, as `HashSet` requires.
impl Hash for Cookie {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.domain.hash(state);
        self.path.hash(state);
    }
}
/// Splits a `name=value` pair at the first `=`, trimming surrounding whitespace
/// from both halves. Any further `=` characters stay in the value.
///
/// # Errors
/// Returns an `ErrorKind::InvalidData` error when `cookie` contains no `=`.
pub(crate) fn parse_cookie_value(cookie: &str) -> Result<(String, String), Error> {
    let (raw_name, raw_value) = cookie.split_once('=').ok_or_else(|| {
        Error::new(ErrorKind::InvalidData, format!("Malformed HTTP cookie: {}", cookie))
    })?;
    Ok((raw_name.trim().to_string(), raw_value.trim().to_string()))
}
/// One parsed attribute from the `;`-separated tail of a cookie string.
enum CookieDirective {
    /// Value-less `Secure` flag.
    Secure,
    /// Value-less `HttpOnly` flag.
    HttpOnly,
    /// `Expires=<date>`, already converted to an absolute time.
    Expires(SystemTime),
    /// `Max-Age=<seconds>`, a lifetime counted from the moment of parsing.
    MaxAge(Duration),
    /// `Domain=<domain>`, as written in the attribute.
    Domain(String),
    /// `Path=<path>`.
    Path(String),
    /// `SameSite=<value>`.
    SameSite(SameSiteValue),
    /// Any other `key=value` attribute; the key is lower-cased.
    Extension(String, String),
}
const DATE_FORMAT_850: &str= "(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday|Mon|Tue|Wed|Thu|Fri|Sat|Sun), \
(0[1-9]|[123][0-9])-(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)-([0-9]{4}|[0-9]{2}) \
([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9]) GMT";
const DATE_FORMAT_1123: &str= "(Mon|Tue|Wed|Thu|Fri|Sat|Sun), \
(0[1-9]|[123][0-9]) (Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec) ([0-9]{4}) \
([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9]) GMT";
const DATE_FORMAT_ASCT: &str= "(Mon|Tue|Wed|Thu|Fri|Sat|Sun) \
(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[ ]{1,2}([1-9]|0[1-9]|[123][0-9]) \
([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9]) ([0-9]{4})";
fn parse_rfc_850_date(date: &str) -> Result<SystemTime, Error> {
lazy_static! {
static ref RE: Regex = Regex::new(DATE_FORMAT_850).unwrap();
}
if let Some(captures) = RE.captures(date) {
let day : u32 = captures.get(2).unwrap().as_str().parse().unwrap();
let month = match captures.get(3).unwrap().as_str() {
"Jan" => 1,
"Feb" => 2,
"Mar" => 3,
"Apr" => 4,
"May" => 5,
"Jun" => 6,
"Jul" => 7,
"Aug" => 8,
"Sep" => 9,
"Oct" => 10,
"Nov" => 11,
"Dec" => 12,
_ => return Err(Error::new(ErrorKind::InvalidData, "Invalid date"))
};
let mut year: i32 = captures.get(4).unwrap().as_str().parse().unwrap();
year+= if year < 70 {2000} else if year < 100 {1900} else {0};
let hour : u32 = captures.get(5).unwrap().as_str().parse().unwrap();
let min : u32 = captures.get(6).unwrap().as_str().parse().unwrap();
let secs : u32 = captures.get(7).unwrap().as_str().parse().unwrap();
let naive =
NaiveDate::from_ymd(year, month, day)
.and_hms(hour,min,secs);
let time = DateTime::<Utc>::from_utc(naive, Utc);
let millis = Duration::from_millis(time.timestamp_millis() as u64);
let time = UNIX_EPOCH.clone().add(millis);
return Ok(time);
} else {
return Err(Error::new(ErrorKind::InvalidData, "Invalid date"));
}
}
fn parse_rfc_1123_date(date: &str) -> Result<SystemTime, Error> {
lazy_static! {
static ref RE: Regex = Regex::new(DATE_FORMAT_1123).unwrap();
}
if let Some(captures) = RE.captures(date) {
let day : u32 = captures.get(2).unwrap().as_str().parse().unwrap();
let month = match captures.get(3).unwrap().as_str() {
"Jan" => 1,
"Feb" => 2,
"Mar" => 3,
"Apr" => 4,
"May" => 5,
"Jun" => 6,
"Jul" => 7,
"Aug" => 8,
"Sep" => 9,
"Oct" => 10,
"Nov" => 11,
"Dec" => 12,
_ => return Err(Error::new(ErrorKind::InvalidData, "Invalid date"))
};
let year: i32 = captures.get(4).unwrap().as_str().parse().unwrap();
let hour : u32 = captures.get(5).unwrap().as_str().parse().unwrap();
let min : u32 = captures.get(6).unwrap().as_str().parse().unwrap();
let secs : u32 = captures.get(7).unwrap().as_str().parse().unwrap();
let naive =
NaiveDate::from_ymd(year, month, day)
.and_hms(hour,min,secs);
let time = DateTime::<Utc>::from_utc(naive, Utc);
let millis = Duration::from_millis(time.timestamp_millis() as u64);
let time = UNIX_EPOCH.clone().add(millis);
return Ok(time);
} else {
return Err(Error::new(ErrorKind::InvalidData, "Invalid date"));
}
}
fn parse_asct_date(date: &str) -> Result<SystemTime, Error> {
lazy_static! {
static ref RE: Regex = Regex::new(DATE_FORMAT_ASCT).unwrap();
}
if let Some(captures) = RE.captures(date) {
let month = match captures.get(2).unwrap().as_str() {
"Jan" => 1,
"Feb" => 2,
"Mar" => 3,
"Apr" => 4,
"May" => 5,
"Jun" => 6,
"Jul" => 7,
"Aug" => 8,
"Sep" => 9,
"Oct" => 10,
"Nov" => 11,
"Dec" => 12,
_ => return Err(Error::new(ErrorKind::InvalidData, "Invalid date"))
};
let day : u32 = captures.get(3).unwrap().as_str().parse().unwrap();
let hour : u32 = captures.get(4).unwrap().as_str().parse().unwrap();
let min : u32 = captures.get(5).unwrap().as_str().parse().unwrap();
let secs : u32 = captures.get(6).unwrap().as_str().parse().unwrap();
let year: i32 = captures.get(7).unwrap().as_str().parse().unwrap();
let naive =
NaiveDate::from_ymd(year, month, day)
.and_hms(hour,min,secs);
let time = DateTime::<Utc>::from_utc(naive, Utc);
let millis = Duration::from_millis(time.timestamp_millis() as u64);
let time = UNIX_EPOCH.clone().add(millis);
return Ok(time);
} else {
return Err(Error::new(ErrorKind::InvalidData, "Invalid date"));
}
}
impl FromStr for CookieDirective {
type Err = Error;
fn from_str(s: &str) -> Result<CookieDirective,Error> {
if let Some(index) = s.find('=') { let key = s[0..index].trim().to_ascii_lowercase();
let value = s[index + 1..].trim();
return match key.as_str() {
COOKIE_EXPIRES => {
let expires = parse_rfc_1123_date(value)
.or_else(|_| parse_rfc_850_date(value))
.or_else(|_| parse_asct_date(value))?;
Ok(CookieDirective::Expires(expires))
},
COOKIE_MAX_AGE => { let digit = u64::from_str(value)
.or_else(|e| {
Err(Error::new(ErrorKind::InvalidData, e))
})?;
Ok(CookieDirective::MaxAge(Duration::from_secs(digit)))
},
COOKIE_DOMAIN => {
Ok(CookieDirective::Domain(String::from(value)))
},
COOKIE_PATH => {
Ok(CookieDirective::Path(String::from(value)))
}
COOKIE_SAME_SITE => {
let lower_case = value.to_ascii_lowercase();
match SameSiteValue::from_str(lower_case.as_str()) {
Ok(site_value) => Ok(CookieDirective::SameSite(site_value)),
Err(e) => Err(e)
}
},
_ => Ok(CookieDirective::Extension(key, value.to_string()))
}
} else {
match s.trim().to_ascii_lowercase().as_str() {
COOKIE_SECURE => Ok(CookieDirective::Secure),
COOKIE_HTTP_ONLY => Ok(CookieDirective::HttpOnly),
_ => return Err(
Error::new(ErrorKind::InvalidData,
format!("Invalid HTTP cookie directive: {}", s)))
}
}
}
}
/// Storage for cookies received in responses and replayed on later requests.
pub trait CookieJar {
    /// Offers `value`, received from `request_domain`, to the jar. The
    /// implementation decides whether to store, replace or discard it.
    fn cookie(&mut self, value: Cookie, request_domain: &str);

    /// Returns the `(name, value)` pairs to send on a request to `request_domain`
    /// and `request_path`. `secure` tells whether the request is considered secure.
    /// Takes `&mut self` so implementations may evict stale entries while answering.
    fn active_cookies(
        &mut self,
        request_domain: &str,
        request_path: &str,
        secure: bool,
    ) -> Vec<(String, String)>;
}
/// A `CookieJar` that keeps cookies in memory for the lifetime of the value.
#[derive(Debug, Default)]
pub struct MemCookieJar {
    /// Stored cookies, deduplicated by `Cookie`'s `PartialEq`/`Hash`.
    cookies: HashSet<Cookie>,
}

impl MemCookieJar {
    /// Creates an empty jar; equivalent to `MemCookieJar::default()`.
    pub fn new() -> MemCookieJar {
        MemCookieJar::default()
    }
}
impl CookieJar for MemCookieJar {
    /// Stores, updates or deletes a cookie received from `request_domain`.
    ///
    /// A cookie whose domain does not domain-match `request_domain` is ignored.
    /// An already-expired cookie removes any stored cookie equal to it, which is how
    /// servers delete cookies. Any other cookie replaces the stored equal one, so
    /// new values and attributes take effect.
    fn cookie(&mut self, value: Cookie, request_domain: &str) {
        if !value.domain_match(request_domain) {
            return;
        }
        match value.expires() {
            Some(expires) if expires < SystemTime::now() => {
                self.cookies.remove(&value);
            }
            _ => {
                // `HashSet::insert` keeps the existing element when an equal one is
                // present; `replace` makes the newly received cookie win.
                self.cookies.replace(value);
            }
        }
    }

    /// Evicts expired cookies, then returns `(name, value)` for every remaining
    /// cookie whose `request_match` accepts this request. Result order follows
    /// `HashSet` iteration order and is unspecified.
    fn active_cookies(&mut self, request_domain: &str, request_path: &str, secure: bool) -> Vec<(String, String)> {
        let now = SystemTime::now();
        // Keep cookies that have not expired yet; session cookies (no expiry) stay.
        self.cookies.retain(|c| c.expires.map_or(true, |time| time >= now));
        self.cookies
            .iter()
            .filter(|c| c.request_match(request_domain, request_path, secure))
            .map(|c| (c.name.clone(), c.value.clone()))
            .collect()
    }
}
// Unit tests for this module, compiled only for test builds.
#[cfg(test)]
mod cookie_test;