1use std::{
4 collections::HashSet,
5 fmt, io,
6 net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener},
7 process::{Child, Command, ExitStatus, Stdio},
8};
9
10#[cfg(unix)]
11use std::os::unix::process::CommandExt;
12#[cfg(unix)]
13use std::sync::atomic::{AtomicI32, Ordering};
14
15use bindport_core::PortRange;
16
/// Name of the environment variable through which the allocated port is
/// handed to the spawned child.
pub const PORT_ENV_VAR: &str = "PORT";

/// PID that SIGINT/SIGTERM are currently forwarded to.
///
/// `0` means forwarding is idle, [`RESERVED_CHILD_PID`] means a spawn has
/// claimed forwarding but the child's PID is not known yet, and a positive
/// value is the PID `forward_signal_to_child` re-sends signals to. It is an
/// atomic because the signal handler reads it.
#[cfg(unix)]
static FORWARDED_CHILD_PID: AtomicI32 = AtomicI32::new(0);

/// Sentinel stored in [`FORWARDED_CHILD_PID`] while a spawn owns forwarding
/// but has not yet published the child's PID.
///
/// It is negative so the handler's `pid > 0` check never passes it to `kill`
/// (where `-1` would address every process the caller may signal).
#[cfg(unix)]
const RESERVED_CHILD_PID: i32 = -1;
24
25#[derive(Debug)]
26pub enum RunnerError {
27 NoCommand,
28 NoAvailablePort { range: PortRange },
29 SignalForwarding { source: io::Error },
30 Spawn { command: String, source: io::Error },
31 Wait { command: String, source: io::Error },
32}
33
34impl fmt::Display for RunnerError {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 Self::NoCommand => write!(f, "no command provided after `--`"),
38 Self::NoAvailablePort { range } => {
39 write!(
40 f,
41 "no available port found in range {}-{}",
42 range.start, range.end
43 )
44 }
45 Self::SignalForwarding { source } => {
46 write!(f, "failed to install signal forwarding: {source}")
47 }
48 Self::Spawn { command, source } => {
49 write!(f, "failed to spawn `{command}`: {source}")
50 }
51 Self::Wait { command, source } => {
52 write!(f, "failed waiting for `{command}`: {source}")
53 }
54 }
55 }
56}
57
58impl std::error::Error for RunnerError {}
59
60pub struct RunningChild {
61 child: Child,
62 port: u16,
63 program: String,
64 signal_forwarding: SignalForwardingState,
65}
66
67impl RunningChild {
68 pub const fn port(&self) -> u16 {
69 self.port
70 }
71
72 pub fn pid(&self) -> u32 {
73 self.child.id()
74 }
75
76 pub fn kill(&mut self) -> io::Result<()> {
77 self.child.kill()
78 }
79
80 pub fn wait(&mut self) -> Result<ExitStatus, RunnerError> {
81 let status = self.child.wait().map_err(|source| RunnerError::Wait {
82 command: self.program.clone(),
83 source,
84 });
85 let signal_forwarding = self
86 .signal_forwarding
87 .deactivate()
88 .map_err(|source| RunnerError::SignalForwarding { source });
89
90 match (status, signal_forwarding) {
91 (Ok(status), Ok(())) => Ok(status),
92 (Err(error), _) | (_, Err(error)) => Err(error),
93 }
94 }
95}
96
97impl Drop for RunningChild {
98 fn drop(&mut self) {
99 let _ = self.signal_forwarding.deactivate();
100 }
101}
102
/// Optional hints that steer port allocation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AllocationHints {
    /// Port tried first. It is used only if it lies inside the range, is not
    /// in the skip list, and is currently available.
    pub preferred_port: Option<u16>,
    /// Port where the sequential scan begins, wrapping around to the range
    /// start. It is ignored when outside the range.
    pub scan_start: Option<u16>,
}
108
109pub fn allocate_port(range: PortRange, skip_ports: &[u16]) -> Result<u16, RunnerError> {
115 allocate_port_with_hints(range, skip_ports, AllocationHints::default())
116}
117
118pub fn allocate_port_with_hints(
119 range: PortRange,
120 skip_ports: &[u16],
121 hints: AllocationHints,
122) -> Result<u16, RunnerError> {
123 allocate_port_with_hints_and_availability(range, skip_ports, hints, is_port_available)
124}
125
126fn allocate_port_with_hints_and_availability(
127 range: PortRange,
128 skip_ports: &[u16],
129 hints: AllocationHints,
130 mut is_available: impl FnMut(u16) -> bool,
131) -> Result<u16, RunnerError> {
132 let skip_ports = skip_ports.iter().copied().collect::<HashSet<_>>();
133
134 if let Some(port) = hints
135 .preferred_port
136 .filter(|port| range.contains(*port) && !skip_ports.contains(port))
137 && is_available(port)
138 {
139 return Ok(port);
140 }
141
142 let range_len = range.len();
143 if range_len == 0 {
144 return Err(RunnerError::NoAvailablePort { range });
145 }
146
147 let scan_start = hints
148 .scan_start
149 .filter(|port| range.contains(*port))
150 .unwrap_or(range.start);
151 let scan_start_offset = scan_start as u32 - range.start as u32;
152
153 for offset in 0..range_len {
154 let port = range.start as u32 + ((scan_start_offset + offset) % range_len);
155 let port = u16::try_from(port).expect("port remains within configured range");
156
157 if skip_ports.contains(&port) {
158 continue;
159 }
160
161 if is_available(port) {
162 return Ok(port);
163 }
164 }
165
166 Err(RunnerError::NoAvailablePort { range })
167}
168
/// Reports whether `port` is free on both loopback interfaces.
///
/// A throwaway listener is bound on `127.0.0.1:port` and on `[::1]:port`;
/// both families are always probed. The port is considered taken only when
/// one of those binds fails with [`io::ErrorKind::AddrInUse`].
pub fn is_port_available(port: u16) -> bool {
    let loopbacks: [SocketAddr; 2] = [
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port)),
        SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, port, 0, 0)),
    ];

    let [ipv4_free, ipv6_free] = loopbacks.map(loopback_free);
    ipv4_free && ipv6_free
}

/// Returns `true` unless binding a listener on `addr` fails because the
/// address is already in use. A successful probe listener is closed again
/// before returning.
fn loopback_free(addr: impl Into<SocketAddr>) -> bool {
    let Err(error) = TcpListener::bind(addr.into()) else {
        return true;
    };
    bind_error_leaves_port_available(error.kind())
}

/// Decides whether a failed probe bind still leaves the port usable.
///
/// Only `AddrInUse` signals a real conflict. Other kinds (for example an
/// unavailable IPv6 stack) are not treated as conflicts.
fn bind_error_leaves_port_available(kind: io::ErrorKind) -> bool {
    !matches!(kind, io::ErrorKind::AddrInUse)
}
191
192pub fn run_child(
193 command: &[String],
194 range: PortRange,
195 skip_ports: &[u16],
196) -> Result<ExitStatus, RunnerError> {
197 let mut child = spawn_child(command, range, skip_ports)?;
198
199 child.wait()
200}
201
202pub fn spawn_child(
207 command: &[String],
208 range: PortRange,
209 skip_ports: &[u16],
210) -> Result<RunningChild, RunnerError> {
211 spawn_child_with_hints(command, range, skip_ports, AllocationHints::default())
212}
213
214pub fn spawn_child_with_hints(
215 command: &[String],
216 range: PortRange,
217 skip_ports: &[u16],
218 allocation_hints: AllocationHints,
219) -> Result<RunningChild, RunnerError> {
220 let port = allocate_port_with_hints(range, skip_ports, allocation_hints)?;
221
222 spawn_child_on_port(command, port, &[])
223}
224
225pub fn spawn_child_on_port(
226 command: &[String],
227 port: u16,
228 extra_env: &[(String, String)],
229) -> Result<RunningChild, RunnerError> {
230 let (program, args) = command.split_first().ok_or(RunnerError::NoCommand)?;
231
232 let mut signal_forwarding =
233 prepare_signal_forwarding().map_err(|source| RunnerError::SignalForwarding { source })?;
234
235 let mut process = Command::new(program);
236 process
237 .args(args)
238 .env(PORT_ENV_VAR, port.to_string())
239 .stdin(Stdio::inherit())
240 .stdout(Stdio::inherit())
241 .stderr(Stdio::inherit());
242 process.envs(extra_env.iter().map(|(name, value)| (name, value)));
243 prepare_child_signal_mask(&mut process, &signal_forwarding);
244
245 let child = match process.spawn() {
246 Ok(child) => child,
247 Err(source) => {
248 if let Err(source) = signal_forwarding.deactivate() {
249 return Err(RunnerError::SignalForwarding { source });
250 }
251
252 return Err(RunnerError::Spawn {
253 command: program.clone(),
254 source,
255 });
256 }
257 };
258 let child = match signal_forwarding.activate_for_pid(child.id()) {
259 Ok(()) => child,
260 Err(source) => {
261 let mut child = child;
262 let _ = child.kill();
263 let _ = child.wait();
264 let _ = signal_forwarding.deactivate();
265
266 return Err(RunnerError::SignalForwarding { source });
267 }
268 };
269
270 Ok(RunningChild {
271 child,
272 port,
273 program: program.clone(),
274 signal_forwarding,
275 })
276}
277
/// Signal state owned by one spawn while SIGINT/SIGTERM forwarding is set up.
#[cfg(unix)]
struct SignalForwardingState {
    // Dispositions that were installed before ours; `Some` until `deactivate`
    // takes them to restore them.
    saved_handlers: Option<SavedSignalHandlers>,
    // Signal mask from before SIGINT/SIGTERM were blocked for the spawn;
    // `Some` until it has been successfully restored.
    saved_signal_mask: Option<libc::sigset_t>,
}

/// Stand-in used where signal forwarding is not implemented (non-Unix).
#[cfg(not(unix))]
struct SignalForwardingState;

/// Previous SIGINT and SIGTERM actions, kept so they can be reinstalled.
#[cfg(unix)]
struct SavedSignalHandlers {
    sigint: libc::sigaction,
    sigterm: libc::sigaction,
}
292
293#[cfg(unix)]
294impl SignalForwardingState {
295 fn activate_for_pid(&mut self, pid: u32) -> io::Result<()> {
296 FORWARDED_CHILD_PID.store(pid as i32, Ordering::SeqCst);
297 self.restore_signal_mask()
298 }
299
300 fn deactivate(&mut self) -> io::Result<()> {
301 FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
302 let signal_mask = self.restore_signal_mask();
303
304 let handlers = if let Some(saved_handlers) = self.saved_handlers.take() {
305 restore_signal_forwarding_handlers(&saved_handlers)
306 } else {
307 Ok(())
308 };
309
310 signal_mask.and(handlers)
311 }
312
313 fn restore_signal_mask(&mut self) -> io::Result<()> {
314 if let Some(saved_signal_mask) = self.saved_signal_mask.as_ref() {
315 restore_signal_mask(saved_signal_mask)?;
316 self.saved_signal_mask = None;
317 }
318
319 Ok(())
320 }
321}
322
/// Non-Unix builds have nothing to forward, so both operations always succeed.
#[cfg(not(unix))]
impl SignalForwardingState {
    /// No-op: there is no forwarding target to publish.
    fn activate_for_pid(&mut self, _pid: u32) -> io::Result<()> {
        Ok(())
    }

    /// No-op: there is no signal state to restore.
    fn deactivate(&mut self) -> io::Result<()> {
        Ok(())
    }
}
333
334#[cfg(unix)]
335fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
336 reserve_signal_forwarding()?;
337 let saved_signal_mask = match block_signal_forwarding_signals() {
338 Ok(saved_signal_mask) => saved_signal_mask,
339 Err(error) => {
340 FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
341 return Err(error);
342 }
343 };
344
345 match install_signal_forwarding_handlers() {
346 Ok(saved_handlers) => Ok(SignalForwardingState {
347 saved_handlers: Some(saved_handlers),
348 saved_signal_mask: Some(saved_signal_mask),
349 }),
350 Err(error) => {
351 FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
352 let _ = restore_signal_mask(&saved_signal_mask);
353 Err(error)
354 }
355 }
356}
357
/// Non-Unix: forwarding is not implemented, so there is nothing to set up.
#[cfg(not(unix))]
fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
    Ok(SignalForwardingState)
}
362
363#[cfg(unix)]
364fn prepare_child_signal_mask(command: &mut Command, signal_forwarding: &SignalForwardingState) {
365 if let Some(saved_signal_mask) = signal_forwarding.saved_signal_mask {
366 unsafe {
367 command.pre_exec(move || restore_signal_mask(&saved_signal_mask));
368 }
369 }
370}
371
372#[cfg(not(unix))]
373fn prepare_child_signal_mask(_command: &mut Command, _signal_forwarding: &SignalForwardingState) {}
374
375#[cfg(unix)]
376fn reserve_signal_forwarding() -> io::Result<()> {
377 FORWARDED_CHILD_PID
378 .compare_exchange(0, RESERVED_CHILD_PID, Ordering::SeqCst, Ordering::SeqCst)
379 .map(|_| ())
380 .map_err(|_| {
381 io::Error::new(
382 io::ErrorKind::AlreadyExists,
383 "signal forwarding is already active",
384 )
385 })
386}
387
388#[cfg(unix)]
389fn install_signal_forwarding_handlers() -> io::Result<SavedSignalHandlers> {
390 let sigint = install_signal_forwarding_handler(libc::SIGINT)?;
391
392 match install_signal_forwarding_handler(libc::SIGTERM) {
393 Ok(sigterm) => Ok(SavedSignalHandlers { sigint, sigterm }),
394 Err(error) => {
395 let _ = restore_signal_handler(libc::SIGINT, &sigint);
396 Err(error)
397 }
398 }
399}
400
401#[cfg(unix)]
402fn block_signal_forwarding_signals() -> io::Result<libc::sigset_t> {
403 let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
404 let mut previous = unsafe { std::mem::zeroed::<libc::sigset_t>() };
405
406 if unsafe { libc::sigemptyset(&mut mask) } == -1 {
407 return Err(io::Error::last_os_error());
408 }
409 if unsafe { libc::sigaddset(&mut mask, libc::SIGINT) } == -1 {
410 return Err(io::Error::last_os_error());
411 }
412 if unsafe { libc::sigaddset(&mut mask, libc::SIGTERM) } == -1 {
413 return Err(io::Error::last_os_error());
414 }
415
416 let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, &mask, &mut previous) };
417 if result == -1 {
418 return Err(io::Error::last_os_error());
419 }
420
421 Ok(previous)
422}
423
424#[cfg(unix)]
425fn restore_signal_mask(mask: &libc::sigset_t) -> io::Result<()> {
426 let result = unsafe { libc::sigprocmask(libc::SIG_SETMASK, mask, std::ptr::null_mut()) };
427 if result == -1 {
428 return Err(io::Error::last_os_error());
429 }
430
431 Ok(())
432}
433
434#[cfg(unix)]
435fn install_signal_forwarding_handler(signal: libc::c_int) -> io::Result<libc::sigaction> {
436 let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
437 let mut previous = unsafe { std::mem::zeroed::<libc::sigaction>() };
438 action.sa_sigaction = forward_signal_to_child as *const () as usize;
439 action.sa_flags = 0;
440
441 let mask_result = unsafe { libc::sigemptyset(&mut action.sa_mask) };
442 if mask_result == -1 {
443 return Err(io::Error::last_os_error());
444 }
445
446 let action_result = unsafe { libc::sigaction(signal, &action, &mut previous) };
447 if action_result == -1 {
448 return Err(io::Error::last_os_error());
449 }
450
451 Ok(previous)
452}
453
454#[cfg(unix)]
455fn restore_signal_forwarding_handlers(saved_handlers: &SavedSignalHandlers) -> io::Result<()> {
456 restore_signal_handler(libc::SIGINT, &saved_handlers.sigint)?;
457 restore_signal_handler(libc::SIGTERM, &saved_handlers.sigterm)
458}
459
460#[cfg(unix)]
461fn restore_signal_handler(signal: libc::c_int, action: &libc::sigaction) -> io::Result<()> {
462 let result = unsafe { libc::sigaction(signal, action, std::ptr::null_mut()) };
463 if result == -1 {
464 return Err(io::Error::last_os_error());
465 }
466
467 Ok(())
468}
469
470#[cfg(unix)]
471extern "C" fn forward_signal_to_child(signal: libc::c_int) {
472 let pid = FORWARDED_CHILD_PID.load(Ordering::SeqCst);
473
474 if pid > 0 {
475 unsafe {
476 libc::kill(pid, signal);
477 }
478 }
479}
480
// Unit tests for port allocation and signal-forwarding bookkeeping.
//
// The allocation tests inject a fake availability probe instead of binding
// sockets. The signal-forwarding test spawns real `sleep` processes and
// inspects process-wide signal state, so it is Unix-only.
#[cfg(test)]
mod tests {
    use super::*;

    // A port listed in `skip_ports` is never returned even if available.
    #[test]
    fn allocate_port_skips_reserved_ports() {
        let range = PortRange {
            start: 29_000,
            end: 29_001,
        };

        assert_eq!(
            allocate_port_with_hints_and_availability(
                range,
                &[29_000],
                AllocationHints::default(),
                |_| true
            )
            .expect("port"),
            29_001
        );
    }

    // An in-range, available preferred port wins over the scan start.
    #[test]
    fn allocate_port_prefers_available_prior_port() {
        let range = PortRange {
            start: 29_000,
            end: 29_002,
        };
        let hints = AllocationHints {
            preferred_port: Some(29_002),
            scan_start: None,
        };

        assert_eq!(
            allocate_port_with_hints_and_availability(range, &[], hints, |_| true).expect("port"),
            29_002
        );
    }

    // Scanning from 29_002 visits 29_002, 29_003, 29_000 (all skipped) and
    // then wraps to 29_001.
    #[test]
    fn allocate_port_scans_from_hint_with_wraparound() {
        let range = PortRange {
            start: 29_000,
            end: 29_003,
        };
        let hints = AllocationHints {
            preferred_port: None,
            scan_start: Some(29_002),
        };

        assert_eq!(
            allocate_port_with_hints_and_availability(
                range,
                &[29_002, 29_003, 29_000],
                hints,
                |_| true
            )
            .expect("port"),
            29_001
        );
    }

    // A single-port range whose only port is skipped yields NoAvailablePort.
    #[test]
    fn allocate_port_reports_exhausted_range() {
        let range = PortRange {
            start: 29_000,
            end: 29_000,
        };

        let error = allocate_port_with_hints_and_availability(
            range,
            &[29_000],
            AllocationHints::default(),
            |_| true,
        )
        .expect_err("range should be exhausted");
        assert!(matches!(error, RunnerError::NoAvailablePort { range: _ }));
    }

    // Only AddrInUse counts as a conflict; other bind failures do not.
    #[test]
    fn bind_errors_only_conflict_when_address_is_in_use() {
        assert!(!bind_error_leaves_port_available(io::ErrorKind::AddrInUse));
        assert!(bind_error_leaves_port_available(
            io::ErrorKind::AddrNotAvailable
        ));
        assert!(bind_error_leaves_port_available(io::ErrorKind::Unsupported));
    }

    // While one child holds forwarding, a second spawn is rejected with
    // AlreadyExists. After the first child is reaped, the SIGINT/SIGTERM
    // handlers and their blocked-ness must match the pre-test state.
    #[cfg(unix)]
    #[test]
    fn signal_forwarding_rejects_concurrent_children_and_restores_handlers() {
        let before_int = current_signal_action(libc::SIGINT);
        let before_term = current_signal_action(libc::SIGTERM);
        let before_mask = current_signal_mask();
        let command = vec!["sleep".to_string(), "5".to_string()];
        let range = PortRange {
            start: 29_000,
            end: 29_010,
        };

        let mut first = spawn_child(&command, range, &[]).expect("first child");
        let error = match spawn_child(&command, range, &[]) {
            Ok(mut second) => {
                // Clean up before failing so no stray process is left running.
                let _ = second.kill();
                let _ = second.wait();
                panic!("second child was not rejected");
            }
            Err(error) => error,
        };

        assert!(
            matches!(error, RunnerError::SignalForwarding { source } if source.kind() == io::ErrorKind::AlreadyExists)
        );

        first.kill().expect("kill first child");
        first.wait().expect("wait for first child");

        assert_signal_action_matches(libc::SIGINT, &before_int);
        assert_signal_action_matches(libc::SIGTERM, &before_term);
        assert_signal_mask_matches(libc::SIGINT, &before_mask);
        assert_signal_mask_matches(libc::SIGTERM, &before_mask);
    }

    // Reads the current action for `signal` without changing it (null `act`).
    #[cfg(unix)]
    fn current_signal_action(signal: libc::c_int) -> libc::sigaction {
        let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
        let result = unsafe { libc::sigaction(signal, std::ptr::null(), &mut action) };
        assert_eq!(result, 0, "read signal action for {signal}");
        action
    }

    // Compares only the handler address, not flags or the handler's mask.
    #[cfg(unix)]
    fn assert_signal_action_matches(signal: libc::c_int, expected: &libc::sigaction) {
        let actual = current_signal_action(signal);
        assert_eq!(actual.sa_sigaction, expected.sa_sigaction);
    }

    // With a null `set`, `how` is ignored and the mask is only read.
    #[cfg(unix)]
    fn current_signal_mask() -> libc::sigset_t {
        let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
        let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, std::ptr::null(), &mut mask) };
        assert_eq!(result, 0, "read signal mask");
        mask
    }

    // Checks that `signal` is blocked now exactly when it was in `expected`.
    #[cfg(unix)]
    fn assert_signal_mask_matches(signal: libc::c_int, expected: &libc::sigset_t) {
        let actual = current_signal_mask();
        let actual_member = unsafe { libc::sigismember(&actual, signal) };
        let expected_member = unsafe { libc::sigismember(expected, signal) };

        assert_eq!(actual_member, expected_member);
    }
}