Skip to main content

qos_test_primitives/
lib.rs

//! Primitive types for test setup.
2
3use std::{
4	net::{Ipv4Addr, SocketAddrV4, TcpListener, TcpStream},
5	ops::{Deref, DerefMut},
6	path::Path,
7	process::Child,
8	thread,
9	time::Duration,
10};
11
/// Number of times `find_free_port` tries to bind an ephemeral port.
const MAX_FIND_FREE_PORT_ATTEMPTS: usize = 50;
/// Pause `find_free_port` takes after a failed attempt.
const FIND_FREE_PORT_RETRY_DELAY: Duration = Duration::from_millis(50);

/// Wait budget for `wait_until_port_is_bound` before it panics.
const MAX_PORT_BIND_WAIT_TIME: Duration = Duration::from_millis(5_000);
/// Time step used by `wait_until_port_is_bound` between port probes.
const PORT_BIND_WAIT_TIME_INCREMENT: Duration = Duration::from_millis(500);
/// Extra sleep `wait_until_port_is_bound` takes once the port accepts connections.
const POST_BIND_SLEEP: Duration = Duration::from_millis(500);

/// Grace period `ChildWrapper` gives a child after sending `SIGINT`.
const EXIT_DELAY: Duration = Duration::from_millis(50);
18
19/// Wrapper type for [`std::process::Child`] that kills the process on drop.
20#[derive(Debug)]
21pub struct ChildWrapper(Child);
22
23impl From<Child> for ChildWrapper {
24	fn from(child: Child) -> Self {
25		Self(child)
26	}
27}
28
29impl Deref for ChildWrapper {
30	type Target = Child;
31
32	fn deref(&self) -> &Self::Target {
33		&self.0
34	}
35}
36
37impl DerefMut for ChildWrapper {
38	fn deref_mut(&mut self) -> &mut Self::Target {
39		&mut self.0
40	}
41}
42
43impl Drop for ChildWrapper {
44	fn drop(&mut self) {
45		#[cfg(unix)]
46		{
47			use nix::{sys::signal::Signal::SIGINT, unistd::Pid};
48			let pid = Pid::from_raw(self.0.id() as i32);
49			match nix::sys::signal::kill(pid, SIGINT) {
50				Ok(()) => {}
51				Err(err) => eprintln!("error sending signal to child: {err}"),
52			}
53
54			// allow clean exit
55			std::thread::sleep(EXIT_DELAY);
56		}
57
58		// Kill the process and explicitly ignore the result
59		drop(self.0.kill());
60	}
61}
62
/// Generic wrapper type for anything that implements [`std::convert::AsRef<std::path::Path>`] that attempts to remove a file or
/// directory at the path on drop.
///
/// Removal is best-effort: any error (including the path not existing) is
/// ignored. A symbolic link is removed itself; its target is left alone.
#[derive(Debug)]
pub struct PathWrapper<P: AsRef<Path>>(P);

impl<P: AsRef<Path>> From<P> for PathWrapper<P> {
	fn from(value: P) -> Self {
		Self(value)
	}
}

impl<P: AsRef<Path>> Drop for PathWrapper<P> {
	fn drop(&mut self) {
		let path = self.0.as_ref();
		// `remove_dir_all` handles directories, and removes a symlink itself
		// without following it. If it fails, the path is most likely a regular
		// file, so fall back to `remove_file`. The fallback is attempted only on
		// failure because at most one of the two calls can succeed.
		let outcome = std::fs::remove_dir_all(path).or_else(|_| std::fs::remove_file(path));
		// Best-effort cleanup: the final result is deliberately ignored.
		drop(outcome);
	}
}

impl<P: AsRef<Path>> Deref for PathWrapper<P> {
	type Target = Path;

	fn deref(&self) -> &Self::Target {
		self.0.as_ref()
	}
}

impl<P: AsRef<Path>> AsRef<Path> for PathWrapper<P> {
	fn as_ref(&self) -> &Path {
		self.0.as_ref()
	}
}
95
96/// Get a bind-able TCP port on the local system.
97#[must_use]
98pub fn find_free_port() -> Option<u16> {
99	let mut last_err = None;
100
101	for _ in 0..MAX_FIND_FREE_PORT_ATTEMPTS {
102		match TcpListener::bind(("127.0.0.1", 0)) {
103			Ok(listener) => {
104				return listener.local_addr().ok().map(|addr| addr.port());
105			}
106			Err(err) => {
107				last_err = Some(err);
108				thread::sleep(FIND_FREE_PORT_RETRY_DELAY);
109			}
110		}
111	}
112
113	if let Some(err) = last_err {
114		eprintln!("failed to find free port: {err}");
115	}
116
117	None
118}
119
120/// Wait until the given `port` is bound. Helpful for telling if something is
121/// listening on the given port.
122///
123/// # Panics
124///
125/// Panics if the port is not bound to within `MAX_PORT_BIND_WAIT_TIME`.
126pub fn wait_until_port_is_bound(port: u16) {
127	let mut wait_time = PORT_BIND_WAIT_TIME_INCREMENT;
128
129	while wait_time < MAX_PORT_BIND_WAIT_TIME {
130		thread::sleep(wait_time);
131		if !can_connect_to_port(port) {
132			wait_time += PORT_BIND_WAIT_TIME_INCREMENT;
133		} else {
134			thread::sleep(POST_BIND_SLEEP);
135			return;
136		}
137	}
138	panic!(
139		"Server has not come up: port {} is still available after {}s",
140		port,
141		MAX_PORT_BIND_WAIT_TIME.as_secs()
142	)
143}
144
/// Report whether something on this machine accepts TCP connections on `port`.
///
/// Tries to connect to `127.0.0.1:<port>`, giving up after 100 ms. Any
/// established connection is closed again immediately.
fn can_connect_to_port(port: u16) -> bool {
	const CONNECT_TIMEOUT_MS: u64 = 100;

	let target = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into();
	match TcpStream::connect_timeout(&target, Duration::from_millis(CONNECT_TIMEOUT_MS)) {
		Ok(_stream) => true,
		Err(_) => false,
	}
}