use axum::{
extract::{ConnectInfo, FromRequestParts},
http::{header, request::Parts, HeaderMap},
};
use std::convert::Infallible;
use std::future;
use std::net::{IpAddr, SocketAddr};
pub fn extract_access_token(headers: &HeaderMap, cookie_name: &str) -> Option<String> {
if let Some(auth_header) = headers.get(header::AUTHORIZATION) {
if let Ok(auth_str) = auth_header.to_str() {
if let Some(token) = auth_str.strip_prefix("Bearer ") {
return Some(token.to_string());
}
}
}
extract_cookie(headers, cookie_name)
}
const MAX_COOKIE_VALUE_LEN: usize = 4096;
const MAX_COOKIE_HEADER_LEN: usize = 16 * 1024;
pub fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
for cookie_header in headers.get_all(header::COOKIE).iter() {
let Ok(cookies) = cookie_header.to_str() else {
continue;
};
if cookies.len() > MAX_COOKIE_HEADER_LEN {
return None;
}
for cookie in cookies.split(';') {
let cookie = cookie.trim();
if let Some(value) = cookie.strip_prefix(&format!("{}=", name)) {
if value.len() > MAX_COOKIE_VALUE_LEN {
return None;
}
return Some(value.to_string());
}
}
}
None
}
#[derive(Debug, Clone, Copy)]
pub struct PeerIp(pub Option<IpAddr>);
impl<S> FromRequestParts<S> for PeerIp
where
S: Send + Sync,
{
type Rejection = Infallible;
fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
let ip = parts
.extensions
.get::<ConnectInfo<SocketAddr>>()
.map(|info| info.0.ip());
future::ready(Ok(PeerIp(ip)))
}
}
const MAX_IP_LEN: usize = 64;
pub fn extract_client_ip(headers: &HeaderMap, trust_proxy: bool) -> Option<String> {
if !trust_proxy {
return None;
}
headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|s| {
let first_ip = s.split(',').next()?.trim();
if first_ip.len() > MAX_IP_LEN {
return None;
}
let parsed = first_ip.parse::<IpAddr>().ok()?;
Some(parsed.to_string())
})
}
pub fn extract_client_ip_with_fallback(
headers: &HeaderMap,
trust_proxy: bool,
peer_ip: Option<IpAddr>,
) -> Option<String> {
extract_client_ip(headers, trust_proxy).or_else(|| peer_ip.map(|ip| ip.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::HeaderValue;
#[test]
fn test_extract_access_token_from_bearer_header() {
let mut headers = HeaderMap::new();
headers.insert(
header::AUTHORIZATION,
HeaderValue::from_static("Bearer test-access-token"),
);
let token = extract_access_token(&headers, "cedros_access");
assert_eq!(token, Some("test-access-token".to_string()));
}
#[test]
fn test_extract_access_token_from_cookie() {
let mut headers = HeaderMap::new();
headers.insert(
header::COOKIE,
HeaderValue::from_static("cedros_access=cookie-token; other=value"),
);
let token = extract_access_token(&headers, "cedros_access");
assert_eq!(token, Some("cookie-token".to_string()));
}
#[test]
fn test_extract_access_token_prefers_header() {
let mut headers = HeaderMap::new();
headers.insert(
header::AUTHORIZATION,
HeaderValue::from_static("Bearer header-token"),
);
headers.insert(
header::COOKIE,
HeaderValue::from_static("cedros_access=cookie-token"),
);
let token = extract_access_token(&headers, "cedros_access");
assert_eq!(token, Some("header-token".to_string()));
}
#[test]
fn test_extract_access_token_none_when_missing() {
let headers = HeaderMap::new();
let token = extract_access_token(&headers, "cedros_access");
assert_eq!(token, None);
}
#[test]
fn test_extract_access_token_none_for_invalid_bearer() {
let mut headers = HeaderMap::new();
headers.insert(
header::AUTHORIZATION,
HeaderValue::from_static("Basic credentials"),
);
let token = extract_access_token(&headers, "cedros_access");
assert_eq!(token, None);
}
#[test]
fn test_extract_cookie_single() {
let mut headers = HeaderMap::new();
headers.insert(
header::COOKIE,
HeaderValue::from_static("my_cookie=my_value"),
);
let value = extract_cookie(&headers, "my_cookie");
assert_eq!(value, Some("my_value".to_string()));
}
#[test]
fn test_extract_cookie_multiple() {
let mut headers = HeaderMap::new();
headers.insert(
header::COOKIE,
HeaderValue::from_static("first=1; second=2; third=3"),
);
assert_eq!(extract_cookie(&headers, "first"), Some("1".to_string()));
assert_eq!(extract_cookie(&headers, "second"), Some("2".to_string()));
assert_eq!(extract_cookie(&headers, "third"), Some("3".to_string()));
}
#[test]
fn test_extract_cookie_multiple_headers() {
let mut headers = HeaderMap::new();
headers.append(header::COOKIE, HeaderValue::from_static("a=1"));
headers.append(header::COOKIE, HeaderValue::from_static("b=2"));
assert_eq!(extract_cookie(&headers, "a"), Some("1".to_string()));
assert_eq!(extract_cookie(&headers, "b"), Some("2".to_string()));
}
#[test]
fn test_extract_cookie_with_leading_space() {
let mut headers = HeaderMap::new();
headers.insert(
header::COOKIE,
HeaderValue::from_static("first=1; spaced=value; other=test"),
);
let value = extract_cookie(&headers, "spaced");
assert_eq!(value, Some("value".to_string()));
}
#[test]
fn test_extract_cookie_not_found() {
let mut headers = HeaderMap::new();
headers.insert(header::COOKIE, HeaderValue::from_static("other=value"));
let value = extract_cookie(&headers, "missing");
assert_eq!(value, None);
}
#[test]
fn test_extract_cookie_empty_headers() {
let headers = HeaderMap::new();
let value = extract_cookie(&headers, "any");
assert_eq!(value, None);
}
#[test]
fn test_extract_cookie_rejects_oversized_header() {
let mut headers = HeaderMap::new();
let oversized = "a".repeat(17 * 1024); headers.insert(
header::COOKIE,
HeaderValue::from_str(&format!("my_cookie={}", oversized)).unwrap(),
);
assert_eq!(extract_cookie(&headers, "my_cookie"), None);
}
#[test]
fn test_extract_client_ip_valid_ipv4() {
let mut headers = HeaderMap::new();
headers.insert("x-forwarded-for", HeaderValue::from_static("192.168.1.1"));
assert_eq!(
extract_client_ip(&headers, true),
Some("192.168.1.1".to_string())
);
}
#[test]
fn test_extract_client_ip_valid_ipv6() {
let mut headers = HeaderMap::new();
headers.insert("x-forwarded-for", HeaderValue::from_static("::1"));
assert_eq!(extract_client_ip(&headers, true), Some("::1".to_string()));
}
#[test]
fn test_extract_client_ip_multiple_takes_first() {
let mut headers = HeaderMap::new();
headers.insert(
"x-forwarded-for",
HeaderValue::from_static("10.0.0.1, 192.168.1.1, 172.16.0.1"),
);
assert_eq!(
extract_client_ip(&headers, true),
Some("10.0.0.1".to_string())
);
}
#[test]
fn test_extract_client_ip_rejects_invalid() {
let mut headers = HeaderMap::new();
headers.insert("x-forwarded-for", HeaderValue::from_static("not-an-ip"));
assert_eq!(extract_client_ip(&headers, true), None);
}
#[test]
fn test_extract_client_ip_rejects_script() {
let mut headers = HeaderMap::new();
headers.insert(
"x-forwarded-for",
HeaderValue::from_static("<script>alert(1)</script>"),
);
assert_eq!(extract_client_ip(&headers, true), None);
}
#[test]
fn test_extract_client_ip_with_fallback_uses_peer_ip() {
let headers = HeaderMap::new();
let peer_ip = Some("203.0.113.10".parse::<IpAddr>().unwrap());
assert_eq!(
extract_client_ip_with_fallback(&headers, false, peer_ip),
Some("203.0.113.10".to_string())
);
}
#[test]
fn test_extract_client_ip_with_fallback_prefers_proxy() {
let mut headers = HeaderMap::new();
headers.insert("x-forwarded-for", HeaderValue::from_static("192.0.2.1"));
let peer_ip = Some("203.0.113.10".parse::<IpAddr>().unwrap());
assert_eq!(
extract_client_ip_with_fallback(&headers, true, peer_ip),
Some("192.0.2.1".to_string())
);
}
#[test]
fn test_extract_client_ip_respects_trust_proxy_false() {
let mut headers = HeaderMap::new();
headers.insert("x-forwarded-for", HeaderValue::from_static("192.168.1.1"));
assert_eq!(extract_client_ip(&headers, false), None);
}
#[test]
fn test_extract_client_ip_none_when_missing() {
let headers = HeaderMap::new();
assert_eq!(extract_client_ip(&headers, true), None);
}
}