1use std::{
4 collections::HashSet,
5 fmt, io,
6 net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener},
7 process::{Child, Command, ExitStatus, Stdio},
8};
9
10#[cfg(unix)]
11use std::os::unix::process::CommandExt;
12#[cfg(unix)]
13use std::sync::atomic::{AtomicI32, Ordering};
14
15use bindport_core::PortRange;
16
/// Name of the environment variable through which the allocated port is handed to the
/// spawned child (`spawn_child_with_hints` sets it with `Command::env`).
pub const PORT_ENV_VAR: &str = "PORT";
18
/// Process-wide signal-forwarding slot, read by the SIGINT/SIGTERM handler.
///
/// * `0` — no forwarding is active; the slot may be reserved by a new spawn.
/// * [`RESERVED_CHILD_PID`] — a spawn holds the slot but no child pid is recorded yet
///   (the handler ignores non-positive values).
/// * `> 0` — pid of the child that received signals are re-sent to.
///
/// It is an atomic because it is read from inside a signal handler.
#[cfg(unix)]
static FORWARDED_CHILD_PID: AtomicI32 = AtomicI32::new(0);

/// Sentinel stored in [`FORWARDED_CHILD_PID`] while a spawn has reserved the slot but has
/// not yet recorded a child pid.
#[cfg(unix)]
const RESERVED_CHILD_PID: i32 = -1;
24
/// Errors produced while allocating a port for, spawning, or waiting on a child command.
#[derive(Debug)]
pub enum RunnerError {
    /// The command slice was empty, so there was no program to run.
    NoCommand,
    /// Every candidate port in `range` was skipped or reported unavailable.
    NoAvailablePort { range: PortRange },
    /// Setting up, activating, or tearing down SIGINT/SIGTERM forwarding failed — including
    /// the case where another child already holds the forwarding slot.
    SignalForwarding { source: io::Error },
    /// The program (`command`, the first element of the command slice) could not be spawned.
    Spawn { command: String, source: io::Error },
    /// Waiting for the child running `command` failed.
    Wait { command: String, source: io::Error },
}
33
34impl fmt::Display for RunnerError {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 Self::NoCommand => write!(f, "no command provided after `--`"),
38 Self::NoAvailablePort { range } => {
39 write!(
40 f,
41 "no available port found in range {}-{}",
42 range.start, range.end
43 )
44 }
45 Self::SignalForwarding { source } => {
46 write!(f, "failed to install signal forwarding: {source}")
47 }
48 Self::Spawn { command, source } => {
49 write!(f, "failed to spawn `{command}`: {source}")
50 }
51 Self::Wait { command, source } => {
52 write!(f, "failed waiting for `{command}`: {source}")
53 }
54 }
55 }
56}
57
// `source()` is not overridden. The `Display` impl already embeds each wrapped `io::Error`,
// so also returning it from `source()` would make chain-walking error reporters print the
// cause twice; change both together if that is ever wanted.
impl std::error::Error for RunnerError {}
59
/// A spawned child process plus the port it was given and its signal-forwarding state.
///
/// Dropping it tears down signal forwarding (best effort) but neither kills nor reaps the
/// child; call `wait` for an orderly shutdown.
pub struct RunningChild {
    child: Child,
    /// Port exported to the child through `PORT_ENV_VAR`.
    port: u16,
    /// Program name (first element of the command), used in error messages.
    program: String,
    /// Forwarding state, deactivated by `wait` and by `Drop`.
    signal_forwarding: SignalForwardingState,
}
66
67impl RunningChild {
68 pub const fn port(&self) -> u16 {
69 self.port
70 }
71
72 pub fn pid(&self) -> u32 {
73 self.child.id()
74 }
75
76 pub fn kill(&mut self) -> io::Result<()> {
77 self.child.kill()
78 }
79
80 pub fn wait(&mut self) -> Result<ExitStatus, RunnerError> {
81 let status = self.child.wait().map_err(|source| RunnerError::Wait {
82 command: self.program.clone(),
83 source,
84 });
85 let signal_forwarding = self
86 .signal_forwarding
87 .deactivate()
88 .map_err(|source| RunnerError::SignalForwarding { source });
89
90 match (status, signal_forwarding) {
91 (Ok(status), Ok(())) => Ok(status),
92 (Err(error), _) | (_, Err(error)) => Err(error),
93 }
94 }
95}
96
97impl Drop for RunningChild {
98 fn drop(&mut self) {
99 let _ = self.signal_forwarding.deactivate();
100 }
101}
102
/// Optional hints that steer port allocation within a range.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AllocationHints {
    /// Port tried before any scanning; ignored when outside the range or in the skip list.
    pub preferred_port: Option<u16>,
    /// Port at which the wrap-around scan begins; when absent or outside the range the scan
    /// starts at the range's first port.
    pub scan_start: Option<u16>,
}
108
109pub fn allocate_port(range: PortRange, skip_ports: &[u16]) -> Result<u16, RunnerError> {
115 allocate_port_with_hints(range, skip_ports, AllocationHints::default())
116}
117
118pub fn allocate_port_with_hints(
119 range: PortRange,
120 skip_ports: &[u16],
121 hints: AllocationHints,
122) -> Result<u16, RunnerError> {
123 allocate_port_with_hints_and_availability(range, skip_ports, hints, is_port_available)
124}
125
126fn allocate_port_with_hints_and_availability(
127 range: PortRange,
128 skip_ports: &[u16],
129 hints: AllocationHints,
130 mut is_available: impl FnMut(u16) -> bool,
131) -> Result<u16, RunnerError> {
132 let skip_ports = skip_ports.iter().copied().collect::<HashSet<_>>();
133
134 if let Some(port) = hints
135 .preferred_port
136 .filter(|port| range.contains(*port) && !skip_ports.contains(port))
137 && is_available(port)
138 {
139 return Ok(port);
140 }
141
142 let range_len = range.len();
143 if range_len == 0 {
144 return Err(RunnerError::NoAvailablePort { range });
145 }
146
147 let scan_start = hints
148 .scan_start
149 .filter(|port| range.contains(*port))
150 .unwrap_or(range.start);
151 let scan_start_offset = scan_start as u32 - range.start as u32;
152
153 for offset in 0..range_len {
154 let port = range.start as u32 + ((scan_start_offset + offset) % range_len);
155 let port = u16::try_from(port).expect("port remains within configured range");
156
157 if skip_ports.contains(&port) {
158 continue;
159 }
160
161 if is_available(port) {
162 return Ok(port);
163 }
164 }
165
166 Err(RunnerError::NoAvailablePort { range })
167}
168
169pub fn is_port_available(port: u16) -> bool {
175 let v4 = loopback_free(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
176 let v6 = loopback_free(SocketAddrV6::new(Ipv6Addr::LOCALHOST, port, 0, 0));
177
178 v4 && v6
179}
180
181fn loopback_free(addr: impl Into<SocketAddr>) -> bool {
182 match TcpListener::bind(addr.into()) {
183 Ok(_) => true,
184 Err(error) => bind_error_leaves_port_available(error.kind()),
185 }
186}
187
/// Classifies a bind failure: only `AddrInUse` means the port is taken. Any other error
/// (for example `AddrNotAvailable` or `Unsupported`) leaves the port counted as available.
fn bind_error_leaves_port_available(kind: io::ErrorKind) -> bool {
    !matches!(kind, io::ErrorKind::AddrInUse)
}
191
192pub fn run_child(
193 command: &[String],
194 range: PortRange,
195 skip_ports: &[u16],
196) -> Result<ExitStatus, RunnerError> {
197 let mut child = spawn_child(command, range, skip_ports)?;
198
199 child.wait()
200}
201
202pub fn spawn_child(
207 command: &[String],
208 range: PortRange,
209 skip_ports: &[u16],
210) -> Result<RunningChild, RunnerError> {
211 spawn_child_with_hints(command, range, skip_ports, AllocationHints::default())
212}
213
214pub fn spawn_child_with_hints(
215 command: &[String],
216 range: PortRange,
217 skip_ports: &[u16],
218 allocation_hints: AllocationHints,
219) -> Result<RunningChild, RunnerError> {
220 let (program, args) = command.split_first().ok_or(RunnerError::NoCommand)?;
221 let port = allocate_port_with_hints(range, skip_ports, allocation_hints)?;
222
223 let mut signal_forwarding =
224 prepare_signal_forwarding().map_err(|source| RunnerError::SignalForwarding { source })?;
225
226 let mut process = Command::new(program);
227 process
228 .args(args)
229 .env(PORT_ENV_VAR, port.to_string())
230 .stdin(Stdio::inherit())
231 .stdout(Stdio::inherit())
232 .stderr(Stdio::inherit());
233 prepare_child_signal_mask(&mut process, &signal_forwarding);
234
235 let child = match process.spawn() {
236 Ok(child) => child,
237 Err(source) => {
238 if let Err(source) = signal_forwarding.deactivate() {
239 return Err(RunnerError::SignalForwarding { source });
240 }
241
242 return Err(RunnerError::Spawn {
243 command: program.clone(),
244 source,
245 });
246 }
247 };
248 let child = match signal_forwarding.activate_for_pid(child.id()) {
249 Ok(()) => child,
250 Err(source) => {
251 let mut child = child;
252 let _ = child.kill();
253 let _ = child.wait();
254 let _ = signal_forwarding.deactivate();
255
256 return Err(RunnerError::SignalForwarding { source });
257 }
258 };
259
260 Ok(RunningChild {
261 child,
262 port,
263 program: program.clone(),
264 signal_forwarding,
265 })
266}
267
/// Bookkeeping for one child's SIGINT/SIGTERM forwarding (Unix).
#[cfg(unix)]
struct SignalForwardingState {
    /// SIGINT/SIGTERM dispositions in effect before the forwarding handler was installed.
    /// `Some` from construction in `prepare_signal_forwarding` until `deactivate` takes it.
    saved_handlers: Option<SavedSignalHandlers>,
    /// Signal mask in effect before SIGINT/SIGTERM were blocked for the spawn; set back to
    /// `None` once it has been restored successfully.
    saved_signal_mask: Option<libc::sigset_t>,
}
273
/// Placeholder on non-Unix targets, where no signal forwarding is performed.
#[cfg(not(unix))]
struct SignalForwardingState;
276
/// The `sigaction`s that were active for SIGINT and SIGTERM before the forwarding handler
/// replaced them; restored when forwarding is torn down.
#[cfg(unix)]
struct SavedSignalHandlers {
    sigint: libc::sigaction,
    sigterm: libc::sigaction,
}
282
283#[cfg(unix)]
284impl SignalForwardingState {
285 fn activate_for_pid(&mut self, pid: u32) -> io::Result<()> {
286 FORWARDED_CHILD_PID.store(pid as i32, Ordering::SeqCst);
287 self.restore_signal_mask()
288 }
289
290 fn deactivate(&mut self) -> io::Result<()> {
291 FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
292 let signal_mask = self.restore_signal_mask();
293
294 let handlers = if let Some(saved_handlers) = self.saved_handlers.take() {
295 restore_signal_forwarding_handlers(&saved_handlers)
296 } else {
297 Ok(())
298 };
299
300 signal_mask.and(handlers)
301 }
302
303 fn restore_signal_mask(&mut self) -> io::Result<()> {
304 if let Some(saved_signal_mask) = self.saved_signal_mask.as_ref() {
305 restore_signal_mask(saved_signal_mask)?;
306 self.saved_signal_mask = None;
307 }
308
309 Ok(())
310 }
311}
312
// Non-Unix stand-ins: forwarding is not implemented, so activation and teardown are no-ops.
#[cfg(not(unix))]
impl SignalForwardingState {
    /// No-op; always succeeds.
    fn activate_for_pid(&mut self, _pid: u32) -> io::Result<()> {
        Ok(())
    }

    /// No-op; always succeeds.
    fn deactivate(&mut self) -> io::Result<()> {
        Ok(())
    }
}
323
324#[cfg(unix)]
325fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
326 reserve_signal_forwarding()?;
327 let saved_signal_mask = match block_signal_forwarding_signals() {
328 Ok(saved_signal_mask) => saved_signal_mask,
329 Err(error) => {
330 FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
331 return Err(error);
332 }
333 };
334
335 match install_signal_forwarding_handlers() {
336 Ok(saved_handlers) => Ok(SignalForwardingState {
337 saved_handlers: Some(saved_handlers),
338 saved_signal_mask: Some(saved_signal_mask),
339 }),
340 Err(error) => {
341 FORWARDED_CHILD_PID.store(0, Ordering::SeqCst);
342 let _ = restore_signal_mask(&saved_signal_mask);
343 Err(error)
344 }
345 }
346}
347
/// Non-Unix stand-in: there is nothing to prepare, so this always succeeds.
#[cfg(not(unix))]
fn prepare_signal_forwarding() -> io::Result<SignalForwardingState> {
    Ok(SignalForwardingState)
}
352
353#[cfg(unix)]
354fn prepare_child_signal_mask(command: &mut Command, signal_forwarding: &SignalForwardingState) {
355 if let Some(saved_signal_mask) = signal_forwarding.saved_signal_mask {
356 unsafe {
357 command.pre_exec(move || restore_signal_mask(&saved_signal_mask));
358 }
359 }
360}
361
/// Non-Unix stand-in: no signal mask is involved, so the command is left unchanged.
#[cfg(not(unix))]
fn prepare_child_signal_mask(_command: &mut Command, _signal_forwarding: &SignalForwardingState) {}
364
365#[cfg(unix)]
366fn reserve_signal_forwarding() -> io::Result<()> {
367 FORWARDED_CHILD_PID
368 .compare_exchange(0, RESERVED_CHILD_PID, Ordering::SeqCst, Ordering::SeqCst)
369 .map(|_| ())
370 .map_err(|_| {
371 io::Error::new(
372 io::ErrorKind::AlreadyExists,
373 "signal forwarding is already active",
374 )
375 })
376}
377
378#[cfg(unix)]
379fn install_signal_forwarding_handlers() -> io::Result<SavedSignalHandlers> {
380 let sigint = install_signal_forwarding_handler(libc::SIGINT)?;
381
382 match install_signal_forwarding_handler(libc::SIGTERM) {
383 Ok(sigterm) => Ok(SavedSignalHandlers { sigint, sigterm }),
384 Err(error) => {
385 let _ = restore_signal_handler(libc::SIGINT, &sigint);
386 Err(error)
387 }
388 }
389}
390
391#[cfg(unix)]
392fn block_signal_forwarding_signals() -> io::Result<libc::sigset_t> {
393 let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
394 let mut previous = unsafe { std::mem::zeroed::<libc::sigset_t>() };
395
396 if unsafe { libc::sigemptyset(&mut mask) } == -1 {
397 return Err(io::Error::last_os_error());
398 }
399 if unsafe { libc::sigaddset(&mut mask, libc::SIGINT) } == -1 {
400 return Err(io::Error::last_os_error());
401 }
402 if unsafe { libc::sigaddset(&mut mask, libc::SIGTERM) } == -1 {
403 return Err(io::Error::last_os_error());
404 }
405
406 let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, &mask, &mut previous) };
407 if result == -1 {
408 return Err(io::Error::last_os_error());
409 }
410
411 Ok(previous)
412}
413
414#[cfg(unix)]
415fn restore_signal_mask(mask: &libc::sigset_t) -> io::Result<()> {
416 let result = unsafe { libc::sigprocmask(libc::SIG_SETMASK, mask, std::ptr::null_mut()) };
417 if result == -1 {
418 return Err(io::Error::last_os_error());
419 }
420
421 Ok(())
422}
423
424#[cfg(unix)]
425fn install_signal_forwarding_handler(signal: libc::c_int) -> io::Result<libc::sigaction> {
426 let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
427 let mut previous = unsafe { std::mem::zeroed::<libc::sigaction>() };
428 action.sa_sigaction = forward_signal_to_child as *const () as usize;
429 action.sa_flags = 0;
430
431 let mask_result = unsafe { libc::sigemptyset(&mut action.sa_mask) };
432 if mask_result == -1 {
433 return Err(io::Error::last_os_error());
434 }
435
436 let action_result = unsafe { libc::sigaction(signal, &action, &mut previous) };
437 if action_result == -1 {
438 return Err(io::Error::last_os_error());
439 }
440
441 Ok(previous)
442}
443
444#[cfg(unix)]
445fn restore_signal_forwarding_handlers(saved_handlers: &SavedSignalHandlers) -> io::Result<()> {
446 restore_signal_handler(libc::SIGINT, &saved_handlers.sigint)?;
447 restore_signal_handler(libc::SIGTERM, &saved_handlers.sigterm)
448}
449
450#[cfg(unix)]
451fn restore_signal_handler(signal: libc::c_int, action: &libc::sigaction) -> io::Result<()> {
452 let result = unsafe { libc::sigaction(signal, action, std::ptr::null_mut()) };
453 if result == -1 {
454 return Err(io::Error::last_os_error());
455 }
456
457 Ok(())
458}
459
460#[cfg(unix)]
461extern "C" fn forward_signal_to_child(signal: libc::c_int) {
462 let pid = FORWARDED_CHILD_PID.load(Ordering::SeqCst);
463
464 if pid > 0 {
465 unsafe {
466 libc::kill(pid, signal);
467 }
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    // Ports listed in `skip_ports` are never handed out, even when the probe says they are free.
    #[test]
    fn allocate_port_skips_reserved_ports() {
        let range = PortRange {
            start: 29_000,
            end: 29_001,
        };

        assert_eq!(
            allocate_port_with_hints_and_availability(
                range,
                &[29_000],
                AllocationHints::default(),
                |_| true
            )
            .expect("port"),
            29_001
        );
    }

    // An in-range, available preferred port is returned ahead of the scan order.
    #[test]
    fn allocate_port_prefers_available_prior_port() {
        let range = PortRange {
            start: 29_000,
            end: 29_002,
        };
        let hints = AllocationHints {
            preferred_port: Some(29_002),
            scan_start: None,
        };

        assert_eq!(
            allocate_port_with_hints_and_availability(range, &[], hints, |_| true).expect("port"),
            29_002
        );
    }

    // The scan begins at `scan_start`, runs to the end of the range, then wraps to its start.
    #[test]
    fn allocate_port_scans_from_hint_with_wraparound() {
        let range = PortRange {
            start: 29_000,
            end: 29_003,
        };
        let hints = AllocationHints {
            preferred_port: None,
            scan_start: Some(29_002),
        };

        assert_eq!(
            allocate_port_with_hints_and_availability(
                range,
                &[29_002, 29_003, 29_000],
                hints,
                |_| true
            )
            .expect("port"),
            29_001
        );
    }

    // When every port in the range is skipped, allocation fails with `NoAvailablePort`.
    #[test]
    fn allocate_port_reports_exhausted_range() {
        let range = PortRange {
            start: 29_000,
            end: 29_000,
        };

        let error = allocate_port_with_hints_and_availability(
            range,
            &[29_000],
            AllocationHints::default(),
            |_| true,
        )
        .expect_err("range should be exhausted");
        assert!(matches!(error, RunnerError::NoAvailablePort { range: _ }));
    }

    // Only `AddrInUse` marks a port as taken; other bind failures leave it available.
    #[test]
    fn bind_errors_only_conflict_when_address_is_in_use() {
        assert!(!bind_error_leaves_port_available(io::ErrorKind::AddrInUse));
        assert!(bind_error_leaves_port_available(
            io::ErrorKind::AddrNotAvailable
        ));
        assert!(bind_error_leaves_port_available(io::ErrorKind::Unsupported));
    }

    // Spawns real `sleep` children. Checks that a second concurrent spawn is rejected with
    // `AlreadyExists`, and that SIGINT/SIGTERM handlers and mask bits are back to their
    // pre-test values afterwards.
    // NOTE(review): relies on process-global signal state, so it assumes no other test in this
    // binary touches SIGINT/SIGTERM concurrently, and that `sleep` is on PATH.
    #[cfg(unix)]
    #[test]
    fn signal_forwarding_rejects_concurrent_children_and_restores_handlers() {
        let before_int = current_signal_action(libc::SIGINT);
        let before_term = current_signal_action(libc::SIGTERM);
        let before_mask = current_signal_mask();
        let command = vec!["sleep".to_string(), "5".to_string()];
        let range = PortRange {
            start: 29_000,
            end: 29_010,
        };

        let mut first = spawn_child(&command, range, &[]).expect("first child");
        let error = match spawn_child(&command, range, &[]) {
            Ok(mut second) => {
                // Clean up the unexpected child before failing the test.
                let _ = second.kill();
                let _ = second.wait();
                panic!("second child was not rejected");
            }
            Err(error) => error,
        };

        assert!(
            matches!(error, RunnerError::SignalForwarding { source } if source.kind() == io::ErrorKind::AlreadyExists)
        );

        first.kill().expect("kill first child");
        first.wait().expect("wait for first child");

        assert_signal_action_matches(libc::SIGINT, &before_int);
        assert_signal_action_matches(libc::SIGTERM, &before_term);
        assert_signal_mask_matches(libc::SIGINT, &before_mask);
        assert_signal_mask_matches(libc::SIGTERM, &before_mask);
    }

    // Reads the current disposition of `signal` without changing it (null `act`).
    #[cfg(unix)]
    fn current_signal_action(signal: libc::c_int) -> libc::sigaction {
        let mut action = unsafe { std::mem::zeroed::<libc::sigaction>() };
        let result = unsafe { libc::sigaction(signal, std::ptr::null(), &mut action) };
        assert_eq!(result, 0, "read signal action for {signal}");
        action
    }

    // Compares only the handler address; flags and masks are not checked.
    #[cfg(unix)]
    fn assert_signal_action_matches(signal: libc::c_int, expected: &libc::sigaction) {
        let actual = current_signal_action(signal);
        assert_eq!(actual.sa_sigaction, expected.sa_sigaction);
    }

    // Reads the current signal mask. With a null `set`, the `how` argument is not significant,
    // so nothing is modified.
    #[cfg(unix)]
    fn current_signal_mask() -> libc::sigset_t {
        let mut mask = unsafe { std::mem::zeroed::<libc::sigset_t>() };
        let result = unsafe { libc::sigprocmask(libc::SIG_BLOCK, std::ptr::null(), &mut mask) };
        assert_eq!(result, 0, "read signal mask");
        mask
    }

    // Asserts that `signal` is blocked now exactly when it was blocked in `expected`.
    #[cfg(unix)]
    fn assert_signal_mask_matches(signal: libc::c_int, expected: &libc::sigset_t) {
        let actual = current_signal_mask();
        let actual_member = unsafe { libc::sigismember(&actual, signal) };
        let expected_member = unsafe { libc::sigismember(expected, signal) };

        assert_eq!(actual_member, expected_member);
    }
}