use crate::config::{AuthConfig, ChannelConfig, ChannelTypeParams, ReconnectionConfig};
use crate::error::{AppError, Result};
use backon::{ExponentialBuilder, Retryable};
use russh::*;
use russh_keys::key::KeyPair;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
/// Stateless `russh` client handler.
///
/// Used for sessions that only open outbound (direct-tcpip) channels, so it
/// carries no per-channel data.
#[derive(Clone)]
struct ClientHandler;
// `russh` client callbacks for `ClientHandler`; only the host-key check is overridden.
#[async_trait::async_trait]
impl client::Handler for ClientHandler {
type Error = russh::Error;
/// Host-key callback: returns `Ok(true)` for every key without inspecting it.
///
/// NOTE(review): the presented key is ignored, so the server's identity is
/// never verified. Confirm this is acceptable for the deployment, or check
/// the key against a pinned fingerprint / known_hosts entry.
async fn check_server_key(
&mut self,
_server_public_key: &russh_keys::key::PublicKey,
) -> std::result::Result<bool, Self::Error> {
Ok(true) }
}
/// `russh` client handler for a reverse (server-to-client) port forward.
///
/// Holds what is needed to bridge each forwarded channel to a local TCP
/// endpoint.
#[derive(Clone)]
struct ReverseForwardHandler {
// Channel name attached to log events emitted by this handler.
channel_name: String,
// Host part of the local address each forwarded channel is connected to.
local_host: String,
// Port part of the local address each forwarded channel is connected to.
local_port: u16,
}
/// `russh` client callbacks for a reverse port forward: every channel the
/// server opens is bridged to `local_host:local_port`.
#[async_trait::async_trait]
impl client::Handler for ReverseForwardHandler {
    type Error = russh::Error;

    /// Host-key callback: returns `Ok(true)` for every key without inspecting it.
    ///
    /// NOTE(review): the presented key is ignored, so the server's identity is
    /// never verified. Confirm this is acceptable for the deployment.
    async fn check_server_key(
        &mut self,
        _server_public_key: &russh_keys::key::PublicKey,
    ) -> std::result::Result<bool, Self::Error> {
        Ok(true)
    }

    /// Bridges a server-initiated `forwarded-tcpip` channel to the local target.
    ///
    /// The local TCP connect and the byte relay both run in a detached task.
    /// The callback itself returns immediately, so a slow or unreachable local
    /// target cannot hold up the caller of this handler while the connect is
    /// pending.
    ///
    /// Always returns `Ok(())`. A failed local connect is logged at error level
    /// and the channel is dropped; a relay that ends with an I/O error is
    /// logged at debug level.
    async fn server_channel_open_forwarded_tcpip(
        &mut self,
        channel: russh::Channel<russh::client::Msg>,
        _connected_address: &str,
        _connected_port: u32,
        _originator_address: &str,
        _originator_port: u32,
        _session: &mut russh::client::Session,
    ) -> std::result::Result<(), Self::Error> {
        let local_addr = format!("{}:{}", self.local_host, self.local_port);
        let channel_name = self.channel_name.clone();

        tokio::spawn(async move {
            let mut stream = match TcpStream::connect(&local_addr).await {
                Ok(stream) => stream,
                Err(e) => {
                    error!(
                        channel = %channel_name,
                        local = %local_addr,
                        error = ?e,
                        "Failed to connect to local address for forwarded-tcpip"
                    );
                    // `channel` is dropped here, exactly as when the connect
                    // used to fail inside the callback.
                    return;
                }
            };

            let mut channel_stream = channel.into_stream();
            if let Err(e) =
                tokio::io::copy_bidirectional(&mut stream, &mut channel_stream).await
            {
                debug!(channel = %channel_name, error = ?e, "Forwarded-tcpip relay ended");
            }
        });

        Ok(())
    }
}
/// Owns one SSH channel configuration and the handles needed to stop the
/// background task that keeps it connected.
pub struct SshManager {
// Channel settings; cloned into the background task by `start`.
config: ChannelConfig,
// Retry/backoff settings; cloned into the background task by `start`.
reconnection_config: ReconnectionConfig,
// Sender used by `stop` to signal the background task; `None` until `start` runs.
shutdown_tx: Option<mpsc::Sender<()>>,
// Token cancelled by `stop`; `None` until `start` runs.
cancellation_token: Option<CancellationToken>,
}
impl SshManager {
    /// Creates a manager for `config`. Nothing is connected or spawned until
    /// [`SshManager::start`] is called.
    pub fn new(config: ChannelConfig, reconnection_config: ReconnectionConfig) -> Self {
        Self {
            config,
            reconnection_config,
            shutdown_tx: None,
            cancellation_token: None,
        }
    }

    /// Spawns the background task that keeps the channel connected.
    ///
    /// The task loops: run one connection round, log how it ended, pause for
    /// one second, repeat. It exits when a shutdown message arrives, when the
    /// shutdown sender is dropped, or when the cancellation token fires. Both
    /// the connection round and the pause watch for those signals.
    ///
    /// Calling `start` again replaces the stored sender and token. Dropping the
    /// previous sender makes the earlier task exit.
    ///
    /// Always returns `Ok(())`; connection failures are only logged.
    pub async fn start(&mut self) -> Result<()> {
        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
        let cancel = CancellationToken::new();
        self.cancellation_token = Some(cancel.clone());
        self.shutdown_tx = Some(shutdown_tx);

        let config = self.config.clone();
        let reconnection_config = self.reconnection_config.clone();

        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = shutdown_rx.recv() => {
                        info!(channel = %config.name, "Shutting down SSH manager");
                        break;
                    }
                    _ = cancel.cancelled() => break,
                    result = Self::connect_and_manage_channel(&config, &reconnection_config, cancel.clone()) => {
                        match result {
                            // The round itself returns Ok(()) when it sees the
                            // cancellation first; that is a requested stop, not
                            // an unexpected close.
                            Ok(()) if cancel.is_cancelled() => break,
                            Ok(()) => {
                                warn!(channel = %config.name, "Connection closed unexpectedly");
                            }
                            Err(e) => {
                                error!(channel = %config.name, error = ?e, "Connection error");
                            }
                        }
                    }
                }

                // Pause before the next round, but stay responsive to shutdown.
                tokio::select! {
                    _ = shutdown_rx.recv() => {
                        info!(channel = %config.name, "Shutting down SSH manager");
                        break;
                    }
                    _ = cancel.cancelled() => break,
                    _ = tokio::time::sleep(Duration::from_secs(1)) => {}
                }
            }
        });

        Ok(())
    }

    /// Signals the background task to stop.
    ///
    /// Sends one shutdown message (a send error is ignored, since it only means
    /// the task has already exited) and then cancels the token. Both handles
    /// are taken out of `self`, so a second call is a no-op.
    pub async fn stop(&mut self) -> Result<()> {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(()).await;
        }
        if let Some(token) = self.cancellation_token.take() {
            token.cancel();
        }
        Ok(())
    }

    /// Runs `establish_connection` under the configured retry policy.
    ///
    /// * With `use_exponential_backoff`, delays start at `initial_delay_secs`
    ///   and are capped at `max_delay_secs`.
    /// * Otherwise the cap equals the initial delay, giving a constant delay.
    /// * `max_retries > 0` bounds the number of retries. Any other value is
    ///   treated as "retry without limit"; the cap is set explicitly because
    ///   the builder's default would otherwise stop after a few attempts. In
    ///   that unlimited mode delays are floored at one second so that a
    ///   zero-delay configuration cannot spin.
    ///
    /// # Errors
    /// Returns `AppError::SshConnection` wrapping the last error once the
    /// retry budget is exhausted.
    async fn connect_and_manage_channel(
        config: &ChannelConfig,
        reconnection_config: &ReconnectionConfig,
        cancel: CancellationToken,
    ) -> Result<()> {
        let bounded = reconnection_config.max_retries > 0;
        let max_times = if bounded {
            reconnection_config.max_retries as usize
        } else {
            usize::MAX
        };
        let delay_floor = if bounded {
            Duration::ZERO
        } else {
            Duration::from_secs(1)
        };

        let min_delay =
            Duration::from_secs(reconnection_config.initial_delay_secs).max(delay_floor);
        let max_delay = if reconnection_config.use_exponential_backoff {
            Duration::from_secs(reconnection_config.max_delay_secs).max(delay_floor)
        } else {
            // Cap == start turns the exponential schedule into a constant one.
            min_delay
        };

        let builder = ExponentialBuilder::default()
            .with_min_delay(min_delay)
            .with_max_delay(max_delay)
            .with_max_times(max_times);

        (|| async { Self::establish_connection(config, cancel.clone()).await })
            .retry(&builder)
            .await
            .map_err(|e| AppError::SshConnection(format!("Failed to establish connection: {}", e)))
    }

    /// Opens one SSH session for `config` and services it until it ends or
    /// `cancel` fires, dispatching on the channel type:
    ///
    /// * `ForwardedTcpIp`: delegates to `run_forwarded_tcpip`.
    /// * `DirectTcpIp`: authenticates with a `ClientHandler`, then runs the
    ///   local listener via `run_direct_tcpip_listener`.
    async fn establish_connection(config: &ChannelConfig, cancel: CancellationToken) -> Result<()> {
        info!(
            channel = %config.name,
            host = %config.host,
            port = config.port,
            "Establishing SSH connection"
        );
        match &config.params {
            ChannelTypeParams::ForwardedTcpIp { .. } => run_forwarded_tcpip(config, cancel).await,
            ChannelTypeParams::DirectTcpIp { .. } => {
                let mut session = connect_and_authenticate(config, ClientHandler).await?;
                info!(channel = %config.name, "Opening channel");
                run_direct_tcpip_listener(&mut session, config, cancel).await
            }
        }
    }
}
/// Sets up a remote (reverse) port forward and keeps it alive.
///
/// Authenticates with a `ReverseForwardHandler` pointing at the configured
/// local target, asks the server to listen on the configured remote bind
/// address, then waits until either `cancel` fires (returns `Ok(())`) or the
/// session future completes (its result is returned, with any error mapped to
/// `AppError::SshConnection`).
///
/// # Errors
/// * `AppError::SshChannel` if `config.params` is not `ForwardedTcpIp` or the
///   `tcpip-forward` request fails.
/// * Whatever `connect_and_authenticate` returns on failure.
async fn run_forwarded_tcpip(config: &ChannelConfig, cancel: CancellationToken) -> Result<()> {
    let (bind_host, bind_port, target_host, target_port) = match &config.params {
        ChannelTypeParams::ForwardedTcpIp {
            remote_bind_host,
            remote_bind_port,
            local_connect_host,
            local_connect_port,
        } => (
            remote_bind_host,
            remote_bind_port,
            local_connect_host,
            local_connect_port,
        ),
        _ => {
            return Err(AppError::SshChannel(
                "run_forwarded_tcpip expects ForwardedTcpIp params".to_string(),
            ));
        }
    };

    let mut session = connect_and_authenticate(
        config,
        ReverseForwardHandler {
            channel_name: config.name.clone(),
            local_host: target_host.clone(),
            local_port: *target_port,
        },
    )
    .await?;

    info!(
        channel = %config.name,
        remote_bind = %format!("{}:{}", bind_host, bind_port),
        "Requesting remote port forward (tcpip-forward)"
    );

    let bound_port = match session
        .tcpip_forward(bind_host.as_str(), *bind_port as u32)
        .await
    {
        Ok(port) => port,
        Err(e) => {
            return Err(AppError::SshChannel(format!("tcpip-forward failed: {}", e)));
        }
    };

    // A reply of 0 carries no port, so fall back to the one that was requested.
    let actual_port = match bound_port {
        0 => *bind_port,
        reported => reported as u16,
    };

    info!(
        channel = %config.name,
        remote = %format!("{}:{}", bind_host, actual_port),
        local = %format!("{}:{}", target_host, target_port),
        "Remote forward active (incoming connections will be bridged to local)"
    );

    tokio::select! {
        _ = cancel.cancelled() => {
            info!(channel = %config.name, "Forward cancelled");
            Ok(())
        }
        outcome = &mut session => {
            match outcome {
                Ok(done) => Ok(done),
                Err(e) => Err(AppError::SshConnection(format!("Session ended: {}", e))),
            }
        }
    }
}
/// Connects to `config.host:config.port` and authenticates as
/// `config.username` using the method selected by `config.auth`.
///
/// The client is configured with a 15-second `keepalive_interval` and a
/// `keepalive_max` of 3; all other client settings are defaults.
///
/// The boolean returned by the authentication call is checked: a server that
/// answers "rejected" without a transport error produces an error here instead
/// of an unauthenticated handle.
///
/// # Errors
/// * `AppError::SshConnection` if the connection cannot be established.
/// * `AppError::SshAuthentication` if the authentication exchange fails or the
///   server rejects the credentials.
/// * Whatever `load_secret_key` returns when the key file cannot be read or
///   decoded.
async fn connect_and_authenticate<H>(
    config: &ChannelConfig,
    handler: H,
) -> Result<client::Handle<H>>
where
    H: client::Handler + Send + 'static,
{
    let config_arc = Arc::new(russh::client::Config {
        keepalive_interval: Some(Duration::from_secs(15)),
        keepalive_max: 3,
        ..Default::default()
    });

    let mut session =
        russh::client::connect(config_arc, (config.host.as_str(), config.port), handler)
            .await
            .map_err(|e| AppError::SshConnection(format!("Failed to connect: {:?}", e)))?;
    info!(channel = %config.name, "SSH connection established, authenticating");

    let (method, accepted): (&str, bool) = match &config.auth {
        AuthConfig::Password { password } => {
            let accepted = session
                .authenticate_password(&config.username, password)
                .await
                .map_err(|e| {
                    AppError::SshAuthentication(format!("Password authentication failed: {}", e))
                })?;
            ("Password", accepted)
        }
        AuthConfig::Key {
            key_path,
            passphrase,
        } => {
            let key = load_secret_key(key_path, passphrase.as_deref()).await?;
            let accepted = session
                .authenticate_publickey(&config.username, Arc::new(key))
                .await
                .map_err(|e| {
                    AppError::SshAuthentication(format!("Key authentication failed: {}", e))
                })?;
            ("Key", accepted)
        }
    };

    // `Ok(false)` means the exchange completed but the server said no.
    if !accepted {
        return Err(AppError::SshAuthentication(format!(
            "{} authentication rejected by server",
            method
        )));
    }

    info!(channel = %config.name, "Authentication successful");
    Ok(session)
}
/// Reads and decodes a private key file on the blocking thread pool.
///
/// The path and optional passphrase are copied into owned values so they can
/// move into the blocking closure.
///
/// # Errors
/// * `AppError::Io` if the file cannot be read as text.
/// * `AppError::SshAuthentication` if decoding fails or the blocking task
///   cannot be joined.
async fn load_secret_key(key_path: &Path, passphrase: Option<&str>) -> Result<KeyPair> {
    let owned_path = key_path.to_path_buf();
    let owned_passphrase = passphrase.map(str::to_owned);

    let decode_job = tokio::task::spawn_blocking(move || -> Result<KeyPair> {
        let pem_text = match std::fs::read_to_string(&owned_path) {
            Ok(text) => text,
            Err(io_err) => return Err(AppError::Io(io_err)),
        };
        match russh_keys::decode_secret_key(&pem_text, owned_passphrase.as_deref()) {
            Ok(key) => Ok(key),
            Err(e) => Err(AppError::SshAuthentication(format!(
                "Failed to decode key: {}",
                e
            ))),
        }
    });

    match decode_job.await {
        Ok(decoded) => decoded,
        Err(e) => Err(AppError::SshAuthentication(format!(
            "Task join error: {}",
            e
        ))),
    }
}
async fn run_direct_tcpip_listener(
session: &mut client::Handle<ClientHandler>,
config: &ChannelConfig,
cancel: CancellationToken,
) -> Result<()> {
let ChannelTypeParams::DirectTcpIp {
listen_host,
local_port,
dest_host,
dest_port,
} = &config.params
else {
return Err(AppError::SshChannel(
"run_direct_tcpip_listener expects DirectTcpIp params".to_string(),
));
};
let listen_addr = format!("{}:{}", listen_host, local_port);
let listener = TcpListener::bind(&listen_addr).await.map_err(|e| {
AppError::SshChannel(format!(
"Failed to bind {}: {}. Try another port or run as admin for port < 1024.",
listen_addr, e
))
})?;
info!(
channel = %config.name,
listen = %listen_addr,
"Local listener started, accepting connections"
);
loop {
tokio::select! {
_ = cancel.cancelled() => {
info!(channel = %config.name, "Listener cancelled");
return Ok(());
}
result = &mut *session => {
let reason = result.map_err(|e| e.to_string())
.err()
.unwrap_or_else(|| "connection closed".to_string());
warn!(
channel = %config.name,
reason = %reason,
"SSH session ended, triggering reconnection"
);
return Err(AppError::SshConnection(
format!("SSH session ended: {}", reason)
));
}
accept_result = listener.accept() => {
let (mut stream, peer_addr) = match accept_result {
Ok(x) => x,
Err(e) => {
error!(channel = %config.name, error = ?e, "Accept failed");
continue;
}
};
let channel_name = config.name.clone();
let dest_host = dest_host.clone();
let dest_port = *dest_port;
match session.channel_open_direct_tcpip(
&dest_host,
dest_port as u32,
"127.0.0.1",
0u32,
).await {
Ok(channel) => {
debug!(
channel = %channel_name,
peer = %peer_addr,
dest = %format!("{}:{}", dest_host, dest_port),
"Direct TCP/IP channel opened for connection"
);
let mut channel_stream = channel.into_stream();
tokio::spawn(async move {
if let Err(e) =
tokio::io::copy_bidirectional(&mut stream, &mut channel_stream).await
{
debug!(channel = %channel_name, error = ?e, "Relay ended");
}
});
}
Err(e @ Error::ChannelOpenFailure(_)) => {
error!(
channel = %channel_name,
peer = %peer_addr,
error = ?e,
"Channel open refused by server (connection alive)"
);
}
Err(e) => {
error!(
channel = %channel_name,
error = ?e,
"SSH session dead detected via channel_open, triggering reconnection"
);
return Err(AppError::SshConnection(
format!("SSH session dead: {}", e)
));
}
}
}
}
}
}