1use rama_core::{Context, context::Extensions};
2use std::net::SocketAddr;
3
4#[cfg(feature = "http")]
5use crate::stream::SocketInfo;
6#[cfg(feature = "http")]
7use rama_http_types::Request;
8
/// Matcher that compares the peer [`SocketAddr`] of a connection or request
/// against a single expected address.
///
/// When the peer address cannot be determined, the match result falls back
/// to whether the matcher was built with [`SocketAddressMatcher::optional`]
/// (matches) or [`SocketAddressMatcher::new`] (does not match).
#[derive(Debug, Clone)]
pub struct SocketAddressMatcher {
    /// The exact peer address a match requires.
    addr: SocketAddr,
    /// Result to report when the peer address is unavailable.
    optional: bool,
}

impl SocketAddressMatcher {
    /// Creates a strict matcher: the peer address must be known and equal `addr`.
    pub fn new(addr: impl Into<SocketAddr>) -> Self {
        let addr: SocketAddr = addr.into();
        Self {
            addr,
            optional: false,
        }
    }

    /// Creates a lenient matcher: a known peer address must equal `addr`,
    /// but an unknown peer address is treated as a match.
    pub fn optional(addr: impl Into<SocketAddr>) -> Self {
        let addr: SocketAddr = addr.into();
        Self {
            addr,
            optional: true,
        }
    }
}
41
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for SocketAddressMatcher {
    /// Matches when the [`SocketInfo`] stored in the context reports a peer
    /// address equal to the configured one.
    ///
    /// If the context holds no [`SocketInfo`], the `optional` flag decides.
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        match ctx.get::<SocketInfo>() {
            Some(info) => info.peer_addr() == &self.addr,
            // No socket information recorded for this request.
            None => self.optional,
        }
    }
}
55
56impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for SocketAddressMatcher
57where
58 Socket: crate::stream::Socket,
59{
60 fn matches(
61 &self,
62 _ext: Option<&mut Extensions>,
63 _ctx: &Context<State>,
64 stream: &Socket,
65 ) -> bool {
66 stream
67 .peer_addr()
68 .map(|addr| addr == self.addr)
69 .unwrap_or(self.optional)
70 }
71}
72
#[cfg(test)]
mod test {
    use rama_core::matcher::Matcher;
    #[cfg(feature = "http")]
    use rama_http_types::Body;

    use super::*;

    /// Exercises the `Request` matcher, which reads `SocketInfo` from the context.
    /// Only available with the `http` feature, like the impl it tests.
    #[cfg(feature = "http")]
    #[test]
    fn test_socket_matcher_http() {
        let matcher = SocketAddressMatcher::new(([127, 0, 0, 1], 8080));

        let mut ctx = Context::default();
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(Body::empty())
            .unwrap();

        // No SocketInfo in the context: a strict matcher must not match.
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 2], 8080).into()));
        assert!(!matcher.matches(None, &ctx, &req));

        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8080).into()));
        assert!(matcher.matches(None, &ctx, &req));

        let matcher = SocketAddressMatcher::optional(([127, 0, 0, 1], 8080));
        let mut ctx = Context::default();
        // No SocketInfo: an optional matcher falls back to matching.
        assert!(matcher.matches(None, &ctx, &req));

        // A known but different peer address must still be rejected.
        ctx.insert(SocketInfo::new(None, ([127, 0, 0, 1], 8081).into()));
        assert!(!matcher.matches(None, &ctx, &req));
    }

    /// Exercises the generic `Socket` matcher. That impl does not depend on the
    /// `http` feature, so neither does this test.
    #[test]
    fn test_socket_matcher_socket_trait() {
        let matcher = SocketAddressMatcher::new(([127, 0, 0, 1], 8080));

        let ctx = Context::default();

        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.local_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.peer_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }
        }

        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: Some(([127, 0, 0, 1], 8081).into()),
        };

        assert!(!matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 2], 8080).into());
        assert!(!matcher.matches(None, &ctx, &socket));

        socket.peer_addr = Some(([127, 0, 0, 1], 8080).into());
        assert!(matcher.matches(None, &ctx, &socket));

        // Peer address unavailable: a strict matcher must not match.
        socket.peer_addr = None;
        assert!(!matcher.matches(None, &ctx, &socket));

        let matcher = SocketAddressMatcher::optional(([127, 0, 0, 1], 8080));
        // Peer address unavailable: an optional matcher falls back to matching.
        assert!(matcher.matches(None, &ctx, &socket));

        // A known but different peer address must still be rejected.
        socket.peer_addr = Some(([127, 0, 0, 1], 8081).into());
        assert!(!matcher.matches(None, &ctx, &socket));
    }
}
161}