Skip to main content

bindport_runner/
lib.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    collections::HashSet,
5    fmt, io,
6    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener},
7    process::{Child, Command, ExitStatus, Stdio},
8};
9
10#[cfg(unix)]
11use std::os::unix::process::CommandExt;
12#[cfg(unix)]
13use std::sync::atomic::{AtomicI32, Ordering};
14
15use bindport_core::PortRange;
16
/// Name of the environment variable that `spawn_child_on_port` sets to the
/// selected port (as a decimal string) for the spawned child.
17pub const PORT_ENV_VAR: &str = "PORT";
18
// Process-wide forwarding slot read by the signal handler: `0` means the slot
// is free, `RESERVED_CHILD_PID` means it is reserved but no child pid has been
// stored yet, and a positive value is the pid that signals are forwarded to.
19#[cfg(unix)]
20static FORWARDED_CHILD_PID: AtomicI32 = AtomicI32::new(0);
21
// Placeholder stored by `reserve_signal_forwarding`; the handler ignores it
// because it only forwards to pids greater than zero.
22#[cfg(unix)]
23const RESERVED_CHILD_PID: i32 = -1;
24
25#[derive(Debug)]
26pub enum RunnerError {
27    NoCommand,
28    NoAvailablePort { range: PortRange },
29    SignalForwarding { source: io::Error },
30    Spawn { command: String, source: io::Error },
31    Wait { command: String, source: io::Error },
32}
33
34impl fmt::Display for RunnerError {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        match self {
37            Self::NoCommand => write!(f, "no command provided after `--`"),
38            Self::NoAvailablePort { range } => {
39                write!(
40                    f,
41                    "no available port found in range {}-{}",
42                    range.start, range.end
43                )
44            }
45            Self::SignalForwarding { source } => {
46                write!(f, "failed to install signal forwarding: {source}")
47            }
48            Self::Spawn { command, source } => {
49                write!(f, "failed to spawn `{command}`: {source}")
50            }
51            Self::Wait { command, source } => {
52                write!(f, "failed waiting for `{command}`: {source}")
53            }
54        }
55    }
56}
57
58impl std::error::Error for RunnerError {}
59
60pub struct RunningChild {
61    child: Child,
62    port: u16,
63    program: String,
64    signal_forwarding: SignalForwardingState,
65}
66
67impl RunningChild {
68    pub const fn port(&self) -> u16 {
69        self.port
70    }
71
72    pub fn pid(&self) -> u32 {
73        self.child.id()
74    }
75
76    pub fn kill(&mut self) -> io::Result<()> {
77        self.child.kill()
78    }
79
80    pub fn wait(&mut self) -> Result<ExitStatus, RunnerError> {
81        let status = self.child.wait().map_err(|source| RunnerError::Wait {
82            command: self.program.clone(),
83            source,
84        });
85        let signal_forwarding = self
86            .signal_forwarding
87            .deactivate()
88            .map_err(|source| RunnerError::SignalForwarding { source });
89
90        match (status, signal_forwarding) {
91            (Ok(status), Ok(())) => Ok(status),
92            (Err(error), _) | (_, Err(error)) => Err(error),
93        }
94    }
95}
96
97impl Drop for RunningChild {
98    fn drop(&mut self) {
99        let _ = self.signal_forwarding.deactivate();
100    }
101}
102
/// Optional guidance for port allocation; the default carries no hints.
103#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
104pub struct AllocationHints {
    /// Port tried before any scanning. It is ignored when it lies outside the
    /// range, appears in the skip list, or is reported unavailable.
105    pub preferred_port: Option<u16>,
    /// Port at which the wrapping scan begins. When absent or outside the
    /// range, the scan begins at the start of the range.
106    pub scan_start: Option<u16>,
107}
108
109/// Scans the configured TCP loopback range and returns an available port.
110///
111/// This bootstrap runner drops the probe listener before spawning the child, so
112/// another process can still claim the port before the child binds. The
113/// registry/lease slice must close that gap for strong coordination.
114pub fn allocate_port(range: PortRange, skip_ports: &[u16]) -> Result<u16, RunnerError> {
115    allocate_port_with_hints(range, skip_ports, AllocationHints::default())
116}
117
118pub fn allocate_port_with_hints(
119    range: PortRange,
120    skip_ports: &[u16],
121    hints: AllocationHints,
122) -> Result<u16, RunnerError> {
123    allocate_port_with_hints_and_availability(range, skip_ports, hints, is_port_available)
124}
125
126fn allocate_port_with_hints_and_availability(
127    range: PortRange,
128    skip_ports: &[u16],
129    hints: AllocationHints,
130    mut is_available: impl FnMut(u16) -> bool,
131) -> Result<u16, RunnerError> {
132    let skip_ports = skip_ports.iter().copied().collect::<HashSet<_>>();
133
134    if let Some(port) = hints
135        .preferred_port
136        .filter(|port| range.contains(*port) && !skip_ports.contains(port))
137        && is_available(port)
138    {
139        return Ok(port);
140    }
141
142    let range_len = range.len();
143    if range_len == 0 {
144        return Err(RunnerError::NoAvailablePort { range });
145    }
146
147    let scan_start = hints
148        .scan_start
149        .filter(|port| range.contains(*port))
150        .unwrap_or(range.start);
151    let scan_start_offset = scan_start as u32 - range.start as u32;
152
153    for offset in 0..range_len {
154        let port = range.start as u32 + ((scan_start_offset + offset) % range_len);
155        let port = u16::try_from(port).expect("port remains within configured range");
156
157        if skip_ports.contains(&port) {
158            continue;
159        }
160
161        if is_available(port) {
162            return Ok(port);
163        }
164    }
165
166    Err(RunnerError::NoAvailablePort { range })
167}
168
169/// Returns true when no supported TCP loopback family reports `port` in use.
170///
171/// Missing address families are not conflicts, so IPv4-only hosts can still
172/// allocate a loopback port. UDP availability is outside the current runner
173/// scope.
174pub fn is_port_available(port: u16) -> bool {
175    let v4 = loopback_free(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
176    let v6 = loopback_free(SocketAddrV6::new(Ipv6Addr::LOCALHOST, port, 0, 0));
177
178    v4 && v6
179}
180
181fn loopback_free(addr: impl Into<SocketAddr>) -> bool {
182    match TcpListener::bind(addr.into()) {
183        Ok(_) => true,
184        Err(error) => bind_error_leaves_port_available(error.kind()),
185    }
186}
187
/// Decides whether a failed probe bind still leaves the port usable.
///
/// `AddrInUse` means something already holds the port. `PermissionDenied`
/// means this process is not allowed to bind it (for example a privileged
/// port, or a port range the OS has excluded), so handing it to a child would
/// fail as well. Every other kind, such as a missing address family, is not
/// treated as a conflict.
fn bind_error_leaves_port_available(kind: io::ErrorKind) -> bool {
    !matches!(
        kind,
        io::ErrorKind::AddrInUse | io::ErrorKind::PermissionDenied
    )
}
191
192pub fn run_child(
193    command: &[String],
194    range: PortRange,
195    skip_ports: &[u16],
196) -> Result<ExitStatus, RunnerError> {
197    let mut child = spawn_child(command, range, skip_ports)?;
198
199    child.wait()
200}
201
202/// Spawns a wrapped command with the selected port in its environment.
203///
204/// On Unix, SIGINT/SIGTERM forwarding uses process-global signal handlers while
205/// the returned child is active. A second concurrent forwarded child is rejected.
206pub fn spawn_child(
207    command: &[String],
208    range: PortRange,
209    skip_ports: &[u16],
210) -> Result<RunningChild, RunnerError> {
211    spawn_child_with_hints(command, range, skip_ports, AllocationHints::default())
212}
213
214pub fn spawn_child_with_hints(
215    command: &[String],
216    range: PortRange,
217    skip_ports: &[u16],
218    allocation_hints: AllocationHints,
219) -> Result<RunningChild, RunnerError> {
220    let port = allocate_port_with_hints(range, skip_ports, allocation_hints)?;
221
222    spawn_child_on_port(command, port, &[])
223}
224
225pub fn spawn_child_on_port(
226    command: &[String],
227    port: u16,
228    extra_env: &[(String, String)],
229) -> Result<RunningChild, RunnerError> {
230    let (program, args) = command.split_first().ok_or(RunnerError::NoCommand)?;
231
232    let mut signal_forwarding =
233        prepare_signal_forwarding().map_err(|source| RunnerError::SignalForwarding { source })?;
234
235    let mut process = Command::new(program);
236    process
237        .args(args)
238        .env(PORT_ENV_VAR, port.to_string())
239        .stdin(Stdio::inherit())
240        .stdout(Stdio::inherit())
241        .stderr(Stdio::inherit());
242    process.envs(extra_env.iter().map(|(name, value)| (name, value)));
243    prepare_child_signal_mask(&mut process, &signal_forwarding);
244
245    let child = match process.spawn() {
246        Ok(child) => child,
247        Err(source) => {
248            if let Err(source) = signal_forwarding.deactivate() {
249                return Err(RunnerError::SignalForwarding { source });
250            }
251
252            return Err(RunnerError::Spawn {
253                command: program.clone(),
254                source,
255            });
256        }
257    };
258    let child = match signal_forwarding.activate_for_pid(child.id()) {
259        Ok(()) => child,
260        Err(source) => {
261            let mut child = child;
262            let _ = child.kill();
263            let _ = child.wait();
264            let _ = signal_forwarding.deactivate();
265
266            return Err(RunnerError::SignalForwarding { source });
267        }
268    };
269
270    Ok(RunningChild {
271        child,
272        port,
273        program: program.clone(),
274        signal_forwarding,
275    })
276}
277
278#[cfg(unix)]
279struct SignalForwardingState {
280    saved_handlers: Option<SavedSignalHandlers>,
281    saved_signal_mask: Option<libc::sigset_t>,
282}
283
284#[cfg(not(unix))]
285struct SignalForwardingState;
286
287#[cfg(unix)]
288struct SavedSignalHandlers {
289    sigint: libc::sigaction,
290    sigterm: libc::sigaction,
291}
292
293#[cfg(unix)]
294impl SignalForwardingState {
295    fn activate_for_pid(&mut self, pid: u32) -> io::Result<()> {
296        FORWARDED_CHILD_PID.store(pid as i32, Ordering::SeqCst);
297        self.restore_signal_mask()
298    }
299
300    fn deactivate(&mut self) -> io::Result<()> {
301        FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
302        let signal_mask = self.restore_signal_mask();
303
304        let handlers = if let Some(saved_handlers) = self.saved_handlers.take() {
305            restore_signal_forwarding_handlers(&saved_handlers)
306        } else {
307            Ok(())
308        };
309
310        signal_mask.and(handlers)
311    }
312
313    fn restore_signal_mask(&mut self) -> io::Result<()> {
314        if let Some(saved_signal_mask) = self.saved_signal_mask.as_ref() {
315            restore_signal_mask(saved_signal_mask)?;
316            self.saved_signal_mask = None;
317        }
318
319        Ok(())
320    }
321}
322
323#[cfg(not(unix))]
324impl SignalForwardingState {
325    fn activate_for_pid(&mut self, _pid: u32) -> io::Result<()> {
326        Ok(())
327    }
328
329    fn deactivate(&mut self) -> io::Result<()> {
330        Ok(())
331    }
332}
333
334#[cfg(unix)]
335fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
336    reserve_signal_forwarding()?;
337    let saved_signal_mask = match block_signal_forwarding_signals() {
338        Ok(saved_signal_mask) => saved_signal_mask,
339        Err(error) => {
340            FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
341            return Err(error);
342        }
343    };
344
345    match install_signal_forwarding_handlers() {
346        Ok(saved_handlers) => Ok(SignalForwardingState {
347            saved_handlers: Some(saved_handlers),
348            saved_signal_mask: Some(saved_signal_mask),
349        }),
350        Err(error) => {
351            FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
352            let _ = restore_signal_mask(&saved_signal_mask);
353            Err(error)
354        }
355    }
356}
357
358#[cfg(not(unix))]
359fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
360    Ok(SignalForwardingState)
361}
362
363#[cfg(unix)]
364fn prepare_child_signal_mask(command: &mut Command, signal_forwarding: &SignalForwardingState) {
365    if let Some(saved_signal_mask) = signal_forwarding.saved_signal_mask {
366        unsafe {
367            command.pre_exec(move || restore_signal_mask(&saved_signal_mask));
368        }
369    }
370}
371
372#[cfg(not(unix))]
373fn prepare_child_signal_mask(_command: &mut Command, _signal_forwarding: &SignalForwardingState) {}
374
375#[cfg(unix)]
376fn reserve_signal_forwarding() -> io::Result<()> {
377    FORWARDED_CHILD_PID
378        .compare_exchange(0, RESERVED_CHILD_PID, Ordering::SeqCst, Ordering::SeqCst)
379        .map(|_| ())
380        .map_err(|_| {
381            io::Error::new(
382                io::ErrorKind::AlreadyExists,
383                "signal forwarding is already active",
384            )
385        })
386}
387
388#[cfg(unix)]
389fn install_signal_forwarding_handlers() -> io::Result<SavedSignalHandlers> {
390    let sigint = install_signal_forwarding_handler(libc::SIGINT)?;
391
392    match install_signal_forwarding_handler(libc::SIGTERM) {
393        Ok(sigterm) => Ok(SavedSignalHandlers { sigint, sigterm }),
394        Err(error) => {
395            let _ = restore_signal_handler(libc::SIGINT, &sigint);
396            Err(error)
397        }
398    }
399}
400
401#[cfg(unix)]
402fn block_signal_forwarding_signals() -> io::Result<libc::sigset_t> {
403    let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
404    let mut previous = unsafe { std::mem::zeroed::<libc::sigset_t>() };
405
406    if unsafe { libc::sigemptyset(&mut mask) } == -1 {
407        return Err(io::Error::last_os_error());
408    }
409    if unsafe { libc::sigaddset(&mut mask, libc::SIGINT) } == -1 {
410        return Err(io::Error::last_os_error());
411    }
412    if unsafe { libc::sigaddset(&mut mask, libc::SIGTERM) } == -1 {
413        return Err(io::Error::last_os_error());
414    }
415
416    let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, &mask, &mut previous) };
417    if result == -1 {
418        return Err(io::Error::last_os_error());
419    }
420
421    Ok(previous)
422}
423
424#[cfg(unix)]
425fn restore_signal_mask(mask: &libc::sigset_t) -> io::Result<()> {
426    let result = unsafe { libc::sigprocmask(libc::SIG_SETMASK, mask, std::ptr::null_mut()) };
427    if result == -1 {
428        return Err(io::Error::last_os_error());
429    }
430
431    Ok(())
432}
433
434#[cfg(unix)]
435fn install_signal_forwarding_handler(signal: libc::c_int) -> io::Result<libc::sigaction> {
436    let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
437    let mut previous = unsafe { std::mem::zeroed::<libc::sigaction>() };
438    action.sa_sigaction = forward_signal_to_child as *const () as usize;
439    action.sa_flags = 0;
440
441    let mask_result = unsafe { libc::sigemptyset(&mut action.sa_mask) };
442    if mask_result == -1 {
443        return Err(io::Error::last_os_error());
444    }
445
446    let action_result = unsafe { libc::sigaction(signal, &action, &mut previous) };
447    if action_result == -1 {
448        return Err(io::Error::last_os_error());
449    }
450
451    Ok(previous)
452}
453
454#[cfg(unix)]
455fn restore_signal_forwarding_handlers(saved_handlers: &SavedSignalHandlers) -> io::Result<()> {
456    restore_signal_handler(libc::SIGINT, &saved_handlers.sigint)?;
457    restore_signal_handler(libc::SIGTERM, &saved_handlers.sigterm)
458}
459
460#[cfg(unix)]
461fn restore_signal_handler(signal: libc::c_int, action: &libc::sigaction) -> io::Result<()> {
462    let result = unsafe { libc::sigaction(signal, action, std::ptr::null_mut()) };
463    if result == -1 {
464        return Err(io::Error::last_os_error());
465    }
466
467    Ok(())
468}
469
470#[cfg(unix)]
471extern "C" fn forward_signal_to_child(signal: libc::c_int) {
472    let pid = FORWARDED_CHILD_PID.load(Ordering::SeqCst);
473
474    if pid > 0 {
475        unsafe {
476            libc::kill(pid, signal);
477        }
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the range literals used by these tests.
    fn ports(start: u16, end: u16) -> PortRange {
        PortRange { start, end }
    }

    /// Runs the allocator with a probe that reports every port as available.
    fn allocate_assuming_all_free(
        range: PortRange,
        skip_ports: &[u16],
        hints: AllocationHints,
    ) -> Result<u16, RunnerError> {
        allocate_port_with_hints_and_availability(range, skip_ports, hints, |_| true)
    }

    #[test]
    fn allocate_port_skips_reserved_ports() {
        let allocated = allocate_assuming_all_free(
            ports(29_000, 29_001),
            &[29_000],
            AllocationHints::default(),
        )
        .expect("port");

        assert_eq!(allocated, 29_001);
    }

    #[test]
    fn allocate_port_prefers_available_prior_port() {
        let hints = AllocationHints {
            preferred_port: Some(29_002),
            scan_start: None,
        };

        let allocated =
            allocate_assuming_all_free(ports(29_000, 29_002), &[], hints).expect("port");

        assert_eq!(allocated, 29_002);
    }

    #[test]
    fn allocate_port_scans_from_hint_with_wraparound() {
        let hints = AllocationHints {
            preferred_port: None,
            scan_start: Some(29_002),
        };

        let allocated = allocate_assuming_all_free(
            ports(29_000, 29_003),
            &[29_002, 29_003, 29_000],
            hints,
        )
        .expect("port");

        assert_eq!(allocated, 29_001);
    }

    #[test]
    fn allocate_port_reports_exhausted_range() {
        let outcome = allocate_assuming_all_free(
            ports(29_000, 29_000),
            &[29_000],
            AllocationHints::default(),
        );

        let error = outcome.expect_err("range should be exhausted");
        assert!(matches!(error, RunnerError::NoAvailablePort { range: _ }));
    }

    #[test]
    fn bind_errors_only_conflict_when_address_is_in_use() {
        assert!(!bind_error_leaves_port_available(io::ErrorKind::AddrInUse));

        for kind in [io::ErrorKind::AddrNotAvailable, io::ErrorKind::Unsupported] {
            assert!(bind_error_leaves_port_available(kind));
        }
    }

    #[cfg(unix)]
    #[test]
    fn signal_forwarding_rejects_concurrent_children_and_restores_handlers() {
        // Snapshot the process state so it can be compared after the child is gone.
        let before_int = current_signal_action(libc::SIGINT);
        let before_term = current_signal_action(libc::SIGTERM);
        let before_mask = current_signal_mask();

        let command = vec!["sleep".to_string(), "5".to_string()];
        let range = ports(29_000, 29_010);

        let mut first = spawn_child(&command, range, &[]).expect("first child");

        // A second forwarded child must be refused while the first is active.
        let error = match spawn_child(&command, range, &[]) {
            Err(error) => error,
            Ok(mut second) => {
                let _ = second.kill();
                let _ = second.wait();
                panic!("second child was not rejected");
            }
        };
        assert!(matches!(
            error,
            RunnerError::SignalForwarding { source }
                if source.kind() == io::ErrorKind::AlreadyExists
        ));

        first.kill().expect("kill first child");
        first.wait().expect("wait for first child");

        assert_signal_action_matches(libc::SIGINT, &before_int);
        assert_signal_action_matches(libc::SIGTERM, &before_term);
        assert_signal_mask_matches(libc::SIGINT, &before_mask);
        assert_signal_mask_matches(libc::SIGTERM, &before_mask);
    }

    /// Reads the current disposition of `signal` without changing it.
    #[cfg(unix)]
    fn current_signal_action(signal: libc::c_int) -> libc::sigaction {
        let mut current = unsafe { std::mem::zeroed::<libc::sigaction>() };
        let status = unsafe { libc::sigaction(signal, std::ptr::null(), &mut current) };
        assert_eq!(status, 0, "read signal action for {signal}");
        current
    }

    /// Asserts that the installed handler for `signal` equals the one in `expected`.
    #[cfg(unix)]
    fn assert_signal_action_matches(signal: libc::c_int, expected: &libc::sigaction) {
        let installed = current_signal_action(signal);
        assert_eq!(installed.sa_sigaction, expected.sa_sigaction);
    }

    /// Reads the current signal mask; passing a null set leaves it unchanged.
    #[cfg(unix)]
    fn current_signal_mask() -> libc::sigset_t {
        let mut current = unsafe { std::mem::zeroed::<libc::sigset_t>() };
        let status =
            unsafe { libc::sigprocmask(libc::SIG_BLOCK, std::ptr::null(), &mut current) };
        assert_eq!(status, 0, "read signal mask");
        current
    }

    /// Asserts that `signal` is blocked now exactly when it was blocked in `expected`.
    #[cfg(unix)]
    fn assert_signal_mask_matches(signal: libc::c_int, expected: &libc::sigset_t) {
        let current = current_signal_mask();
        let blocked_now = unsafe { libc::sigismember(&current, signal) };
        let blocked_before = unsafe { libc::sigismember(expected, signal) };

        assert_eq!(blocked_now, blocked_before);
    }
}