use cedar_policy::{Context, RestrictedExpression};
use chrono::{DateTime, Utc};
use super::error::AuthzError;
/// A value that can be converted into a Cedar [`Context`].
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
pub trait BuildRequestContext: Send + Sync {
/// Builds the Cedar [`Context`] represented by `self`.
///
/// # Errors
///
/// Returns an [`AuthzError`] when the implementor cannot produce a context.
fn to_cedar_context(&self) -> Result<Context, AuthzError>;
}
/// Marker type for requests that carry no context attributes.
///
/// Its [`BuildRequestContext`] implementation yields an empty Cedar context.
/// The common traits are derived so the marker can be logged, copied and
/// default-constructed like the other public types in this module.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoContext;
/// `NoContext` contributes no attributes to the Cedar context.
impl BuildRequestContext for NoContext {
/// Always succeeds and returns `Context::empty()`.
fn to_cedar_context(&self) -> Result<Context, AuthzError> {
Ok(Context::empty())
}
}
/// Request attributes that are exposed to Cedar policies as context.
///
/// The [`BuildRequestContext`] implementation always emits `mfa_verified` and
/// emits `ip_address` / `timestamp` only when the matching field is `Some`.
///
/// `Debug` and `Clone` are derived so the value can be logged and duplicated,
/// matching the other public types in this module.
#[derive(Debug, Clone)]
pub struct StandardRequestContext {
    /// MFA flag; emitted as the boolean attribute `mfa_verified`.
    pub mfa_verified: bool,
    /// IP address associated with the request (presumably the client's), if
    /// known; emitted as the string attribute `ip_address`.
    pub ip_address: Option<std::net::IpAddr>,
    /// Point in time attached to the request, if any; emitted as the RFC 3339
    /// string attribute `timestamp`.
    pub timestamp: Option<DateTime<Utc>>,
}
impl StandardRequestContext {
    /// Creates a context without a timestamp.
    ///
    /// `timestamp` is left as `None`, so no `timestamp` attribute is produced
    /// when the value is converted to a Cedar context.
    pub fn new(mfa_verified: bool, ip_address: Option<std::net::IpAddr>) -> Self {
        let timestamp = None;
        Self {
            mfa_verified,
            ip_address,
            timestamp,
        }
    }

    /// Creates a context pinned to the given `timestamp`.
    ///
    /// Identical to [`StandardRequestContext::new`] except that `timestamp`
    /// is stored as `Some(timestamp)`.
    pub fn at(
        mfa_verified: bool,
        ip_address: Option<std::net::IpAddr>,
        timestamp: DateTime<Utc>,
    ) -> Self {
        Self {
            timestamp: Some(timestamp),
            ..Self::new(mfa_verified, ip_address)
        }
    }
}
impl BuildRequestContext for StandardRequestContext {
    /// Converts the stored request attributes into a Cedar [`Context`].
    ///
    /// Attributes are supplied in this order:
    /// * `mfa_verified` — always present, as a boolean;
    /// * `ip_address` — only when set, as the address's string form;
    /// * `timestamp` — only when set, as an RFC 3339 string.
    ///
    /// # Errors
    ///
    /// Returns `AuthzError::Context` holding the debug rendering of the error
    /// reported by `Context::from_pairs`.
    fn to_cedar_context(&self) -> Result<Context, AuthzError> {
        let mfa_attr = (
            "mfa_verified".to_string(),
            RestrictedExpression::new_bool(self.mfa_verified),
        );
        let ip_attr = self.ip_address.as_ref().map(|addr| {
            (
                "ip_address".to_string(),
                RestrictedExpression::new_string(addr.to_string()),
            )
        });
        let timestamp_attr = self.timestamp.as_ref().map(|when| {
            (
                "timestamp".to_string(),
                RestrictedExpression::new_string(when.to_rfc3339()),
            )
        });

        // An `Option` iterates over zero or one item, so chaining the optional
        // attributes skips the ones that are not set.
        let attributes: Vec<(String, RestrictedExpression)> = std::iter::once(mfa_attr)
            .chain(ip_attr)
            .chain(timestamp_attr)
            .collect();

        match Context::from_pairs(attributes) {
            Ok(context) => Ok(context),
            Err(err) => Err(AuthzError::Context(format!("{err:?}"))),
        }
    }
}
/// Extracts an IP address from the forwarding headers of a request.
///
/// Headers are consulted in order: `X-Real-IP` first, then `X-Forwarded-For`.
/// For each header, the text before the first comma is trimmed and parsed as
/// an IP address, and the first header that yields a valid address wins.
///
/// A header that is present but unusable (its value cannot be read with
/// `to_str`, or its first entry does not parse as an IP address) is skipped
/// instead of ending the search, so a malformed `X-Real-IP` does not hide a
/// usable `X-Forwarded-For`.
///
/// This function does not check who sent the request. Use
/// [`ip_from_headers_trusted`] when the headers should only be honored for
/// known proxies.
///
/// Returns `None` when neither header yields a parseable address.
pub fn ip_from_headers(headers: &axum::http::HeaderMap) -> Option<std::net::IpAddr> {
    ["X-Real-IP", "X-Forwarded-For"]
        .into_iter()
        .find_map(|name| {
            headers
                .get(name)?
                .to_str()
                .ok()?
                // Only the first comma-separated entry is considered.
                .split(',')
                .next()?
                .trim()
                .parse::<std::net::IpAddr>()
                .ok()
        })
}
/// Set of proxy addresses whose forwarding headers may be believed.
///
/// Addresses are stored behind an `Arc`, so cloning the set does not copy the
/// list. The default value is an empty set, which trusts no peer.
///
/// Addresses are compared in canonical form: an IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`) is treated as the IPv4 address `a.b.c.d`. Without this,
/// an IPv4 proxy observed in mapped form (as can happen on a dual-stack
/// listener) would never match its configured entry.
#[derive(Debug, Clone, Default)]
pub struct TrustedProxies {
    /// Trusted addresses, each stored in canonical form.
    proxies: std::sync::Arc<[std::net::IpAddr]>,
}

impl TrustedProxies {
    /// Builds a set from the given addresses, canonicalizing each one.
    pub fn new(addrs: impl IntoIterator<Item = std::net::IpAddr>) -> Self {
        Self {
            proxies: addrs
                .into_iter()
                .map(|addr| addr.to_canonical())
                .collect(),
        }
    }

    /// Builds a set that contains only the IPv4 and IPv6 loopback addresses.
    pub fn loopback_only() -> Self {
        Self::new([
            std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
            std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST),
        ])
    }

    /// Returns `true` when `peer`, in canonical form, is one of the trusted
    /// addresses.
    pub fn trusts(&self, peer: std::net::IpAddr) -> bool {
        self.proxies.contains(&peer.to_canonical())
    }
}
/// Resolves the address to attribute a request to.
///
/// Forwarding headers are consulted only when `peer` is listed in `trusted`;
/// in that case the address found by [`ip_from_headers`] is returned if there
/// is one. In every other case (untrusted peer, or no usable header) `peer`
/// itself is returned.
///
/// NOTE(review): `peer` is presumably the socket-level remote address of the
/// connection; confirm against the caller.
pub fn ip_from_headers_trusted(
    headers: &axum::http::HeaderMap,
    peer: std::net::IpAddr,
    trusted: &TrustedProxies,
) -> std::net::IpAddr {
    // Headers from an untrusted peer are never read.
    if !trusted.trusts(peer) {
        return peer;
    }
    ip_from_headers(headers).unwrap_or(peer)
}
/// Forwards [`BuildRequestContext`] through an `Arc`.
///
/// The `?Sized` bound lets the impl also cover unsized targets such as
/// `Arc<dyn BuildRequestContext>`, not only `Arc` of a concrete type.
impl<T: BuildRequestContext + ?Sized> BuildRequestContext for std::sync::Arc<T> {
    /// Delegates to the pointed-to value's `to_cedar_context`.
    fn to_cedar_context(&self) -> Result<Context, AuthzError> {
        (**self).to_cedar_context()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Builds a header map that holds exactly one header.
    fn single_header(name: &'static str, value: &str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(name, value.parse().unwrap());
        map
    }

    /// Shorthand for an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn ip_from_headers_reads_x_real_ip() {
        let headers = single_header("X-Real-IP", "203.0.113.5");
        assert_eq!(ip_from_headers(&headers), Some(v4(203, 0, 113, 5)));
    }

    #[test]
    fn ip_from_headers_falls_back_to_x_forwarded_for_first_entry() {
        let headers = single_header("X-Forwarded-For", "198.51.100.10, 10.0.0.1");
        assert_eq!(ip_from_headers(&headers), Some(v4(198, 51, 100, 10)));
    }

    #[test]
    fn ip_from_headers_returns_none_when_neither_header_present() {
        assert_eq!(ip_from_headers(&HeaderMap::new()), None);
    }

    #[test]
    fn trusted_proxies_loopback_only_contains_v4_and_v6_loopback() {
        let trusted = TrustedProxies::loopback_only();
        let loopbacks = [
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            IpAddr::V6(Ipv6Addr::LOCALHOST),
        ];
        for addr in loopbacks {
            assert!(trusted.trusts(addr));
        }
    }

    #[test]
    fn trusted_proxies_rejects_unrelated_peer() {
        let unrelated = v4(203, 0, 113, 1);
        assert!(!TrustedProxies::loopback_only().trusts(unrelated));
    }

    #[test]
    fn trusted_proxies_empty_set_trusts_nothing() {
        let nothing_trusted = TrustedProxies::new([]);
        assert!(!nothing_trusted.trusts(IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }
}