Skip to main content

securitydept_realip/
access.rs

1use std::net::IpAddr;
2
3use ipnet::IpNet;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    error::{RealIpError, RealIpResult},
8    resolve::{ResolvedClientIp, ResolvedSourceKind},
9};
10
11#[derive(Debug, Clone, Deserialize, Serialize, Default)]
12pub struct RealIpAccessConfig {
13    #[serde(default)]
14    pub allowed_cidrs: Vec<IpNet>,
15    #[serde(default)]
16    pub allow_fallback: bool,
17}
18
19impl RealIpAccessConfig {
20    pub fn validate(&self) -> RealIpResult<()> {
21        if self.allowed_cidrs.is_empty() {
22            return Err(RealIpError::AccessConfig {
23                message: "allowed_cidrs must not be empty".to_string(),
24            });
25        }
26
27        Ok(())
28    }
29
30    pub fn allows_client_ip(&self, client_ip: IpAddr) -> bool {
31        self.allowed_cidrs
32            .iter()
33            .any(|cidr| cidr.contains(&client_ip))
34    }
35}
36
37#[derive(Debug, Clone)]
38pub struct RealIpAccessManager {
39    config: RealIpAccessConfig,
40}
41
42impl RealIpAccessManager {
43    pub fn from_config(config: RealIpAccessConfig) -> RealIpResult<Self> {
44        config.validate()?;
45        Ok(Self { config })
46    }
47
48    pub fn config(&self) -> &RealIpAccessConfig {
49        &self.config
50    }
51
52    pub fn ensure_allowed(&self, resolved: &ResolvedClientIp) -> RealIpResult<()> {
53        if resolved.source_kind == ResolvedSourceKind::Fallback && !self.config.allow_fallback {
54            return Err(RealIpError::AccessDenied {
55                client_ip: resolved.client_ip,
56                reason: "fallback source is not allowed".to_string(),
57            });
58        }
59
60        if !self.config.allows_client_ip(resolved.client_ip) {
61            return Err(RealIpError::AccessDenied {
62                client_ip: resolved.client_ip,
63                reason: "client IP is not in allowed_cidrs".to_string(),
64            });
65        }
66
67        Ok(())
68    }
69}
70
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use super::*;

    /// Builds a manager whose allow-list is the single CIDR `cidr`.
    fn manager_for(cidr: &str, allow_fallback: bool) -> RealIpAccessManager {
        RealIpAccessManager::from_config(RealIpAccessConfig {
            allowed_cidrs: vec![cidr.parse().expect("cidr should parse")],
            allow_fallback,
        })
        .expect("access manager should build")
    }

    /// Builds a fallback-sourced resolution whose peer IP equals its client IP.
    fn fallback_resolved(ip: IpAddr) -> ResolvedClientIp {
        ResolvedClientIp {
            client_ip: ip,
            peer_ip: ip,
            source_name: None,
            source_kind: ResolvedSourceKind::Fallback,
            header_name: None,
        }
    }

    #[test]
    fn access_config_requires_non_empty_allowed_cidrs() {
        let error = RealIpAccessManager::from_config(RealIpAccessConfig::default())
            .expect_err("empty allowed_cidrs should be rejected");

        assert!(matches!(error, RealIpError::AccessConfig { .. }));
    }

    #[test]
    fn access_manager_rejects_fallback_when_disabled() {
        let manager = manager_for("10.0.0.0/8", false);
        let resolved = fallback_resolved(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)));

        let error = manager
            .ensure_allowed(&resolved)
            .expect_err("fallback should be rejected");

        assert!(matches!(
            &error,
            RealIpError::AccessDenied { reason, .. }
                if reason.as_str() == "fallback source is not allowed"
        ));
    }

    #[test]
    fn access_manager_allows_in_range_fallback_when_enabled() {
        let manager = manager_for("10.0.0.0/8", true);
        let resolved = fallback_resolved(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)));

        manager
            .ensure_allowed(&resolved)
            .expect("in-range fallback client should be allowed");
    }

    #[test]
    fn access_manager_rejects_client_outside_allowed_cidrs() {
        let manager = manager_for("10.0.0.0/8", true);
        let client_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));

        let error = manager
            .ensure_allowed(&fallback_resolved(client_ip))
            .expect_err("out-of-range client should be rejected");

        assert!(matches!(
            &error,
            RealIpError::AccessDenied { client_ip: denied, reason, .. }
                if *denied == client_ip
                    && reason.as_str() == "client IP is not in allowed_cidrs"
        ));
    }

    #[test]
    fn allows_client_ip_respects_cidr_boundaries() {
        let config = RealIpAccessConfig {
            allowed_cidrs: vec![
                "10.0.0.0/8".parse().expect("cidr should parse"),
                "2001:db8::/32".parse().expect("cidr should parse"),
            ],
            allow_fallback: false,
        };

        assert!(config.allows_client_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0))));
        assert!(config.allows_client_ip(IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255))));
        assert!(!config.allows_client_ip(IpAddr::V4(Ipv4Addr::new(11, 0, 0, 0))));
        assert!(config.allows_client_ip("2001:db8::1".parse().expect("ip should parse")));
        assert!(!config.allows_client_ip("2001:db9::1".parse().expect("ip should parse")));
    }
}