1mod socket;
10#[doc(inline)]
11pub use socket::SocketAddressMatcher;
12
13mod port;
14#[doc(inline)]
15pub use port::PortMatcher;
16
17mod private_ip;
18#[doc(inline)]
19pub use private_ip::PrivateIpNetMatcher;
20
21mod loopback;
22#[doc(inline)]
23pub use loopback::LoopbackMatcher;
24
25pub mod ip;
26#[doc(inline)]
27pub use ip::IpNetMatcher;
28
29use rama_core::{Context, context::Extensions, matcher::IteratorMatcherExt};
30use std::{fmt, sync::Arc};
31
32#[cfg(feature = "http")]
33use rama_http_types::Request;
34
35pub struct SocketMatcher<State, Socket> {
39 kind: SocketMatcherKind<State, Socket>,
40 negate: bool,
41}
42
43impl<State, Socket> Clone for SocketMatcher<State, Socket> {
44 fn clone(&self) -> Self {
45 Self {
46 kind: self.kind.clone(),
47 negate: self.negate,
48 }
49 }
50}
51
52impl<State, Socket> fmt::Debug for SocketMatcher<State, Socket> {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("SocketMatcher")
55 .field("kind", &self.kind)
56 .field("negate", &self.negate)
57 .finish()
58 }
59}
60
61enum SocketMatcherKind<State, Socket> {
63 SocketAddress(SocketAddressMatcher),
67 Loopback(LoopbackMatcher),
69 PrivateIpNet(PrivateIpNetMatcher),
71 Port(PortMatcher),
75 IpNet(IpNetMatcher),
81 All(Vec<SocketMatcher<State, Socket>>),
83 Any(Vec<SocketMatcher<State, Socket>>),
85 Custom(Arc<dyn rama_core::matcher::Matcher<State, Socket>>),
87}
88
89impl<State, Socket> Clone for SocketMatcherKind<State, Socket> {
90 fn clone(&self) -> Self {
91 match self {
92 Self::SocketAddress(matcher) => Self::SocketAddress(matcher.clone()),
93 Self::Loopback(matcher) => Self::Loopback(matcher.clone()),
94 Self::PrivateIpNet(matcher) => Self::PrivateIpNet(matcher.clone()),
95 Self::Port(matcher) => Self::Port(matcher.clone()),
96 Self::IpNet(matcher) => Self::IpNet(matcher.clone()),
97 Self::All(matcher) => Self::All(matcher.clone()),
98 Self::Any(matcher) => Self::Any(matcher.clone()),
99 Self::Custom(matcher) => Self::Custom(matcher.clone()),
100 }
101 }
102}
103
104impl<State, Socket> fmt::Debug for SocketMatcherKind<State, Socket> {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Self::SocketAddress(matcher) => f.debug_tuple("SocketAddress").field(matcher).finish(),
108 Self::Loopback(matcher) => f.debug_tuple("Loopback").field(matcher).finish(),
109 Self::PrivateIpNet(matcher) => f.debug_tuple("PrivateIpNet").field(matcher).finish(),
110 Self::Port(matcher) => f.debug_tuple("Port").field(matcher).finish(),
111 Self::IpNet(matcher) => f.debug_tuple("IpNet").field(matcher).finish(),
112 Self::All(matcher) => f.debug_tuple("All").field(matcher).finish(),
113 Self::Any(matcher) => f.debug_tuple("Any").field(matcher).finish(),
114 Self::Custom(_) => f.debug_tuple("Custom").finish(),
115 }
116 }
117}
118
119impl<State, Socket> SocketMatcher<State, Socket> {
120 pub fn socket_addr(addr: impl Into<std::net::SocketAddr>) -> Self {
124 Self {
125 kind: SocketMatcherKind::SocketAddress(SocketAddressMatcher::new(addr)),
126 negate: false,
127 }
128 }
129
130 pub fn optional_socket_addr(addr: impl Into<std::net::SocketAddr>) -> Self {
135 Self {
136 kind: SocketMatcherKind::SocketAddress(SocketAddressMatcher::optional(addr)),
137 negate: false,
138 }
139 }
140
141 pub fn and_socket_addr(self, addr: impl Into<std::net::SocketAddr>) -> Self {
143 self.and(Self::socket_addr(addr))
144 }
145
146 pub fn and_optional_socket_addr(self, addr: impl Into<std::net::SocketAddr>) -> Self {
150 self.and(Self::optional_socket_addr(addr))
151 }
152
153 pub fn or_socket_addr(self, addr: impl Into<std::net::SocketAddr>) -> Self {
157 self.or(Self::socket_addr(addr))
158 }
159
160 pub fn or_optional_socket_addr(self, addr: impl Into<std::net::SocketAddr>) -> Self {
164 self.or(Self::optional_socket_addr(addr))
165 }
166
167 pub fn loopback() -> Self {
171 Self {
172 kind: SocketMatcherKind::Loopback(LoopbackMatcher::new()),
173 negate: false,
174 }
175 }
176
177 pub fn optional_loopback() -> Self {
182 Self {
183 kind: SocketMatcherKind::Loopback(LoopbackMatcher::optional()),
184 negate: false,
185 }
186 }
187
188 pub fn and_loopback(self) -> Self {
192 self.and(Self::loopback())
193 }
194
195 pub fn and_optional_loopback(self) -> Self {
199 self.and(Self::optional_loopback())
200 }
201
202 pub fn or_loopback(self) -> Self {
206 self.or(Self::loopback())
207 }
208
209 pub fn or_optional_loopback(self) -> Self {
213 self.or(Self::optional_loopback())
214 }
215
216 pub fn port(port: u16) -> Self {
220 Self {
221 kind: SocketMatcherKind::Port(PortMatcher::new(port)),
222 negate: false,
223 }
224 }
225
226 pub fn optional_port(port: u16) -> Self {
231 Self {
232 kind: SocketMatcherKind::Port(PortMatcher::optional(port)),
233 negate: false,
234 }
235 }
236
237 pub fn and_port(self, port: u16) -> Self {
242 self.and(Self::port(port))
243 }
244
245 pub fn and_optional_port(self, port: u16) -> Self {
250 self.and(Self::optional_port(port))
251 }
252
253 pub fn or_port(self, port: u16) -> Self {
258 self.or(Self::port(port))
259 }
260
261 pub fn or_optional_port(self, port: u16) -> Self {
266 self.or(Self::optional_port(port))
267 }
268
269 pub fn ip_net(ip_net: impl ip::IntoIpNet) -> Self {
273 Self {
274 kind: SocketMatcherKind::IpNet(IpNetMatcher::new(ip_net)),
275 negate: false,
276 }
277 }
278
279 pub fn optional_ip_net(ip_net: impl ip::IntoIpNet) -> Self {
284 Self {
285 kind: SocketMatcherKind::IpNet(IpNetMatcher::optional(ip_net)),
286 negate: false,
287 }
288 }
289
290 pub fn and_ip_net(self, ip_net: impl ip::IntoIpNet) -> Self {
294 self.and(Self::ip_net(ip_net))
295 }
296
297 pub fn and_optional_ip_net(self, ip_net: impl ip::IntoIpNet) -> Self {
301 self.and(Self::optional_ip_net(ip_net))
302 }
303
304 pub fn or_ip_net(self, ip_net: impl ip::IntoIpNet) -> Self {
308 self.or(Self::ip_net(ip_net))
309 }
310
311 pub fn or_optional_ip_net(self, ip_net: impl ip::IntoIpNet) -> Self {
315 self.or(Self::optional_ip_net(ip_net))
316 }
317
318 pub fn private_ip_net() -> Self {
322 Self {
323 kind: SocketMatcherKind::PrivateIpNet(PrivateIpNetMatcher::new()),
324 negate: false,
325 }
326 }
327
328 pub fn optional_private_ip_net() -> Self {
333 Self {
334 kind: SocketMatcherKind::PrivateIpNet(PrivateIpNetMatcher::optional()),
335 negate: false,
336 }
337 }
338
339 pub fn and_private_ip_net(self) -> Self {
343 self.and(Self::private_ip_net())
344 }
345
346 pub fn and_optional_private_ip_net(self) -> Self {
350 self.and(Self::optional_private_ip_net())
351 }
352
353 pub fn or_private_ip_net(self) -> Self {
357 self.or(Self::private_ip_net())
358 }
359
360 pub fn or_optional_private_ip_net(self) -> Self {
364 self.or(Self::optional_private_ip_net())
365 }
366
367 pub fn custom<M>(matcher: M) -> Self
371 where
372 M: rama_core::matcher::Matcher<State, Socket>,
373 {
374 Self {
375 kind: SocketMatcherKind::Custom(Arc::new(matcher)),
376 negate: false,
377 }
378 }
379
380 pub fn and_custom<M>(self, matcher: M) -> Self
384 where
385 M: rama_core::matcher::Matcher<State, Socket>,
386 {
387 self.and(Self::custom(matcher))
388 }
389
390 pub fn or_custom<M>(self, matcher: M) -> Self
394 where
395 M: rama_core::matcher::Matcher<State, Socket>,
396 {
397 self.or(Self::custom(matcher))
398 }
399
400 pub fn and(mut self, matcher: SocketMatcher<State, Socket>) -> Self {
402 match (self.negate, &mut self.kind) {
403 (false, SocketMatcherKind::All(v)) => {
404 v.push(matcher);
405 self
406 }
407 _ => SocketMatcher {
408 kind: SocketMatcherKind::All(vec![self, matcher]),
409 negate: false,
410 },
411 }
412 }
413
414 pub fn or(mut self, matcher: SocketMatcher<State, Socket>) -> Self {
416 match (self.negate, &mut self.kind) {
417 (false, SocketMatcherKind::Any(v)) => {
418 v.push(matcher);
419 self
420 }
421 _ => SocketMatcher {
422 kind: SocketMatcherKind::Any(vec![self, matcher]),
423 negate: false,
424 },
425 }
426 }
427
428 pub fn negate(self) -> Self {
430 Self {
431 kind: self.kind,
432 negate: true,
433 }
434 }
435}
436
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>>
    for SocketMatcherKind<State, Request<Body>>
where
    State: 'static,
    Body: 'static,
{
    /// Evaluate this kind against an HTTP request.
    ///
    /// Combinators fold over their children. Every other variant delegates
    /// to the matcher it wraps.
    fn matches(
        &self,
        ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        req: &Request<Body>,
    ) -> bool {
        match self {
            Self::All(children) => children.iter().matches_and(ext, ctx, req),
            Self::Any(children) => children.iter().matches_or(ext, ctx, req),
            Self::SocketAddress(leaf) => leaf.matches(ext, ctx, req),
            Self::IpNet(leaf) => leaf.matches(ext, ctx, req),
            Self::PrivateIpNet(leaf) => leaf.matches(ext, ctx, req),
            Self::Loopback(leaf) => leaf.matches(ext, ctx, req),
            Self::Port(leaf) => leaf.matches(ext, ctx, req),
            Self::Custom(custom) => custom.matches(ext, ctx, req),
        }
    }
}
462
#[cfg(feature = "http")]
impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>>
    for SocketMatcher<State, Request<Body>>
where
    State: 'static,
    Body: 'static,
{
    /// Evaluate the wrapped kind, inverting the outcome when negated.
    fn matches(
        &self,
        ext: Option<&mut Extensions>,
        ctx: &Context<State>,
        req: &Request<Body>,
    ) -> bool {
        // XOR with `negate`: unchanged when false, inverted when true.
        self.kind.matches(ext, ctx, req) ^ self.negate
    }
}
480
481impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for SocketMatcherKind<State, Socket>
482where
483 Socket: crate::stream::Socket,
484 State: 'static,
485{
486 fn matches(&self, ext: Option<&mut Extensions>, ctx: &Context<State>, stream: &Socket) -> bool {
487 match self {
488 SocketMatcherKind::SocketAddress(matcher) => matcher.matches(ext, ctx, stream),
489 SocketMatcherKind::IpNet(matcher) => matcher.matches(ext, ctx, stream),
490 SocketMatcherKind::Loopback(matcher) => matcher.matches(ext, ctx, stream),
491 SocketMatcherKind::PrivateIpNet(matcher) => matcher.matches(ext, ctx, stream),
492 SocketMatcherKind::Port(matcher) => matcher.matches(ext, ctx, stream),
493 SocketMatcherKind::All(matchers) => matchers.iter().matches_and(ext, ctx, stream),
494 SocketMatcherKind::Any(matchers) => matchers.iter().matches_or(ext, ctx, stream),
495 SocketMatcherKind::Custom(matcher) => matcher.matches(ext, ctx, stream),
496 }
497 }
498}
499
500impl<State, Socket> rama_core::matcher::Matcher<State, Socket> for SocketMatcher<State, Socket>
501where
502 Socket: crate::stream::Socket,
503 State: 'static,
504{
505 fn matches(&self, ext: Option<&mut Extensions>, ctx: &Context<State>, stream: &Socket) -> bool {
506 let result = self.kind.matches(ext, ctx, stream);
507 if self.negate { !result } else { result }
508 }
509}
510
#[cfg(all(test, feature = "http"))]
mod test {
    use rama_core::matcher::Matcher;

    use super::*;

    /// A matcher whose outcome is fixed up front.
    struct BooleanMatcher(bool);

    impl Matcher<(), Request<()>> for BooleanMatcher {
        fn matches(
            &self,
            _ext: Option<&mut Extensions>,
            _ctx: &Context<()>,
            _req: &Request<()>,
        ) -> bool {
            self.0
        }
    }

    /// Yield every possible assignment of `n` booleans (all `2^n` of them).
    ///
    /// `Itertools::permutations(k)` is not a substitute. It selects `k`
    /// distinct positions of its input, so on the two-element `[true, false]`
    /// it yields nothing for `k > 2`, and the loops below would never run.
    fn bool_combinations(n: usize) -> impl Iterator<Item = Vec<bool>> {
        (0..(1u32 << n)).map(move |bits| (0..n).map(|i| (bits & (1 << i)) != 0).collect())
    }

    /// Build a matcher with a fixed outcome.
    fn fixed(value: bool) -> SocketMatcher<(), Request<()>> {
        SocketMatcher::custom(BooleanMatcher(value))
    }

    /// Assert that `matcher` evaluates to `expected` against an empty request.
    fn assert_matches(matcher: &SocketMatcher<(), Request<()>>, expected: bool) {
        let req = Request::builder().body(()).unwrap();
        assert_eq!(
            matcher.matches(None, &Context::default(), &req),
            expected,
            "({:#?}).matches({:#?})",
            matcher,
            req
        );
    }

    #[test]
    fn test_bool_combinations_is_exhaustive() {
        let combinations: Vec<Vec<bool>> = bool_combinations(3).collect();
        assert_eq!(combinations.len(), 8);
        assert!(combinations.iter().all(|v| v.len() == 3));
        let unique: std::collections::HashSet<&Vec<bool>> = combinations.iter().collect();
        assert_eq!(unique.len(), 8);
    }

    #[test]
    fn test_matcher_and_combination() {
        for v in bool_combinations(3) {
            let expected = v[0] && v[1] && v[2];
            let matcher = fixed(v[0]).and(fixed(v[1])).and(fixed(v[2]));
            assert_matches(&matcher, expected);
        }
    }

    #[test]
    fn test_matcher_negation_with_and_combination() {
        for v in bool_combinations(3) {
            let expected = !v[0] && v[1] && v[2];
            let matcher = fixed(v[0]).negate().and(fixed(v[1])).and(fixed(v[2]));
            assert_matches(&matcher, expected);
        }
    }

    #[test]
    fn test_matcher_and_combination_negated() {
        for v in bool_combinations(3) {
            let expected = !(v[0] && v[1] && v[2]);
            let matcher = fixed(v[0]).and(fixed(v[1])).and(fixed(v[2])).negate();
            assert_matches(&matcher, expected);
        }
    }

    #[test]
    fn test_matcher_ors_combination() {
        for v in bool_combinations(3) {
            let expected = v[0] || v[1] || v[2];
            let matcher = fixed(v[0]).or(fixed(v[1])).or(fixed(v[2]));
            assert_matches(&matcher, expected);
        }
    }

    #[test]
    fn test_matcher_negation_with_ors_combination() {
        for v in bool_combinations(3) {
            let expected = !v[0] || v[1] || v[2];
            let matcher = fixed(v[0]).negate().or(fixed(v[1])).or(fixed(v[2]));
            assert_matches(&matcher, expected);
        }
    }

    #[test]
    fn test_matcher_ors_combination_negated() {
        for v in bool_combinations(3) {
            let expected = !(v[0] || v[1] || v[2]);
            let matcher = fixed(v[0]).or(fixed(v[1])).or(fixed(v[2])).negate();
            assert_matches(&matcher, expected);
        }
    }

    #[test]
    fn test_matcher_or_and_or_and_negation() {
        for v in bool_combinations(5) {
            let expected = (v[0] || v[1]) && (v[2] || v[3]) && !v[4];
            let matcher = (fixed(v[0]).or(fixed(v[1])))
                .and(fixed(v[2]).or(fixed(v[3])))
                .and(fixed(v[4]).negate());
            assert_matches(&matcher, expected);
        }
    }
}