1use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::net::ToSocketAddrs;
11use std::path::Path;
12
13use color_eyre::Result;
14use color_eyre::eyre::eyre;
15use serde::Deserialize;
16use serde::Serialize;
17use socket2::Domain;
18use socket2::Socket;
19use socket2::Type;
20use tokio::net::TcpListener;
21use tokio::sync::RwLock;
22use tracing::info;
23
/// First port (inclusive) of the range scanned when a port is auto-assigned.
const PORT_RANGE_START: u16 = 32768;
/// Last port (inclusive) of the range scanned when a port is auto-assigned.
const PORT_RANGE_END: u16 = 61000;
26
27pub fn create_freebind_socket(addr: &SocketAddr) -> std::io::Result<Socket> {
34 let domain = if addr.is_ipv4() {
35 Domain::IPV4
36 } else {
37 Domain::IPV6
38 };
39
40 let socket = Socket::new(domain, Type::STREAM, None)?;
41
42 socket.set_reuse_address(true)?;
45
46 if addr.is_ipv4() {
47 socket.set_freebind_v4(true)?;
48 } else {
49 socket.set_freebind_v6(true)?;
50 }
51
52 Ok(socket)
53}
54
55pub fn is_port_free<A: ToSocketAddrs>(addr: A) -> bool {
59 let Ok(mut addrs) = addr.to_socket_addrs() else {
60 return false;
61 };
62
63 let Some(sock_addr) = addrs.next() else {
64 return false;
65 };
66
67 let Ok(socket) = create_freebind_socket(&sock_addr) else {
68 return false;
69 };
70
71 socket.bind(&sock_addr.into()).is_ok()
72}
73
/// Keeps listening sockets open so that the addresses they are bound to stay reserved.
///
/// A reservation lasts as long as its listener stays in the map; removing the
/// entry drops the listener.
pub struct PortAllocator {
    /// Listener held for each reserved address, keyed by the address passed to `allocate`.
    sockets: RwLock<HashMap<SocketAddr, TcpListener>>,
}
86
87impl Default for PortAllocator {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl PortAllocator {
94 pub fn new() -> Self {
95 Self {
96 sockets: RwLock::new(HashMap::new()),
97 }
98 }
99
100 pub async fn allocate(&self, addr: SocketAddr) -> Result<()> {
106 let socket = create_freebind_socket(&addr)
107 .map_err(|e| eyre!("Failed to create socket for {addr}: {e}"))?;
108
109 socket
110 .bind(&addr.into())
111 .map_err(|e| eyre!("Failed to reserve {addr}: {e}"))?;
112
113 socket
114 .listen(128)
115 .map_err(|e| eyre!("Failed to listen on {addr}: {e}"))?;
116
117 let std_listener: std::net::TcpListener = socket.into();
118 std_listener
119 .set_nonblocking(true)
120 .map_err(|e| eyre!("Failed to set nonblocking for {addr}: {e}"))?;
121
122 let listener = TcpListener::from_std(std_listener)
123 .map_err(|e| eyre!("Failed to create tokio listener for {addr}: {e}"))?;
124
125 info!("Reserved {addr}");
126 self.sockets.write().await.insert(addr, listener);
127 Ok(())
128 }
129
130 pub async fn deallocate(&self, addr: SocketAddr) {
132 self.sockets.write().await.remove(&addr);
133 info!("Released {addr}");
134 }
135
136 pub async fn is_allocated(&self, addr: SocketAddr) -> bool {
138 if self.sockets.read().await.contains_key(&addr) {
139 return true;
140 }
141
142 let Ok(socket) = create_freebind_socket(&addr) else {
143 return false;
144 };
145
146 match socket.bind(&addr.into()) {
147 Err(e) => e.kind() == std::io::ErrorKind::AddrInUse,
148 Ok(_) => false,
149 }
150 }
151
152 pub async fn deallocate_all(&self) {
154 let mut sockets = self.sockets.write().await;
155 let count = sockets.len();
156 sockets.clear();
157 info!("Released all {count} reservations");
158 }
159}
160
/// Mapping from a string key to the port assigned to it.
///
/// The type is serializable; `load` and `save` read and write it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PortAssignments {
    /// Assigned port for each key.
    assignments: HashMap<String, u16>,
}
173
174impl PortAssignments {
175 pub fn load(path: &Path) -> Self {
178 std::fs::read_to_string(path)
179 .ok()
180 .and_then(|s| serde_json::from_str(&s).ok())
181 .unwrap_or_default()
182 }
183
184 pub fn save(&self, path: &Path) -> Result<(), std::io::Error> {
187 if let Some(parent) = path.parent() {
188 std::fs::create_dir_all(parent)?;
189 }
190 let contents = serde_json::to_string_pretty(self)?;
191 std::fs::write(path, contents)
192 }
193
194 pub fn get(&self, key: &str) -> Option<u16> {
196 self.assignments.get(key).copied()
197 }
198
199 pub fn set(&mut self, key: String, port: u16) {
201 self.assignments.insert(key, port);
202 }
203
204 #[allow(dead_code)]
207 pub fn remove(&mut self, key: &str) -> Option<u16> {
208 self.assignments.remove(key)
209 }
210
211 pub fn is_used(&self, port: u16) -> bool {
213 self.assignments.values().any(|&p| p == port)
214 }
215
216 pub fn get_or_allocate(&mut self, key: &str) -> Option<u16> {
220 if let Some(&p) = self.assignments.get(key) {
221 return Some(p);
222 }
223 let p = self.allocate_port()?;
224 self.assignments.insert(key.to_string(), p);
225 Some(p)
226 }
227
228 fn allocate_port(&self) -> Option<u16> {
229 (PORT_RANGE_START..=PORT_RANGE_END)
230 .find(|&port| !self.is_used(port) && is_port_free(format!("0.0.0.0:{port}")))
231 }
232}
233
#[cfg(test)]
mod tests {
    use std::net::TcpListener;

    use tempfile::TempDir;

    use super::*;

    /// A value stored with `set` is returned by `get`.
    #[test]
    fn set_and_get() {
        let mut table = PortAssignments::default();
        table.set("example-drive-80".into(), 32000);

        let found = table.get("example-drive-80");
        assert_eq!(found, Some(32000));
    }

    /// `remove` hands back the stored port and leaves the key unassigned.
    #[test]
    fn remove() {
        let mut table = PortAssignments::default();
        table.set("key".into(), 32000);

        let removed = table.remove("key");
        assert_eq!(removed, Some(32000));
        assert!(table.get("key").is_none());
    }

    /// Assignments survive a save/load round trip through a file.
    #[test]
    fn persistence() {
        let scratch = TempDir::new().unwrap();
        let file = scratch.path().join("ports.json");

        let mut original = PortAssignments::default();
        for (key, port) in [("s1", 40000), ("s2", 40001)] {
            original.set(key.into(), port);
        }
        original.save(&file).unwrap();

        let restored = PortAssignments::load(&file);
        assert_eq!(restored.get("s1"), Some(40000));
        assert_eq!(restored.get("s2"), Some(40001));
    }

    /// An already-assigned key keeps its port.
    #[test]
    fn get_or_allocate_returns_existing() {
        let mut table = PortAssignments::default();
        table.set("key".into(), 32000);

        assert_eq!(table.get_or_allocate("key"), Some(32000));
    }

    /// An unassigned key receives a port from the scan range and remembers it.
    #[test]
    fn get_or_allocate_creates_new() {
        let mut table = PortAssignments::default();

        let assigned = table.get_or_allocate("new-key").unwrap();
        assert!((PORT_RANGE_START..=PORT_RANGE_END).contains(&assigned));
        assert_eq!(table.get("new-key"), Some(assigned));
    }

    /// A port whose listener has been dropped is reported as free.
    #[test]
    fn is_port_free_localhost() {
        let port = {
            let holder = TcpListener::bind("127.0.0.1:0").unwrap();
            holder.local_addr().unwrap().port()
            // `holder` is dropped here, releasing the port.
        };

        assert!(is_port_free(format!("127.0.0.1:{port}")));
    }

    /// A port is reported busy while a listener holds it and free afterwards.
    #[test]
    fn is_port_free_occupied() {
        let holder = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = holder.local_addr().unwrap().port();

        let while_held = is_port_free(format!("127.0.0.1:{port}"));
        assert!(!while_held);

        drop(holder);
        let after_release = is_port_free(format!("127.0.0.1:{port}"));
        assert!(after_release);
    }

    /// Once a port is recorded as assigned, the next allocation picks a different one.
    #[test]
    fn allocate_port_assigns_unique() {
        let mut table = PortAssignments::default();

        let first = table.allocate_port().unwrap();
        table.set("s1".into(), first);
        let second = table.allocate_port().unwrap();

        assert_ne!(first, second);
    }
}