rama_net/stream/matcher/socket.rs

1use rama_core::{Context, context::Extensions};
2use std::net::SocketAddr;
3
4#[cfg(feature = "http")]
5use crate::stream::SocketInfo;
6#[cfg(feature = "http")]
7use rama_http_types::Request;
8
/// Matcher based on the [`SocketAddr`] of the peer.
///
/// Whether a missing peer address counts as a match depends on the
/// constructor used: [`SocketAddressMatcher::new`] or
/// [`SocketAddressMatcher::optional`].
#[derive(Debug, Clone)]
pub struct SocketAddressMatcher {
    /// Peer address that must be equal for the matcher to match.
    addr: SocketAddr,
    /// Result returned when no peer address is available.
    optional: bool,
}
15
16impl SocketAddressMatcher {
17    /// create a new socket address matcher to match on a socket address
18    ///
19    /// This matcher will not match in case socket address could not be found,
20    /// if you want to match in case socket address could not be found,
21    /// use the [`SocketAddressMatcher::optional`] constructor..
22    pub fn new(addr: impl Into<SocketAddr>) -> Self {
23        Self {
24            addr: addr.into(),
25            optional: false,
26        }
27    }
28
29    /// create a new socket address matcher to match on a socket address
30    ///
31    /// This matcher will match in case socket address could not be found.
32    /// Use the [`SocketAddressMatcher::new`] constructor if you want do not want
33    /// to match in case socket address could not be found.
34    pub fn optional(addr: impl Into<SocketAddr>) -> Self {
35        Self {
36            addr: addr.into(),
37            optional: true,
38        }
39    }
40}
41
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for SocketAddressMatcher {
    /// Matches when the [`SocketInfo`] stored in the context has a peer
    /// address equal to the configured one. When the context holds no
    /// [`SocketInfo`], the `optional` flag decides. The request itself is
    /// not inspected.
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        let configured = &self.addr;
        ctx.get::<SocketInfo>()
            .map_or(self.optional, |socket_info| socket_info.peer_addr() == configured)
    }
}
55
56impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for SocketAddressMatcher
57where
58    Socket: crate::stream::Socket,
59{
60    fn matches(
61        &self,
62        _ext: Option<&mut Extensions>,
63        _ctx: &Context<State>,
64        stream: &Socket,
65    ) -> bool {
66        stream
67            .peer_addr()
68            .map(|addr| addr == self.addr)
69            .unwrap_or(self.optional)
70    }
71}
72
#[cfg(test)]
mod test {
    use rama_core::matcher::Matcher;

    use super::*;

    // Only the HTTP matcher impl is gated behind the `http` feature, so only
    // this test is; the socket-trait test below must run in every build.
    #[cfg(feature = "http")]
    #[test]
    fn test_socket_matcher_http() {
        use rama_http_types::Body;

        let matcher = SocketAddressMatcher::new(([127, 0, 0, 1], 8080));

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(Body::empty())
            .unwrap();

        // test #1: no match: test with no socket info registered
        assert!(!matcher.matches(None, &ctx, &req));

        // test #2: no match: test with different socket info (port difference)
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        // test #3: no match: test with different socket info (ip addr difference)
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        // test #4: match: test with correct address
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        // test #5: match: test with missing socket info, but it's seen as optional
        let matcher = SocketAddressMatcher::optional(([127, 0, 0, 1], 8080));
        let mut ctx = Context::default();
        assert!(matcher.matches(None, &ctx, &req));

        // test #6: no match: optional only covers missing socket info,
        // a different address must still be rejected
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(!matcher.matches(None, &ctx, &req));
    }

    #[test]
    fn test_socket_matcher_socket_trait() {
        let matcher = SocketAddressMatcher::new(([127, 0, 0, 1], 8080));

        let ctx = Context::default();

        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.local_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.peer_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }
        }

        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: Some(([127, 0, 0, 1], 8081).into()),
        };

        // test #1: no match: test with different socket info (port difference)
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #2: no match: test with different socket info (ip addr difference)
        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #3: match: test with correct address
        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        // test #4: no match: missing peer address and the matcher is not optional
        socket.peer_addr = None;
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #5: match: test with missing socket info, but it's seen as optional
        let matcher = SocketAddressMatcher::optional(([127, 0, 0, 1], 8080));
        assert!(matcher.matches(None, &ctx, &socket));

        // test #6: no match: optional only covers a missing peer address,
        // a different address must still be rejected
        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(!matcher.matches(None, &ctx, &socket));
    }
}