use std::net::TcpListener;
use std::process::{Child, Command, Stdio};
use std::time::Duration;
use super::error::HostError;
use super::schema::HostConfig;
/// Handle to a background `ssh` process that forwards a local TCP port to a
/// remote Docker socket.
#[derive(Debug)]
pub struct SshTunnel {
    /// The spawned `ssh` child process.
    child: Child,
    /// Local TCP port the tunnel is expected to listen on.
    local_port: u16,
    /// Name of the configured host (returned by `host_name()` and used in
    /// log messages).
    host_name: String,
}
impl SshTunnel {
    /// Spawns a background `ssh -N` process that forwards a free loopback port
    /// on this machine to the remote Docker socket `/var/run/docker.sock`.
    ///
    /// Connection details come from `host` (`user`, `hostname`, and the
    /// optional `jump_host`, `identity_file` and `port`). `host_name` is only
    /// stored for logging and [`SshTunnel::host_name`]. The process is started
    /// but not waited on, so call [`SshTunnel::wait_ready`] before using
    /// [`SshTunnel::docker_url`].
    ///
    /// # Errors
    ///
    /// - [`HostError::PortAllocation`] if no free local port could be found.
    /// - [`HostError::SshSpawn`] if `ssh` cannot be started, with a dedicated
    ///   message when the binary is not installed.
    pub fn new(host: &HostConfig, host_name: &str) -> Result<Self, HostError> {
        let local_port = find_available_port()?;

        let mut cmd = Command::new("ssh");
        // Bind the forward explicitly to IPv4 loopback. This is the address
        // `find_available_port` probed and the one `docker_url`/`wait_ready`
        // use; without a bind address ssh chooses one on its own.
        cmd.arg("-L")
            .arg(format!("127.0.0.1:{local_port}:/var/run/docker.sock"));
        // No remote command: the process exists only to hold the forward.
        cmd.arg("-N");
        cmd.arg("-o").arg("BatchMode=yes");
        cmd.arg("-o").arg("StrictHostKeyChecking=accept-new");
        cmd.arg("-o").arg("ConnectTimeout=10");
        cmd.arg("-o").arg("RequestTTY=no");
        // `find_available_port` releases its probe socket before ssh binds the
        // port, so another process can grab it in between. Make ssh exit in
        // that case instead of staying up without a working forward.
        cmd.arg("-o").arg("ExitOnForwardFailure=yes");
        if let Some(jump) = &host.jump_host {
            cmd.arg("-J").arg(jump);
        }
        if let Some(key) = &host.identity_file {
            cmd.arg("-i").arg(key);
        }
        if let Some(port) = host.port {
            cmd.arg("-p").arg(port.to_string());
        }
        cmd.arg(format!("{}@{}", host.user, host.hostname));
        // NOTE(review): stderr is piped but never read in this module; if
        // nothing else consumes it, `Stdio::null()` would rule out ssh ever
        // blocking on a full pipe.
        cmd.stdin(Stdio::null())
            .stdout(Stdio::null())
            .stderr(Stdio::piped());

        tracing::debug!(
            "Spawning SSH tunnel: ssh -L 127.0.0.1:{}:/var/run/docker.sock {}@{}",
            local_port,
            host.user,
            host.hostname
        );

        let child = cmd.spawn().map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                HostError::SshSpawn("SSH not found. Install OpenSSH client.".to_string())
            } else {
                HostError::SshSpawn(e.to_string())
            }
        })?;

        Ok(Self {
            child,
            local_port,
            host_name: host_name.to_string(),
        })
    }

    /// Local loopback port that forwards to the remote Docker socket.
    pub fn local_port(&self) -> u16 {
        self.local_port
    }

    /// Docker endpoint URL for the tunnel, `tcp://127.0.0.1:<local_port>`.
    pub fn docker_url(&self) -> String {
        format!("tcp://127.0.0.1:{}", self.local_port)
    }

    /// Host name this tunnel was created for.
    pub fn host_name(&self) -> &str {
        &self.host_name
    }

    /// Polls the forwarded port until it accepts a TCP connection.
    ///
    /// Makes up to 3 connection attempts, each with a 1 s connect timeout.
    /// Before the 2nd and 3rd attempt it sleeps `100 ms * 2^attempt`, which is
    /// 200 ms and then 400 ms.
    ///
    /// # Errors
    ///
    /// Returns [`HostError::TunnelTimeout`] with the attempt count if no
    /// attempt succeeds.
    pub async fn wait_ready(&self) -> Result<(), HostError> {
        let max_attempts = 3;
        let initial_delay_ms = 100;
        // Built directly rather than parsed from a string, so there is no
        // `unwrap` on this path.
        let addr = std::net::SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, self.local_port));
        // NOTE(review): the total back-off is only ~600 ms while ssh is given
        // `ConnectTimeout=10`. A slow login or a jump host may exhaust the
        // attempts before the forward is listening. Confirm the intended budget.
        for attempt in 0..max_attempts {
            if attempt > 0 {
                let delay = Duration::from_millis(initial_delay_ms * 2u64.pow(attempt));
                tracing::debug!("Tunnel wait attempt {} after {:?}", attempt + 1, delay);
                tokio::time::sleep(delay).await;
            }
            // NOTE(review): `connect_timeout` is a blocking std call inside an
            // async fn. Against loopback it normally returns at once, but it
            // can hold the executor thread for up to the 1 s timeout.
            match std::net::TcpStream::connect_timeout(&addr, Duration::from_secs(1)) {
                Ok(_) => {
                    tracing::debug!("SSH tunnel ready on port {}", self.local_port);
                    return Ok(());
                }
                Err(e) => {
                    tracing::debug!("Tunnel not ready: {}", e);
                }
            }
        }
        Err(HostError::TunnelTimeout(max_attempts))
    }

    /// Returns `true` while the `ssh` child has not exited, checked with a
    /// non-blocking `try_wait`. An error from `try_wait` counts as not alive.
    pub fn is_alive(&mut self) -> bool {
        matches!(self.child.try_wait(), Ok(None))
    }
}
impl Drop for SshTunnel {
    /// Tears the tunnel down: kills the `ssh` child, then waits on it so the
    /// process is reaped. A kill failure is only logged at debug level and a
    /// wait failure is ignored, since a destructor cannot report errors.
    fn drop(&mut self) {
        let (name, port) = (&self.host_name, self.local_port);
        tracing::debug!("Cleaning up SSH tunnel to {} (port {})", name, port);

        if let Some(kill_err) = self.child.kill().err() {
            tracing::debug!("SSH tunnel kill result: {}", kill_err);
        }

        // Best-effort reap; nothing useful to do with an error here.
        let _ = self.child.wait();
    }
}
fn find_available_port() -> Result<u16, HostError> {
let listener =
TcpListener::bind("127.0.0.1:0").map_err(|e| HostError::PortAllocation(e.to_string()))?;
let port = listener
.local_addr()
.map_err(|e| HostError::PortAllocation(e.to_string()))?
.port();
drop(listener);
Ok(port)
}
/// Checks that `host` is reachable over SSH and that Docker responds there.
///
/// Runs `docker version --format {{.Server.Version}}` on the remote host via
/// `ssh` (batch mode, 10 s connect timeout, new host keys accepted). On
/// success it returns the trimmed server version string.
///
/// # Errors
///
/// - `HostError::SshSpawn` if the local `ssh` binary cannot be started.
/// - `HostError::AuthFailed` if ssh's stderr mentions `Permission denied` or
///   `Host key verification failed`.
/// - `HostError::RemoteDockerUnavailable` if the remote command exits with
///   status 127 or stderr contains `not found`.
/// - `HostError::ConnectionFailed` with ssh's stderr for any other failure.
pub async fn test_connection(host: &HostConfig) -> Result<String, HostError> {
    // Exit status a POSIX shell reports when a command cannot be found; ssh
    // relays the remote command's exit status to us.
    const COMMAND_NOT_FOUND: i32 = 127;

    let mut cmd = Command::new("ssh");
    cmd.arg("-o")
        .arg("BatchMode=yes")
        .arg("-o")
        .arg("ConnectTimeout=10")
        .arg("-o")
        .arg("StrictHostKeyChecking=accept-new");
    // Presumably ends with the ssh destination, since everything after it is
    // treated as the remote command. Confirm in `HostConfig::ssh_args`.
    cmd.args(host.ssh_args());
    cmd.arg("docker")
        .arg("version")
        .arg("--format")
        .arg("{{.Server.Version}}");
    cmd.stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());

    // NOTE(review): `Command::output` blocks the current thread inside an
    // `async fn`, potentially for the whole SSH timeout. Consider
    // `tokio::process::Command` or `tokio::task::spawn_blocking` if the
    // matching tokio features are enabled.
    let output = cmd.output().map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            HostError::SshSpawn("SSH not found. Install OpenSSH client.".to_string())
        } else {
            HostError::SshSpawn(e.to_string())
        }
    })?;

    if output.status.success() {
        let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
        tracing::info!("Docker version on remote: {}", version);
        return Ok(version);
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    if stderr.contains("Permission denied") || stderr.contains("Host key verification failed") {
        return Err(HostError::AuthFailed {
            key_hint: host.identity_file.clone(),
        });
    }
    // The exit status is checked first because it does not depend on the
    // shell's wording. "not found" covers both "command not found" (bash,
    // zsh) and plain "not found" (e.g. dash).
    if output.status.code() == Some(COMMAND_NOT_FOUND) || stderr.contains("not found") {
        return Err(HostError::RemoteDockerUnavailable(
            "Docker is not installed on remote host".to_string(),
        ));
    }
    Err(HostError::ConnectionFailed(stderr.into_owned()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Reports whether this environment allows binding a loopback socket.
    /// Network-dependent tests skip themselves when it does not.
    fn can_bind_localhost() -> bool {
        TcpListener::bind("127.0.0.1:0").is_ok()
    }

    #[test]
    fn test_find_available_port() {
        if can_bind_localhost() {
            let port = find_available_port().unwrap();
            assert_ne!(port, 0);
            // The probe listener was released, so the port should be bindable again.
            let rebound = TcpListener::bind(format!("127.0.0.1:{port}"));
            assert!(rebound.is_ok());
        } else {
            eprintln!("Skipping test: cannot bind to localhost in this environment.");
        }
    }

    // NOTE(review): this only checks the URL format string; it does not call
    // `SshTunnel::docker_url`, which needs a spawned child to construct.
    #[test]
    fn test_docker_url_format() {
        let port: u16 = 12345;
        let expected = "tcp://127.0.0.1:12345";
        assert_eq!(format!("tcp://127.0.0.1:{port}"), expected);
    }
}