Skip to main content

bindport_runner/
lib.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    collections::HashSet,
5    fmt, io,
6    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener},
7    process::{Child, Command, ExitStatus, Stdio},
8};
9
10#[cfg(unix)]
11use std::os::unix::process::CommandExt;
12#[cfg(unix)]
13use std::sync::atomic::{AtomicI32, Ordering};
14
15use bindport_core::PortRange;
16
/// Environment variable through which the selected port is handed to the child.
pub const PORT_ENV_VAR: &str = "PORT";

/// Value stored in [`FORWARDED_CHILD_PID`] while a spawn has claimed the
/// forwarding slot but no child PID is being forwarded to. Being non-positive,
/// it is never used as a `kill(2)` target.
#[cfg(unix)]
const RESERVED_CHILD_PID: i32 = -1;

/// Process-wide forwarding slot read by the signal handler.
///
/// `0` means the slot is free, [`RESERVED_CHILD_PID`] means it is claimed, and
/// any positive value is the PID that SIGINT/SIGTERM are relayed to.
#[cfg(unix)]
static FORWARDED_CHILD_PID: AtomicI32 = AtomicI32::new(0);
24
25#[derive(Debug)]
26pub enum RunnerError {
27    NoCommand,
28    NoAvailablePort { range: PortRange },
29    SignalForwarding { source: io::Error },
30    Spawn { command: String, source: io::Error },
31    Wait { command: String, source: io::Error },
32}
33
34impl fmt::Display for RunnerError {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        match self {
37            Self::NoCommand => write!(f, "no command provided after `--`"),
38            Self::NoAvailablePort { range } => {
39                write!(
40                    f,
41                    "no available port found in range {}-{}",
42                    range.start, range.end
43                )
44            }
45            Self::SignalForwarding { source } => {
46                write!(f, "failed to install signal forwarding: {source}")
47            }
48            Self::Spawn { command, source } => {
49                write!(f, "failed to spawn `{command}`: {source}")
50            }
51            Self::Wait { command, source } => {
52                write!(f, "failed waiting for `{command}`: {source}")
53            }
54        }
55    }
56}
57
58impl std::error::Error for RunnerError {}
59
60pub struct RunningChild {
61    child: Child,
62    port: u16,
63    program: String,
64    signal_forwarding: SignalForwardingState,
65}
66
67impl RunningChild {
68    pub const fn port(&self) -> u16 {
69        self.port
70    }
71
72    pub fn pid(&self) -> u32 {
73        self.child.id()
74    }
75
76    pub fn kill(&mut self) -> io::Result<()> {
77        self.child.kill()
78    }
79
80    pub fn wait(&mut self) -> Result<ExitStatus, RunnerError> {
81        let status = self.child.wait().map_err(|source| RunnerError::Wait {
82            command: self.program.clone(),
83            source,
84        });
85        let signal_forwarding = self
86            .signal_forwarding
87            .deactivate()
88            .map_err(|source| RunnerError::SignalForwarding { source });
89
90        match (status, signal_forwarding) {
91            (Ok(status), Ok(())) => Ok(status),
92            (Err(error), _) | (_, Err(error)) => Err(error),
93        }
94    }
95}
96
97impl Drop for RunningChild {
98    fn drop(&mut self) {
99        let _ = self.signal_forwarding.deactivate();
100    }
101}
102
/// Optional preferences that steer [`allocate_port_with_hints`].
///
/// The default value carries no hints, so allocation scans the range from its
/// start.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct AllocationHints {
    /// Port to try before scanning (e.g. a previously used port). Ignored when
    /// it lies outside the range or is in the skip list.
    pub preferred_port: Option<u16>,
    /// Port at which the wrapping scan begins. Ignored when outside the range.
    pub scan_start: Option<u16>,
}
108
109/// Scans the configured TCP loopback range and returns an available port.
110///
111/// This bootstrap runner drops the probe listener before spawning the child, so
112/// another process can still claim the port before the child binds. The
113/// registry/lease slice must close that gap for strong coordination.
114pub fn allocate_port(range: PortRange, skip_ports: &[u16]) -> Result<u16, RunnerError> {
115    allocate_port_with_hints(range, skip_ports, AllocationHints::default())
116}
117
118pub fn allocate_port_with_hints(
119    range: PortRange,
120    skip_ports: &[u16],
121    hints: AllocationHints,
122) -> Result<u16, RunnerError> {
123    allocate_port_with_hints_and_availability(range, skip_ports, hints, is_port_available)
124}
125
126fn allocate_port_with_hints_and_availability(
127    range: PortRange,
128    skip_ports: &[u16],
129    hints: AllocationHints,
130    mut is_available: impl FnMut(u16) -> bool,
131) -> Result<u16, RunnerError> {
132    let skip_ports = skip_ports.iter().copied().collect::<HashSet<_>>();
133
134    if let Some(port) = hints
135        .preferred_port
136        .filter(|port| range.contains(*port) && !skip_ports.contains(port))
137        && is_available(port)
138    {
139        return Ok(port);
140    }
141
142    let range_len = range.len();
143    if range_len == 0 {
144        return Err(RunnerError::NoAvailablePort { range });
145    }
146
147    let scan_start = hints
148        .scan_start
149        .filter(|port| range.contains(*port))
150        .unwrap_or(range.start);
151    let scan_start_offset = scan_start as u32 - range.start as u32;
152
153    for offset in 0..range_len {
154        let port = range.start as u32 + ((scan_start_offset + offset) % range_len);
155        let port = u16::try_from(port).expect("port remains within configured range");
156
157        if skip_ports.contains(&port) {
158            continue;
159        }
160
161        if is_available(port) {
162            return Ok(port);
163        }
164    }
165
166    Err(RunnerError::NoAvailablePort { range })
167}
168
169/// Returns true when no supported TCP loopback family reports `port` in use.
170///
171/// Missing address families are not conflicts, so IPv4-only hosts can still
172/// allocate a loopback port. UDP availability is outside the current runner
173/// scope.
174pub fn is_port_available(port: u16) -> bool {
175    let v4 = loopback_free(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
176    let v6 = loopback_free(SocketAddrV6::new(Ipv6Addr::LOCALHOST, port, 0, 0));
177
178    v4 && v6
179}
180
181fn loopback_free(addr: impl Into<SocketAddr>) -> bool {
182    match TcpListener::bind(addr.into()) {
183        Ok(_) => true,
184        Err(error) => bind_error_leaves_port_available(error.kind()),
185    }
186}
187
/// Decides whether a failed probe bind still counts as "port available".
///
/// Only `AddrInUse` is a conflict. Any other error, such as an address family
/// the host does not support, is treated as available.
fn bind_error_leaves_port_available(kind: io::ErrorKind) -> bool {
    !matches!(kind, io::ErrorKind::AddrInUse)
}
191
192pub fn run_child(
193    command: &[String],
194    range: PortRange,
195    skip_ports: &[u16],
196) -> Result<ExitStatus, RunnerError> {
197    let mut child = spawn_child(command, range, skip_ports)?;
198
199    child.wait()
200}
201
202/// Spawns a wrapped command with the selected port in its environment.
203///
204/// On Unix, SIGINT/SIGTERM forwarding uses process-global signal handlers while
205/// the returned child is active. A second concurrent forwarded child is rejected.
206pub fn spawn_child(
207    command: &[String],
208    range: PortRange,
209    skip_ports: &[u16],
210) -> Result<RunningChild, RunnerError> {
211    spawn_child_with_hints(command, range, skip_ports, AllocationHints::default())
212}
213
214pub fn spawn_child_with_hints(
215    command: &[String],
216    range: PortRange,
217    skip_ports: &[u16],
218    allocation_hints: AllocationHints,
219) -> Result<RunningChild, RunnerError> {
220    let (program, args) = command.split_first().ok_or(RunnerError::NoCommand)?;
221    let port = allocate_port_with_hints(range, skip_ports, allocation_hints)?;
222
223    let mut signal_forwarding =
224        prepare_signal_forwarding().map_err(|source| RunnerError::SignalForwarding { source })?;
225
226    let mut process = Command::new(program);
227    process
228        .args(args)
229        .env(PORT_ENV_VAR, port.to_string())
230        .stdin(Stdio::inherit())
231        .stdout(Stdio::inherit())
232        .stderr(Stdio::inherit());
233    prepare_child_signal_mask(&mut process, &signal_forwarding);
234
235    let child = match process.spawn() {
236        Ok(child) => child,
237        Err(source) => {
238            if let Err(source) = signal_forwarding.deactivate() {
239                return Err(RunnerError::SignalForwarding { source });
240            }
241
242            return Err(RunnerError::Spawn {
243                command: program.clone(),
244                source,
245            });
246        }
247    };
248    let child = match signal_forwarding.activate_for_pid(child.id()) {
249        Ok(()) => child,
250        Err(source) => {
251            let mut child = child;
252            let _ = child.kill();
253            let _ = child.wait();
254            let _ = signal_forwarding.deactivate();
255
256            return Err(RunnerError::SignalForwarding { source });
257        }
258    };
259
260    Ok(RunningChild {
261        child,
262        port,
263        program: program.clone(),
264        signal_forwarding,
265    })
266}
267
268#[cfg(unix)]
269struct SignalForwardingState {
270    saved_handlers: Option<SavedSignalHandlers>,
271    saved_signal_mask: Option<libc::sigset_t>,
272}
273
274#[cfg(not(unix))]
275struct SignalForwardingState;
276
277#[cfg(unix)]
278struct SavedSignalHandlers {
279    sigint: libc::sigaction,
280    sigterm: libc::sigaction,
281}
282
283#[cfg(unix)]
284impl SignalForwardingState {
285    fn activate_for_pid(&mut self, pid: u32) -> io::Result<()> {
286        FORWARDED_CHILD_PID.store(pid as i32, Ordering::SeqCst);
287        self.restore_signal_mask()
288    }
289
290    fn deactivate(&mut self) -> io::Result<()> {
291        FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
292        let signal_mask = self.restore_signal_mask();
293
294        let handlers = if let Some(saved_handlers) = self.saved_handlers.take() {
295            restore_signal_forwarding_handlers(&saved_handlers)
296        } else {
297            Ok(())
298        };
299
300        signal_mask.and(handlers)
301    }
302
303    fn restore_signal_mask(&mut self) -> io::Result<()> {
304        if let Some(saved_signal_mask) = self.saved_signal_mask.as_ref() {
305            restore_signal_mask(saved_signal_mask)?;
306            self.saved_signal_mask = None;
307        }
308
309        Ok(())
310    }
311}
312
313#[cfg(not(unix))]
314impl SignalForwardingState {
315    fn activate_for_pid(&mut self, _pid: u32) -> io::Result<()> {
316        Ok(())
317    }
318
319    fn deactivate(&mut self) -> io::Result<()> {
320        Ok(())
321    }
322}
323
324#[cfg(unix)]
325fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
326    reserve_signal_forwarding()?;
327    let saved_signal_mask = match block_signal_forwarding_signals() {
328        Ok(saved_signal_mask) => saved_signal_mask,
329        Err(error) => {
330            FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
331            return Err(error);
332        }
333    };
334
335    match install_signal_forwarding_handlers() {
336        Ok(saved_handlers) => Ok(SignalForwardingState {
337            saved_handlers: Some(saved_handlers),
338            saved_signal_mask: Some(saved_signal_mask),
339        }),
340        Err(error) => {
341            FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
342            let _ = restore_signal_mask(&saved_signal_mask);
343            Err(error)
344        }
345    }
346}
347
348#[cfg(not(unix))]
349fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
350    Ok(SignalForwardingState)
351}
352
353#[cfg(unix)]
354fn prepare_child_signal_mask(command: &mut Command, signal_forwarding: &SignalForwardingState) {
355    if let Some(saved_signal_mask) = signal_forwarding.saved_signal_mask {
356        unsafe {
357            command.pre_exec(move || restore_signal_mask(&saved_signal_mask));
358        }
359    }
360}
361
362#[cfg(not(unix))]
363fn prepare_child_signal_mask(_command: &mut Command, _signal_forwarding: &SignalForwardingState) {}
364
365#[cfg(unix)]
366fn reserve_signal_forwarding() -> io::Result<()> {
367    FORWARDED_CHILD_PID
368        .compare_exchange(0, RESERVED_CHILD_PID, Ordering::SeqCst, Ordering::SeqCst)
369        .map(|_| ())
370        .map_err(|_| {
371            io::Error::new(
372                io::ErrorKind::AlreadyExists,
373                "signal forwarding is already active",
374            )
375        })
376}
377
378#[cfg(unix)]
379fn install_signal_forwarding_handlers() -> io::Result<SavedSignalHandlers> {
380    let sigint = install_signal_forwarding_handler(libc::SIGINT)?;
381
382    match install_signal_forwarding_handler(libc::SIGTERM) {
383        Ok(sigterm) => Ok(SavedSignalHandlers { sigint, sigterm }),
384        Err(error) => {
385            let _ = restore_signal_handler(libc::SIGINT, &sigint);
386            Err(error)
387        }
388    }
389}
390
391#[cfg(unix)]
392fn block_signal_forwarding_signals() -> io::Result<libc::sigset_t> {
393    let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
394    let mut previous = unsafe { std::mem::zeroed::<libc::sigset_t>() };
395
396    if unsafe { libc::sigemptyset(&mut mask) } == -1 {
397        return Err(io::Error::last_os_error());
398    }
399    if unsafe { libc::sigaddset(&mut mask, libc::SIGINT) } == -1 {
400        return Err(io::Error::last_os_error());
401    }
402    if unsafe { libc::sigaddset(&mut mask, libc::SIGTERM) } == -1 {
403        return Err(io::Error::last_os_error());
404    }
405
406    let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, &mask, &mut previous) };
407    if result == -1 {
408        return Err(io::Error::last_os_error());
409    }
410
411    Ok(previous)
412}
413
414#[cfg(unix)]
415fn restore_signal_mask(mask: &libc::sigset_t) -> io::Result<()> {
416    let result = unsafe { libc::sigprocmask(libc::SIG_SETMASK, mask, std::ptr::null_mut()) };
417    if result == -1 {
418        return Err(io::Error::last_os_error());
419    }
420
421    Ok(())
422}
423
424#[cfg(unix)]
425fn install_signal_forwarding_handler(signal: libc::c_int) -> io::Result<libc::sigaction> {
426    let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
427    let mut previous = unsafe { std::mem::zeroed::<libc::sigaction>() };
428    action.sa_sigaction = forward_signal_to_child as *const () as usize;
429    action.sa_flags = 0;
430
431    let mask_result = unsafe { libc::sigemptyset(&mut action.sa_mask) };
432    if mask_result == -1 {
433        return Err(io::Error::last_os_error());
434    }
435
436    let action_result = unsafe { libc::sigaction(signal, &action, &mut previous) };
437    if action_result == -1 {
438        return Err(io::Error::last_os_error());
439    }
440
441    Ok(previous)
442}
443
444#[cfg(unix)]
445fn restore_signal_forwarding_handlers(saved_handlers: &SavedSignalHandlers) -> io::Result<()> {
446    restore_signal_handler(libc::SIGINT, &saved_handlers.sigint)?;
447    restore_signal_handler(libc::SIGTERM, &saved_handlers.sigterm)
448}
449
450#[cfg(unix)]
451fn restore_signal_handler(signal: libc::c_int, action: &libc::sigaction) -> io::Result<()> {
452    let result = unsafe { libc::sigaction(signal, action, std::ptr::null_mut()) };
453    if result == -1 {
454        return Err(io::Error::last_os_error());
455    }
456
457    Ok(())
458}
459
460#[cfg(unix)]
461extern "C" fn forward_signal_to_child(signal: libc::c_int) {
462    let pid = FORWARDED_CHILD_PID.load(Ordering::SeqCst);
463
464    if pid > 0 {
465        unsafe {
466            libc::kill(pid, signal);
467        }
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test port range (bounds are inclusive).
    fn port_range(start: u16, end: u16) -> PortRange {
        PortRange { start, end }
    }

    /// Runs the allocator with a probe that reports every port as free.
    fn allocate_all_free(
        range: PortRange,
        skip_ports: &[u16],
        hints: AllocationHints,
    ) -> Result<u16, RunnerError> {
        allocate_port_with_hints_and_availability(range, skip_ports, hints, |_| true)
    }

    #[test]
    fn allocate_port_skips_reserved_ports() {
        let port = allocate_all_free(
            port_range(29_000, 29_001),
            &[29_000],
            AllocationHints::default(),
        )
        .expect("port");

        assert_eq!(port, 29_001);
    }

    #[test]
    fn allocate_port_prefers_available_prior_port() {
        let hints = AllocationHints {
            preferred_port: Some(29_002),
            ..AllocationHints::default()
        };

        let port = allocate_all_free(port_range(29_000, 29_002), &[], hints).expect("port");

        assert_eq!(port, 29_002);
    }

    #[test]
    fn allocate_port_scans_from_hint_with_wraparound() {
        let hints = AllocationHints {
            scan_start: Some(29_002),
            ..AllocationHints::default()
        };

        let port = allocate_all_free(
            port_range(29_000, 29_003),
            &[29_002, 29_003, 29_000],
            hints,
        )
        .expect("port");

        assert_eq!(port, 29_001);
    }

    #[test]
    fn allocate_port_reports_exhausted_range() {
        let outcome = allocate_all_free(
            port_range(29_000, 29_000),
            &[29_000],
            AllocationHints::default(),
        );

        assert!(matches!(outcome, Err(RunnerError::NoAvailablePort { .. })));
    }

    #[test]
    fn bind_errors_only_conflict_when_address_is_in_use() {
        let cases = [
            (io::ErrorKind::AddrInUse, false),
            (io::ErrorKind::AddrNotAvailable, true),
            (io::ErrorKind::Unsupported, true),
        ];

        for (kind, expected) in cases {
            assert_eq!(bind_error_leaves_port_available(kind), expected, "{kind:?}");
        }
    }

    #[cfg(unix)]
    #[test]
    fn signal_forwarding_rejects_concurrent_children_and_restores_handlers() {
        let initial_sigint = current_signal_action(libc::SIGINT);
        let initial_sigterm = current_signal_action(libc::SIGTERM);
        let initial_mask = current_signal_mask();
        let command = ["sleep", "5"].map(String::from);
        let range = port_range(29_000, 29_010);

        let mut first = spawn_child(&command, range, &[]).expect("first child");
        let error = match spawn_child(&command, range, &[]) {
            Err(error) => error,
            Ok(mut unexpected) => {
                let _ = unexpected.kill();
                let _ = unexpected.wait();
                panic!("second child was not rejected");
            }
        };

        assert!(matches!(
            error,
            RunnerError::SignalForwarding { source }
                if source.kind() == io::ErrorKind::AlreadyExists
        ));

        first.kill().expect("kill first child");
        first.wait().expect("wait for first child");

        for (signal, initial) in [
            (libc::SIGINT, &initial_sigint),
            (libc::SIGTERM, &initial_sigterm),
        ] {
            assert_signal_action_matches(signal, initial);
        }
        for signal in [libc::SIGINT, libc::SIGTERM] {
            assert_signal_mask_matches(signal, &initial_mask);
        }
    }

    /// Reads the current disposition of `signal` without changing it.
    #[cfg(unix)]
    fn current_signal_action(signal: libc::c_int) -> libc::sigaction {
        let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
        // A null new action only queries the current disposition.
        let result = unsafe { libc::sigaction(signal, std::ptr::null(), &mut action) };
        assert_eq!(result, 0, "read signal action for {signal}");
        action
    }

    /// Asserts that `signal` has the same handler as `expected`.
    #[cfg(unix)]
    fn assert_signal_action_matches(signal: libc::c_int, expected: &libc::sigaction) {
        let actual = current_signal_action(signal);
        assert_eq!(actual.sa_sigaction, expected.sa_sigaction);
    }

    /// Reads the current signal mask without changing it.
    #[cfg(unix)]
    fn current_signal_mask() -> libc::sigset_t {
        let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
        // With a null new set, `how` is ignored and the mask is only read.
        let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, std::ptr::null(), &mut mask) };
        assert_eq!(result, 0, "read signal mask");
        mask
    }

    /// Asserts that `signal` is blocked now exactly when it was in `expected`.
    #[cfg(unix)]
    fn assert_signal_mask_matches(signal: libc::c_int, expected: &libc::sigset_t) {
        let actual = current_signal_mask();
        let (actual_member, expected_member) = unsafe {
            (
                libc::sigismember(&actual, signal),
                libc::sigismember(expected, signal),
            )
        };

        assert_eq!(actual_member, expected_member);
    }
}