1use anyhow::{anyhow, Context};
2
3#[derive(Debug, Clone)]
4pub struct PortRanges {
5 ranges: Vec<std::ops::Range<u16>>,
6}
7
8impl PortRanges {
9 pub fn parse(ranges_str: &str) -> anyhow::Result<Self> {
11 let mut ranges = Vec::new();
12 for range_str in ranges_str.split(',') {
13 let range_str = range_str.trim();
14 if range_str.is_empty() {
15 continue;
16 }
17 if let Some((start_str, end_str)) = range_str.split_once('-') {
18 let start: u16 = start_str
19 .trim()
20 .parse()
21 .with_context(|| format!("Invalid start port in range: {start_str}"))?;
22 let end: u16 = end_str
23 .trim()
24 .parse()
25 .with_context(|| format!("Invalid end port in range: {end_str}"))?;
26 if start > end {
27 return Err(anyhow!(
28 "Invalid port range: start port {} > end port {}",
29 start,
30 end
31 ));
32 }
33 if start == 0 {
34 return Err(anyhow!("Port 0 is not allowed in ranges"));
35 }
36 ranges.push(start..end + 1); } else {
38 let port: u16 = range_str
40 .parse()
41 .with_context(|| format!("Invalid port: {range_str}"))?;
42 if port == 0 {
43 return Err(anyhow!("Port 0 is not allowed"));
44 }
45 ranges.push(port..port + 1);
46 }
47 }
48 if ranges.is_empty() {
49 return Err(anyhow!("No valid port ranges found"));
50 }
51 Ok(PortRanges { ranges })
52 }
53
54 pub fn bind_udp_socket(&self, ip: std::net::IpAddr) -> anyhow::Result<std::net::UdpSocket> {
56 use rand::seq::SliceRandom;
57 use std::time::{Duration, Instant};
58 let mut all_ports: Vec<u16> = Vec::new();
60 for range in &self.ranges {
61 all_ports.extend(range.clone());
62 }
63 let mut rng = rand::thread_rng();
65 all_ports.shuffle(&mut rng);
66 let start_time = Instant::now();
67 let max_duration_secs = match std::env::var("RCP_UDP_BIND_MAX_DURATION_SECONDS")
69 .ok()
70 .and_then(|s| s.parse::<u64>().ok())
71 {
72 Some(x) => {
73 tracing::debug!(
74 "Using custom UDP bind timeout: {x}s (from RCP_UDP_BIND_MAX_DURATION_SECONDS)",
75 );
76 x
77 }
78 None => 5,
79 };
80 let max_duration = Duration::from_secs(max_duration_secs);
81 let mut attempts = 0;
82 let mut last_error = None;
83 for port in all_ports {
84 if start_time.elapsed() > max_duration {
85 tracing::warn!(
86 "Port binding timeout after {} attempts in {:?}",
87 attempts,
88 start_time.elapsed()
89 );
90 break;
91 }
92 attempts += 1;
93 let addr = std::net::SocketAddr::new(ip, port);
94 match std::net::UdpSocket::bind(addr) {
95 Ok(socket) => {
96 tracing::info!(
97 "Successfully bound to manually selected port {}:{} after {} attempts",
98 ip,
99 port,
100 attempts
101 );
102 return Ok(socket);
103 }
104 Err(e) => {
105 tracing::debug!("Failed to bind to {}:{}: {}", ip, port, e);
106 let is_addr_in_use = e.kind() == std::io::ErrorKind::AddrInUse;
108 last_error = Some(e);
109 if is_addr_in_use && attempts % 10 == 0 {
110 std::thread::sleep(Duration::from_millis(1));
111 }
112 }
113 }
114 }
115 Err(anyhow!(
116 "Failed to bind to any port in the specified ranges after {} attempts in {:?}: {}",
117 attempts,
118 start_time.elapsed(),
119 last_error
120 .map(|e| e.to_string())
121 .unwrap_or_else(|| "no ports available".to_string())
122 ))
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone port becomes a one-element half-open range.
    #[test]
    fn test_parse_single_port() {
        let PortRanges { ranges } = PortRanges::parse("8080").unwrap();
        assert_eq!(ranges, vec![8080..8081]);
    }

    /// An inclusive `a-b` range is stored as `a..b+1`.
    #[test]
    fn test_parse_range() {
        let PortRanges { ranges } = PortRanges::parse("8000-8999").unwrap();
        assert_eq!(ranges, vec![8000..9000]);
    }

    /// Comma-separated entries keep their input order.
    #[test]
    fn test_parse_multiple_ranges() {
        let PortRanges { ranges } = PortRanges::parse("8000-8999,10000-10999,12345").unwrap();
        assert_eq!(ranges, vec![8000..9000, 10000..11000, 12345..12346]);
    }

    /// Reversed ranges, port 0, non-numeric input and empty input are rejected.
    #[test]
    fn test_parse_invalid_range() {
        for spec in ["9000-8000", "0-100", "abc", ""] {
            assert!(
                PortRanges::parse(spec).is_err(),
                "expected {spec:?} to be rejected"
            );
        }
    }
}