// relay_core_lib/capture/loop_detection.rs
use std::collections::{BTreeSet, HashSet};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{Arc, PoisonError, RwLock};

use if_addrs::get_if_addrs;
/// Detects when an outbound connection target would route back into one of the
/// listener addresses this detector was configured with.
#[derive(Debug)]
pub struct LoopDetector {
    /// Socket addresses being listened on. A target naming one of these, or a local
    /// IP on one of their ports, is treated as a loop.
    listen_addrs: BTreeSet<SocketAddr>,

    /// Cached IP addresses of this host's network interfaces, rebuilt by
    /// `refresh_local_addrs`.
    local_addrs: Arc<RwLock<BTreeSet<IpAddr>>>,
}
14impl LoopDetector {
15 pub fn new(listen_addrs: BTreeSet<SocketAddr>) -> Self {
16 Self {
17 listen_addrs,
18 local_addrs: Arc::new(RwLock::new(BTreeSet::new())),
19 }
20 }
21
22 pub fn would_loop(&self, target: SocketAddr) -> bool {
24 if self.listen_addrs.contains(&target) {
26 return true;
27 }
28
29 if self.is_local_ip(target.ip()) {
31 let listen_ports: HashSet<u16> = self.listen_addrs.iter()
32 .map(|a| a.port())
33 .collect();
34 if listen_ports.contains(&target.port()) {
35 return true;
36 }
37 }
38
39 false
40 }
41
42 fn is_local_ip(&self, ip: IpAddr) -> bool {
43 if ip.is_loopback() {
44 return true;
45 }
46 if ip.is_unspecified() {
47 return true;
48 }
49
50 if let Ok(guard) = self.local_addrs.read() {
52 if guard.contains(&ip) {
53 return true;
54 }
55 }
56
57 false
58 }
59
60 pub async fn refresh_local_addrs(&self) {
62 let mut ips = BTreeSet::new();
63
64 if let Ok(ifaces) = get_if_addrs() {
65 for iface in ifaces {
66 ips.insert(iface.ip());
67 }
68 }
69
70 ips.insert(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
72 ips.insert(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST));
73 ips.insert(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
74 ips.insert(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED));
75
76 if let Ok(mut guard) = self.local_addrs.write() {
77 *guard = ips;
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Parses a socket-address literal; a panic here means a typo in the test itself.
    fn addr(s: &str) -> SocketAddr {
        SocketAddr::from_str(s).expect("addr")
    }

    #[test]
    fn test_would_loop_on_direct_listen_match() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:8080")]));
        assert!(detector.would_loop(addr("127.0.0.1:8080")));
    }

    #[test]
    fn test_would_loop_on_cached_local_ip_and_listen_port() {
        let detector = LoopDetector::new(BTreeSet::from([addr("0.0.0.0:9090")]));

        // Seed the cache directly; expect() makes a poisoned lock fail loudly instead
        // of silently skipping the seed and failing the assert for the wrong reason.
        detector
            .local_addrs
            .write()
            .expect("lock")
            .insert(IpAddr::from_str("10.10.10.1").expect("ip"));

        assert!(detector.would_loop(addr("10.10.10.1:9090")));
    }

    #[test]
    fn test_would_loop_on_unspecified_target_with_listen_port() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:8080")]));
        assert!(detector.would_loop(addr("0.0.0.0:8080")));
    }

    #[test]
    fn test_no_loop_when_port_differs() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:8080")]));
        assert!(!detector.would_loop(addr("127.0.0.1:8081")));
    }

    #[test]
    fn test_no_loop_for_uncached_remote_ip() {
        let detector = LoopDetector::new(BTreeSet::from([addr("0.0.0.0:9090")]));
        // The cache is empty and this IP is neither loopback nor unspecified.
        assert!(!detector.would_loop(addr("203.0.113.7:9090")));
    }

    #[tokio::test]
    async fn test_refresh_local_addrs_populates_cache() {
        let detector = LoopDetector::new(BTreeSet::new());
        detector.refresh_local_addrs().await;

        let guard = detector.local_addrs.read().expect("lock");
        assert!(
            !guard.is_empty(),
            "refresh should populate at least fallback local addresses"
        );
        assert!(guard.contains(&IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)));
    }
}