relay_core_lib/capture/
loop_detection.rs1use if_addrs::get_if_addrs;
2use std::collections::{BTreeSet, HashSet};
3use std::net::{IpAddr, SocketAddr};
4use std::sync::{Arc, RwLock};
5
/// Decides whether an outbound connection would loop back into one of this
/// process's own listeners.
///
/// Holds the configured listen addresses plus a cache of this host's interface
/// addresses, which `refresh_local_addrs` populates.
#[derive(Debug)]
pub struct LoopDetector {
    /// Socket addresses this process's listeners are bound to; connecting to
    /// one of them (or to a local IP on one of their ports) counts as a loop.
    listen_addrs: BTreeSet<SocketAddr>,

    /// Cached IP addresses of this host. Kept behind a lock so it can be
    /// refreshed through a shared reference.
    local_addrs: Arc<RwLock<BTreeSet<IpAddr>>>,
}
13
14impl LoopDetector {
15 pub fn new(listen_addrs: BTreeSet<SocketAddr>) -> Self {
16 Self {
17 listen_addrs,
18 local_addrs: Arc::new(RwLock::new(BTreeSet::new())),
19 }
20 }
21
22 pub fn would_loop(&self, target: SocketAddr) -> bool {
24 if self.listen_addrs.contains(&target) {
26 return true;
27 }
28
29 if self.is_local_ip(target.ip()) {
31 let listen_ports: HashSet<u16> = self.listen_addrs.iter().map(|a| a.port()).collect();
32 if listen_ports.contains(&target.port()) {
33 return true;
34 }
35 }
36
37 false
38 }
39
40 fn is_local_ip(&self, ip: IpAddr) -> bool {
41 if ip.is_loopback() {
42 return true;
43 }
44 if ip.is_unspecified() {
45 return true;
46 }
47
48 if let Ok(guard) = self.local_addrs.read()
50 && guard.contains(&ip)
51 {
52 return true;
53 }
54
55 false
56 }
57
58 pub async fn refresh_local_addrs(&self) {
60 let mut ips = BTreeSet::new();
61
62 if let Ok(ifaces) = get_if_addrs() {
63 for iface in ifaces {
64 ips.insert(iface.ip());
65 }
66 }
67
68 ips.insert(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
70 ips.insert(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST));
71 ips.insert(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
72 ips.insert(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED));
73
74 if let Ok(mut guard) = self.local_addrs.write() {
75 *guard = ips;
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_would_loop_on_direct_listen_match() {
        let listen = BTreeSet::from([SocketAddr::from_str("127.0.0.1:8080").expect("addr")]);
        let detector = LoopDetector::new(listen);
        let target = SocketAddr::from_str("127.0.0.1:8080").expect("addr");
        assert!(detector.would_loop(target));
    }

    #[test]
    fn test_would_loop_on_cached_local_ip_and_listen_port() {
        let listen = BTreeSet::from([SocketAddr::from_str("0.0.0.0:9090").expect("addr")]);
        let detector = LoopDetector::new(listen);

        // The write guard is a temporary, dropped at the end of this statement,
        // so the read inside `would_loop` below cannot deadlock on it.
        detector
            .local_addrs
            .write()
            .expect("lock")
            .insert(IpAddr::from_str("10.10.10.1").expect("ip"));

        let target = SocketAddr::from_str("10.10.10.1:9090").expect("addr");
        assert!(detector.would_loop(target));
    }

    #[test]
    fn test_no_loop_for_remote_ip_on_listen_port() {
        let listen = BTreeSet::from([SocketAddr::from_str("127.0.0.1:8080").expect("addr")]);
        let detector = LoopDetector::new(listen);

        // Same port as the listener, but a non-local IP absent from the empty cache.
        let target = SocketAddr::from_str("203.0.113.7:8080").expect("addr");
        assert!(!detector.would_loop(target));
    }

    #[test]
    fn test_no_loop_for_local_ip_on_other_port() {
        let listen = BTreeSet::from([SocketAddr::from_str("0.0.0.0:9090").expect("addr")]);
        let detector = LoopDetector::new(listen);

        // Loopback is local, but nothing listens on this port.
        let target = SocketAddr::from_str("127.0.0.1:8080").expect("addr");
        assert!(!detector.would_loop(target));
    }

    #[tokio::test]
    async fn test_refresh_local_addrs_populates_cache() {
        let detector = LoopDetector::new(BTreeSet::new());
        detector.refresh_local_addrs().await;

        let guard = detector.local_addrs.read().expect("lock");
        assert!(
            !guard.is_empty(),
            "refresh should populate at least fallback local addresses"
        );
    }
}