Skip to main content

relay_core_lib/capture/loop_detection.rs

1use std::net::{SocketAddr, IpAddr};
2use std::collections::{BTreeSet, HashSet};
3use std::sync::{Arc, RwLock};
4use if_addrs::get_if_addrs;
5
/// Detects connection targets that would make the proxy connect back to one of
/// its own listening addresses.
#[derive(Debug)]
pub struct LoopDetector {
    /// All addresses the proxy is listening on.
    listen_addrs: BTreeSet<SocketAddr>,

    /// Cached local interface addresses, replaced by
    /// [`LoopDetector::refresh_local_addrs`].
    ///
    /// Kept behind a lock so the cache can be refreshed through `&self`.
    local_addrs: Arc<RwLock<BTreeSet<IpAddr>>>,
}
13
14impl LoopDetector {
15    pub fn new(listen_addrs: BTreeSet<SocketAddr>) -> Self {
16        Self {
17            listen_addrs,
18            local_addrs: Arc::new(RwLock::new(BTreeSet::new())),
19        }
20    }
21
22    /// Check if connecting to target would create a loop
23    pub fn would_loop(&self, target: SocketAddr) -> bool {
24        // Direct match with listening addresses
25        if self.listen_addrs.contains(&target) {
26            return true;
27        }
28        
29        // Check if target IP is local/loopback and port matches any listen port.
30        if self.is_local_ip(target.ip()) {
31             let listen_ports: HashSet<u16> = self.listen_addrs.iter()
32                .map(|a| a.port())
33                .collect();
34            if listen_ports.contains(&target.port()) {
35                return true;
36            }
37        }
38        
39        false
40    }
41    
42    fn is_local_ip(&self, ip: IpAddr) -> bool {
43        if ip.is_loopback() {
44            return true;
45        }
46        if ip.is_unspecified() {
47            return true;
48        }
49        
50        // Check cached local addrs
51        if let Ok(guard) = self.local_addrs.read()
52            && guard.contains(&ip) {
53                return true;
54            }
55        
56        false
57    }
58    
59    /// Refresh local interface addresses from system interfaces.
60    pub async fn refresh_local_addrs(&self) {
61        let mut ips = BTreeSet::new();
62
63        if let Ok(ifaces) = get_if_addrs() {
64            for iface in ifaces {
65                ips.insert(iface.ip());
66            }
67        }
68
69        // Keep loopback and unspecified in cache as a defensive fallback.
70        ips.insert(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
71        ips.insert(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST));
72        ips.insert(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
73        ips.insert(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED));
74
75        if let Ok(mut guard) = self.local_addrs.write() {
76            *guard = ips;
77        }
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_would_loop_on_direct_listen_match() {
        let listen = BTreeSet::from([SocketAddr::from_str("127.0.0.1:8080").expect("addr")]);
        let detector = LoopDetector::new(listen);
        let target = SocketAddr::from_str("127.0.0.1:8080").expect("addr");
        assert!(detector.would_loop(target));
    }

    #[test]
    fn test_would_loop_on_cached_local_ip_and_listen_port() {
        let listen = BTreeSet::from([SocketAddr::from_str("0.0.0.0:9090").expect("addr")]);
        let detector = LoopDetector::new(listen);

        // Seed the cache directly instead of depending on the host's real interfaces.
        // `expect` (not `if let Ok`) so a poisoned lock fails loudly instead of
        // silently skipping the setup.
        detector
            .local_addrs
            .write()
            .expect("lock")
            .insert(IpAddr::from_str("10.10.10.1").expect("ip"));

        let target = SocketAddr::from_str("10.10.10.1:9090").expect("addr");
        assert!(detector.would_loop(target));
    }

    #[test]
    fn test_would_not_loop_on_remote_ip_or_unlistened_port() {
        let listen = BTreeSet::from([SocketAddr::from_str("127.0.0.1:8080").expect("addr")]);
        let detector = LoopDetector::new(listen);

        // Local IP, but nothing listens on this port.
        let other_port = SocketAddr::from_str("127.0.0.1:9090").expect("addr");
        assert!(!detector.would_loop(other_port));

        // Listened port, but the IP is neither a listen address nor local.
        let remote = SocketAddr::from_str("203.0.113.7:8080").expect("addr");
        assert!(!detector.would_loop(remote));
    }

    #[tokio::test]
    async fn test_refresh_local_addrs_populates_cache() {
        let detector = LoopDetector::new(BTreeSet::new());
        detector.refresh_local_addrs().await;

        let guard = detector.local_addrs.read().expect("lock");
        assert!(
            !guard.is_empty(),
            "refresh should populate at least fallback local addresses"
        );
    }
}