Skip to main content

cloudillo_core/rate_limit/
extractors.rs

1//! Address Key Extractors
2//!
3//! Custom key extractors for hierarchical IP address rate limiting.
4//! Supports IPv4 /32, /24 and IPv6 /64, /48 address levels.
5
6use std::net::{IpAddr, Ipv4Addr, SocketAddr};
7
8use axum::extract::ConnectInfo;
9use hyper::Request;
10
11use crate::app::ServerMode;
12
/// Represents the hierarchical address level being limited
///
/// IPv4 clients are keyed per address (/32) and per /24 network; IPv6 clients are keyed per
/// /64 subnet and per /48 allocation (there is no per-address /128 level for IPv6).
/// Every constructor treats an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) as the IPv4
/// address it embeds.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum AddressKey {
	/// IPv4 individual (/32)
	Ipv4Individual(Ipv4Addr),
	/// IPv4 /24 network (C-class)
	Ipv4Network([u8; 3]),
	/// IPv6 /64 subnet (standard allocation)
	Ipv6Subnet([u8; 8]),
	/// IPv6 /48 provider allocation
	Ipv6Provider([u8; 6]),
}

impl AddressKey {
	/// Create individual key from IP address
	///
	/// IPv4 (including IPv4-mapped IPv6) yields [`AddressKey::Ipv4Individual`]. Native IPv6
	/// yields the /64 [`AddressKey::Ipv6Subnet`] key, since /128 tracking is not supported.
	pub fn from_ip_individual(addr: &IpAddr) -> Self {
		match Self::canonical(addr) {
			IpAddr::V4(ip) => AddressKey::Ipv4Individual(ip),
			IpAddr::V6(ip) => Self::ipv6_subnet(&ip.octets()),
		}
	}

	/// Create network key from IP address (IPv4 /24 or IPv6 /64)
	///
	/// IPv4-mapped IPv6 addresses yield the /24 key of the embedded IPv4 address.
	pub fn from_ip_network(addr: &IpAddr) -> Self {
		match Self::canonical(addr) {
			IpAddr::V4(ip) => Self::ipv4_network(&ip.octets()),
			IpAddr::V6(ip) => Self::ipv6_subnet(&ip.octets()),
		}
	}

	/// Create provider key from IPv6 address (/48)
	///
	/// Returns `None` for IPv4 addresses, including IPv4-mapped IPv6 addresses.
	pub fn from_ipv6_provider(addr: &IpAddr) -> Option<Self> {
		match Self::canonical(addr) {
			IpAddr::V4(_) => None,
			IpAddr::V6(ip) => Some(Self::ipv6_provider(&ip.octets())),
		}
	}

	/// Extract all applicable hierarchical keys for an address
	///
	/// Keys are ordered from the narrowest to the widest level:
	/// - IPv4 (and IPv4-mapped IPv6): `[Ipv4Individual, Ipv4Network]`
	/// - native IPv6: `[Ipv6Subnet, Ipv6Provider]` (the /64 subnet is the lowest IPv6 level)
	pub fn extract_all(addr: &IpAddr) -> Vec<Self> {
		match Self::canonical(addr) {
			IpAddr::V4(ip) => {
				vec![AddressKey::Ipv4Individual(ip), Self::ipv4_network(&ip.octets())]
			}
			IpAddr::V6(ip) => {
				let octets = ip.octets();
				vec![Self::ipv6_subnet(&octets), Self::ipv6_provider(&octets)]
			}
		}
	}

	/// Check if this is an individual-level key (IPv4 only, IPv6 uses /64)
	pub fn is_individual(&self) -> bool {
		matches!(self, AddressKey::Ipv4Individual(_))
	}

	/// Check if this is a network-level key
	pub fn is_network(&self) -> bool {
		matches!(self, AddressKey::Ipv4Network(_) | AddressKey::Ipv6Subnet(_))
	}

	/// Check if this is a provider-level key
	pub fn is_provider(&self) -> bool {
		matches!(self, AddressKey::Ipv6Provider(_))
	}

	/// Get address level name for logging/responses
	pub fn level_name(&self) -> &'static str {
		match self {
			AddressKey::Ipv4Individual(_) => "ipv4_individual",
			AddressKey::Ipv4Network(_) => "ipv4_network",
			AddressKey::Ipv6Subnet(_) => "ipv6_subnet",
			AddressKey::Ipv6Provider(_) => "ipv6_provider",
		}
	}

	/// Unwrap an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to plain IPv4.
	///
	/// A dual-stack IPv6 listener (or a proxy header) can report an IPv4 peer in mapped form.
	/// The first eight octets of every mapped address are zero, so without this step all such
	/// peers would collapse into a single IPv6 /64 (and /48) key and share one rate limit.
	fn canonical(addr: &IpAddr) -> IpAddr {
		match addr {
			IpAddr::V6(ip) => ip.to_ipv4_mapped().map_or(*addr, IpAddr::V4),
			IpAddr::V4(_) => *addr,
		}
	}

	/// Build the /24 network key from IPv4 octets.
	fn ipv4_network(octets: &[u8; 4]) -> Self {
		AddressKey::Ipv4Network([octets[0], octets[1], octets[2]])
	}

	/// Build the /64 subnet key from the leading 8 octets of an IPv6 address.
	fn ipv6_subnet(octets: &[u8; 16]) -> Self {
		let mut subnet = [0u8; 8];
		subnet.copy_from_slice(&octets[..8]);
		AddressKey::Ipv6Subnet(subnet)
	}

	/// Build the /48 provider key from the leading 6 octets of an IPv6 address.
	fn ipv6_provider(octets: &[u8; 16]) -> Self {
		let mut provider = [0u8; 6];
		provider.copy_from_slice(&octets[..6]);
		AddressKey::Ipv6Provider(provider)
	}
}
119
120/// Extract client IP from request based on ServerMode
121///
122/// - Standalone mode: Use peer IP directly from ConnectInfo
123/// - Proxy/StreamProxy mode: Check forwarding headers first
124pub fn extract_client_ip<B>(req: &Request<B>, mode: &ServerMode) -> Option<IpAddr> {
125	match mode {
126		ServerMode::Standalone => {
127			// Direct connection - use peer IP
128			req.extensions().get::<ConnectInfo<SocketAddr>>().map(|ci| ci.0.ip())
129		}
130		ServerMode::Proxy | ServerMode::StreamProxy => {
131			// Behind reverse proxy - check headers first
132			extract_from_xff(req)
133				.or_else(|| extract_from_x_real_ip(req))
134				.or_else(|| extract_from_forwarded(req))
135				.or_else(|| req.extensions().get::<ConnectInfo<SocketAddr>>().map(|ci| ci.0.ip()))
136		}
137	}
138}
139
140/// Extract IP from X-Forwarded-For header
141fn extract_from_xff<B>(req: &Request<B>) -> Option<IpAddr> {
142	req.headers()
143		.get("x-forwarded-for")
144		.and_then(|h| h.to_str().ok())
145		.and_then(|s| {
146			// X-Forwarded-For can contain multiple IPs: "client, proxy1, proxy2"
147			// Take the first (leftmost) IP as the original client
148			s.split(',').next().map(|ip| ip.trim()).and_then(|ip| ip.parse().ok())
149		})
150}
151
152/// Extract IP from X-Real-IP header
153fn extract_from_x_real_ip<B>(req: &Request<B>) -> Option<IpAddr> {
154	req.headers()
155		.get("x-real-ip")
156		.and_then(|h| h.to_str().ok())
157		.and_then(|s| s.trim().parse().ok())
158}
159
160/// Extract IP from Forwarded header (RFC 7239)
161fn extract_from_forwarded<B>(req: &Request<B>) -> Option<IpAddr> {
162	req.headers().get("forwarded").and_then(|h| h.to_str().ok()).and_then(|s| {
163		// Forwarded header format: "for=192.0.2.60;proto=http;by=203.0.113.43"
164		// or with IPv6: "for=\"[2001:db8::1]\""
165		s.split(';')
166			.find(|part| part.trim().to_lowercase().starts_with("for="))
167			.and_then(|for_part| {
168				let value = for_part
169					.trim()
170					.strip_prefix("for=")
171					.or_else(|| for_part.trim().strip_prefix("FOR="))?;
172				// Handle quoted IPv6: "for=\"[2001:db8::1]\""
173				let cleaned = value.trim_matches('"').trim_matches('[').trim_matches(']');
174				cleaned.parse().ok()
175			})
176	})
177}
178
#[cfg(test)]
mod tests {
	use super::*;
	use std::net::Ipv6Addr;

	/// Shorthand for an IPv4 `IpAddr`.
	fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
		IpAddr::V4(Ipv4Addr::new(a, b, c, d))
	}

	/// Shorthand for an IPv6 `IpAddr` built from its eight 16-bit segments.
	fn v6(segments: [u16; 8]) -> IpAddr {
		IpAddr::V6(Ipv6Addr::from(segments))
	}

	#[test]
	fn test_address_key_extraction_ipv4() {
		let keys = AddressKey::extract_all(&v4(192, 168, 1, 100));

		// Exactly two keys, narrowest first: the /32 address, then its /24 network.
		match keys.as_slice() {
			[AddressKey::Ipv4Individual(addr), AddressKey::Ipv4Network(net)] => {
				assert_eq!(*addr, Ipv4Addr::new(192, 168, 1, 100));
				assert_eq!(*net, [192, 168, 1]);
			}
			other => panic!("unexpected IPv4 keys: {other:?}"),
		}
	}

	#[test]
	fn test_address_key_extraction_ipv6() {
		let keys = AddressKey::extract_all(&v6([0x2001, 0xdb8, 0x85a3, 0, 0, 0, 0, 1]));

		// IPv6 has no /128 level: the /64 subnet comes first, then the /48 provider.
		assert!(
			matches!(keys.as_slice(), [AddressKey::Ipv6Subnet(_), AddressKey::Ipv6Provider(_)]),
			"unexpected IPv6 keys: {keys:?}"
		);
	}

	#[test]
	fn test_address_key_levels() {
		let addr = v4(10, 0, 0, 1);
		let individual = AddressKey::from_ip_individual(&addr);
		let network = AddressKey::from_ip_network(&addr);

		assert!(individual.is_individual() && !individual.is_network());
		assert!(network.is_network() && !network.is_individual());
	}

	#[test]
	fn test_ipv6_provider_key() {
		assert_eq!(AddressKey::from_ipv6_provider(&v4(10, 0, 0, 1)), None);
		assert!(AddressKey::from_ipv6_provider(&v6([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1])).is_some());
	}

	#[test]
	fn test_level_names() {
		let ipv4 = v4(10, 0, 0, 1);
		let ipv6 = v6([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1]);
		let provider = AddressKey::from_ipv6_provider(&ipv6)
			.expect("IPv6 address must yield a provider key");

		let cases = [
			(AddressKey::from_ip_individual(&ipv4), "ipv4_individual"),
			(AddressKey::from_ip_network(&ipv4), "ipv4_network"),
			// IPv6 individual falls back to subnet
			(AddressKey::from_ip_individual(&ipv6), "ipv6_subnet"),
			(AddressKey::from_ip_network(&ipv6), "ipv6_subnet"),
			(provider, "ipv6_provider"),
		];
		for (key, expected) in &cases {
			assert_eq!(key.level_name(), *expected, "wrong level name for {key:?}");
		}
	}
}
241
242// vim: ts=4