ugi 0.2.1

Runtime-agnostic Rust request client with HTTP/1.1, HTTP/2, HTTP/3, H2C, WebSocket, SSE, and gRPC support
Documentation
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime};

use async_trait::async_trait;

use crate::error::{Error, ErrorKind, Result};
use crate::middleware::{Middleware, Next};
use crate::request::Request;
use crate::response::Response;
use crate::url::Url;

/// A single cookie held by a [`CookieJar`].
///
/// Cookies are identified by (`name`, `domain`, `path`): storing a cookie with the same
/// triple replaces the existing entry.
#[derive(Clone, Debug, Eq, PartialEq)]
struct StoredCookie {
    /// Cookie name, whitespace-trimmed.
    name: String,
    /// Cookie value, whitespace-trimmed.
    value: String,
    /// Lowercased domain: the request host, or the `Domain` attribute without leading dots.
    domain: String,
    /// `true` when no `Domain` attribute was given; the cookie then matches only that host.
    host_only: bool,
    /// Path scope, from the `Path` attribute or derived from the request URL.
    path: String,
    /// Set by the `Secure` attribute; such cookies are only sent over `https`/`wss`.
    secure: bool,
    /// Monotonic deadline after which the cookie is discarded; `None` means no expiry.
    expires_at: Option<Instant>,
}

/// In-memory cookie store shared between clones.
///
/// Cloning a `CookieJar` clones the inner `Arc`, so every clone reads and writes the same
/// cookies. Access is serialized by a `Mutex`.
#[derive(Clone, Default)]
pub struct CookieJar {
    inner: Arc<Mutex<Vec<StoredCookie>>>,
}

impl CookieJar {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn get_cookie_header(&self, url: &Url) -> Option<String> {
        let mut cookies = self.inner.lock().unwrap_or_else(|err| err.into_inner());
        cookies.retain(|cookie| !cookie.is_expired());
        let mut pairs = Vec::new();
        for cookie in cookies.iter() {
            if cookie_matches(cookie, url) {
                pairs.push(format!("{}={}", cookie.name, cookie.value));
            }
        }
        if pairs.is_empty() {
            None
        } else {
            Some(pairs.join("; "))
        }
    }

    pub fn store_set_cookie(&self, url: &Url, header: &str) -> Result<()> {
        let cookie = parse_set_cookie(url, header)?;
        let mut cookies = self.inner.lock().unwrap_or_else(|err| err.into_inner());
        cookies.retain(|existing| {
            !(existing.name == cookie.name
                && existing.domain == cookie.domain
                && existing.path == cookie.path)
        });
        if !cookie.is_expired() {
            cookies.push(cookie);
        }
        Ok(())
    }

    pub fn store_set_cookies<'a, I>(&self, url: &Url, headers: I) -> Result<()>
    where
        I: IntoIterator<Item = &'a str>,
    {
        for header in headers {
            self.store_set_cookie(url, header)?;
        }
        Ok(())
    }
}

/// Middleware that sends cookies from a [`CookieJar`] with each request and stores the
/// `set-cookie` headers of each response back into that jar.
pub struct CookieMiddleware {
    jar: CookieJar,
}

impl CookieMiddleware {
    pub fn new(jar: CookieJar) -> Self {
        Self { jar }
    }

    pub fn jar(&self) -> &CookieJar {
        &self.jar
    }
}

impl StoredCookie {
    /// Returns `true` once the expiry deadline has been reached.
    ///
    /// Cookies without a deadline (`expires_at == None`) never expire.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= Instant::now(),
            None => false,
        }
    }
}

#[async_trait]
impl Middleware for CookieMiddleware {
    /// Attaches matching jar cookies to the outgoing request as a `cookie` header, runs the
    /// rest of the chain, then records every `set-cookie` header of the response in the jar.
    ///
    /// NOTE(review): the header is added with `insert`, which presumably replaces any
    /// `cookie` header the caller already set — confirm against the headers API.
    async fn handle(&self, mut req: Request, next: Next<'_>) -> Result<Response> {
        if let Some(cookie_header) = self.jar.get_cookie_header(req.url()) {
            req.headers_mut().insert("cookie", cookie_header)?;
        }

        let url = req.url().clone();
        let response = next.run(req).await?;
        for header in response.headers().get_all("set-cookie") {
            // RFC 6265 §5.2/§5.3: a user agent ignores Set-Cookie headers it cannot process.
            // A malformed cookie must not turn a successful response into an error, nor
            // prevent the remaining cookies in this response from being stored.
            let _ = self.jar.store_set_cookie(&url, header);
        }
        Ok(response)
    }
}

/// Parses a single `Set-Cookie` header value received in response to `url`.
///
/// The first `;`-separated part must be `name=value`; both halves are trimmed. Recognised
/// attributes (case-insensitive) are `Secure`, `Domain`, `Path`, `Max-Age` and `Expires`;
/// any other attribute (e.g. `HttpOnly`, `SameSite`) is ignored. Without `Domain` the
/// cookie is host-only for the request host; without `Path` it uses the default path
/// derived from `url`.
///
/// Expiry follows RFC 6265 §5.3 step 3: `Max-Age` takes precedence over `Expires`
/// regardless of attribute order, and a non-positive `Max-Age` yields an already-expired
/// cookie.
///
/// # Errors
///
/// Returns an [`ErrorKind::Decode`] error if the name/value pair has no `=`, if `Max-Age`
/// is not an integer, or if `Expires` is not a valid HTTP date.
fn parse_set_cookie(url: &Url, header: &str) -> Result<StoredCookie> {
    let mut parts = header.split(';');
    let name_value = parts
        .next()
        .ok_or_else(|| Error::new(ErrorKind::Decode, "set-cookie header is empty"))?;
    let (name, value) = name_value
        .split_once('=')
        .ok_or_else(|| Error::new(ErrorKind::Decode, "set-cookie is missing name/value"))?;

    let mut domain = url.host().to_ascii_lowercase();
    let mut host_only = true;
    let mut path = default_cookie_path(url);
    let mut secure = false;
    // `Some(deadline)` once a `Max-Age` attribute has been seen; an inner `None` means the
    // deadline lies too far in the future to be represented as an `Instant`.
    let mut max_age_expiry: Option<Option<Instant>> = None;
    let mut expires_expiry: Option<Instant> = None;

    for attribute in parts {
        let attribute = attribute.trim();
        if attribute.eq_ignore_ascii_case("secure") {
            secure = true;
            continue;
        }
        // Other valueless attributes (e.g. `HttpOnly`) are not tracked.
        let Some((key, value)) = attribute.split_once('=') else {
            continue;
        };
        // RFC 6265 §5.2: whitespace around the attribute name and value is insignificant.
        let (key, value) = (key.trim(), value.trim());
        if key.eq_ignore_ascii_case("secure") {
            secure = true;
        } else if key.eq_ignore_ascii_case("domain") && !value.is_empty() {
            domain = value.trim_start_matches('.').to_ascii_lowercase();
            host_only = false;
        } else if key.eq_ignore_ascii_case("path") && !value.is_empty() {
            path = normalize_cookie_path(value);
        } else if key.eq_ignore_ascii_case("max-age") {
            let seconds: i64 = value.parse().map_err(|_| {
                Error::new(ErrorKind::Decode, "invalid max-age attribute in set-cookie")
            })?;
            let now = Instant::now();
            max_age_expiry = Some(if seconds <= 0 {
                Some(now)
            } else {
                // `Instant + Duration` panics on overflow, and this value is chosen by the
                // server (e.g. `Max-Age=9223372036854775807`), so add with a check. An
                // unrepresentable deadline means the cookie never expires in this process.
                now.checked_add(Duration::from_secs(seconds.unsigned_abs()))
            });
        } else if key.eq_ignore_ascii_case("expires") {
            expires_expiry = Some(parse_cookie_expires(value)?);
        }
    }

    Ok(StoredCookie {
        name: name.trim().to_owned(),
        value: value.trim().to_owned(),
        domain,
        host_only,
        path,
        secure,
        expires_at: max_age_expiry.unwrap_or(expires_expiry),
    })
}

/// Derives the cookie path used when a `Set-Cookie` header has no `Path` attribute.
///
/// Takes the request path (query string excluded) up to and including its last `/`.
/// The result is `/` when the only `/` is the first character or there is none at all.
///
/// NOTE(review): RFC 6265 §5.1.4 drops the trailing `/` (`/a/b` -> `/a`); this keeps it
/// (`/a/b` -> `/a/`).
fn default_cookie_path(url: &Url) -> String {
    let full = url.path_and_query();
    let request_path = match full.find('?') {
        Some(query_start) => &full[..query_start],
        None => full,
    };
    match request_path.rfind('/') {
        Some(0) | None => String::from("/"),
        Some(last_slash) => format!("{}/", &request_path[..last_slash]),
    }
}

/// Ensures a `Path` attribute value starts with `/`, prefixing one when it is missing.
fn normalize_cookie_path(path: &str) -> String {
    let mut normalized = String::with_capacity(path.len() + 1);
    if !path.starts_with('/') {
        normalized.push('/');
    }
    normalized.push_str(path);
    normalized
}

/// Decides whether `cookie` should be sent with a request to `url`.
///
/// All three conditions must hold:
/// - for `Secure` cookies, the scheme is `https` or `wss`;
/// - the lowercased host matches the cookie's domain (exactly, for host-only cookies);
/// - the request path (query string excluded) path-matches the cookie path.
fn cookie_matches(cookie: &StoredCookie, url: &Url) -> bool {
    let scheme_ok = !cookie.secure || matches!(url.scheme(), "https" | "wss");
    if !scheme_ok {
        return false;
    }

    let host = url.host().to_ascii_lowercase();
    if !cookie_domain_matches(cookie, &host) {
        return false;
    }

    let full = url.path_and_query();
    let request_path = full.split_once('?').map_or(full, |(path, _query)| path);
    path_matches(&cookie.path, request_path)
}

/// Returns `true` when `host` equals `cookie_domain` or is a subdomain of it, i.e. `host`
/// ends with `.` followed by `cookie_domain`.
///
/// The comparison is case-sensitive; the callers in this file lowercase both values first.
fn domain_matches(cookie_domain: &str, host: &str) -> bool {
    if host == cookie_domain {
        return true;
    }
    host.strip_suffix(cookie_domain)
        .is_some_and(|label_prefix| label_prefix.ends_with('.'))
}

/// Matches `host` against the cookie's domain.
///
/// Host-only cookies require an exact match. Domain cookies also match subdomains.
fn cookie_domain_matches(cookie: &StoredCookie, host: &str) -> bool {
    if !cookie.host_only {
        return domain_matches(&cookie.domain, host);
    }
    cookie.domain == host
}

/// Reports whether `request_path` path-matches `cookie_path` (RFC 6265 §5.1.4).
///
/// The paths match when they are identical. They also match when `cookie_path` is a prefix
/// of `request_path` that ends on a segment boundary: either `cookie_path` ends with `/`,
/// or the next character of `request_path` is `/`. So `/account` matches `/account` and
/// `/account/me`, but not `/accounting`.
fn path_matches(cookie_path: &str, request_path: &str) -> bool {
    match request_path.strip_prefix(cookie_path) {
        None => false,
        Some(rest) => rest.is_empty() || cookie_path.ends_with('/') || rest.starts_with('/'),
    }
}

fn parse_cookie_expires(value: &str) -> Result<Instant> {
    let expires_at = httpdate::parse_http_date(value).map_err(|err| {
        Error::with_source(
            ErrorKind::Decode,
            "invalid expires attribute in set-cookie",
            err,
        )
    })?;
    match expires_at.duration_since(SystemTime::now()) {
        Ok(duration) => Ok(Instant::now() + duration),
        Err(_) => Ok(Instant::now()),
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{CookieJar, parse_set_cookie};
    use crate::Url;

    #[test]
    fn stores_and_formats_cookie_header() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(&url, "session=abc; Path=/; Secure")
            .unwrap();

        assert_eq!(jar.get_cookie_header(&url).as_deref(), Some("session=abc"));
    }

    #[test]
    fn cookie_path_must_match_request_path() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/account/login").unwrap();
        jar.store_set_cookie(&url, "session=abc; Path=/account")
            .unwrap();

        assert_eq!(
            jar.get_cookie_header(&Url::parse("https://api.example.com/account/me").unwrap())
                .as_deref(),
            Some("session=abc")
        );
        assert_eq!(
            jar.get_cookie_header(&Url::parse("https://api.example.com/admin").unwrap()),
            None
        );
    }

    #[test]
    fn parses_domain_and_secure_attributes() {
        let url = Url::parse("https://api.example.com/users").unwrap();
        let cookie =
            parse_set_cookie(&url, "theme=dark; Domain=.example.com; Path=/; Secure").unwrap();

        assert_eq!(cookie.domain, "example.com");
        assert!(!cookie.host_only);
        assert_eq!(cookie.path, "/");
        assert!(cookie.secure);
    }

    #[test]
    fn host_only_cookie_does_not_match_subdomains() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(&url, "session=abc; Path=/").unwrap();

        assert_eq!(
            jar.get_cookie_header(&Url::parse("https://api.example.com/me").unwrap())
                .as_deref(),
            Some("session=abc")
        );
        assert_eq!(
            jar.get_cookie_header(&Url::parse("https://sub.api.example.com/me").unwrap()),
            None
        );
    }

    #[test]
    fn domain_cookie_matches_subdomains() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(&url, "session=abc; Domain=example.com; Path=/")
            .unwrap();

        assert_eq!(
            jar.get_cookie_header(&Url::parse("https://sub.example.com/me").unwrap())
                .as_deref(),
            Some("session=abc")
        );
    }

    #[test]
    fn max_age_zero_removes_existing_cookie() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(&url, "session=abc; Path=/").unwrap();
        jar.store_set_cookie(&url, "session=gone; Path=/; Max-Age=0")
            .unwrap();

        assert_eq!(jar.get_cookie_header(&url), None);
    }

    #[test]
    fn expired_cookie_is_not_returned() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(&url, "session=abc; Path=/; Max-Age=1")
            .unwrap();
        std::thread::sleep(Duration::from_secs(2));

        assert_eq!(jar.get_cookie_header(&url), None);
    }

    #[test]
    fn future_expires_cookie_is_returned() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(
            &url,
            "session=abc; Path=/; Expires=Wed, 01 Jan 3000 00:00:00 GMT",
        )
        .unwrap();

        assert_eq!(jar.get_cookie_header(&url).as_deref(), Some("session=abc"));
    }

    #[test]
    fn past_expires_cookie_is_not_returned() {
        let jar = CookieJar::new();
        let url = Url::parse("https://api.example.com/users").unwrap();
        jar.store_set_cookie(
            &url,
            "session=abc; Path=/; Expires=Sat, 01 Jan 2000 00:00:00 GMT",
        )
        .unwrap();

        assert_eq!(jar.get_cookie_header(&url), None);
    }
}