rama_net/stream/matcher/
port.rs1use rama_core::{Context, context::Extensions};
2
3#[cfg(feature = "http")]
4use crate::stream::SocketInfo;
5#[cfg(feature = "http")]
6use rama_http_types::Request;
7
/// Matcher based on the port of the peer's socket address.
///
/// When matching an HTTP request, the peer address is read from the `SocketInfo`
/// stored in the context. When matching a raw socket, it is taken from the socket's
/// own `peer_addr`. Construct it with [`PortMatcher::new`] (does not match when the
/// peer address is unavailable) or [`PortMatcher::optional`] (matches in that case).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PortMatcher {
    /// Port that the peer address's port must equal.
    port: u16,
    /// Verdict returned when no peer address is available.
    optional: bool,
}
16
17impl PortMatcher {
18 pub const fn new(port: u16) -> Self {
26 Self {
27 port,
28 optional: false,
29 }
30 }
31
32 pub const fn optional(port: u16) -> Self {
40 Self {
41 port,
42 optional: true,
43 }
44 }
45}
46
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for PortMatcher {
    /// Compares the port of the peer address in the context's `SocketInfo`
    /// with the configured port.
    ///
    /// The request itself is not inspected. If the context holds no `SocketInfo`,
    /// the configured `optional` flag is returned.
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        match ctx.get::<SocketInfo>() {
            Some(socket_info) => socket_info.peer_addr().port() == self.port,
            None => self.optional,
        }
    }
}
60
61impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for PortMatcher
62where
63 Socket: crate::stream::Socket,
64{
65 fn matches(
66 &self,
67 _ext: Option<&mut Extensions>,
68 _ctx: &Context<State>,
69 stream: &Socket,
70 ) -> bool {
71 stream
72 .peer_addr()
73 .map(|addr| addr.port() == self.port)
74 .unwrap_or(self.optional)
75 }
76}
77
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;
    use std::net::SocketAddr;

    #[cfg(feature = "http")]
    #[test]
    fn test_port_matcher_http() {
        let matcher = PortMatcher::new(8080);

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(())
            .unwrap();

        // Strict matcher: no SocketInfo in the context means no match.
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        let matcher = PortMatcher::optional(8080);

        // Optional matcher: a missing SocketInfo counts as a match...
        let mut ctx = Context::default();
        assert!(matcher.matches(None, &ctx, &req));

        // ...but a known peer port must still equal the configured port.
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));
    }

    #[test]
    fn test_port_matcher_socket_trait() {
        let matcher = PortMatcher::new(8080);

        let ctx = Context::default();

        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.local_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.peer_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }
        }

        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: None,
        };

        // Strict matcher: an unavailable peer address means no match.
        assert!(!matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 1], 8081).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        let matcher = PortMatcher::optional(8080);

        // Optional matcher: an unavailable peer address counts as a match...
        socket.peer_addr = None;
        assert!(matcher.matches(None, &ctx, &socket));

        // ...but a known peer port must still equal the configured port.
        socket.peer_addr = Some(([127, 0, 0, 1], 8081).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));
    }
}