use axum::http::HeaderMap;
use crate::config::SessionCookieSecurePolicy;
pub fn is_https(policy: SessionCookieSecurePolicy, headers: &HeaderMap) -> bool {
match policy {
SessionCookieSecurePolicy::Always => true,
SessionCookieSecurePolicy::Never => false,
SessionCookieSecurePolicy::Auto => headers
.get("x-forwarded-proto")
.and_then(|v| v.to_str().ok())
.map(first_csv)
.is_some_and(|s| s.eq_ignore_ascii_case("https")),
}
}
pub fn host(headers: &HeaderMap) -> Option<String> {
headers
.get("x-forwarded-host")
.or_else(|| headers.get(axum::http::header::HOST))
.and_then(|v| v.to_str().ok())
.map(|s| first_csv(s).to_owned())
}
/// Builds the externally visible origin (`scheme://host`) for the request.
///
/// The host comes from `host(headers)`. The scheme is `https` when `is_https(policy, headers)`
/// holds and `http` otherwise. Returns `None` when no host can be determined.
pub fn origin(policy: SessionCookieSecurePolicy, headers: &HeaderMap) -> Option<String> {
    let authority = host(headers)?;
    let scheme = if is_https(policy, headers) { "https" } else { "http" };
    Some(format!("{scheme}://{authority}"))
}
/// Returns the first entry of a comma-separated header value, with surrounding whitespace trimmed.
///
/// Proxies conventionally append to `X-Forwarded-*` headers, so the first entry is the hop
/// closest to the client. A value containing no comma is returned whole (trimmed).
fn first_csv(s: &str) -> &str {
    let head = match s.find(',') {
        Some(comma) => &s[..comma],
        None => s,
    };
    head.trim()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HeaderMap` from static `(name, value)` pairs, panicking on an unparsable value.
    /// A later pair with the same name replaces an earlier one.
    fn headers(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        pairs.iter().fold(HeaderMap::new(), |mut map, &(name, value)| {
            map.insert(name, value.parse().unwrap());
            map
        })
    }

    #[test]
    fn is_https_follows_policy() {
        let plain = HeaderMap::new();
        let proxied = headers(&[("x-forwarded-proto", "https")]);

        // Explicit policies ignore the request headers.
        assert!(is_https(SessionCookieSecurePolicy::Always, &plain));
        assert!(!is_https(SessionCookieSecurePolicy::Never, &proxied));

        // Auto defers to X-Forwarded-Proto.
        assert!(is_https(SessionCookieSecurePolicy::Auto, &proxied));
        assert!(!is_https(SessionCookieSecurePolicy::Auto, &plain));
    }

    #[test]
    fn host_prefers_forwarded() {
        let behind_proxy = headers(&[
            ("host", "internal:8080"),
            ("x-forwarded-host", "dns.example.com"),
        ]);
        assert_eq!(host(&behind_proxy).as_deref(), Some("dns.example.com"));

        let direct = headers(&[("host", "localhost:8080")]);
        assert_eq!(host(&direct).as_deref(), Some("localhost:8080"));
    }

    #[test]
    fn origin_combines_scheme_and_host() {
        let behind_proxy = headers(&[
            ("host", "internal"),
            ("x-forwarded-host", "dns.example.com"),
            ("x-forwarded-proto", "https"),
        ]);
        assert_eq!(
            origin(SessionCookieSecurePolicy::Auto, &behind_proxy).as_deref(),
            Some("https://dns.example.com")
        );

        let direct = headers(&[("host", "127.0.0.1:8080")]);
        assert_eq!(
            origin(SessionCookieSecurePolicy::Never, &direct).as_deref(),
            Some("http://127.0.0.1:8080")
        );
    }

    #[test]
    fn first_csv_takes_first_hop() {
        for (input, expected) in [
            ("https, http", "https"),
            ("a.example , b", "a.example"),
            ("solo", "solo"),
        ] {
            assert_eq!(first_csv(input), expected);
        }
    }
}