use std::time::Duration;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use http::header;
use serde::{Serialize, de::DeserializeOwned};
/// Default lifetime for cookies issued by this module: 400 days (9600 hours).
///
/// NOTE(review): 400 days presumably mirrors the maximum cookie lifetime that user agents
/// enforce under RFC 6265bis; confirm before raising it.
pub(crate) const DEFAULT_COOKIE_MAX_AGE: Duration = Duration::from_secs(400 * 24 * 60 * 60);
/// Serializes `value` into a freshly allocated CBOR byte buffer.
///
/// # Errors
///
/// Returns the `ciborium` serialization error if `value` cannot be encoded.
pub(crate) fn encode_payload<T: Serialize>(
    value: &T,
) -> Result<Vec<u8>, ciborium::ser::Error<std::io::Error>> {
    // Start with 128 bytes of capacity; the buffer grows as needed.
    let mut encoded: Vec<u8> = Vec::with_capacity(128);
    ciborium::into_writer(value, &mut encoded).map(|()| encoded)
}
/// Deserializes a CBOR-encoded `T` from `bytes`; the counterpart of `encode_payload`.
///
/// # Errors
///
/// Returns the `ciborium` deserialization error when `bytes` cannot be decoded as a `T`
/// (for example truncated input, malformed CBOR, or a shape mismatch).
pub(crate) fn decode_payload<T: DeserializeOwned>(
    bytes: &[u8],
) -> Result<T, ciborium::de::Error<std::io::Error>> {
    // Any `std::io::Read` works here; the reader error type stays `std::io::Error`.
    let reader = std::io::Cursor::new(bytes);
    ciborium::from_reader(reader)
}
/// Default value for the `prefix` argument of `login_state_cookie_name`; the resulting
/// cookie name is `{security_prefix}huskarl_login_{state}`.
pub(crate) const DEFAULT_LOGIN_COOKIE_PREFIX: &str = "huskarl_login";
/// Maximum length, in bytes, of an OAuth `state` accepted by `is_valid_oauth_state`.
pub(crate) const MAX_OAUTH_STATE_LEN: usize = 256;
/// Reports whether `state` is an acceptable OAuth `state` value.
///
/// A valid state is between 1 and `MAX_OAUTH_STATE_LEN` bytes long and uses only ASCII
/// letters, digits, `-` and `_` (the unpadded base64url alphabet). This makes it safe to
/// embed verbatim in a cookie name (see `login_state_cookie_name`).
#[must_use]
pub fn is_valid_oauth_state(state: &str) -> bool {
    let is_allowed =
        |byte: u8| matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_');
    (1..=MAX_OAUTH_STATE_LEN).contains(&state.len()) && state.bytes().all(is_allowed)
}
/// Builds the name of the cookie that carries per-login data for OAuth `state`.
///
/// The result is `{security_prefix}{prefix}_{state}`, where `security_prefix` is:
/// - `__Host-` when `secure` is set and `path` is exactly `/`,
/// - `__Secure-` when `secure` is set and `path` is anything else,
/// - empty when `secure` is not set.
///
/// `state` must satisfy [`is_valid_oauth_state`]; this is asserted in debug builds only.
#[must_use]
pub fn login_state_cookie_name(state: &str, secure: bool, path: &str, prefix: &str) -> String {
    debug_assert!(
        is_valid_oauth_state(state),
        "login_state_cookie_name called with state that is not URL-safe base64url"
    );
    let security_prefix = match (secure, path) {
        (false, _) => "",
        (true, "/") => "__Host-",
        (true, _) => "__Secure-",
    };
    [security_prefix, prefix, "_", state].concat()
}
/// Returns the cookie attribute string `HttpOnly; SameSite=Lax; Path={path}`, followed by
/// `; Secure` when `secure` is set.
///
/// NOTE(review): `path` is inserted verbatim. This presumably relies on it being trusted
/// configuration that contains no `;` — confirm at call sites.
#[must_use]
pub fn cookie_attrs(secure: bool, path: &str) -> String {
    let mut attrs = format!("HttpOnly; SameSite=Lax; Path={path}");
    if secure {
        attrs.push_str("; Secure");
    }
    attrs
}
/// Suffix appended to a base cookie name to form its companion `.kid` cookie.
/// For example, `huskarl_session` becomes `huskarl_session.kid`.
pub(crate) const KID_COOKIE_SUFFIX: &str = ".kid";
/// Returns the name of the companion `.kid` cookie for `base_name`
/// (`base_name` immediately followed by `KID_COOKIE_SUFFIX`).
#[must_use]
pub(crate) fn kid_cookie_name(base_name: &str) -> String {
    [base_name, KID_COOKIE_SUFFIX].concat()
}
/// Reads the companion `.kid` cookie for `base_name` and decodes its value.
///
/// The cookie value is expected to be unpadded base64url (as produced by `encode_kid`)
/// wrapping UTF-8 text. Returns `None` in three cases:
/// - the cookie is absent,
/// - the value is not valid base64url,
/// - the decoded bytes are not valid UTF-8.
pub(crate) fn get_kid_cookie(headers: &http::HeaderMap, base_name: &str) -> Option<String> {
    get_cookie(headers, &kid_cookie_name(base_name))
        .and_then(|encoded| URL_SAFE_NO_PAD.decode(encoded).ok())
        .and_then(|raw| String::from_utf8(raw).ok())
}
/// Encodes `identity` as unpadded base64url so it can be stored as a cookie value.
/// `get_kid_cookie` performs the inverse decoding.
#[must_use]
pub(crate) fn encode_kid(identity: &str) -> String {
    URL_SAFE_NO_PAD.encode(identity)
}
pub fn get_cookie<'a>(headers: &'a http::HeaderMap, name: &str) -> Option<&'a str> {
for value in headers.get_all(header::COOKIE) {
let Ok(s) = value.to_str() else { continue };
for pair in s.split(';') {
if let Some((k, v)) = pair.trim().split_once('=')
&& k.trim() == name
{
return Some(v.trim());
}
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- get_cookie ----------------------------------------------------------------------

    #[test]
    fn get_cookie_present() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "foo=bar".parse().unwrap());
        assert_eq!(get_cookie(&headers, "foo"), Some("bar"));
    }

    #[test]
    fn get_cookie_missing() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "foo=bar".parse().unwrap());
        assert_eq!(get_cookie(&headers, "baz"), None);
    }

    #[test]
    fn get_cookie_multiple_pairs() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "a=1; b=2; c=3".parse().unwrap());
        assert_eq!(get_cookie(&headers, "a"), Some("1"));
        assert_eq!(get_cookie(&headers, "b"), Some("2"));
        assert_eq!(get_cookie(&headers, "c"), Some("3"));
    }

    #[test]
    fn get_cookie_whitespace_trimmed() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, " foo = bar ".parse().unwrap());
        assert_eq!(get_cookie(&headers, "foo"), Some("bar"));
    }

    #[test]
    fn get_cookie_empty_headers() {
        let headers = http::HeaderMap::new();
        assert_eq!(get_cookie(&headers, "foo"), None);
    }

    #[test]
    fn get_cookie_multiple_cookie_headers() {
        let mut headers = http::HeaderMap::new();
        headers.append(header::COOKIE, "a=1".parse().unwrap());
        headers.append(header::COOKIE, "b=2".parse().unwrap());
        assert_eq!(get_cookie(&headers, "a"), Some("1"));
        assert_eq!(get_cookie(&headers, "b"), Some("2"));
    }

    #[test]
    fn get_cookie_value_with_equals() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "token=abc=def".parse().unwrap());
        assert_eq!(get_cookie(&headers, "token"), Some("abc=def"));
    }

    #[test]
    fn get_cookie_first_match_wins_for_duplicate_names() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "a=1; a=2".parse().unwrap());
        assert_eq!(get_cookie(&headers, "a"), Some("1"));
    }

    #[test]
    fn get_cookie_ignores_pairs_without_equals() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "flag; foo=bar".parse().unwrap());
        assert_eq!(get_cookie(&headers, "flag"), None);
        assert_eq!(get_cookie(&headers, "foo"), Some("bar"));
    }

    #[test]
    fn get_cookie_empty_value() {
        let mut headers = http::HeaderMap::new();
        headers.insert(header::COOKIE, "foo=".parse().unwrap());
        assert_eq!(get_cookie(&headers, "foo"), Some(""));
    }

    // ---- login_state_cookie_name / cookie_attrs ------------------------------------------

    #[test]
    fn cookie_name_secure_root_uses_host_prefix() {
        let name = login_state_cookie_name("abc123", true, "/", DEFAULT_LOGIN_COOKIE_PREFIX);
        assert!(name.starts_with("__Host-"));
    }

    #[test]
    fn cookie_name_secure_subpath_uses_secure_prefix() {
        let name = login_state_cookie_name("abc123", true, "/app", DEFAULT_LOGIN_COOKIE_PREFIX);
        assert!(name.starts_with("__Secure-"));
    }

    #[test]
    fn cookie_name_insecure_no_prefix() {
        let name = login_state_cookie_name("abc123", false, "/", DEFAULT_LOGIN_COOKIE_PREFIX);
        assert!(!name.starts_with("__"));
    }

    #[test]
    fn cookie_name_contains_state() {
        let name = login_state_cookie_name("mystate", true, "/", DEFAULT_LOGIN_COOKIE_PREFIX);
        assert!(name.contains("mystate"));
    }

    #[test]
    fn cookie_name_exact_format() {
        assert_eq!(login_state_cookie_name("abc", true, "/", "p"), "__Host-p_abc");
        assert_eq!(login_state_cookie_name("abc", true, "/app", "p"), "__Secure-p_abc");
        assert_eq!(login_state_cookie_name("abc", false, "/", "p"), "p_abc");
    }

    #[test]
    fn cookie_attrs_secure_and_insecure() {
        assert_eq!(cookie_attrs(true, "/"), "HttpOnly; SameSite=Lax; Path=/; Secure");
        assert_eq!(cookie_attrs(false, "/app"), "HttpOnly; SameSite=Lax; Path=/app");
    }

    // ---- is_valid_oauth_state ------------------------------------------------------------

    #[test]
    fn state_accepts_alphanumeric_and_url_safe_chars() {
        assert!(is_valid_oauth_state("abc123"));
        assert!(is_valid_oauth_state("AbC-_xyz"));
    }

    #[test]
    fn state_rejects_empty() {
        assert!(!is_valid_oauth_state(""));
    }

    #[test]
    fn state_accepts_exactly_max_length() {
        let at_limit = "a".repeat(MAX_OAUTH_STATE_LEN);
        assert!(is_valid_oauth_state(&at_limit));
    }

    #[test]
    fn state_rejects_overly_long() {
        let long = "a".repeat(MAX_OAUTH_STATE_LEN + 1);
        assert!(!is_valid_oauth_state(&long));
    }

    #[test]
    fn state_rejects_separators_and_specials() {
        for s in [
            "abc;def", "abc=def", "abc def", "abc\nxyz", "abc/def", "abc+def", "abc.def",
        ] {
            assert!(!is_valid_oauth_state(s), "expected reject: {s:?}");
        }
    }

    #[test]
    fn state_rejects_non_ascii() {
        assert!(!is_valid_oauth_state("café"));
    }

    // ---- payload codec -------------------------------------------------------------------

    #[test]
    fn payload_round_trips() {
        let original = (String::from("state"), 42_u32, vec![1_u8, 2, 3]);
        let bytes = encode_payload(&original).unwrap();
        let decoded: (String, u32, Vec<u8>) = decode_payload(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn decode_payload_rejects_empty_input() {
        assert!(decode_payload::<u32>(&[]).is_err());
    }

    // ---- kid cookie ----------------------------------------------------------------------

    #[test]
    fn kid_cookie_name_suffixes_base() {
        assert_eq!(kid_cookie_name("huskarl_session"), "huskarl_session.kid");
    }

    #[test]
    fn encode_kid_is_unpadded_base64url() {
        assert_eq!(encode_kid("a"), "YQ");
    }

    #[test]
    fn get_kid_cookie_decodes_present_value() {
        let mut headers = http::HeaderMap::new();
        let encoded = encode_kid("arn:aws:kms:us-east-1:111:key/abc");
        headers.insert(
            header::COOKIE,
            format!("huskarl_session.kid={encoded}").parse().unwrap(),
        );
        assert_eq!(
            get_kid_cookie(&headers, "huskarl_session").as_deref(),
            Some("arn:aws:kms:us-east-1:111:key/abc")
        );
    }

    #[test]
    fn get_kid_cookie_absent_returns_none() {
        let headers = http::HeaderMap::new();
        assert_eq!(get_kid_cookie(&headers, "huskarl_session"), None);
    }

    #[test]
    fn get_kid_cookie_invalid_base64_returns_none() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            header::COOKIE,
            "huskarl_session.kid=!!!notbase64!!!".parse().unwrap(),
        );
        assert_eq!(get_kid_cookie(&headers, "huskarl_session"), None);
    }

    #[test]
    fn get_kid_cookie_invalid_utf8_returns_none() {
        let mut headers = http::HeaderMap::new();
        let bad = URL_SAFE_NO_PAD.encode([0xff_u8, 0xfe, 0xfd]);
        headers.insert(
            header::COOKIE,
            format!("huskarl_session.kid={bad}").parse().unwrap(),
        );
        assert_eq!(get_kid_cookie(&headers, "huskarl_session"), None);
    }
}