// remote/port_ranges.rs

use anyhow::{anyhow, Context};

3#[derive(Debug, Clone)]
4pub struct PortRanges {
5    ranges: Vec<std::ops::Range<u16>>,
6}
7
8impl PortRanges {
9    /// Parse port ranges from a string like "8000-8999,10000-10999"
10    pub fn parse(ranges_str: &str) -> anyhow::Result<Self> {
11        let mut ranges = Vec::new();
12        for range_str in ranges_str.split(',') {
13            let range_str = range_str.trim();
14            if range_str.is_empty() {
15                continue;
16            }
17            if let Some((start_str, end_str)) = range_str.split_once('-') {
18                let start: u16 = start_str
19                    .trim()
20                    .parse()
21                    .with_context(|| format!("Invalid start port in range: {start_str}"))?;
22                let end: u16 = end_str
23                    .trim()
24                    .parse()
25                    .with_context(|| format!("Invalid end port in range: {end_str}"))?;
26                if start > end {
27                    return Err(anyhow!(
28                        "Invalid port range: start port {} > end port {}",
29                        start,
30                        end
31                    ));
32                }
33                if start == 0 {
34                    return Err(anyhow!("Port 0 is not allowed in ranges"));
35                }
36                ranges.push(start..end + 1); // range is exclusive of end, so add 1
37            } else {
38                // single port
39                let port: u16 = range_str
40                    .parse()
41                    .with_context(|| format!("Invalid port: {range_str}"))?;
42                if port == 0 {
43                    return Err(anyhow!("Port 0 is not allowed"));
44                }
45                ranges.push(port..port + 1);
46            }
47        }
48        if ranges.is_empty() {
49            return Err(anyhow!("No valid port ranges found"));
50        }
51        Ok(PortRanges { ranges })
52    }
53
54    /// Try to bind to a UDP socket within the specified port ranges
55    pub fn bind_udp_socket(&self, ip: std::net::IpAddr) -> anyhow::Result<std::net::UdpSocket> {
56        use rand::seq::SliceRandom;
57        use std::time::{Duration, Instant};
58        // collect all possible ports from all ranges
59        let mut all_ports: Vec<u16> = Vec::new();
60        for range in &self.ranges {
61            all_ports.extend(range.clone());
62        }
63        // randomize the order to avoid always using the same ports
64        let mut rng = rand::thread_rng();
65        all_ports.shuffle(&mut rng);
66        let start_time = Instant::now();
67        // allow overriding the timeout via environment variable
68        let max_duration_secs = match std::env::var("RCP_UDP_BIND_MAX_DURATION_SECONDS")
69            .ok()
70            .and_then(|s| s.parse::<u64>().ok())
71        {
72            Some(x) => {
73                tracing::debug!(
74                    "Using custom UDP bind timeout: {x}s (from RCP_UDP_BIND_MAX_DURATION_SECONDS)",
75                );
76                x
77            }
78            None => 5,
79        };
80        let max_duration = Duration::from_secs(max_duration_secs);
81        let mut attempts = 0;
82        let mut last_error = None;
83        for port in all_ports {
84            if start_time.elapsed() > max_duration {
85                tracing::warn!(
86                    "Port binding timeout after {} attempts in {:?}",
87                    attempts,
88                    start_time.elapsed()
89                );
90                break;
91            }
92            attempts += 1;
93            let addr = std::net::SocketAddr::new(ip, port);
94            match std::net::UdpSocket::bind(addr) {
95                Ok(socket) => {
96                    tracing::info!(
97                        "Successfully bound to manually selected port {}:{} after {} attempts",
98                        ip,
99                        port,
100                        attempts
101                    );
102                    return Ok(socket);
103                }
104                Err(e) => {
105                    tracing::debug!("Failed to bind to {}:{}: {}", ip, port, e);
106                    // add small delay on port collisions to reduce thundering herd
107                    let is_addr_in_use = e.kind() == std::io::ErrorKind::AddrInUse;
108                    last_error = Some(e);
109                    if is_addr_in_use && attempts % 10 == 0 {
110                        std::thread::sleep(Duration::from_millis(1));
111                    }
112                }
113            }
114        }
115        Err(anyhow!(
116            "Failed to bind to any port in the specified ranges after {} attempts in {:?}: {}",
117            attempts,
118            start_time.elapsed(),
119            last_error
120                .map(|e| e.to_string())
121                .unwrap_or_else(|| "no ports available".to_string())
122        ))
123    }
124}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, failing the test with the input and the error chain if parsing fails.
    fn parse_ok(input: &str) -> PortRanges {
        PortRanges::parse(input)
            .unwrap_or_else(|e| panic!("expected {input:?} to parse, got error: {e:#}"))
    }

    #[test]
    fn test_parse_single_port() {
        // a single port becomes a one-element half-open range
        assert_eq!(parse_ok("8080").ranges, vec![8080..8081]);
    }

    #[test]
    fn test_parse_range() {
        // the inclusive end 8999 is stored as the exclusive end 9000
        assert_eq!(parse_ok("8000-8999").ranges, vec![8000..9000]);
    }

    #[test]
    fn test_parse_multiple_ranges() {
        assert_eq!(
            parse_ok("8000-8999,10000-10999,12345").ranges,
            vec![8000..9000, 10000..11000, 12345..12346]
        );
    }

    #[test]
    fn test_parse_invalid_range() {
        let rejected = [
            "9000-8000", // start > end
            "0-100",     // port 0 not allowed
            "abc",       // non-numeric
            "",          // empty
        ];
        for input in rejected {
            assert!(
                PortRanges::parse(input).is_err(),
                "expected {input:?} to be rejected"
            );
        }
    }
}