// relay_core_lib/capture/loop_detection.rs
use std::collections::{BTreeSet, HashSet};
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, PoisonError, RwLock};

use if_addrs::get_if_addrs;

6pub struct LoopDetector {
7 listen_addrs: BTreeSet<SocketAddr>,
9
10 local_addrs: Arc<RwLock<BTreeSet<IpAddr>>>,
12}
13
14impl LoopDetector {
15 pub fn new(listen_addrs: BTreeSet<SocketAddr>) -> Self {
16 Self {
17 listen_addrs,
18 local_addrs: Arc::new(RwLock::new(BTreeSet::new())),
19 }
20 }
21
22 pub fn would_loop(&self, target: SocketAddr) -> bool {
24 if self.listen_addrs.contains(&target) {
26 return true;
27 }
28
29 if self.is_local_ip(target.ip()) {
31 let listen_ports: HashSet<u16> = self.listen_addrs.iter()
32 .map(|a| a.port())
33 .collect();
34 if listen_ports.contains(&target.port()) {
35 return true;
36 }
37 }
38
39 false
40 }
41
42 fn is_local_ip(&self, ip: IpAddr) -> bool {
43 if ip.is_loopback() {
44 return true;
45 }
46 if ip.is_unspecified() {
47 return true;
48 }
49
50 if let Ok(guard) = self.local_addrs.read()
52 && guard.contains(&ip) {
53 return true;
54 }
55
56 false
57 }
58
59 pub async fn refresh_local_addrs(&self) {
61 let mut ips = BTreeSet::new();
62
63 if let Ok(ifaces) = get_if_addrs() {
64 for iface in ifaces {
65 ips.insert(iface.ip());
66 }
67 }
68
69 ips.insert(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
71 ips.insert(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST));
72 ips.insert(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
73 ips.insert(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED));
74
75 if let Ok(mut guard) = self.local_addrs.write() {
76 *guard = ips;
77 }
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address fixture literal.
    fn sock(text: &str) -> SocketAddr {
        text.parse().expect("addr")
    }

    #[test]
    fn test_would_loop_on_direct_listen_match() {
        // A target identical to the only listen address is an exact-match loop.
        let detector = LoopDetector::new(BTreeSet::from([sock("127.0.0.1:8080")]));
        let looping = detector.would_loop(sock("127.0.0.1:8080"));
        assert!(looping);
    }

    #[test]
    fn test_would_loop_on_cached_local_ip_and_listen_port() {
        // Wildcard listener: a cached interface IP on the listen port loops.
        let detector = LoopDetector::new(BTreeSet::from([sock("0.0.0.0:9090")]));

        if let Ok(mut cache) = detector.local_addrs.write() {
            let cached_ip: IpAddr = "10.10.10.1".parse().expect("ip");
            cache.insert(cached_ip);
        }

        let looping = detector.would_loop(sock("10.10.10.1:9090"));
        assert!(looping);
    }

    #[tokio::test]
    async fn test_refresh_local_addrs_populates_cache() {
        // Even with no interfaces, refresh must leave the fallback addresses behind.
        let detector = LoopDetector::new(BTreeSet::new());
        detector.refresh_local_addrs().await;

        let cache = detector.local_addrs.read().expect("lock");
        assert!(
            !cache.is_empty(),
            "refresh should populate at least fallback local addresses"
        );
    }
}