use crate::{Permission, PermissionContext};
use async_trait::async_trait;
use std::net::IpAddr;
use std::str::FromStr;
/// Permission that admits only clients whose IP address is on an allow-list.
///
/// The client address is the TCP peer address (`remote_addr`) unless that peer is one of
/// [`trusted_proxies`](Self::trusted_proxies); only then are the `X-Forwarded-For` and
/// `X-Real-IP` headers consulted. IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are also
/// matched by their IPv4 form, so rules keep working behind dual-stack listeners.
#[derive(Debug, Clone)]
pub struct IpWhitelistPermission {
    /// Individual addresses that are allowed.
    pub allowed_ips: Vec<IpAddr>,
    /// Networks whose members are allowed.
    pub allowed_cidrs: Vec<CidrRange>,
    /// Outcome when no client IP can be determined: `true` denies, `false` allows.
    pub deny_on_error: bool,
    /// Peer addresses whose forwarding headers are believed.
    pub trusted_proxies: Vec<IpAddr>,
}

impl IpWhitelistPermission {
    /// Creates an empty whitelist: nothing is allowed, no proxy is trusted, and requests
    /// without a determinable client IP are denied.
    pub fn new() -> Self {
        Self {
            allowed_ips: Vec::new(),
            allowed_cidrs: Vec::new(),
            deny_on_error: true,
            trusted_proxies: Vec::new(),
        }
    }

    /// Allows a single address such as `"192.168.1.1"` or `"2001:db8::1"`.
    ///
    /// Input that does not parse as an IP address is silently ignored.
    pub fn add_ip(mut self, ip: impl AsRef<str>) -> Self {
        if let Ok(addr) = IpAddr::from_str(ip.as_ref()) {
            self.allowed_ips.push(addr);
        }
        self
    }

    /// Allows every address in a CIDR range such as `"10.0.0.0/8"`.
    ///
    /// Input that does not parse as a [`CidrRange`] is silently ignored.
    pub fn add_cidr(mut self, cidr: impl AsRef<str>) -> Self {
        if let Ok(range) = CidrRange::from_str(cidr.as_ref()) {
            self.allowed_cidrs.push(range);
        }
        self
    }

    /// Sets whether a request whose client IP cannot be determined is denied (`true`).
    pub fn deny_on_error(mut self, deny: bool) -> Self {
        self.deny_on_error = deny;
        self
    }

    /// Trusts forwarding headers on requests whose TCP peer is `ip`.
    ///
    /// Input that does not parse as an IP address is silently ignored.
    pub fn add_trusted_proxy(mut self, ip: impl AsRef<str>) -> Self {
        if let Ok(addr) = IpAddr::from_str(ip.as_ref()) {
            self.trusted_proxies.push(addr);
        }
        self
    }

    /// Returns whether `ip` is allowed, either exactly or through one of the CIDR ranges.
    ///
    /// An IPv4-mapped IPv6 address is checked both as given and in its IPv4 form.
    pub fn is_allowed(&self, ip: &IpAddr) -> bool {
        let canonical = ip.to_canonical();
        self.allowed_ips
            .iter()
            .any(|allowed| allowed.to_canonical() == canonical)
            || self
                .allowed_cidrs
                .iter()
                .any(|cidr| cidr.contains(ip) || cidr.contains(&canonical))
    }

    /// Returns whether `ip` (compared in canonical form) is a configured trusted proxy.
    fn is_trusted_ip(&self, ip: &IpAddr) -> bool {
        let canonical = ip.to_canonical();
        self.trusted_proxies
            .iter()
            .any(|proxy| proxy.to_canonical() == canonical)
    }

    /// Returns whether the request's TCP peer is a trusted proxy.
    fn is_trusted_proxy(&self, context: &PermissionContext) -> bool {
        context
            .request
            .remote_addr
            .is_some_and(|addr| self.is_trusted_ip(&addr.ip()))
    }

    /// Derives the client address from every `X-Forwarded-For` header line.
    ///
    /// By convention each proxy appends the address it received the request from, so only
    /// the right-hand end of the list is trustworthy. Entries are scanned right-to-left,
    /// trusted proxies are skipped, and the first untrusted address is returned. Entries
    /// further left were supplied by the client and may be forged. If every entry is a
    /// trusted proxy, the leftmost one is returned.
    ///
    /// Returns `None` in any of these cases:
    /// - the header is absent or empty;
    /// - any header line is not visible ASCII;
    /// - an entry reached during the scan is not a bare IP address.
    fn forwarded_for_client_ip(&self, context: &PermissionContext) -> Option<IpAddr> {
        let header_values = context.request.headers.get_all("x-forwarded-for");
        // Reject the whole header if any line is not valid text: skipping a line could let an
        // earlier, client-controlled line stand in for the proxy-appended one.
        let lines: Vec<&str> = header_values
            .iter()
            .map(|value| value.to_str())
            .collect::<Result<_, _>>()
            .ok()?;
        let mut leftmost_trusted = None;
        for hop in lines
            .into_iter()
            .flat_map(|line| line.split(','))
            .map(str::trim)
            .filter(|hop| !hop.is_empty())
            .rev()
        {
            let hop_ip = IpAddr::from_str(hop).ok()?;
            if !self.is_trusted_ip(&hop_ip) {
                return Some(hop_ip);
            }
            leftmost_trusted = Some(hop_ip);
        }
        leftmost_trusted
    }

    /// Determines the client IP for `context`.
    ///
    /// When the TCP peer is a trusted proxy, two headers are tried in order:
    /// 1. `X-Forwarded-For`, parsed right-to-left;
    /// 2. `X-Real-IP`.
    ///
    /// Otherwise, or if neither header yields an address, the peer address is used. Returns
    /// `None` only when the request has no `remote_addr`.
    fn extract_client_ip(&self, context: &PermissionContext) -> Option<IpAddr> {
        if self.is_trusted_proxy(context) {
            if let Some(ip) = self.forwarded_for_client_ip(context) {
                return Some(ip);
            }
            if let Some(real_ip) = context.request.headers.get("x-real-ip")
                && let Ok(real_ip_str) = real_ip.to_str()
                && let Ok(ip) = IpAddr::from_str(real_ip_str.trim())
            {
                return Some(ip);
            }
        }
        context.request.remote_addr.map(|addr| addr.ip())
    }
}

impl Default for IpWhitelistPermission {
    /// Equivalent to [`IpWhitelistPermission::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Permission for IpWhitelistPermission {
    /// Grants access when the client IP is allowed. If no client IP can be determined, the
    /// result is `!deny_on_error`.
    async fn has_permission(&self, context: &PermissionContext<'_>) -> bool {
        match self.extract_client_ip(context) {
            Some(ip) => self.is_allowed(&ip),
            None => !self.deny_on_error,
        }
    }
}
/// Permission that rejects clients whose IP address is on a block-list.
///
/// The client address is the TCP peer address (`remote_addr`) unless that peer is one of
/// [`trusted_proxies`](Self::trusted_proxies); only then are the `X-Forwarded-For` and
/// `X-Real-IP` headers consulted. IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are also
/// matched by their IPv4 form, so IPv4 rules cannot be bypassed through a dual-stack listener.
#[derive(Debug, Clone)]
pub struct IpBlacklistPermission {
    /// Individual addresses that are blocked.
    pub blocked_ips: Vec<IpAddr>,
    /// Networks whose members are blocked.
    pub blocked_cidrs: Vec<CidrRange>,
    /// Outcome when no client IP can be determined: `true` allows, `false` denies.
    pub allow_on_error: bool,
    /// Peer addresses whose forwarding headers are believed.
    pub trusted_proxies: Vec<IpAddr>,
}

impl IpBlacklistPermission {
    /// Creates an empty blacklist: nothing is blocked, no proxy is trusted, and requests
    /// without a determinable client IP are denied.
    pub fn new() -> Self {
        Self {
            blocked_ips: Vec::new(),
            blocked_cidrs: Vec::new(),
            allow_on_error: false,
            trusted_proxies: Vec::new(),
        }
    }

    /// Blocks a single address such as `"192.168.1.100"` or `"2001:db8::1"`.
    ///
    /// Input that does not parse as an IP address is silently ignored.
    pub fn add_ip(mut self, ip: impl AsRef<str>) -> Self {
        if let Ok(addr) = IpAddr::from_str(ip.as_ref()) {
            self.blocked_ips.push(addr);
        }
        self
    }

    /// Blocks every address in a CIDR range such as `"10.0.0.0/8"`.
    ///
    /// Input that does not parse as a [`CidrRange`] is silently ignored.
    pub fn add_cidr(mut self, cidr: impl AsRef<str>) -> Self {
        if let Ok(range) = CidrRange::from_str(cidr.as_ref()) {
            self.blocked_cidrs.push(range);
        }
        self
    }

    /// Sets whether a request whose client IP cannot be determined is allowed (`true`).
    pub fn allow_on_error(mut self, allow: bool) -> Self {
        self.allow_on_error = allow;
        self
    }

    /// Trusts forwarding headers on requests whose TCP peer is `ip`.
    ///
    /// Input that does not parse as an IP address is silently ignored.
    pub fn add_trusted_proxy(mut self, ip: impl AsRef<str>) -> Self {
        if let Ok(addr) = IpAddr::from_str(ip.as_ref()) {
            self.trusted_proxies.push(addr);
        }
        self
    }

    /// Returns whether `ip` is blocked, either exactly or through one of the CIDR ranges.
    ///
    /// An IPv4-mapped IPv6 address is checked both as given and in its IPv4 form.
    pub fn is_blocked(&self, ip: &IpAddr) -> bool {
        let canonical = ip.to_canonical();
        self.blocked_ips
            .iter()
            .any(|blocked| blocked.to_canonical() == canonical)
            || self
                .blocked_cidrs
                .iter()
                .any(|cidr| cidr.contains(ip) || cidr.contains(&canonical))
    }

    /// Returns whether `ip` (compared in canonical form) is a configured trusted proxy.
    fn is_trusted_ip(&self, ip: &IpAddr) -> bool {
        let canonical = ip.to_canonical();
        self.trusted_proxies
            .iter()
            .any(|proxy| proxy.to_canonical() == canonical)
    }

    /// Returns whether the request's TCP peer is a trusted proxy.
    fn is_trusted_proxy(&self, context: &PermissionContext) -> bool {
        context
            .request
            .remote_addr
            .is_some_and(|addr| self.is_trusted_ip(&addr.ip()))
    }

    /// Derives the client address from every `X-Forwarded-For` header line.
    ///
    /// By convention each proxy appends the address it received the request from, so only
    /// the right-hand end of the list is trustworthy. Entries are scanned right-to-left,
    /// trusted proxies are skipped, and the first untrusted address is returned. Entries
    /// further left were supplied by the client and may be forged. If every entry is a
    /// trusted proxy, the leftmost one is returned.
    ///
    /// Returns `None` in any of these cases:
    /// - the header is absent or empty;
    /// - any header line is not visible ASCII;
    /// - an entry reached during the scan is not a bare IP address.
    fn forwarded_for_client_ip(&self, context: &PermissionContext) -> Option<IpAddr> {
        let header_values = context.request.headers.get_all("x-forwarded-for");
        // Reject the whole header if any line is not valid text: skipping a line could let an
        // earlier, client-controlled line stand in for the proxy-appended one.
        let lines: Vec<&str> = header_values
            .iter()
            .map(|value| value.to_str())
            .collect::<Result<_, _>>()
            .ok()?;
        let mut leftmost_trusted = None;
        for hop in lines
            .into_iter()
            .flat_map(|line| line.split(','))
            .map(str::trim)
            .filter(|hop| !hop.is_empty())
            .rev()
        {
            let hop_ip = IpAddr::from_str(hop).ok()?;
            if !self.is_trusted_ip(&hop_ip) {
                return Some(hop_ip);
            }
            leftmost_trusted = Some(hop_ip);
        }
        leftmost_trusted
    }

    /// Determines the client IP for `context`.
    ///
    /// When the TCP peer is a trusted proxy, two headers are tried in order:
    /// 1. `X-Forwarded-For`, parsed right-to-left;
    /// 2. `X-Real-IP`.
    ///
    /// Otherwise, or if neither header yields an address, the peer address is used. Returns
    /// `None` only when the request has no `remote_addr`.
    fn extract_client_ip(&self, context: &PermissionContext) -> Option<IpAddr> {
        if self.is_trusted_proxy(context) {
            if let Some(ip) = self.forwarded_for_client_ip(context) {
                return Some(ip);
            }
            if let Some(real_ip) = context.request.headers.get("x-real-ip")
                && let Ok(real_ip_str) = real_ip.to_str()
                && let Ok(ip) = IpAddr::from_str(real_ip_str.trim())
            {
                return Some(ip);
            }
        }
        context.request.remote_addr.map(|addr| addr.ip())
    }
}

impl Default for IpBlacklistPermission {
    /// Equivalent to [`IpBlacklistPermission::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Permission for IpBlacklistPermission {
    /// Grants access unless the client IP is blocked. If no client IP can be determined, the
    /// result is `allow_on_error`.
    async fn has_permission(&self, context: &PermissionContext<'_>) -> bool {
        match self.extract_client_ip(context) {
            Some(ip) => !self.is_blocked(&ip),
            None => self.allow_on_error,
        }
    }
}
/// An IP network in CIDR notation, e.g. `192.168.1.0/24` or `2001:db8::/32`.
///
/// Parsing via [`FromStr`] rejects prefix lengths longer than the address family allows.
/// [`CidrRange::new`] does not validate. [`CidrRange::contains`] therefore treats an
/// over-long prefix as a full-length (single-address) match instead of panicking.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CidrRange {
    /// Network address; bits beyond `prefix_len` are ignored when matching.
    pub network: IpAddr,
    /// Number of leading bits an address must share with `network`.
    pub prefix_len: u8,
}

impl CidrRange {
    /// Creates a range without validating `prefix_len` against the address family.
    pub fn new(network: IpAddr, prefix_len: u8) -> Self {
        Self {
            network,
            prefix_len,
        }
    }

    /// Returns whether `ip` lies inside this range.
    ///
    /// Same-family addresses are compared on their first `prefix_len` bits. Across families,
    /// only IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) can match:
    /// - a mapped IPv6 candidate is checked against an IPv4 range;
    /// - an IPv4 candidate is checked against an IPv4-mapped IPv6 range of prefix length 96
    ///   or more.
    ///
    /// Every other cross-family pair is outside the range.
    pub fn contains(&self, ip: &IpAddr) -> bool {
        match (self.network, *ip) {
            (IpAddr::V4(net), IpAddr::V4(addr)) => {
                Self::v4_prefix_matches(net, addr, self.prefix_len)
            }
            (IpAddr::V6(net), IpAddr::V6(addr)) => {
                Self::v6_prefix_matches(net, addr, self.prefix_len)
            }
            (IpAddr::V4(net), IpAddr::V6(addr)) => addr
                .to_ipv4_mapped()
                .is_some_and(|mapped| Self::v4_prefix_matches(net, mapped, self.prefix_len)),
            (IpAddr::V6(net), IpAddr::V4(addr)) => match net.to_ipv4_mapped() {
                // The first 96 bits of a mapped network are the fixed `::ffff:` prefix, which
                // every IPv4 address satisfies, so only the remaining bits are compared.
                Some(mapped_net) if self.prefix_len >= 96 => {
                    Self::v4_prefix_matches(mapped_net, addr, self.prefix_len - 96)
                }
                _ => false,
            },
        }
    }

    /// Compares the leading `prefix_len` bits (clamped to 32) of two IPv4 addresses.
    fn v4_prefix_matches(
        net: std::net::Ipv4Addr,
        addr: std::net::Ipv4Addr,
        prefix_len: u8,
    ) -> bool {
        // A shift of 32 (prefix 0) overflows `checked_shl`, giving an all-zero "match all" mask;
        // prefixes above 32 saturate to a shift of 0, i.e. an exact match.
        let mask = u32::MAX
            .checked_shl(32u32.saturating_sub(u32::from(prefix_len)))
            .unwrap_or(0);
        (u32::from(net) & mask) == (u32::from(addr) & mask)
    }

    /// Compares the leading `prefix_len` bits (clamped to 128) of two IPv6 addresses.
    fn v6_prefix_matches(
        net: std::net::Ipv6Addr,
        addr: std::net::Ipv6Addr,
        prefix_len: u8,
    ) -> bool {
        let mask = u128::MAX
            .checked_shl(128u32.saturating_sub(u32::from(prefix_len)))
            .unwrap_or(0);
        (u128::from(net) & mask) == (u128::from(addr) & mask)
    }
}

impl FromStr for CidrRange {
    type Err = String;

    /// Parses `address/prefix`, e.g. `"10.0.0.0/8"` or `"2001:db8::/32"`.
    ///
    /// # Errors
    ///
    /// Returns a message when:
    /// - the input does not contain exactly one `/`;
    /// - the address or prefix fails to parse;
    /// - the prefix exceeds 32 (IPv4) or 128 (IPv6).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (address, prefix) = match s.split_once('/') {
            Some((address, prefix)) if !prefix.contains('/') => (address, prefix),
            _ => return Err("Invalid CIDR format".to_string()),
        };
        let network = IpAddr::from_str(address).map_err(|e| e.to_string())?;
        let prefix_len = prefix.parse::<u8>().map_err(|e| e.to_string())?;
        match network {
            IpAddr::V4(_) if prefix_len > 32 => Err("IPv4 prefix length must be <= 32".to_string()),
            IpAddr::V6(_) if prefix_len > 128 => {
                Err("IPv6 prefix length must be <= 128".to_string())
            }
            _ => Ok(Self::new(network, prefix_len)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method};
    use reinhardt_http::Request;
    use rstest::rstest;
    use std::net::SocketAddr;

    /// Parses an IP literal; test inputs are always well-formed.
    fn addr(text: &str) -> IpAddr {
        IpAddr::from_str(text).unwrap()
    }

    /// Builds a `GET /test` request arriving from `peer`.
    ///
    /// A header map carrying `X-Forwarded-For: <value>` is attached only when `forwarded_for`
    /// is `Some`; otherwise the builder's `headers` method is never called.
    fn request_from(peer: &str, forwarded_for: Option<&str>) -> Request {
        let builder = Request::builder().method(Method::GET).uri("/test");
        let builder = if let Some(value) = forwarded_for {
            let mut headers = HeaderMap::new();
            headers.insert("x-forwarded-for", value.parse().unwrap());
            builder.headers(headers)
        } else {
            builder
        };
        let remote_addr: SocketAddr = peer.parse().unwrap();
        builder
            .remote_addr(remote_addr)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Wraps `request` in a context for an unauthenticated, non-admin, inactive user.
    fn anonymous(request: &Request) -> PermissionContext<'_> {
        PermissionContext {
            request,
            is_authenticated: false,
            is_admin: false,
            is_active: false,
            user: None,
        }
    }

    #[rstest]
    fn test_cidr_range_from_str() {
        let v4_range = CidrRange::from_str("192.168.1.0/24").unwrap();
        let v6_range = CidrRange::from_str("2001:db8::/32").unwrap();
        assert_eq!(v4_range.prefix_len, 24);
        assert_eq!(v6_range.prefix_len, 32);
    }

    #[rstest]
    fn test_cidr_range_invalid() {
        for bad in ["192.168.1.0", "192.168.1.0/33", "invalid/24"] {
            assert!(CidrRange::from_str(bad).is_err());
        }
    }

    #[rstest]
    fn test_cidr_contains_ipv4() {
        let range = CidrRange::from_str("192.168.1.0/24").unwrap();
        assert!(range.contains(&addr("192.168.1.1")));
        assert!(range.contains(&addr("192.168.1.255")));
        assert!(!range.contains(&addr("192.168.2.1")));
    }

    #[rstest]
    fn test_cidr_contains_ipv6() {
        let range = CidrRange::from_str("2001:db8::/32").unwrap();
        assert!(range.contains(&addr("2001:db8::1")));
        assert!(range.contains(&addr("2001:db8:ffff::1")));
        assert!(!range.contains(&addr("2001:db9::1")));
    }

    #[rstest]
    fn test_whitelist_permission_creation() {
        let whitelist = IpWhitelistPermission::new();
        assert!(whitelist.allowed_ips.is_empty());
        assert!(whitelist.allowed_cidrs.is_empty());
        assert!(whitelist.deny_on_error);
        assert!(whitelist.trusted_proxies.is_empty());
    }

    #[rstest]
    fn test_whitelist_add_ip() {
        let whitelist = IpWhitelistPermission::new()
            .add_ip("192.168.1.1")
            .add_ip("10.0.0.1");
        assert_eq!(whitelist.allowed_ips.len(), 2);
    }

    #[rstest]
    fn test_whitelist_add_cidr() {
        let whitelist = IpWhitelistPermission::new()
            .add_cidr("192.168.1.0/24")
            .add_cidr("10.0.0.0/8");
        assert_eq!(whitelist.allowed_cidrs.len(), 2);
    }

    #[rstest]
    fn test_whitelist_is_allowed() {
        let whitelist = IpWhitelistPermission::new()
            .add_ip("192.168.1.1")
            .add_cidr("10.0.0.0/24");
        assert!(whitelist.is_allowed(&addr("192.168.1.1")));
        assert!(whitelist.is_allowed(&addr("10.0.0.100")));
        assert!(!whitelist.is_allowed(&addr("172.16.0.1")));
    }

    #[rstest]
    #[tokio::test]
    async fn test_whitelist_ignores_forwarded_header_without_trusted_proxy() {
        let whitelist = IpWhitelistPermission::new().add_ip("192.168.1.1");
        let request = request_from("10.0.0.99:12345", Some("192.168.1.1"));
        let context = anonymous(&request);
        assert!(!whitelist.has_permission(&context).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_whitelist_trusts_forwarded_header_from_trusted_proxy() {
        let whitelist = IpWhitelistPermission::new()
            .add_trusted_proxy("10.0.0.1")
            .add_ip("192.168.1.1");
        let request = request_from("10.0.0.1:12345", Some("192.168.1.1"));
        let context = anonymous(&request);
        assert!(whitelist.has_permission(&context).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_whitelist_uses_remote_addr_directly() {
        let whitelist = IpWhitelistPermission::new().add_ip("192.168.1.1");
        let request = request_from("192.168.1.1:12345", None);
        let context = anonymous(&request);
        assert!(whitelist.has_permission(&context).await);
    }

    #[rstest]
    fn test_blacklist_permission_creation() {
        let blacklist = IpBlacklistPermission::new();
        assert!(blacklist.blocked_ips.is_empty());
        assert!(blacklist.blocked_cidrs.is_empty());
        assert!(!blacklist.allow_on_error);
        assert!(blacklist.trusted_proxies.is_empty());
    }

    #[rstest]
    fn test_blacklist_add_ip() {
        let blacklist = IpBlacklistPermission::new()
            .add_ip("192.168.1.100")
            .add_ip("10.0.0.100");
        assert_eq!(blacklist.blocked_ips.len(), 2);
    }

    #[rstest]
    fn test_blacklist_is_blocked() {
        let blacklist = IpBlacklistPermission::new()
            .add_ip("192.168.1.100")
            .add_cidr("10.0.0.0/24");
        assert!(blacklist.is_blocked(&addr("192.168.1.100")));
        assert!(blacklist.is_blocked(&addr("10.0.0.50")));
        assert!(!blacklist.is_blocked(&addr("172.16.0.1")));
    }

    #[rstest]
    #[tokio::test]
    async fn test_blacklist_ignores_forwarded_header_without_trusted_proxy() {
        let blacklist = IpBlacklistPermission::new().add_ip("192.168.1.100");
        let request = request_from("192.168.1.100:12345", Some("10.0.0.1"));
        let context = anonymous(&request);
        assert!(!blacklist.has_permission(&context).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_blacklist_trusts_forwarded_header_from_trusted_proxy() {
        let blacklist = IpBlacklistPermission::new()
            .add_trusted_proxy("10.0.0.1")
            .add_ip("192.168.1.100");
        let request = request_from("10.0.0.1:12345", Some("192.168.1.100"));
        let context = anonymous(&request);
        assert!(!blacklist.has_permission(&context).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_blacklist_allows_non_blocked_ip_via_trusted_proxy() {
        let blacklist = IpBlacklistPermission::new()
            .add_trusted_proxy("10.0.0.1")
            .add_ip("192.168.1.100");
        let request = request_from("10.0.0.1:12345", Some("192.168.1.1"));
        let context = anonymous(&request);
        assert!(blacklist.has_permission(&context).await);
    }

    #[rstest]
    fn test_ipv6_whitelist_single_address() {
        let whitelist = IpWhitelistPermission::new().add_ip("2001:db8::1");
        assert!(whitelist.is_allowed(&addr("2001:db8::1")));
        assert!(!whitelist.is_allowed(&addr("2001:db8::2")));
    }

    #[rstest]
    fn test_cidr_range_boundary_addresses() {
        let whitelist = IpWhitelistPermission::new().add_cidr("10.0.0.0/24");
        assert!(whitelist.is_allowed(&addr("10.0.0.0")));
        assert!(whitelist.is_allowed(&addr("10.0.0.255")));
        assert!(!whitelist.is_allowed(&addr("10.0.1.0")));
    }

    #[rstest]
    fn test_mixed_ipv4_ipv6_whitelist() {
        let whitelist = IpWhitelistPermission::new()
            .add_ip("192.168.1.1")
            .add_ip("2001:db8::1");
        assert!(whitelist.is_allowed(&addr("192.168.1.1")));
        assert!(whitelist.is_allowed(&addr("2001:db8::1")));
        assert!(!whitelist.is_allowed(&addr("10.0.0.1")));
    }
}