Skip to main content

cloudillo_core/rate_limit/
extractors.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! Address Key Extractors
5//!
6//! Custom key extractors for hierarchical IP address rate limiting.
7//! Supports IPv4 /32, /24 and IPv6 /64, /48 address levels.
8
9use std::net::{IpAddr, Ipv4Addr, SocketAddr};
10
11use axum::extract::ConnectInfo;
12use hyper::Request;
13
14use crate::app::ServerMode;
15
/// Key identifying one level of the IP address hierarchy used for rate limiting.
///
/// Each variant stores only the address prefix relevant to its level, so two
/// addresses inside the same network compare (and hash) equal at that level.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AddressKey {
	/// A single IPv4 address (/32).
	Ipv4Individual(Ipv4Addr),
	/// The first three octets of an IPv4 address, i.e. its /24 (C-class) network.
	Ipv4Network([u8; 3]),
	/// The first eight octets of an IPv6 address, i.e. its /64 subnet (standard allocation).
	Ipv6Subnet([u8; 8]),
	/// The first six octets of an IPv6 address, i.e. its /48 provider allocation.
	Ipv6Provider([u8; 6]),
}
28
29impl AddressKey {
30	/// Create individual key from IP address
31	/// For IPv6, returns /64 subnet since /128 tracking is not supported
32	pub fn from_ip_individual(addr: &IpAddr) -> Self {
33		match addr {
34			IpAddr::V4(ip) => AddressKey::Ipv4Individual(*ip),
35			IpAddr::V6(ip) => {
36				let octets = ip.octets();
37				let mut subnet = [0u8; 8];
38				subnet.copy_from_slice(&octets[..8]);
39				AddressKey::Ipv6Subnet(subnet)
40			}
41		}
42	}
43
44	/// Create network key from IP address (IPv4 /24 or IPv6 /64)
45	pub fn from_ip_network(addr: &IpAddr) -> Self {
46		match addr {
47			IpAddr::V4(ip) => {
48				let octets = ip.octets();
49				AddressKey::Ipv4Network([octets[0], octets[1], octets[2]])
50			}
51			IpAddr::V6(ip) => {
52				let octets = ip.octets();
53				let mut subnet = [0u8; 8];
54				subnet.copy_from_slice(&octets[..8]);
55				AddressKey::Ipv6Subnet(subnet)
56			}
57		}
58	}
59
60	/// Create provider key from IPv6 address (/48)
61	/// Returns None for IPv4 addresses
62	pub fn from_ipv6_provider(addr: &IpAddr) -> Option<Self> {
63		match addr {
64			IpAddr::V4(_) => None,
65			IpAddr::V6(ip) => {
66				let octets = ip.octets();
67				let mut provider = [0u8; 6];
68				provider.copy_from_slice(&octets[..6]);
69				Some(AddressKey::Ipv6Provider(provider))
70			}
71		}
72	}
73
74	/// Extract all applicable hierarchical keys for an address
75	pub fn extract_all(addr: &IpAddr) -> Vec<Self> {
76		let mut keys = Vec::with_capacity(3);
77		match addr {
78			IpAddr::V4(ip) => {
79				keys.push(AddressKey::Ipv4Individual(*ip));
80				let octets = ip.octets();
81				keys.push(AddressKey::Ipv4Network([octets[0], octets[1], octets[2]]));
82			}
83			IpAddr::V6(ip) => {
84				// IPv6 uses /64 subnet as lowest level (no /128 tracking)
85				let octets = ip.octets();
86				let mut subnet = [0u8; 8];
87				subnet.copy_from_slice(&octets[..8]);
88				keys.push(AddressKey::Ipv6Subnet(subnet));
89				let mut provider = [0u8; 6];
90				provider.copy_from_slice(&octets[..6]);
91				keys.push(AddressKey::Ipv6Provider(provider));
92			}
93		}
94		keys
95	}
96
97	/// Check if this is an individual-level key (IPv4 only, IPv6 uses /64)
98	pub fn is_individual(&self) -> bool {
99		matches!(self, AddressKey::Ipv4Individual(_))
100	}
101
102	/// Check if this is a network-level key
103	pub fn is_network(&self) -> bool {
104		matches!(self, AddressKey::Ipv4Network(_) | AddressKey::Ipv6Subnet(_))
105	}
106
107	/// Check if this is a provider-level key
108	pub fn is_provider(&self) -> bool {
109		matches!(self, AddressKey::Ipv6Provider(_))
110	}
111
112	/// Get address level name for logging/responses
113	pub fn level_name(&self) -> &'static str {
114		match self {
115			AddressKey::Ipv4Individual(_) => "ipv4_individual",
116			AddressKey::Ipv4Network(_) => "ipv4_network",
117			AddressKey::Ipv6Subnet(_) => "ipv6_subnet",
118			AddressKey::Ipv6Provider(_) => "ipv6_provider",
119		}
120	}
121}
122
123/// Extract client IP from request based on ServerMode
124///
125/// - Standalone mode: Use peer IP directly from ConnectInfo
126/// - Proxy/StreamProxy mode: Check forwarding headers first
127pub fn extract_client_ip<B>(req: &Request<B>, mode: &ServerMode) -> Option<IpAddr> {
128	match mode {
129		ServerMode::Standalone => {
130			// Direct connection - use peer IP
131			req.extensions().get::<ConnectInfo<SocketAddr>>().map(|ci| ci.0.ip())
132		}
133		ServerMode::Proxy | ServerMode::StreamProxy => {
134			// Behind reverse proxy - check headers first
135			extract_from_xff(req)
136				.or_else(|| extract_from_x_real_ip(req))
137				.or_else(|| extract_from_forwarded(req))
138				.or_else(|| req.extensions().get::<ConnectInfo<SocketAddr>>().map(|ci| ci.0.ip()))
139		}
140	}
141}
142
143/// Extract IP from X-Forwarded-For header
144fn extract_from_xff<B>(req: &Request<B>) -> Option<IpAddr> {
145	req.headers()
146		.get("x-forwarded-for")
147		.and_then(|h| h.to_str().ok())
148		.and_then(|s| {
149			// X-Forwarded-For can contain multiple IPs: "client, proxy1, proxy2"
150			// Take the first (leftmost) IP as the original client
151			s.split(',').next().map(str::trim).and_then(|ip| ip.parse().ok())
152		})
153}
154
155/// Extract IP from X-Real-IP header
156fn extract_from_x_real_ip<B>(req: &Request<B>) -> Option<IpAddr> {
157	req.headers()
158		.get("x-real-ip")
159		.and_then(|h| h.to_str().ok())
160		.and_then(|s| s.trim().parse().ok())
161}
162
163/// Extract IP from Forwarded header (RFC 7239)
164fn extract_from_forwarded<B>(req: &Request<B>) -> Option<IpAddr> {
165	req.headers().get("forwarded").and_then(|h| h.to_str().ok()).and_then(|s| {
166		// Forwarded header format: "for=192.0.2.60;proto=http;by=203.0.113.43"
167		// or with IPv6: "for=\"[2001:db8::1]\""
168		s.split(';')
169			.find(|part| part.trim().to_lowercase().starts_with("for="))
170			.and_then(|for_part| {
171				let value = for_part
172					.trim()
173					.strip_prefix("for=")
174					.or_else(|| for_part.trim().strip_prefix("FOR="))?;
175				// Handle quoted IPv6: "for=\"[2001:db8::1]\""
176				let cleaned = value.trim_matches('"').trim_matches('[').trim_matches(']');
177				cleaned.parse().ok()
178			})
179	})
180}
181
#[cfg(test)]
mod tests {
	use super::*;
	use std::net::Ipv6Addr;

	/// Shorthand for an `IpAddr::V4`.
	fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
		IpAddr::V4(Ipv4Addr::new(a, b, c, d))
	}

	/// `2001:db8::1`, shared by several IPv6 tests.
	fn doc_v6() -> IpAddr {
		IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1))
	}

	#[test]
	fn test_address_key_extraction_ipv4() {
		let keys = AddressKey::extract_all(&v4(192, 168, 1, 100));
		let expected = vec![
			AddressKey::Ipv4Individual(Ipv4Addr::new(192, 168, 1, 100)),
			AddressKey::Ipv4Network([192, 168, 1]),
		];
		assert_eq!(keys, expected);
	}

	#[test]
	fn test_address_key_extraction_ipv6() {
		let addr = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0x85a3, 0, 0, 0, 0, 1));
		let keys = AddressKey::extract_all(&addr);

		// The /64 subnet is the finest IPv6 level (no /128 tracking), followed by the /48.
		assert!(matches!(
			keys.as_slice(),
			[AddressKey::Ipv6Subnet(_), AddressKey::Ipv6Provider(_)]
		));
	}

	#[test]
	fn test_address_key_levels() {
		let addr = v4(10, 0, 0, 1);
		let individual = AddressKey::from_ip_individual(&addr);
		let network = AddressKey::from_ip_network(&addr);

		// (is_individual, is_network)
		assert_eq!((individual.is_individual(), individual.is_network()), (true, false));
		assert_eq!((network.is_individual(), network.is_network()), (false, true));
	}

	#[test]
	fn test_ipv6_provider_key() {
		assert_eq!(AddressKey::from_ipv6_provider(&v4(10, 0, 0, 1)), None);
		assert!(AddressKey::from_ipv6_provider(&doc_v6()).is_some());
	}

	#[test]
	fn test_level_names() {
		let ipv4 = v4(10, 0, 0, 1);
		let ipv6 = doc_v6();

		let cases = vec![
			(AddressKey::from_ip_individual(&ipv4), "ipv4_individual"),
			(AddressKey::from_ip_network(&ipv4), "ipv4_network"),
			// IPv6 individual falls back to subnet
			(AddressKey::from_ip_individual(&ipv6), "ipv6_subnet"),
			(AddressKey::from_ip_network(&ipv6), "ipv6_subnet"),
			(AddressKey::from_ipv6_provider(&ipv6).unwrap(), "ipv6_provider"),
		];
		for (key, name) in &cases {
			assert_eq!(key.level_name(), *name, "{:?}", key);
		}
	}
}
244
245// vim: ts=4