Skip to main content

relay_core_lib/capture/
loop_detection.rs

1use if_addrs::get_if_addrs;
2use std::collections::{BTreeSet, HashSet};
3use std::net::{IpAddr, SocketAddr};
4use std::sync::{Arc, RwLock};
5
/// Detects whether an outbound connection target would route back into the
/// proxy's own listeners.
#[derive(Debug)]
pub struct LoopDetector {
    /// All addresses the proxy is listening on.
    listen_addrs: BTreeSet<SocketAddr>,

    /// Cached local interface addresses, populated by
    /// [`LoopDetector::refresh_local_addrs`]. The `RwLock` lets that method
    /// update the cache through `&self`.
    local_addrs: Arc<RwLock<BTreeSet<IpAddr>>>,
}
13
14impl LoopDetector {
15    pub fn new(listen_addrs: BTreeSet<SocketAddr>) -> Self {
16        Self {
17            listen_addrs,
18            local_addrs: Arc::new(RwLock::new(BTreeSet::new())),
19        }
20    }
21
22    /// Check if connecting to target would create a loop
23    pub fn would_loop(&self, target: SocketAddr) -> bool {
24        // Direct match with listening addresses
25        if self.listen_addrs.contains(&target) {
26            return true;
27        }
28
29        // Check if target IP is local/loopback and port matches any listen port.
30        if self.is_local_ip(target.ip()) {
31            let listen_ports: HashSet<u16> = self.listen_addrs.iter().map(|a| a.port()).collect();
32            if listen_ports.contains(&target.port()) {
33                return true;
34            }
35        }
36
37        false
38    }
39
40    fn is_local_ip(&self, ip: IpAddr) -> bool {
41        if ip.is_loopback() {
42            return true;
43        }
44        if ip.is_unspecified() {
45            return true;
46        }
47
48        // Check cached local addrs
49        if let Ok(guard) = self.local_addrs.read()
50            && guard.contains(&ip)
51        {
52            return true;
53        }
54
55        false
56    }
57
58    /// Refresh local interface addresses from system interfaces.
59    pub async fn refresh_local_addrs(&self) {
60        let mut ips = BTreeSet::new();
61
62        if let Ok(ifaces) = get_if_addrs() {
63            for iface in ifaces {
64                ips.insert(iface.ip());
65            }
66        }
67
68        // Keep loopback and unspecified in cache as a defensive fallback.
69        ips.insert(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
70        ips.insert(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST));
71        ips.insert(IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
72        ips.insert(IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED));
73
74        if let Ok(mut guard) = self.local_addrs.write() {
75            *guard = ips;
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Parses a socket address literal, panicking on malformed test input.
    fn addr(s: &str) -> SocketAddr {
        SocketAddr::from_str(s).expect("addr")
    }

    #[test]
    fn test_would_loop_on_direct_listen_match() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:8080")]));
        assert!(detector.would_loop(addr("127.0.0.1:8080")));
    }

    #[test]
    fn test_would_loop_on_cached_local_ip_and_listen_port() {
        let detector = LoopDetector::new(BTreeSet::from([addr("0.0.0.0:9090")]));

        // Fail loudly if the lock is unusable instead of silently skipping setup.
        detector
            .local_addrs
            .write()
            .expect("lock")
            .insert(IpAddr::from_str("10.10.10.1").expect("ip"));

        assert!(detector.would_loop(addr("10.10.10.1:9090")));
    }

    #[test]
    fn test_would_loop_on_ipv6_loopback_and_listen_port() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:8080")]));
        assert!(detector.would_loop(addr("[::1]:8080")));
    }

    #[test]
    fn test_would_loop_on_unspecified_target() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:9090")]));
        assert!(detector.would_loop(addr("0.0.0.0:9090")));
    }

    #[test]
    fn test_no_loop_on_loopback_with_other_port() {
        let detector = LoopDetector::new(BTreeSet::from([addr("127.0.0.1:8080")]));
        assert!(!detector.would_loop(addr("127.0.0.1:8081")));
    }

    #[test]
    fn test_no_loop_on_remote_ip_with_listen_port() {
        // The cache is never refreshed here, so a TEST-NET address is not local.
        let detector = LoopDetector::new(BTreeSet::from([addr("0.0.0.0:9090")]));
        assert!(!detector.would_loop(addr("203.0.113.7:9090")));
    }

    #[tokio::test]
    async fn test_refresh_local_addrs_populates_cache() {
        let detector = LoopDetector::new(BTreeSet::new());
        detector.refresh_local_addrs().await;

        let guard = detector.local_addrs.read().expect("lock");
        assert!(
            !guard.is_empty(),
            "refresh should populate at least fallback local addresses"
        );
    }
}
117}