rama_net/stream/matcher/loopback.rs

1use rama_core::{Context, context::Extensions};
2
3#[cfg(feature = "http")]
4use crate::stream::SocketInfo;
5#[cfg(feature = "http")]
6use rama_http_types::Request;
7
/// Matches when the ip portion of the peer's [`SocketAddr`] is a loopback
/// address (e.g. `127.0.0.1` or `::1`).
///
/// [`SocketAddr`]: std::net::SocketAddr
#[derive(Clone, Debug)]
pub struct LoopbackMatcher {
    /// Result reported when no peer address is available to inspect.
    optional: bool,
}
16
17impl LoopbackMatcher {
18    /// create a new loopback matcher to match on the ip part a [`SocketAddr`],
19    /// matching only if the ip is a loopback address.
20    ///
21    /// This matcher will not match in case socket address could not be found,
22    /// if you want to match in case socket address could not be found,
23    /// use the [`LoopbackMatcher::optional`] constructor..
24    ///
25    /// [`SocketAddr`]: std::net::SocketAddr
26    pub const fn new() -> Self {
27        Self { optional: false }
28    }
29
30    /// create a new loopback matcher to match on the ip part a [`SocketAddr`],
31    /// matching only if the ip is a loopback address or no socket address could be found.
32    ///
33    /// This matcher will match in case socket address could not be found.
34    /// Use the [`LoopbackMatcher::new`] constructor if you want do not want
35    /// to match in case socket address could not be found.
36    ///
37    /// [`SocketAddr`]: std::net::SocketAddr
38    pub const fn optional() -> Self {
39        Self { optional: true }
40    }
41}
42
43impl Default for LoopbackMatcher {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for LoopbackMatcher {
    /// Looks up the [`SocketInfo`] stored in the context and reports whether its
    /// peer ip is a loopback address. When the context holds no [`SocketInfo`],
    /// the `optional` flag chosen at construction is returned instead.
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        let Some(socket_info) = ctx.get::<SocketInfo>() else {
            return self.optional;
        };
        socket_info.peer_addr().ip().is_loopback()
    }
}
62
63impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for LoopbackMatcher
64where
65    Socket: crate::stream::Socket,
66{
67    fn matches(
68        &self,
69        _ext: Option<&mut Extensions>,
70        _ctx: &Context<State>,
71        stream: &Socket,
72    ) -> bool {
73        stream
74            .peer_addr()
75            .map(|addr| addr.ip().is_loopback())
76            .unwrap_or(self.optional)
77    }
78}
79
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;
    use std::net::SocketAddr;

    #[cfg(feature = "http")]
    #[test]
    fn test_loopback_matcher_http() {
        let strict = LoopbackMatcher::new();

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(())
            .unwrap();

        // no SocketInfo in the context: the strict matcher must refuse
        assert!(!strict.matches(None, &ctx, &req));

        // (peer address, expected outcome) — non-loopback first, then loopback
        let cases: [(SocketAddr, bool); 5] = [
            (([192, 168, 0, 1], 8080).into(), false),
            (([1, 1, 1, 1, 1, 1, 1, 1], 8080).into(), false),
            (([127, 0, 0, 1], 8080).into(), true),
            (([127, 3, 2, 1], 8080).into(), true),
            (([0, 0, 0, 0, 0, 0, 0, 1], 8080).into(), true),
        ];
        for (peer, expected) in cases {
            ctx.insert(SocketInfo::new(None, peer));
            assert_eq!(strict.matches(None, &ctx, &req), expected, "peer: {peer}");
        }

        // the optional matcher accepts a context that lacks SocketInfo
        let lenient = LoopbackMatcher::optional();
        let empty_ctx = Context::default();
        assert!(lenient.matches(None, &empty_ctx, &req));
    }

    #[test]
    fn test_loopback_matcher_socket_trait() {
        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        /// Turns a missing address into the `AddrNotAvailable` io error.
        fn addr_or_unavailable(addr: Option<SocketAddr>) -> std::io::Result<SocketAddr> {
            addr.ok_or_else(|| std::io::Error::from(std::io::ErrorKind::AddrNotAvailable))
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                addr_or_unavailable(self.local_addr)
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                addr_or_unavailable(self.peer_addr)
            }
        }

        let strict = LoopbackMatcher::new();
        let ctx = Context::default();
        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: None,
        };

        // no peer address available: the strict matcher must refuse
        assert!(!strict.matches(None, &ctx, &socket));

        // (peer address, expected outcome) — non-loopback first, then loopback
        let cases: [(SocketAddr, bool); 5] = [
            (([192, 168, 0, 1], 8080).into(), false),
            (([1, 1, 1, 1, 1, 1, 1, 1], 8080).into(), false),
            (([127, 0, 0, 1], 8080).into(), true),
            (([127, 3, 2, 1], 8080).into(), true),
            (([0, 0, 0, 0, 0, 0, 0, 1], 8080).into(), true),
        ];
        for (peer, expected) in cases {
            socket.peer_addr = Some(peer);
            assert_eq!(strict.matches(None, &ctx, &socket), expected, "peer: {peer}");
        }

        // the optional matcher accepts a socket whose peer address is unknown
        let lenient = LoopbackMatcher::optional();
        socket.peer_addr = None;
        assert!(lenient.matches(None, &ctx, &socket));
    }
}