// rama_net/stream/matcher/port.rs

1use rama_core::{Context, context::Extensions};
2
3#[cfg(feature = "http")]
4use crate::stream::SocketInfo;
5#[cfg(feature = "http")]
6use rama_http_types::Request;
7
/// Matcher based on the port part of the [`SocketAddr`] of the peer.
///
/// When the peer address cannot be determined, the outcome depends on how the
/// matcher was constructed: [`PortMatcher::new`] does not match, while
/// [`PortMatcher::optional`] does. A peer address that *is* known but carries a
/// different port never matches, regardless of the constructor used.
///
/// [`SocketAddr`]: std::net::SocketAddr
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PortMatcher {
    /// Port that the peer address must have for the matcher to match.
    port: u16,
    /// Outcome used when the peer address cannot be determined.
    optional: bool,
}

impl PortMatcher {
    /// Create a new port matcher that matches on the port part of a [`SocketAddr`].
    ///
    /// This matcher will not match when the socket address cannot be found.
    /// Use the [`PortMatcher::optional`] constructor if you want it to match
    /// in that case.
    ///
    /// [`SocketAddr`]: std::net::SocketAddr
    pub const fn new(port: u16) -> Self {
        Self {
            port,
            optional: false,
        }
    }

    /// Create a new port matcher that matches on the port part of a [`SocketAddr`],
    /// and also matches when the socket address cannot be found.
    ///
    /// Use the [`PortMatcher::new`] constructor if you do not want it to match
    /// when the socket address cannot be found. A socket address that is found
    /// but has a different port still does not match.
    ///
    /// [`SocketAddr`]: std::net::SocketAddr
    pub const fn optional(port: u16) -> Self {
        Self {
            port,
            optional: true,
        }
    }
}
46
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for PortMatcher {
    /// Compare the peer port stored in the context's [`SocketInfo`] with the
    /// configured port.
    ///
    /// The request itself is not inspected. When the context holds no
    /// [`SocketInfo`], the matcher's `optional` flag is returned instead.
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        ctx.get::<SocketInfo>().map_or(self.optional, |socket_info| {
            socket_info.peer_addr().port() == self.port
        })
    }
}
60
61impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for PortMatcher
62where
63    Socket: crate::stream::Socket,
64{
65    fn matches(
66        &self,
67        _ext: Option<&mut Extensions>,
68        _ctx: &Context<State>,
69        stream: &Socket,
70    ) -> bool {
71        stream
72            .peer_addr()
73            .map(|addr| addr.port() == self.port)
74            .unwrap_or(self.optional)
75    }
76}
77
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;
    use std::net::SocketAddr;

    #[cfg(feature = "http")]
    #[test]
    fn test_port_matcher_http() {
        let matcher = PortMatcher::new(8080);

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(())
            .unwrap();

        // test #1: no match: test with no socket info registered
        assert!(!matcher.matches(None, &ctx, &req));

        // test #2: no match: test with different socket info (port difference)
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        // test #3: match: test with matching port
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        // test #4: match: test with different ip, same port
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        // test #5: match: test with missing socket info, but it's seen as optional
        let matcher = PortMatcher::optional(8080);
        let mut ctx = Context::default();
        assert!(matcher.matches(None, &ctx, &req));

        // test #6: no match: `optional` only covers missing socket info,
        // a mismatching port is still rejected
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        // test #7: match: optional matcher with a matching port
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));
    }

    #[test]
    fn test_port_matcher_socket_trait() {
        let matcher = PortMatcher::new(8080);

        let ctx = Context::default();

        /// Socket stub whose addresses can be toggled between known and unknown.
        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.local_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.peer_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }
        }

        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: Some(([127, 0, 0, 1], 8081).into()),
        };

        // test #1: no match: test with different socket info (port difference)
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #2: match: test with correct port
        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        // test #3: match: test with another correct address
        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        // test #4: no match: missing peer address and the matcher is not optional
        socket.peer_addr = None;
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #5: match: test with missing socket info, but it's seen as optional
        let matcher = PortMatcher::optional(8080);
        assert!(matcher.matches(None, &ctx, &socket));

        // test #6: no match: `optional` only covers a missing peer address,
        // a mismatching port is still rejected
        socket.peer_addr = Some(([127, 0, 0, 1], 8081).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        // test #7: match: optional matcher with a matching port
        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));
    }
}