#![cfg_attr(target_os = "windows", allow(unused_imports, dead_code))]
use std::{
io::{self, Read, Write},
net::{TcpListener, TcpStream},
str::FromStr,
sync::{Arc, Barrier},
time::Duration,
};
use dumbpipe::EndpointTicket;
use rand::RngExt;
/// Absolute path of the `dumbpipe` executable under test.
///
/// The value comes from the `CARGO_BIN_EXE_dumbpipe` variable, resolved at
/// compile time via `env!`, so it is a `'static` string baked into the binary.
fn dumbpipe_bin() -> &'static str {
    const BIN_PATH: &str = env!("CARGO_BIN_EXE_dumbpipe");
    BIN_PATH
}
/// Reads `reader` one byte at a time until `n` newline (`\n`) bytes have been
/// consumed or EOF is reached, returning every byte read (newlines included).
///
/// Bytes are read individually so nothing past the last requested newline is
/// consumed; the remainder of the stream stays available to the caller.
/// When `n == 0` nothing is read and an empty buffer is returned.
///
/// # Errors
/// Returns any I/O error from `reader`, except `ErrorKind::Interrupted`,
/// which is retried as the `Read::read` contract recommends.
fn read_ascii_lines(n: usize, reader: &mut impl Read) -> io::Result<Vec<u8>> {
    let mut res = Vec::new();
    let mut remaining = n;
    let mut byte = [0u8; 1];
    while remaining > 0 {
        match reader.read(&mut byte) {
            // EOF: return whatever was collected so far.
            Ok(0) => break,
            Ok(_) => {
                res.push(byte[0]);
                if byte[0] == b'\n' {
                    remaining -= 1;
                }
            }
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(res)
}
/// Builds a shareable rendezvous point for exactly two threads.
///
/// Clone the returned `Arc` and hand one handle to each participant; both
/// block in `wait()` until the other arrives.
fn wait2() -> Arc<Barrier> {
    let parties = 2;
    let barrier = Barrier::new(parties);
    Arc::new(barrier)
}
fn random_port() -> u16 {
rand::rng().random_range(10000u16..60000)
}
/// Sends one message in each direction between `dumbpipe listen` and
/// `dumbpipe connect`, each side writing its message on stdin.
#[test]
#[ignore = "flaky"]
fn connect_listen_happy() {
    let sent_by_listen = b"hello from listen";
    let sent_by_connect = b"hello from connect";

    // stderr is merged into the reader so the startup banner carrying the
    // ticket is read from the same stream as the payload.
    let mut listen_side = duct::cmd(dumbpipe_bin(), ["listen"])
        .env_remove("RUST_LOG")
        .stdin_bytes(sent_by_listen)
        .stderr_to_stdout()
        .reader()
        .unwrap();

    // The ticket is the final whitespace-separated token of the first 3 lines.
    let banner_bytes = read_ascii_lines(3, &mut listen_side).unwrap();
    let banner = String::from_utf8(banner_bytes).unwrap();
    let ticket_text = banner.split_ascii_whitespace().next_back().unwrap();
    let ticket = EndpointTicket::from_str(ticket_text).unwrap().to_string();

    let connect_out = duct::cmd(dumbpipe_bin(), ["connect", ticket.as_str()])
        .env_remove("RUST_LOG")
        .stdin_bytes(sent_by_connect)
        .stderr_null()
        .stdout_capture()
        .run()
        .unwrap();
    assert!(connect_out.status.success());
    assert!(connect_out.stdout.starts_with(sent_by_listen));

    let mut listen_out = Vec::new();
    listen_side.read_to_end(&mut listen_out).unwrap();
    assert!(listen_out.starts_with(sent_by_connect));
}
/// Same exchange as the default-ALPN happy path, but both processes are
/// started with the same `--custom-alpn` value.
#[test]
#[ignore = "flaky"]
fn connect_listen_custom_alpn_happy() {
    const ALPN: &str = "utf8:mysuperalpn/0.1.0";
    let sent_by_listen = b"hello from listen";
    let sent_by_connect = b"hello from connect";

    // stderr is merged into the reader so the banner with the ticket is
    // readable from the same stream.
    let mut listen_side = duct::cmd(dumbpipe_bin(), ["listen", "--custom-alpn", ALPN])
        .env_remove("RUST_LOG")
        .stdin_bytes(sent_by_listen)
        .stderr_to_stdout()
        .reader()
        .unwrap();

    // The ticket is the final whitespace-separated token of the first 3 lines.
    let banner_bytes = read_ascii_lines(3, &mut listen_side).unwrap();
    let banner = String::from_utf8(banner_bytes).unwrap();
    let ticket_text = banner.split_ascii_whitespace().next_back().unwrap();
    let ticket = EndpointTicket::from_str(ticket_text).unwrap().to_string();

    let connect_args = ["connect", ticket.as_str(), "--custom-alpn", ALPN];
    let connect_out = duct::cmd(dumbpipe_bin(), connect_args)
        .env_remove("RUST_LOG")
        .stdin_bytes(sent_by_connect)
        .stderr_null()
        .stdout_capture()
        .run()
        .unwrap();
    assert!(connect_out.status.success());
    assert!(connect_out.stdout.starts_with(sent_by_listen));

    let mut listen_out = Vec::new();
    listen_side.read_to_end(&mut listen_out).unwrap();
    assert!(listen_out.starts_with(sent_by_connect));
}
/// Sends SIGINT to the `dumbpipe connect` process(es) after reading one line
/// from its output, then reads both processes to EOF.
///
/// No assertions are made and read errors are ignored, so this presumably
/// only guards against the processes hanging after Ctrl-C.
#[cfg(unix)]
#[test]
fn connect_listen_ctrlc_connect() {
    use nix::{
        sys::signal::{self, Signal},
        unistd::Pid,
    };
    let mut listen = duct::cmd(dumbpipe_bin(), ["listen"])
        .env_remove("RUST_LOG")
        .stdin_bytes(b"hello from listen\n")
        .stderr_to_stdout()
        .reader()
        .unwrap();
    // The ticket is the last whitespace-separated token of the first 3 lines.
    let header = read_ascii_lines(3, &mut listen).unwrap();
    let header = String::from_utf8(header).unwrap();
    let ticket = header.split_ascii_whitespace().last().unwrap();
    let ticket = EndpointTicket::from_str(ticket).unwrap();
    let mut connect = duct::cmd(dumbpipe_bin(), ["connect", &ticket.to_string()])
        .env_remove("RUST_LOG")
        .stderr_null()
        .stdout_capture()
        .reader()
        .unwrap();
    // Read (up to) one line from connect's output before interrupting it.
    read_ascii_lines(1, &mut connect).unwrap();
    for pid in connect.pids() {
        // Checked conversion: `as i32` would silently wrap an out-of-range pid.
        let pid = i32::try_from(pid).expect("pid fits in i32");
        signal::kill(Pid::from_raw(pid), Signal::SIGINT).unwrap();
    }
    // Both reads must return for the test to finish; their results are ignored.
    let mut tmp = Vec::new();
    listen.read_to_end(&mut tmp).ok();
    connect.read_to_end(&mut tmp).ok();
}
/// Starts `dumbpipe listen` and `dumbpipe connect`, waits one second, sends
/// SIGINT to the listen process(es), then reads both processes to EOF.
///
/// No assertions are made and read errors are ignored, so this presumably
/// only guards against the processes hanging after Ctrl-C on the listen side.
#[cfg(unix)]
#[test]
fn connect_listen_ctrlc_listen() {
    // `Duration` comes from the file-level import; no local re-import needed.
    use nix::{
        sys::signal::{self, Signal},
        unistd::Pid,
    };
    let mut listen = duct::cmd(dumbpipe_bin(), ["listen"])
        .env_remove("RUST_LOG")
        .stderr_to_stdout()
        .reader()
        .unwrap();
    // The ticket is the last whitespace-separated token of the first 3 lines.
    let header = read_ascii_lines(3, &mut listen).unwrap();
    let header = String::from_utf8(header).unwrap();
    let ticket = header.split_ascii_whitespace().last().unwrap();
    let ticket = EndpointTicket::from_str(ticket).unwrap();
    let mut connect = duct::cmd(dumbpipe_bin(), ["connect", &ticket.to_string()])
        .env_remove("RUST_LOG")
        .stderr_null()
        .stdout_capture()
        .reader()
        .unwrap();
    std::thread::sleep(Duration::from_secs(1));
    for pid in listen.pids() {
        // Checked conversion: `as i32` would silently wrap an out-of-range pid.
        let pid = i32::try_from(pid).expect("pid fits in i32");
        signal::kill(Pid::from_raw(pid), Signal::SIGINT).unwrap();
    }
    // Both reads must return for the test to finish; their results are ignored.
    let mut tmp = Vec::new();
    listen.read_to_end(&mut tmp).ok();
    connect.read_to_end(&mut tmp).ok();
}
/// Runs a one-shot TCP backend and exposes it with `dumbpipe listen-tcp`.
/// Then checks that `dumbpipe connect` receives the backend's greeting.
#[test]
#[cfg(unix)]
#[ignore = "flaky"]
fn listen_tcp_happy() {
    let b1 = wait2();
    let b2 = b1.clone();
    let port = random_port();
    let host_port = format!("localhost:{port}");
    let host_port_2 = host_port.clone();
    std::thread::spawn(move || {
        let bound = TcpListener::bind(host_port_2);
        // Reach the barrier before unwrapping. If bind failed and we panicked
        // first, the main thread would block on `b2.wait()` forever instead
        // of failing.
        b1.wait();
        let server = bound.expect("bind tcp backend");
        let (mut stream, _addr) = server.accept().unwrap();
        stream.write_all(b"hello from tcp").unwrap();
        stream.flush().unwrap();
        drop(stream);
    });
    b2.wait();
    let mut listen_tcp = duct::cmd(dumbpipe_bin(), ["listen-tcp", "--host", &host_port])
        .env_remove("RUST_LOG")
        .stderr_to_stdout()
        .reader()
        .unwrap();
    // listen-tcp prints 4 banner lines; the ticket is the last token in them.
    let header = read_ascii_lines(4, &mut listen_tcp).unwrap();
    let header = String::from_utf8(header).unwrap();
    let ticket = header.split_ascii_whitespace().last().unwrap();
    let ticket = EndpointTicket::from_str(ticket).unwrap();
    let connect = duct::cmd(dumbpipe_bin(), ["connect", &ticket.to_string()])
        .env_remove("RUST_LOG")
        .stderr_null()
        .stdout_capture()
        .stdin_bytes(b"hello from connect")
        .run()
        .unwrap();
    assert!(connect.status.success());
    assert!(connect.stdout.starts_with(b"hello from tcp"));
}
/// Exposes a `dumbpipe listen` endpoint as a local TCP port via `connect-tcp`.
/// Then checks that a plain TCP client receives the listener's stdin.
#[test]
fn connect_tcp_happy() {
    use std::time::Instant;

    let port = random_port();
    let host_port = format!("localhost:{port}");
    let mut listen = duct::cmd(dumbpipe_bin(), ["listen"])
        .env_remove("RUST_LOG")
        .stdin_bytes(b"hello from listen\n")
        .stderr_to_stdout()
        .reader()
        .unwrap();
    // The ticket is the last whitespace-separated token of the first 3 lines.
    let header = read_ascii_lines(3, &mut listen).unwrap();
    let header = String::from_utf8(header).unwrap();
    let ticket = header.split_ascii_whitespace().last().unwrap();
    let ticket = EndpointTicket::from_str(ticket).unwrap();
    let ticket = ticket.to_string();
    let _connect_tcp = duct::cmd(
        dumbpipe_bin(),
        ["connect-tcp", "--addr", &host_port, &ticket],
    )
    .env_remove("RUST_LOG")
    .stderr_to_stdout()
    .reader()
    .unwrap();
    // Poll until connect-tcp accepts connections. A fixed sleep is slow when
    // the process starts quickly and racy when it starts slowly.
    let deadline = Instant::now() + Duration::from_secs(10);
    let mut conn = loop {
        match TcpStream::connect(&host_port) {
            Ok(conn) => break conn,
            Err(err) => {
                assert!(
                    Instant::now() < deadline,
                    "connect-tcp never accepted on {host_port}: {err}"
                );
                std::thread::sleep(Duration::from_millis(50));
            }
        }
    };
    conn.write_all(b"hello from tcp").unwrap();
    conn.flush().unwrap();
    let mut buf = Vec::new();
    conn.read_to_end(&mut buf).unwrap();
    assert_eq!(&buf, b"hello from listen\n");
}
#[cfg(all(test, unix))]
mod unix_socket_tests {
    //! End-to-end test of `listen-unix` / `connect-unix`. A client talks to a
    //! local Unix-socket backend through two `dumbpipe` processes.

    use std::{
        io::{BufRead, BufReader, Read, Write},
        net::Shutdown,
        os::unix::net::{UnixListener, UnixStream},
        path::PathBuf,
        process::{Child, Command, Stdio},
        sync::{Arc, Barrier},
        thread::JoinHandle,
        time::{Duration, Instant},
    };

    use tempfile::TempDir;

    use super::*;

    /// Polls `condition` every 25ms until it returns `true`.
    ///
    /// # Panics
    /// Panics with "timeout waiting for condition" if `timeout` elapses first.
    fn wait_until<F>(timeout: Duration, mut condition: F)
    where
        F: FnMut() -> bool,
    {
        let deadline = Instant::now() + timeout;
        while !condition() {
            if Instant::now() >= deadline {
                panic!("timeout waiting for condition");
            }
            std::thread::sleep(Duration::from_millis(25));
        }
    }

    /// Creates a temporary directory and returns it with the path of a
    /// (not yet created) `test.sock` inside it. Keep the `TempDir` alive for
    /// the whole test; dropping it removes the directory.
    fn temp_socket_path() -> (TempDir, PathBuf) {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        (temp_dir, socket_path)
    }

    /// Spawns a thread that echoes each line of `reader` to the test's stderr
    /// as `[prefix] line`, stopping at EOF or the first read error.
    fn drain_stderr(
        reader: impl BufRead + Send + 'static,
        prefix: &'static str,
    ) -> JoinHandle<()> {
        std::thread::spawn(move || {
            for line in reader.lines().map_while(Result::ok) {
                eprintln!("[{prefix}] {line}");
            }
        })
    }

    /// Owns a spawned child process and kills and reaps it when dropped.
    ///
    /// `std::process::Child` does not kill on drop. Without this guard, a
    /// panic anywhere in the test would leave the `dumbpipe` processes running.
    struct KillOnDrop(Child);

    impl Drop for KillOnDrop {
        fn drop(&mut self) {
            // Errors are ignored: the process may already have exited.
            self.0.kill().ok();
            self.0.wait().ok();
        }
    }

    /// Spawns a Unix-socket backend at `socket_path`. For each connection it
    /// reads once; if that read returned data, it replies "hello from unix"
    /// and shuts down its write half. It then drains the connection until EOF.
    ///
    /// `barrier` is released once the bind has been attempted. If the bind
    /// failed, the server thread then panics. The caller's `barrier.wait()`
    /// still returns, so the failure shows up as an unreachable backend
    /// rather than a deadlock.
    fn dummy_unix_server(socket_path: PathBuf, barrier: Arc<Barrier>) -> JoinHandle<()> {
        std::thread::spawn(move || {
            // Remove a leftover socket file, if any; bind fails on an existing path.
            let _ = std::fs::remove_file(&socket_path);
            let bound = UnixListener::bind(&socket_path);
            barrier.wait();
            let listener = bound.expect("bind backend unix socket");
            for stream in listener.incoming() {
                let Ok(mut stream) = stream else { break };
                std::thread::spawn(move || {
                    let mut buf = vec![0; 1024];
                    if let Ok(n) = stream.read(&mut buf) {
                        if n > 0 && stream.write_all(b"hello from unix").is_ok() {
                            stream.shutdown(Shutdown::Write).ok();
                        }
                    }
                    while stream.read(&mut buf).unwrap_or(0) > 0 {}
                });
            }
        })
    }

    #[test]
    fn unix_socket_roundtrip() {
        let (_tmp_dir, backend_sock) = temp_socket_path();
        let client_sock = backend_sock.with_extension("client");

        let barrier = Arc::new(Barrier::new(2));
        let _backend_thread = dummy_unix_server(backend_sock.clone(), barrier.clone());
        barrier.wait();
        let deadline = Instant::now() + Duration::from_secs(5);
        while UnixStream::connect(&backend_sock).is_err() {
            if Instant::now() >= deadline {
                panic!("backend server not connectable after 5s");
            }
            std::thread::sleep(Duration::from_millis(100));
        }

        let mut listen_proc = KillOnDrop(
            Command::new(dumbpipe_bin())
                .args([
                    "listen-unix",
                    "--socket-path",
                    backend_sock.to_str().unwrap(),
                ])
                .env_remove("RUST_LOG")
                .stdout(Stdio::null())
                .stderr(Stdio::piped())
                .spawn()
                .expect("spawn listen-unix"),
        );
        // The ticket is the last token of the first stderr line that
        // mentions "connect-unix".
        let mut stderr_reader = BufReader::new(listen_proc.0.stderr.take().unwrap());
        let mut ticket = String::new();
        for line in stderr_reader.by_ref().lines() {
            let line = line.unwrap();
            eprintln!("[listen-unix-stderr] {line}");
            if line.contains("connect-unix") {
                ticket = line.split_whitespace().last().unwrap().to_owned();
                break;
            }
        }
        assert!(!ticket.is_empty(), "Failed to get ticket");
        let listen_stderr_thread = drain_stderr(stderr_reader, "listen-unix-stderr");

        let mut connect_proc = KillOnDrop(
            Command::new(dumbpipe_bin())
                .args([
                    "connect-unix",
                    "--socket-path",
                    client_sock.to_str().unwrap(),
                    ticket.as_str(),
                ])
                .env_remove("RUST_LOG")
                .stdout(Stdio::null())
                .stderr(Stdio::piped())
                .spawn()
                .expect("spawn connect-unix"),
        );
        let connect_stderr_thread = drain_stderr(
            BufReader::new(connect_proc.0.stderr.take().unwrap()),
            "connect-unix-stderr",
        );

        // Retry the connect itself rather than only waiting for the socket
        // file. The file can exist before the socket is listening.
        let mut client = None;
        wait_until(Duration::from_secs(5), || {
            client = UnixStream::connect(&client_sock).ok();
            client.is_some()
        });
        let mut client = client.expect("connect to client socket");
        client
            .write_all(b"hello from client")
            .expect("client write");
        let mut reply = Vec::new();
        client.read_to_end(&mut reply).expect("client read");
        assert_eq!(&reply, b"hello from unix");

        // Kill both processes so the drain threads see EOF and can be joined.
        drop(listen_proc);
        drop(connect_proc);
        listen_stderr_thread.join().ok();
        connect_stderr_thread.join().ok();
    }
}