Skip to main content

relay_core_lib/capture/
loop_detection.rs

1use std::net::{SocketAddr, IpAddr};
2use std::collections::{BTreeSet, HashSet};
3use std::sync::{Arc, RwLock};
4use if_addrs::get_if_addrs;
5
6pub struct LoopDetector {
7    /// All addresses the proxy is listening on
8    listen_addrs: BTreeSet<SocketAddr>,
9    
10    /// Local interface addresses (updated periodically)
11    local_addrs: Arc<RwLock<BTreeSet<IpAddr>>>,
12}
13
14impl LoopDetector {
15    pub fn new(listen_addrs: BTreeSet<SocketAddr>) -> Self {
16        Self {
17            listen_addrs,
18            local_addrs: Arc::new(RwLock::new(BTreeSet::new())),
19        }
20    }
21
22    /// Check if connecting to target would create a loop
23    pub fn would_loop(&self, target: SocketAddr) -> bool {
24        // Direct match with listening addresses
25        if self.listen_addrs.contains(&target) {
26            return true;
27        }
28        
29        // Check if target IP is local/loopback and port matches any listen port.
30        if self.is_local_ip(target.ip()) {
31             let listen_ports: HashSet<u16> = self.listen_addrs.iter()
32                .map(|a| a.port())
33                .collect();
34            if listen_ports.contains(&target.port()) {
35                return true;
36            }
37        }
38        
39        false
40    }
41    
42    fn is_local_ip(&self, ip: IpAddr) -> bool {
43        if ip.is_loopback() {
44            return true;
45        }
46        if ip.is_unspecified() {
47            return true;
48        }
49        
50        // Check cached local addrs
51        if let Ok(guard) = self.local_addrs.read() {
52            if guard.contains(&ip) {
53                return true;
54            }
55        }
56        
57        false
58    }
59    
60    /// Refresh local interface addresses from system interfaces.
61    pub async fn refresh_local_addrs(&self) {
62        let mut ips = BTreeSet::new();
63
64        if let Ok(ifaces) = get_if_addrs() {
65            for iface in ifaces {
66                ips.insert(iface.ip());
67            }
68        }
69
70        // Keep loopback and unspecified in cache as a defensive fallback.
71        ips.insert(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
72        ips.insert(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST));
73        ips.insert(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
74        ips.insert(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED));
75
76        if let Ok(mut guard) = self.local_addrs.write() {
77            *guard = ips;
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address literal used as test input.
    fn sock(text: &str) -> SocketAddr {
        text.parse().expect("addr")
    }

    #[test]
    fn test_would_loop_on_direct_listen_match() {
        // The target is byte-for-byte one of the listen addresses.
        let detector = LoopDetector::new(BTreeSet::from([sock("127.0.0.1:8080")]));
        assert!(detector.would_loop(sock("127.0.0.1:8080")));
    }

    #[test]
    fn test_would_loop_on_cached_local_ip_and_listen_port() {
        // Wildcard listener; the target IP is only known to be local via the cache.
        let detector = LoopDetector::new(BTreeSet::from([sock("0.0.0.0:9090")]));

        if let Ok(mut cache) = detector.local_addrs.write() {
            cache.insert("10.10.10.1".parse::<IpAddr>().expect("ip"));
        }

        assert!(detector.would_loop(sock("10.10.10.1:9090")));
    }

    #[tokio::test]
    async fn test_refresh_local_addrs_populates_cache() {
        let detector = LoopDetector::new(BTreeSet::new());
        detector.refresh_local_addrs().await;

        // Even with no discoverable interfaces, the fallback entries must be present.
        let cache = detector.local_addrs.read().expect("lock");
        assert!(
            !cache.is_empty(),
            "refresh should populate at least fallback local addresses"
        );
    }
}