1use rama_core::{Context, context::Extensions};
2
3#[cfg(feature = "http")]
4use crate::stream::SocketInfo;
5#[cfg(feature = "http")]
6use rama_http_types::Request;
7
/// Matcher that accepts traffic whose peer (remote) address is a loopback
/// address, such as `127.0.0.1` or `::1`.
///
/// The `optional` flag decides the outcome when no peer address is
/// available: `false` rejects, `true` accepts.
#[derive(Clone, Debug)]
pub struct LoopbackMatcher {
    /// Result returned when the peer address cannot be determined.
    optional: bool,
}
16
17impl LoopbackMatcher {
18 pub const fn new() -> Self {
27 Self { optional: false }
28 }
29
30 pub const fn optional() -> Self {
39 Self { optional: true }
40 }
41}
42
43impl Default for LoopbackMatcher {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for LoopbackMatcher {
    /// Matches when the peer address in the context's [`SocketInfo`] is a
    /// loopback address.
    ///
    /// The IP is canonicalised first, so an IPv4-mapped IPv6 peer such as
    /// `::ffff:127.0.0.1` is treated as the IPv4 loopback address it wraps.
    /// Dual-stack sockets typically report IPv4 peers in this form.
    ///
    /// If the context holds no [`SocketInfo`], the result is the matcher's
    /// `optional` flag: `false` for [`LoopbackMatcher::new`], `true` for
    /// [`LoopbackMatcher::optional`].
    fn matches(
        &self,
        _ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        _req: &Request<Body>,
    ) -> bool {
        ctx.get::<SocketInfo>().map_or(self.optional, |info| {
            // `Ipv6Addr::is_loopback` is only true for `::1`, not for
            // `::ffff:127.x.y.z`, so unwrap IPv4-mapped addresses first.
            info.peer_addr().ip().to_canonical().is_loopback()
        })
    }
}
62
63impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for LoopbackMatcher
64where
65 Socket: crate::stream::Socket,
66{
67 fn matches(
68 &self,
69 _ext: Option<&mut Extensions>,
70 _ctx: &Context<State>,
71 stream: &Socket,
72 ) -> bool {
73 stream
74 .peer_addr()
75 .map(|addr| addr.ip().is_loopback())
76 .unwrap_or(self.optional)
77 }
78}
79
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;
    use std::net::SocketAddr;

    /// Peer addresses paired with whether they are loopback addresses.
    ///
    /// The cases are:
    /// - a private IPv4 address and a public IPv6 address (not loopback);
    /// - `127.0.0.1` and another address in `127.0.0.0/8` (loopback);
    /// - the IPv6 loopback address `::1`.
    fn peer_cases() -> [(SocketAddr, bool); 5] {
        [
            (([192, 168, 0, 1], 8080).into(), false),
            (([1, 1, 1, 1, 1, 1, 1, 1], 8080).into(), false),
            (([127, 0, 0, 1], 8080).into(), true),
            (([127, 3, 2, 1], 8080).into(), true),
            (([0, 0, 0, 0, 0, 0, 0, 1], 8080).into(), true),
        ]
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_loopback_matcher_http() {
        let req = Request::builder()
            .method("GET")
            .uri("/hello")
            .body(())
            .unwrap();

        // Strict matcher: a request without `SocketInfo` is rejected.
        let matcher = LoopbackMatcher::new();
        let mut ctx = Context::default();
        assert!(!matcher.matches(None, &ctx, &req));

        for (peer_addr, is_loopback) in peer_cases() {
            ctx.insert(SocketInfo::new(None, peer_addr));
            assert_eq!(
                matcher.matches(None, &ctx, &req),
                is_loopback,
                "strict matcher, peer {peer_addr}",
            );
        }

        // Optional matcher: a missing `SocketInfo` is accepted...
        let matcher = LoopbackMatcher::optional();
        let mut ctx = Context::default();
        assert!(matcher.matches(None, &ctx, &req));

        // ...but a known peer address must still be loopback.
        for (peer_addr, is_loopback) in peer_cases() {
            ctx.insert(SocketInfo::new(None, peer_addr));
            assert_eq!(
                matcher.matches(None, &ctx, &req),
                is_loopback,
                "optional matcher, peer {peer_addr}",
            );
        }
    }

    #[test]
    fn test_loopback_matcher_socket_trait() {
        struct FakeSocket {
            local_addr: Option<SocketAddr>,
            peer_addr: Option<SocketAddr>,
        }

        impl crate::stream::Socket for FakeSocket {
            fn local_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.local_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }

            fn peer_addr(&self) -> std::io::Result<SocketAddr> {
                match &self.peer_addr {
                    Some(addr) => Ok(*addr),
                    None => Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)),
                }
            }
        }

        let ctx = Context::default();
        let mut socket = FakeSocket {
            local_addr: None,
            peer_addr: None,
        };

        // Strict matcher: an unknown peer address is rejected.
        let matcher = LoopbackMatcher::new();
        assert!(!matcher.matches(None, &ctx, &socket));

        for (peer_addr, is_loopback) in peer_cases() {
            socket.peer_addr = Some(peer_addr);
            assert_eq!(
                matcher.matches(None, &ctx, &socket),
                is_loopback,
                "strict matcher, peer {peer_addr}",
            );
        }

        // Optional matcher: an unknown peer address is accepted...
        let matcher = LoopbackMatcher::optional();
        socket.peer_addr = None;
        assert!(matcher.matches(None, &ctx, &socket));

        // ...but a known peer address must still be loopback.
        for (peer_addr, is_loopback) in peer_cases() {
            socket.peer_addr = Some(peer_addr);
            assert_eq!(
                matcher.matches(None, &ctx, &socket),
                is_loopback,
                "optional matcher, peer {peer_addr}",
            );
        }
    }
}