rama_net/stream/matcher/
ip.rs

1//! ip matcher and utilities
2
3pub use crate::stream::dep::ipnet::{IpNet, Ipv4Net, Ipv6Net};
4
5use rama_core::{Context, context::Extensions};
6
7#[cfg(feature = "http")]
8use crate::stream::SocketInfo;
9#[cfg(feature = "http")]
10use rama_http_types::Request;
11
12#[derive(Debug, Clone)]
13/// Matcher based on whether or not the [`IpNet`] contains the [`SocketAddr`] of the peer.
14///
15/// [`SocketAddr`]: std::net::SocketAddr
16pub struct IpNetMatcher {
17    net: IpNet,
18    optional: bool,
19}
20
21impl IpNetMatcher {
22    /// create a new IP network matcher to match on an IP Network.
23    ///
24    /// This matcher will not match in case socket address could not be found,
25    /// if you want to match in case socket address could not be found,
26    /// use the [`IpNetMatcher::optional`] constructor..
27    pub fn new(net: impl IntoIpNet) -> Self {
28        Self {
29            net: net.into_ip_net(),
30            optional: false,
31        }
32    }
33
34    /// create a new IP network matcher to match on an IP network
35    ///
36    /// This matcher will match in case socket address could not be found.
37    /// Use the [`IpNetMatcher::new`] constructor if you want do not want
38    /// to match in case socket address could not be found.
39    pub fn optional(net: impl IntoIpNet) -> Self {
40        Self {
41            net: net.into_ip_net(),
42            optional: true,
43        }
44    }
45}
46
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for IpNetMatcher {
    /// Match an http request based on the [`SocketInfo`] stored in the context.
    ///
    /// Returns whether the configured network contains the peer's IP address,
    /// or the matcher's `optional` flag when the context holds no [`SocketInfo`].
    /// The extensions and the request itself are not inspected.
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        match ctx.get::<SocketInfo>() {
            Some(socket_info) => {
                let peer_ip = socket_info.peer_addr().ip();
                self.net.contains(&IpNet::from(peer_ip))
            }
            // no socket info registered: fall back to the configured default
            None => self.optional,
        }
    }
}
60
61impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for IpNetMatcher
62where
63    Socket: crate::stream::Socket,
64{
65    fn matches(
66        &self,
67        _ext: Option<&mut Extensions>,
68        _ctx: &Context<State>,
69        stream: &Socket,
70    ) -> bool {
71        stream
72            .peer_addr()
73            .map(|addr| self.net.contains(&IpNet::from(addr.ip())))
74            .unwrap_or(self.optional)
75    }
76}
77
78/// utility trait to consume a tpe into an [`IpNet`]
79pub trait IntoIpNet: private::Sealed {}
80
81macro_rules! impl_ip_net_from_ip_addr_into_all {
82    ($($ty:ty),+ $(,)?) => {
83        $(
84            impl IntoIpNet for $ty {}
85        )+
86    };
87}
88
89impl_ip_net_from_ip_addr_into_all!(
90    Ipv4Net,
91    Ipv6Net,
92    IpNet,
93    std::net::IpAddr,
94    std::net::Ipv4Addr,
95    std::net::Ipv6Addr,
96    [u16; 8],
97    [u8; 16],
98    [u8; 4],
99);
100
101mod private {
102    use super::*;
103
104    pub trait Sealed {
105        /// Consume `self` into an [`IpNet`]
106        fn into_ip_net(self) -> IpNet;
107    }
108
109    impl Sealed for Ipv4Net {
110        fn into_ip_net(self) -> IpNet {
111            IpNet::V4(self)
112        }
113    }
114
115    impl Sealed for Ipv6Net {
116        fn into_ip_net(self) -> IpNet {
117            IpNet::V6(self)
118        }
119    }
120
121    impl Sealed for IpNet {
122        fn into_ip_net(self) -> IpNet {
123            self
124        }
125    }
126
127    macro_rules! impl_sealed_from_ip_addr_into_all {
128        ($($ty:ty),+ $(,)?) => {
129            $(
130                impl Sealed for $ty {
131                    fn into_ip_net(self) -> IpNet {
132                        let ip_addr: std::net::IpAddr = self.into();
133                        ip_addr.into()
134                    }
135                }
136            )+
137        };
138    }
139
140    impl_sealed_from_ip_addr_into_all!(
141        std::net::IpAddr,
142        std::net::Ipv4Addr,
143        std::net::Ipv6Addr,
144        [u16; 8],
145        [u8; 16],
146        [u8; 4],
147    );
148}
149
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;
    use std::net::SocketAddr;

    const SUBNET_IPV4: &str = "192.168.0.0/24";
    const SUBNET_IPV4_VALID_CASES: [&str; 2] = ["192.168.0.0/25", "192.168.0.1"];
    const SUBNET_IPV4_INVALID_CASES: [&str; 2] = ["192.167.0.0/23", "192.168.1.0"];

    const SUBNET_IPV6: &str = "fd00::/16";
    const SUBNET_IPV6_VALID_CASES: [&str; 2] = ["fd00::/17", "fd00::1"];
    const SUBNET_IPV6_INVALID_CASES: [&str; 2] = ["fd01::/15", "fd01::"];

    /// Subnet scenarios shared by both tests:
    /// `(label, matcher network, peer cases, expected match result)`.
    ///
    /// The label is the prefix of the assertion message on failure.
    const SUBNET_SCENARIOS: [(&str, &str, [&str; 2], bool); 4] = [
        (
            "valid ipv4 subnets",
            SUBNET_IPV4,
            SUBNET_IPV4_VALID_CASES,
            true,
        ),
        (
            "valid ipv6 subnets",
            SUBNET_IPV6,
            SUBNET_IPV6_VALID_CASES,
            true,
        ),
        (
            "invalid ipv4 subnets",
            SUBNET_IPV4,
            SUBNET_IPV4_INVALID_CASES,
            false,
        ),
        (
            "invalid ipv6 subnets",
            SUBNET_IPV6,
            SUBNET_IPV6_INVALID_CASES,
            false,
        ),
    ];

    /// Build the peer address for a test case.
    ///
    /// A case containing a `/` is parsed as a network and its address is used,
    /// any other case is parsed as a plain IP address. The port is fixed as
    /// the matcher only looks at the IP.
    fn socket_addr_from_case(s: &str) -> SocketAddr {
        let ip = if s.contains('/') {
            s.parse::<IpNet>().unwrap().addr()
        } else {
            s.parse::<std::net::IpAddr>().unwrap()
        };
        SocketAddr::new(ip, 60000)
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_ip_net_matcher_http() {
        let matcher = IpNetMatcher::new([127, 0, 0, 1]);

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(())
            .unwrap();

        // test #1: no match: test with no socket info registered
        assert!(!matcher.matches(None, &ctx, &req));

        // test #2: no match: test with different socket info (ip addr difference)
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        // test #3: match: test with correct address
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        // test #4: match: test with missing socket info, but it's seen as optional
        let matcher = IpNetMatcher::optional([127, 0, 0, 1]);
        let mut ctx = Context::default();
        assert!(matcher.matches(None, &ctx, &req));

        // tests #5-#8: peers inside the subnet match, peers outside do not
        for &(label, net, cases, expected) in SUBNET_SCENARIOS.iter() {
            let matcher = IpNetMatcher::new(net.parse::<IpNet>().unwrap());
            for case in cases.iter() {
                let addr = socket_addr_from_case(case);
                ctx.insert(SocketInfo::new(None, addr));
                assert_eq!(
                    matcher.matches(None, &ctx, &req),
                    expected,
                    "{} => {} >=? {} ({})",
                    label,
                    net,
                    addr,
                    case
                );
            }
        }
    }

    #[test]
    fn test_ip_net_matcher_socket_trait() {
        /// Socket stub reporting the configured addresses,
        /// or an error for an address which is not set.
        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                self.local_addr
                    .ok_or_else(|| std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                self.peer_addr
                    .ok_or_else(|| std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))
            }
        }

        let matcher = IpNetMatcher::new([127, 0, 0, 1]);

        let ctx = Context::default();

        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: None,
        };

        // test #1: no match: peer address unavailable and the matcher is not optional
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #2: no match: test with different socket info (ip addr difference)
        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #3: match: test with correct address
        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        // test #4: match: test with missing socket info, but it's seen as optional
        let matcher = IpNetMatcher::optional([127, 0, 0, 1]);
        socket.peer_addr = None;
        assert!(matcher.matches(None, &ctx, &socket));

        // tests #5-#8: peers inside the subnet match, peers outside do not
        for &(label, net, cases, expected) in SUBNET_SCENARIOS.iter() {
            let matcher = IpNetMatcher::new(net.parse::<IpNet>().unwrap());
            for case in cases.iter() {
                let addr = socket_addr_from_case(case);
                socket.peer_addr = Some(addr);
                assert_eq!(
                    matcher.matches(None, &ctx, &socket),
                    expected,
                    "{} => {} >=? {} ({})",
                    label,
                    net,
                    addr,
                    case
                );
            }
        }
    }
}