use std::io;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::process::{Child, Command, Stdio};
use std::time::{Duration, Instant};
use crate::connection::SshConfig;
/// How long [`SshTunnel::spawn`] keeps probing the forwarded local port before
/// giving up with an [`std::io::ErrorKind::TimedOut`] error.
pub const READY_TIMEOUT: Duration = Duration::from_millis(8_000);
/// A running `ssh` client process that forwards a port on 127.0.0.1 to a
/// remote target (see `SshTunnel::spawn`).
///
/// Dropping the value kills and reaps the `ssh` process.
#[derive(Debug)]
pub struct SshTunnel {
    /// The spawned `ssh` process.
    child: Child,
    /// Local port on 127.0.0.1 where the forward listens.
    local_port: u16,
    /// `user@host` of the ssh server, used in error messages.
    /// NOTE(review): despite the name, this is the ssh destination, not the
    /// forwarded `host:port` target.
    target: String,
}
impl SshTunnel {
    /// Pause between readiness probes in `wait_for_ready`.
    const PROBE_INTERVAL: Duration = Duration::from_millis(100);
    /// Upper bound on a single readiness-probe `connect`.
    const PROBE_CONNECT_TIMEOUT: Duration = Duration::from_millis(250);

    /// Async wrapper around [`SshTunnel::spawn`].
    ///
    /// `spawn` blocks its thread while it polls the forwarded port (up to
    /// roughly [`READY_TIMEOUT`]), so it runs on tokio's blocking pool. This
    /// keeps the async runtime free to make progress meanwhile.
    ///
    /// # Errors
    ///
    /// Any error from [`SshTunnel::spawn`]. If the blocking task fails to
    /// complete (panic or cancellation), an [`io::ErrorKind::Other`] error is
    /// returned.
    pub async fn spawn_async(
        config: SshConfig,
        target_host: String,
        target_port: u16,
    ) -> io::Result<Self> {
        tokio::task::spawn_blocking(move || Self::spawn(&config, &target_host, target_port))
            .await
            .map_err(|e| io::Error::other(format!("ssh tunnel spawn task panicked: {e}")))?
    }

    /// Starts `ssh` forwarding a free port on 127.0.0.1 to
    /// `target_host:target_port`, with the target resolved from the ssh
    /// server's side. Blocks until the local end accepts TCP connections.
    ///
    /// # Errors
    ///
    /// - A free local port could not be reserved.
    /// - The `ssh` binary could not be spawned. The OS error kind is kept,
    ///   e.g. [`io::ErrorKind::NotFound`] when the binary is missing.
    /// - `ssh` exited before the port became ready. This gives
    ///   [`io::ErrorKind::ConnectionRefused`], with ssh's stderr in the message.
    /// - The port was not ready within [`READY_TIMEOUT`]. This gives
    ///   [`io::ErrorKind::TimedOut`].
    pub fn spawn(config: &SshConfig, target_host: &str, target_port: u16) -> io::Result<Self> {
        let local_port = pick_free_port()?;
        let bind_spec = format!(
            "127.0.0.1:{local_port}:{}:{target_port}",
            Self::forward_host(target_host)
        );
        let mut cmd = Command::new("ssh");
        cmd.args(["-N", "-T"])
            .args(["-o", "ExitOnForwardFailure=yes"])
            .args(["-o", "ServerAliveInterval=30"])
            .args(["-o", "ServerAliveCountMax=3"])
            .args(["-o", "StrictHostKeyChecking=accept-new"])
            .arg("-L")
            .arg(&bind_spec);
        if let Some(port) = config.port {
            cmd.arg("-p").arg(port.to_string());
        }
        if let Some(key) = config.key_path.as_ref() {
            cmd.arg("-i").arg(key);
        }
        if let Some(jump) = config.jump_host.as_ref() {
            cmd.arg("-J").arg(jump);
        }
        let user_at_host = format!("{}@{}", config.user, config.host);
        // `--` ends option parsing. Without it, a user or host beginning with
        // `-` could be parsed as an ssh option, e.g. `-oProxyCommand=...`.
        cmd.arg("--").arg(&user_at_host);
        cmd.stdin(Stdio::null())
            .stdout(Stdio::null())
            // Captured so that a failed startup can report ssh's own diagnostics.
            .stderr(Stdio::piped());
        let child = cmd.spawn().map_err(|e| {
            // The install hint is only meaningful when the binary was not found.
            let hint = if e.kind() == io::ErrorKind::NotFound {
                " (is the OpenSSH client installed?)"
            } else {
                ""
            };
            io::Error::new(e.kind(), format!("could not spawn ssh: {e}{hint}"))
        })?;
        // Built before waiting, so any early return below drops it. `Drop` then
        // kills and reaps the ssh process.
        let mut tunnel = Self {
            child,
            local_port,
            target: user_at_host,
        };
        tunnel.wait_for_ready()?;
        tunnel.discard_stderr_in_background();
        Ok(tunnel)
    }

    /// Host the forwarded port listens on; always the IPv4 loopback address.
    pub const fn local_host(&self) -> &'static str {
        "127.0.0.1"
    }

    /// Local port the forward listens on.
    pub const fn local_port(&self) -> u16 {
        self.local_port
    }

    /// Formats `host` for the host field of an `ssh -L` spec.
    ///
    /// An IPv6 literal must be wrapped in brackets, or its colons are read as
    /// field separators. Already-bracketed hosts and names without colons are
    /// passed through unchanged.
    fn forward_host(host: &str) -> std::borrow::Cow<'_, str> {
        if host.contains(':') && !host.starts_with('[') {
            format!("[{host}]").into()
        } else {
            host.into()
        }
    }

    /// Polls the forwarded port until it accepts a TCP connection.
    ///
    /// Fails if ssh exits first, or if [`READY_TIMEOUT`] elapses. Each
    /// successful probe connection is closed immediately.
    fn wait_for_ready(&mut self) -> io::Result<()> {
        let addr = SocketAddr::from(([127, 0, 0, 1], self.local_port));
        let deadline = Instant::now() + READY_TIMEOUT;
        loop {
            let connected = TcpStream::connect_timeout(&addr, Self::PROBE_CONNECT_TIMEOUT).is_ok();
            match self.child.try_wait() {
                Ok(Some(status)) => {
                    let stderr = self.drain_stderr();
                    return Err(io::Error::new(
                        io::ErrorKind::ConnectionRefused,
                        format!(
                            "ssh tunnel to {} exited ({status}) before the port was ready: {}",
                            self.target,
                            stderr.trim()
                        ),
                    ));
                }
                // A successful probe only counts while ssh is still running.
                // If ssh has died, whatever answered on the port is not our
                // forward. This narrows, but cannot close, the window in which
                // another process grabbed the port picked by `pick_free_port`.
                Ok(None) | Err(_) if connected => return Ok(()),
                _ => {}
            }
            if Instant::now() >= deadline {
                return Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    format!(
                        "ssh tunnel to {} did not accept connections within {:?}",
                        self.target, READY_TIMEOUT
                    ),
                ));
            }
            std::thread::sleep(Self::PROBE_INTERVAL);
        }
    }

    /// Reads ssh's captured stderr to EOF and returns it as text.
    ///
    /// Intended for use after ssh has exited. Invalid UTF-8 is replaced rather
    /// than discarding the whole message. A read error yields whatever was
    /// read before it.
    fn drain_stderr(&mut self) -> String {
        use std::io::Read;
        let mut bytes = Vec::new();
        if let Some(mut err) = self.child.stderr.take() {
            let _ = err.read_to_end(&mut bytes);
        }
        String::from_utf8_lossy(&bytes).into_owned()
    }

    /// Keeps ssh's stderr pipe drained for the rest of the tunnel's life.
    ///
    /// Nothing else reads the pipe after startup. A long-lived ssh that keeps
    /// logging (e.g. channel-open failures) would otherwise fill the pipe
    /// buffer and block on write, stalling the forward. The thread ends when
    /// the pipe reaches EOF.
    fn discard_stderr_in_background(&mut self) {
        if let Some(mut err) = self.child.stderr.take() {
            // If the thread cannot be created, the closure and the pipe's read
            // end are dropped. ssh's later stderr writes then fail instead of
            // blocking.
            let _ = std::thread::Builder::new()
                .name("ssh-tunnel-stderr".into())
                .spawn(move || {
                    let _ = io::copy(&mut err, &mut io::sink());
                });
        }
    }
}
impl Drop for SshTunnel {
    /// Tears the tunnel down by killing the `ssh` process and then reaping it.
    fn drop(&mut self) {
        let child = &mut self.child;
        // Both steps are best effort. `kill` fails harmlessly if ssh has
        // already exited, and `wait` collects the exit status so no zombie is
        // left behind. `drop` cannot report errors anyway.
        let _ = child.kill();
        let _ = child.wait();
    }
}
/// Asks the OS for a currently unused TCP port on 127.0.0.1.
///
/// The probe listener is closed before this returns, so the port is only
/// *likely* to still be free. Another process may bind it before ssh does.
fn pick_free_port() -> io::Result<u16> {
    // Port 0 makes the kernel choose an ephemeral port. `probe` is dropped,
    // releasing the port, once the address has been read.
    let probe = TcpListener::bind("127.0.0.1:0")?;
    probe.local_addr().map(|addr| addr.port())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An address in TEST-NET-1 (RFC 5737), which is reserved for
    /// documentation and not routed. An ssh connection to it cannot succeed.
    const UNREACHABLE_HOST: &str = "192.0.2.1";

    #[test]
    fn pick_free_port_yields_bindable_port() {
        let port = pick_free_port().unwrap();
        let _l = TcpListener::bind(("127.0.0.1", port)).unwrap();
    }

    /// While `spawn_async` waits out a doomed tunnel, a sibling task on the
    /// same single-threaded runtime must keep making progress.
    #[tokio::test(flavor = "current_thread")]
    async fn spawn_async_does_not_block_current_thread_runtime() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

        let counter = Arc::new(AtomicUsize::new(0));
        let stop = Arc::new(AtomicBool::new(false));
        let bumper = {
            let counter = Arc::clone(&counter);
            let stop = Arc::clone(&stop);
            tokio::spawn(async move {
                while !stop.load(Ordering::Relaxed) {
                    counter.fetch_add(1, Ordering::Relaxed);
                    tokio::task::yield_now().await;
                }
            })
        };

        let cfg = SshConfig::new(UNREACHABLE_HOST, "nobody");
        let outcome = SshTunnel::spawn_async(cfg, "127.0.0.1".into(), 1).await;
        stop.store(true, Ordering::Relaxed);
        let _ = bumper.await;

        let err = match outcome {
            Ok(tunnel) => panic!("expected failure, got: {tunnel:?}"),
            Err(err) => err,
        };
        // Without an ssh binary, spawn fails at once. There is then no long
        // blocking wait whose overlap with the runtime could be measured.
        if err.kind() == io::ErrorKind::NotFound {
            eprintln!("skipping concurrency check, ssh could not be spawned: {err}");
            return;
        }
        let n = counter.load(Ordering::Relaxed);
        assert!(n > 10, "expected concurrent progress, got {n} increments");
    }

    #[test]
    fn spawn_fails_fast_against_unreachable_host() {
        // Slack on top of READY_TIMEOUT for process startup and teardown.
        let limit = READY_TIMEOUT + Duration::from_secs(2);
        let cfg = SshConfig::new(UNREACHABLE_HOST, "nobody");
        let start = Instant::now();
        let outcome = SshTunnel::spawn(&cfg, "127.0.0.1", 1);
        let elapsed = start.elapsed();
        assert!(outcome.is_err(), "expected failure, got: {outcome:?}");
        assert!(
            elapsed <= limit,
            "spawn took {elapsed:?}, expected <= {limit:?}"
        );
    }
}