use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::ssh::SshClient;
/// A local TCP listener whose connections are forwarded over SSH.
///
/// Every connection accepted on `127.0.0.1:<local_port>` is relayed to a fixed
/// remote host/port through a direct-tcpip channel on a shared SSH client.
/// Dropping the tunnel cancels the accept loop and all active forwarders.
#[derive(Debug)]
pub struct SshTunnel {
    /// Ephemeral loopback port the listener is bound to.
    local_port: u16,
    /// Token shared (via clones) with the accept loop and every forwarder;
    /// cancelling it stops all of them.
    cancel: CancellationToken,
    /// Handle to the background accept loop. It is never awaited; dropping it
    /// detaches the task, which then exits once `cancel` fires.
    _accept_task: JoinHandle<()>,
}
impl SshTunnel {
    /// Returns the loopback port that local clients should connect to.
    pub fn local_port(&self) -> u16 {
        self.local_port
    }

    /// Binds an ephemeral port on `127.0.0.1` and starts forwarding every
    /// connection accepted there to `remote_host:remote_port` via `ssh_client`.
    ///
    /// The listener is already bound when this returns, so callers may connect
    /// to [`SshTunnel::local_port`] immediately. Forwarding runs on a spawned
    /// background task until the returned tunnel is dropped.
    ///
    /// # Errors
    ///
    /// Fails if the local listener cannot be bound or its address cannot be read.
    pub async fn open(
        ssh_client: Arc<RwLock<SshClient>>,
        remote_host: String,
        remote_port: u16,
    ) -> anyhow::Result<Self> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let bound_addr = listener.local_addr()?;
        let token = CancellationToken::new();

        // The accept loop gets its own clone of the token; the tunnel keeps the
        // original so `Drop` can cancel it.
        let accept_task = tokio::spawn(run_accept_loop(
            listener,
            ssh_client,
            remote_host,
            remote_port,
            token.clone(),
        ));

        Ok(Self {
            local_port: bound_addr.port(),
            cancel: token,
            _accept_task: accept_task,
        })
    }
}
impl Drop for SshTunnel {
fn drop(&mut self) {
self.cancel.cancel();
}
}
async fn run_accept_loop(
listener: TcpListener,
ssh_client: Arc<RwLock<SshClient>>,
remote_host: String,
remote_port: u16,
cancel: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel.cancelled() => {
tracing::debug!("postgres tunnel accept loop cancelled");
return;
}
res = listener.accept() => {
match res {
Ok((local_stream, peer)) => {
let ssh_client = ssh_client.clone();
let remote_host = remote_host.clone();
let conn_cancel = cancel.clone();
tokio::spawn(async move {
if let Err(e) = forward_one(
local_stream,
ssh_client,
&remote_host,
remote_port,
conn_cancel,
)
.await
{
tracing::warn!(
peer = %peer,
error = %e,
"postgres tunnel forwarder ended with error"
);
}
});
}
Err(e) => {
tracing::warn!("postgres tunnel accept failed: {e}");
tokio::task::yield_now().await;
}
}
}
}
}
}
/// Relays one accepted local connection to `remote_host:remote_port` over a
/// new direct-tcpip channel opened on `ssh_client`.
///
/// Both directions are copied concurrently. When one direction's reader
/// reaches EOF, the opposite writer is shut down so the peer sees a half-close.
/// The function returns when both directions complete, when either direction
/// fails, or as soon as `cancel` fires, including while the channel is still
/// being opened. Cancellation is not an error.
///
/// # Errors
///
/// Returns an error if the channel cannot be opened (annotated with the target
/// address) or if copying in either direction fails.
async fn forward_one(
    mut local_stream: tokio::net::TcpStream,
    ssh_client: Arc<RwLock<SshClient>>,
    remote_host: &str,
    remote_port: u16,
    cancel: CancellationToken,
) -> anyhow::Result<()> {
    // The client read lock is held only while the channel is opened, not for
    // the lifetime of the forwarded connection.
    let open_channel = async {
        let guard = ssh_client.read().await;
        guard.open_direct_tcpip(remote_host, remote_port).await
    };
    // Race the open against cancellation so a stalled open does not keep this
    // task (and the read lock) alive after the tunnel has been dropped.
    let channel = tokio::select! {
        _ = cancel.cancelled() => {
            tracing::debug!("postgres tunnel forwarder cancelled");
            return Ok(());
        }
        res = open_channel => res.map_err(|e| {
            anyhow::Error::from(e).context(format!(
                "failed to open direct-tcpip channel to {remote_host}:{remote_port}"
            ))
        })?,
    };

    let mut stream = channel.into_stream();
    let (mut local_read, mut local_write) = local_stream.split();
    let (mut ssh_read, mut ssh_write) = tokio::io::split(&mut stream);

    // Shutdown errors are ignored: the copy result is what gets reported, and
    // the peer may already have gone away.
    let local_to_ssh = async {
        let r = tokio::io::copy(&mut local_read, &mut ssh_write).await;
        let _ = ssh_write.shutdown().await;
        r
    };
    let ssh_to_local = async {
        let r = tokio::io::copy(&mut ssh_read, &mut local_write).await;
        let _ = local_write.shutdown().await;
        r
    };

    tokio::select! {
        _ = cancel.cancelled() => {
            tracing::debug!("postgres tunnel forwarder cancelled");
            Ok(())
        }
        res = async {
            tokio::try_join!(local_to_ssh, ssh_to_local).map(|_| ())
        } => {
            res.map_err(anyhow::Error::from)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ssh::HostKeyStore;

    /// Builds an SSH client that is never connected. Its known-hosts file is
    /// the given file name under the system temp dir.
    fn unconnected_client(known_hosts_name: &str) -> Arc<RwLock<SshClient>> {
        let host_keys = Arc::new(HostKeyStore::new(
            std::env::temp_dir().join(known_hosts_name),
        ));
        Arc::new(RwLock::new(SshClient::new(host_keys)))
    }

    #[tokio::test]
    async fn open_binds_local_port_immediately() {
        let client = unconnected_client("r-shell-tunnel-test-known-hosts");
        let tunnel = SshTunnel::open(client, "irrelevant".to_string(), 5432)
            .await
            .expect("bind should succeed");
        assert!(tunnel.local_port() > 0);
        let probe = tokio::net::TcpStream::connect(("127.0.0.1", tunnel.local_port())).await;
        assert!(probe.is_ok(), "listener should accept connections");
    }

    #[tokio::test]
    async fn drop_releases_local_port() {
        let client = unconnected_client("r-shell-tunnel-test-known-hosts-2");
        let tunnel = SshTunnel::open(client, "irrelevant".to_string(), 5432)
            .await
            .expect("bind");
        let port = tunnel.local_port();
        drop(tunnel);

        // The listener is closed by the accept task once it observes the
        // cancellation, which happens asynchronously. Allow a bounded number
        // of scheduler turns instead of relying on exactly one.
        let mut rebind = None;
        for _ in 0..100 {
            tokio::task::yield_now().await;
            if let Ok(listener) = TcpListener::bind(("127.0.0.1", port)).await {
                rebind = Some(listener);
                break;
            }
        }
        assert!(rebind.is_some(), "port {port} should be reusable after drop");
    }
}