Skip to main content

lab_ops_lab_lib/
port.rs

//! Port utilities and management.
//!
//! Three layers of port management:
//! - [`create_freebind_socket`] / [`is_port_free`] — low-level port checking
//! - [`PortAllocator`] — runtime TCP pre-bind reservation for conflict prevention
//! - [`PortAssignments`] / [`PortAssignments::get_or_allocate`] — persistent ephemeral port allocation
7
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::net::ToSocketAddrs;
11use std::path::Path;
12
13use color_eyre::Result;
14use color_eyre::eyre::eyre;
15use serde::Deserialize;
16use serde::Serialize;
17use socket2::Domain;
18use socket2::Socket;
19use socket2::Type;
20use tokio::net::TcpListener;
21use tokio::sync::RwLock;
22use tracing::info;
23
/// First port (inclusive) of the range scanned by `PortAssignments::get_or_allocate`.
///
/// Together with [`PORT_RANGE_END`] this roughly matches the Linux ephemeral
/// range (`net.ipv4.ip_local_port_range`, 32768–60999 by default).
const PORT_RANGE_START: u16 = 32_768;
/// Last port (inclusive) of the range scanned by `PortAssignments::get_or_allocate`.
const PORT_RANGE_END: u16 = 61_000;
26
27// ---------------------------------------------------------------------------
28// Low-level socket utilities
29// ---------------------------------------------------------------------------
30
31/// Creates and configures a `Socket` for `addr` with `SO_REUSEADDR`
32/// and the appropriate `IP_FREEBIND` option.
33pub fn create_freebind_socket(addr: &SocketAddr) -> std::io::Result<Socket> {
34    let domain = if addr.is_ipv4() {
35        Domain::IPV4
36    } else {
37        Domain::IPV6
38    };
39
40    let socket = Socket::new(domain, Type::STREAM, None)?;
41
42    // Without this we may fail to bind if the socket was just released
43    // and the port is briefly in TIME_WAIT.
44    socket.set_reuse_address(true)?;
45
46    if addr.is_ipv4() {
47        socket.set_freebind_v4(true)?;
48    } else {
49        socket.set_freebind_v6(true)?;
50    }
51
52    Ok(socket)
53}
54
55/// Checks if a TCP port is free by attempting to bind to it using a socket
56/// configured with `SO_REUSEADDR` and `IP_FREEBIND`. This is more robust
57/// than a simple `TcpListener::bind` as it handles `TIME_WAIT` states gracefully.
58pub fn is_port_free<A: ToSocketAddrs>(addr: A) -> bool {
59    let Ok(mut addrs) = addr.to_socket_addrs() else {
60        return false;
61    };
62
63    let Some(sock_addr) = addrs.next() else {
64        return false;
65    };
66
67    let Ok(socket) = create_freebind_socket(&sock_addr) else {
68        return false;
69    };
70
71    socket.bind(&sock_addr.into()).is_ok()
72}
73
74// ---------------------------------------------------------------------------
75// PortAllocator — runtime TCP pre-bind reservation
76// ---------------------------------------------------------------------------
77
78/// A concurrency-safe port reservation system backed by TCP pre-bind.
79///
80/// Ports are keyed by [`SocketAddr`]. Allocating a port binds a
81/// [`TcpListener`] to it, preventing other processes or concurrent daemon
82/// operations from claiming the same port.
83pub struct PortAllocator {
84    sockets: RwLock<HashMap<SocketAddr, TcpListener>>,
85}
86
87impl Default for PortAllocator {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl PortAllocator {
94    pub fn new() -> Self {
95        Self {
96            sockets: RwLock::new(HashMap::new()),
97        }
98    }
99
100    /// Reserves `addr` by binding a [`TcpListener`] to it.
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if the port is already bound by another process.
105    pub async fn allocate(&self, addr: SocketAddr) -> Result<()> {
106        let socket = create_freebind_socket(&addr)
107            .map_err(|e| eyre!("Failed to create socket for {addr}: {e}"))?;
108
109        socket
110            .bind(&addr.into())
111            .map_err(|e| eyre!("Failed to reserve {addr}: {e}"))?;
112
113        socket
114            .listen(128)
115            .map_err(|e| eyre!("Failed to listen on {addr}: {e}"))?;
116
117        let std_listener: std::net::TcpListener = socket.into();
118        std_listener
119            .set_nonblocking(true)
120            .map_err(|e| eyre!("Failed to set nonblocking for {addr}: {e}"))?;
121
122        let listener = TcpListener::from_std(std_listener)
123            .map_err(|e| eyre!("Failed to create tokio listener for {addr}: {e}"))?;
124
125        info!("Reserved {addr}");
126        self.sockets.write().await.insert(addr, listener);
127        Ok(())
128    }
129
130    /// Releases the reservation for `addr`, if any.
131    pub async fn deallocate(&self, addr: SocketAddr) {
132        self.sockets.write().await.remove(&addr);
133        info!("Released {addr}");
134    }
135
136    /// Returns `true` if `addr` has an active reservation.
137    pub async fn is_allocated(&self, addr: SocketAddr) -> bool {
138        if self.sockets.read().await.contains_key(&addr) {
139            return true;
140        }
141
142        let Ok(socket) = create_freebind_socket(&addr) else {
143            return false;
144        };
145
146        match socket.bind(&addr.into()) {
147            Err(e) => e.kind() == std::io::ErrorKind::AddrInUse,
148            Ok(_) => false,
149        }
150    }
151
152    /// Releases all active reservations.
153    pub async fn deallocate_all(&self) {
154        let mut sockets = self.sockets.write().await;
155        let count = sockets.len();
156        sockets.clear();
157        info!("Released all {count} reservations");
158    }
159}
160
161// ---------------------------------------------------------------------------
162// PortAssignments — persistent ephemeral port allocation
163// ---------------------------------------------------------------------------
164
165/// Persistent mapping of service keys to allocated host ports.
166///
167/// Keys follow the format `"{service_name}-{container_port}"` (e.g. `"example-drive-80"`).
168/// Loaded from and saved to `ports.json` in the state directory.
169#[derive(Debug, Clone, Serialize, Deserialize, Default)]
170pub struct PortAssignments {
171    assignments: HashMap<String, u16>,
172}
173
174impl PortAssignments {
175    /// Load assignments from a JSON file. Returns empty defaults if the
176    /// file does not exist or is unreadable.
177    pub fn load(path: &Path) -> Self {
178        std::fs::read_to_string(path)
179            .ok()
180            .and_then(|s| serde_json::from_str(&s).ok())
181            .unwrap_or_default()
182    }
183
184    /// Persist assignments to a JSON file. Creates parent directories as
185    /// needed.
186    pub fn save(&self, path: &Path) -> Result<(), std::io::Error> {
187        if let Some(parent) = path.parent() {
188            std::fs::create_dir_all(parent)?;
189        }
190        let contents = serde_json::to_string_pretty(self)?;
191        std::fs::write(path, contents)
192    }
193
194    /// Look up a port assignment by key. Returns `None` if not assigned.
195    pub fn get(&self, key: &str) -> Option<u16> {
196        self.assignments.get(key).copied()
197    }
198
199    /// Assign a port for the given key.
200    pub fn set(&mut self, key: String, port: u16) {
201        self.assignments.insert(key, port);
202    }
203
204    /// Remove an assignment by key. Returns the previously assigned port,
205    /// or `None`.
206    #[allow(dead_code)]
207    pub fn remove(&mut self, key: &str) -> Option<u16> {
208        self.assignments.remove(key)
209    }
210
211    /// Returns `true` if the given port is already assigned.
212    pub fn is_used(&self, port: u16) -> bool {
213        self.assignments.values().any(|&p| p == port)
214    }
215
216    /// Look up an existing assignment, or allocate a free port and assign it.
217    ///
218    /// Returns `None` if no port is available in the ephemeral range.
219    pub fn get_or_allocate(&mut self, key: &str) -> Option<u16> {
220        if let Some(&p) = self.assignments.get(key) {
221            return Some(p);
222        }
223        let p = self.allocate_port()?;
224        self.assignments.insert(key.to_string(), p);
225        Some(p)
226    }
227
228    fn allocate_port(&self) -> Option<u16> {
229        (PORT_RANGE_START..=PORT_RANGE_END)
230            .find(|&port| !self.is_used(port) && is_port_free(format!("0.0.0.0:{port}")))
231    }
232}
233
#[cfg(test)]
mod tests {
    // Explicitly import the blocking std listener; it shadows the tokio
    // `TcpListener` brought in by the glob import below.
    use std::net::TcpListener;

    use tempfile::TempDir;

    use super::*;

    #[test]
    fn set_and_get() {
        let mut pa = PortAssignments::default();
        pa.set("example-drive-80".into(), 32000);
        assert_eq!(pa.get("example-drive-80"), Some(32000));
    }

    #[test]
    fn remove() {
        let mut pa = PortAssignments::default();
        pa.set("key".into(), 32000);
        assert_eq!(pa.remove("key"), Some(32000));
        assert_eq!(pa.get("key"), None);
    }

    #[test]
    fn persistence() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("ports.json");

        let mut pa = PortAssignments::default();
        pa.set("s1".into(), 40000);
        pa.set("s2".into(), 40001);
        pa.save(&path).unwrap();

        let loaded = PortAssignments::load(&path);
        assert_eq!(loaded.get("s1"), Some(40000));
        assert_eq!(loaded.get("s2"), Some(40001));
    }

    #[test]
    fn load_missing_file_returns_default() {
        let dir = TempDir::new().unwrap();
        let pa = PortAssignments::load(&dir.path().join("missing.json"));
        assert_eq!(pa.get("anything"), None);
    }

    #[test]
    fn load_corrupt_file_returns_default() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("ports.json");
        std::fs::write(&path, "{ not json").unwrap();

        let pa = PortAssignments::load(&path);
        assert_eq!(pa.get("s1"), None);
    }

    #[test]
    fn save_creates_parent_dirs_and_overwrites() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("nested").join("state").join("ports.json");

        let mut pa = PortAssignments::default();
        pa.set("s1".into(), 40000);
        pa.save(&path).unwrap();

        pa.set("s1".into(), 40002);
        pa.save(&path).unwrap();

        assert_eq!(PortAssignments::load(&path).get("s1"), Some(40002));
    }

    #[test]
    fn is_used_tracks_assigned_ports() {
        let mut pa = PortAssignments::default();
        assert!(!pa.is_used(40000));
        pa.set("s1".into(), 40000);
        assert!(pa.is_used(40000));
        assert_eq!(pa.remove("s1"), Some(40000));
        assert!(!pa.is_used(40000));
    }

    #[test]
    fn get_or_allocate_returns_existing() {
        let mut pa = PortAssignments::default();
        pa.set("key".into(), 32000);
        assert_eq!(pa.get_or_allocate("key"), Some(32000));
    }

    #[test]
    fn get_or_allocate_creates_new() {
        let mut pa = PortAssignments::default();
        let p = pa.get_or_allocate("new-key").unwrap();
        assert!((PORT_RANGE_START..=PORT_RANGE_END).contains(&p));
        assert_eq!(pa.get("new-key"), Some(p));
    }

    #[test]
    fn is_port_free_localhost() {
        // Grab an OS-chosen port, then release it. This is inherently racy
        // against other processes, but good enough for a local check.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        assert!(is_port_free(format!("127.0.0.1:{port}")));
    }

    #[test]
    fn is_port_free_occupied() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        assert!(!is_port_free(format!("127.0.0.1:{port}")));
        drop(listener);
        assert!(is_port_free(format!("127.0.0.1:{port}")));
    }

    #[test]
    fn allocate_port_assigns_unique() {
        let mut pa = PortAssignments::default();
        let p1 = pa.allocate_port().unwrap();
        pa.set("s1".into(), p1);
        let p2 = pa.allocate_port().unwrap();
        assert_ne!(p1, p2);
    }
}