1pub use crate::stream::dep::ipnet::{IpNet, Ipv4Net, Ipv6Net};
4
5use rama_core::{Context, context::Extensions};
6
7#[cfg(feature = "http")]
8use crate::stream::SocketInfo;
9#[cfg(feature = "http")]
10use rama_http_types::Request;
11
12#[derive(Debug, Clone)]
13pub struct IpNetMatcher {
17 net: IpNet,
18 optional: bool,
19}
20
21impl IpNetMatcher {
22 pub fn new(net: impl IntoIpNet) -> Self {
28 Self {
29 net: net.into_ip_net(),
30 optional: false,
31 }
32 }
33
34 pub fn optional(net: impl IntoIpNet) -> Self {
40 Self {
41 net: net.into_ip_net(),
42 optional: true,
43 }
44 }
45}
46
47#[cfg(feature = "http")]
48impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for IpNetMatcher {
49 fn matches(
50 &self,
51 _ext: Option<&mut Extensions>,
52 ctx: &Context<State>,
53 _req: &Request<Body>,
54 ) -> bool {
55 ctx.get::<SocketInfo>()
56 .map(|info| self.net.contains(&IpNet::from(info.peer_addr().ip())))
57 .unwrap_or(self.optional)
58 }
59}
60
61impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for IpNetMatcher
62where
63 Socket: crate::stream::Socket,
64{
65 fn matches(
66 &self,
67 _ext: Option<&mut Extensions>,
68 _ctx: &Context<State>,
69 stream: &Socket,
70 ) -> bool {
71 stream
72 .peer_addr()
73 .map(|addr| self.net.contains(&IpNet::from(addr.ip())))
74 .unwrap_or(self.optional)
75 }
76}
77
78pub trait IntoIpNet: private::Sealed {}
80
81macro_rules! impl_ip_net_from_ip_addr_into_all {
82 ($($ty:ty),+ $(,)?) => {
83 $(
84 impl IntoIpNet for $ty {}
85 )+
86 };
87}
88
89impl_ip_net_from_ip_addr_into_all!(
90 Ipv4Net,
91 Ipv6Net,
92 IpNet,
93 std::net::IpAddr,
94 std::net::Ipv4Addr,
95 std::net::Ipv6Addr,
96 [u16; 8],
97 [u8; 16],
98 [u8; 4],
99);
100
101mod private {
102 use super::*;
103
104 pub trait Sealed {
105 fn into_ip_net(self) -> IpNet;
107 }
108
109 impl Sealed for Ipv4Net {
110 fn into_ip_net(self) -> IpNet {
111 IpNet::V4(self)
112 }
113 }
114
115 impl Sealed for Ipv6Net {
116 fn into_ip_net(self) -> IpNet {
117 IpNet::V6(self)
118 }
119 }
120
121 impl Sealed for IpNet {
122 fn into_ip_net(self) -> IpNet {
123 self
124 }
125 }
126
127 macro_rules! impl_sealed_from_ip_addr_into_all {
128 ($($ty:ty),+ $(,)?) => {
129 $(
130 impl Sealed for $ty {
131 fn into_ip_net(self) -> IpNet {
132 let ip_addr: std::net::IpAddr = self.into();
133 ip_addr.into()
134 }
135 }
136 )+
137 };
138 }
139
140 impl_sealed_from_ip_addr_into_all!(
141 std::net::IpAddr,
142 std::net::Ipv4Addr,
143 std::net::Ipv6Addr,
144 [u16; 8],
145 [u8; 16],
146 [u8; 4],
147 );
148}
149
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;
    use std::net::SocketAddr;

    const SUBNET_IPV4: &str = "192.168.0.0/24";
    const SUBNET_IPV4_VALID_CASES: [&str; 2] = ["192.168.0.0/25", "192.168.0.1"];
    const SUBNET_IPV4_INVALID_CASES: [&str; 2] = ["192.167.0.0/23", "192.168.1.0"];

    const SUBNET_IPV6: &str = "fd00::/16";
    const SUBNET_IPV6_VALID_CASES: [&str; 2] = ["fd00::/17", "fd00::1"];
    const SUBNET_IPV6_INVALID_CASES: [&str; 2] = ["fd01::/15", "fd01::"];

    /// Converts a test case (a bare IP address or a CIDR string) into a socket
    /// address on port 60000. For CIDR input, the address part of the CIDR is used.
    fn socket_addr_from_case(s: &str) -> SocketAddr {
        let ip = if s.contains('/') {
            s.parse::<IpNet>().unwrap().addr()
        } else {
            s.parse::<std::net::IpAddr>().unwrap()
        };
        SocketAddr::new(ip, 60000)
    }

    /// Builds a required matcher for `subnet` and asserts that `matches_addr`
    /// returns `expected` for the socket address of every entry in `cases`.
    fn assert_subnet_cases(
        subnet: &str,
        cases: &[&str],
        expected: bool,
        mut matches_addr: impl FnMut(&IpNetMatcher, SocketAddr) -> bool,
    ) {
        let matcher = IpNetMatcher::new(subnet.parse::<IpNet>().unwrap());
        let label = if expected { "valid" } else { "invalid" };
        for case in cases {
            let addr = socket_addr_from_case(case);
            assert_eq!(
                matches_addr(&matcher, addr),
                expected,
                "{label} subnets => {subnet} >=? {addr} ({case})",
            );
        }
    }

    /// Runs the shared IPv4/IPv6 valid and invalid subnet cases through `matches_addr`.
    fn assert_all_subnet_cases(mut matches_addr: impl FnMut(&IpNetMatcher, SocketAddr) -> bool) {
        assert_subnet_cases(SUBNET_IPV4, &SUBNET_IPV4_VALID_CASES, true, &mut matches_addr);
        assert_subnet_cases(SUBNET_IPV6, &SUBNET_IPV6_VALID_CASES, true, &mut matches_addr);
        assert_subnet_cases(SUBNET_IPV4, &SUBNET_IPV4_INVALID_CASES, false, &mut matches_addr);
        assert_subnet_cases(SUBNET_IPV6, &SUBNET_IPV6_INVALID_CASES, false, &mut matches_addr);
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_ip_net_matcher_http() {
        let matcher = IpNetMatcher::new([127, 0, 0, 1]);

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(())
            .unwrap();

        // no socket info: a required matcher cannot match
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        let matcher = IpNetMatcher::optional([127, 0, 0, 1]);
        let mut ctx = Context::default();
        // no socket info: an optional matcher falls back to matching
        assert!(matcher.matches(None, &ctx, &req));

        // socket info present: an optional matcher still checks the address
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        assert_all_subnet_cases(|matcher, addr| {
            ctx.insert(SocketInfo::new(None, addr));
            matcher.matches(None, &ctx, &req)
        });
    }

    #[test]
    fn test_ip_net_matcher_socket_trait() {
        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.local_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.peer_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }
        }

        let matcher = IpNetMatcher::new([127, 0, 0, 1]);
        let ctx = Context::default();

        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: Some(([127, 0, 0, 1], 8081).into()),
        };
        // the initial peer address is inside the matched network
        assert!(matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        // no peer address: a required matcher cannot match
        socket.peer_addr = None;
        assert!(!matcher.matches(None, &ctx, &socket));

        let matcher = IpNetMatcher::optional([127, 0, 0, 1]);
        // no peer address: an optional matcher falls back to matching
        assert!(matcher.matches(None, &ctx, &socket));

        // peer address present: an optional matcher still checks the address
        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        assert_all_subnet_cases(|matcher, addr| {
            socket.peer_addr = Some(addr);
            matcher.matches(None, &ctx, &socket)
        });
    }
}
360}