use anyhow::{Context, Result};
use russh::client::{self, Handle, Msg};
use russh::keys::*;
use russh::*;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
/// Settings for a reverse SSH tunnel: which SSH server to log into, how to
/// authenticate, which port to request on the server, and which local service
/// forwarded connections are proxied to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReverseSshConfig {
    /// Host name or address of the SSH server.
    pub server_addr: String,
    /// TCP port of the SSH server.
    pub server_port: u16,
    /// User name to authenticate as.
    pub username: String,
    /// Path to a private key file; when set it is used instead of `password`.
    pub key_path: Option<String>,
    /// Password, used only when `key_path` is `None`.
    pub password: Option<String>,
    /// Bind address sent with the `tcpip-forward` request; may be empty.
    pub bind_address: String,
    /// Port the server is asked to listen on for forwarded connections.
    pub remote_port: u32,
    /// Address of the local service that forwarded connections are proxied to.
    pub local_addr: String,
    /// Port of the local service that forwarded connections are proxied to.
    pub local_port: u16,
}
struct Client {
tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
message_tx: mpsc::UnboundedSender<String>,
}
#[async_trait::async_trait]
impl client::Handler for Client {
type Error = russh::Error;
async fn check_server_key(
&mut self,
_server_public_key: &key::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
async fn server_channel_open_forwarded_tcpip(
&mut self,
channel: Channel<Msg>,
connected_address: &str,
connected_port: u32,
originator_address: &str,
originator_port: u32,
_session: &mut client::Session,
) -> Result<(), Self::Error> {
debug!(
"Forwarded channel: {}:{} -> {}:{}",
originator_address, originator_port, connected_address, connected_port
);
let _ = self
.tx
.send((channel, connected_address.to_string(), connected_port));
Ok(())
}
async fn data(
&mut self,
_channel: ChannelId,
data: &[u8],
_session: &mut client::Session,
) -> Result<(), Self::Error> {
if let Ok(message) = String::from_utf8(data.to_vec()) {
debug!("Received data ({} bytes): {}", data.len(), message);
let _ = self.message_tx.send(message);
} else {
debug!(
"Received {} bytes of non-UTF8 data on channel {:?}",
data.len(),
_channel
);
}
Ok(())
}
async fn extended_data(
&mut self,
_channel: ChannelId,
ext: u32,
data: &[u8],
_session: &mut client::Session,
) -> Result<(), Self::Error> {
if let Ok(message) = String::from_utf8(data.to_vec()) {
info!("Received extended data (type {}): {}", ext, message);
let _ = self.message_tx.send(message);
}
debug!(
"Received {} bytes of extended data (type {}) on channel {:?}",
data.len(),
ext,
_channel
);
Ok(())
}
}
impl Client {
fn new(
tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
message_tx: mpsc::UnboundedSender<String>,
) -> Self {
Self { tx, message_tx }
}
}
/// Client that logs into an SSH server, asks it to forward a remote port back
/// over the connection, and proxies each forwarded connection to a local
/// service described by [`ReverseSshConfig`].
pub struct ReverseSshClient {
    /// Tunnel settings supplied at construction.
    config: ReverseSshConfig,
    /// Authenticated session; `None` until [`ReverseSshClient::connect`] succeeds.
    handle: Option<Handle<Client>>,
}

impl ReverseSshClient {
    /// Creates a client for `config` without connecting.
    pub fn new(config: ReverseSshConfig) -> Self {
        Self {
            config,
            handle: None,
        }
    }

    /// Connects to the configured SSH server and authenticates.
    ///
    /// `tx` later receives each forwarded-tcpip channel the server opens.
    /// `message_tx` receives text the server sends on other channels. A key
    /// file (`key_path`) takes precedence over `password`.
    ///
    /// # Errors
    /// Fails in these cases:
    /// - neither credential is configured, or the key file cannot be loaded
    ///   (both are checked before dialing);
    /// - the connection fails;
    /// - the server rejects the credentials.
    pub async fn connect(
        &mut self,
        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
        message_tx: mpsc::UnboundedSender<String>,
    ) -> Result<()> {
        /// Credential chosen from the config. `K` is whatever key type
        /// `load_secret_key` returns.
        enum Credential<'a, K> {
            Key(&'a str, K),
            Password(&'a str),
        }

        // Resolve the credential first, so a misconfiguration fails fast
        // instead of after a TCP connection and key exchange.
        let credential = match (&self.config.key_path, &self.config.password) {
            (Some(key_path), _) => {
                let key_pair = russh_keys::load_secret_key(key_path, None)
                    .context("Failed to load private key")?;
                Credential::Key(key_path.as_str(), key_pair)
            }
            (None, Some(password)) => Credential::Password(password.as_str()),
            (None, None) => {
                anyhow::bail!("No authentication method provided (need key_path or password)")
            }
        };

        info!(
            "Connecting to SSH server {}:{}",
            self.config.server_addr, self.config.server_port
        );
        let client_config = client::Config {
            inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
            ..<_>::default()
        };
        let client_handler = Client::new(tx, message_tx);
        let mut session = client::connect(
            Arc::new(client_config),
            (self.config.server_addr.as_str(), self.config.server_port),
            client_handler,
        )
        .await
        .context("Failed to connect to SSH server")?;

        let auth_result = match credential {
            Credential::Key(key_path, key_pair) => {
                info!("Authenticating with private key: {}", key_path);
                session
                    .authenticate_publickey(&self.config.username, Arc::new(key_pair))
                    .await
            }
            Credential::Password(password) => {
                info!("Authenticating with password");
                session
                    .authenticate_password(&self.config.username, password)
                    .await
            }
        };
        if !auth_result.context("Authentication failed")? {
            anyhow::bail!("Authentication rejected by server");
        }
        info!("Successfully authenticated to SSH server");
        self.handle = Some(session);
        Ok(())
    }

    /// Requests remote port forwarding of `bind_address:remote_port`, then
    /// opens a shell session whose output reaches the handler's message
    /// channel.
    ///
    /// Failure to open or start the shell is only logged.
    ///
    /// # Errors
    /// Fails if [`ReverseSshClient::connect`] has not succeeded, or if the
    /// `tcpip-forward` request returns an error.
    pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
        let handle = self
            .handle
            .as_mut()
            .context("Not connected - call connect() first")?;
        if self.config.bind_address.is_empty() {
            info!(
                "Setting up reverse tunnel: server port {} -> local {}:{}",
                self.config.remote_port, self.config.local_addr, self.config.local_port
            );
        } else {
            info!(
                "Setting up reverse tunnel: {}:{} -> local {}:{}",
                self.config.bind_address,
                self.config.remote_port,
                self.config.local_addr,
                self.config.local_port
            );
        }
        handle
            .tcpip_forward(&self.config.bind_address, self.config.remote_port)
            .await
            .context("Failed to set up remote port forwarding")?;
        info!("Reverse tunnel established successfully");
        // Some tunnel services print information (e.g. the public URL) on a
        // shell session; its data arrives through `Client::data`.
        match handle.channel_open_session().await {
            Ok(channel) => {
                info!("Opened shell session to receive server messages");
                if let Err(e) = channel.request_shell(false).await {
                    warn!("Failed to request shell: {}", e);
                } else {
                    debug!("Shell requested successfully");
                }
            }
            Err(e) => {
                warn!(
                    "Could not open shell session: {} (this may be normal for some servers)",
                    e
                );
            }
        }
        Ok(())
    }

    /// Opens a short-lived shell session, waits 500 ms, then closes it.
    ///
    /// NOTE(review): this does not read any output. It returns a fixed
    /// placeholder line when the session opens, or an empty list otherwise.
    /// Real server output arrives through the message channel passed to
    /// [`ReverseSshClient::connect`].
    #[allow(dead_code)] // no caller in this crate at present
    pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
        let handle = self
            .handle
            .as_mut()
            .context("Not connected - call connect() first")?;
        let mut messages = Vec::new();
        match handle.channel_open_session().await {
            Ok(channel) => {
                let _ = channel.request_shell(false).await;
                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
                let _ = channel.eof().await;
                let _ = channel.close().await;
                messages.push("Check SSH session output for connection URL".to_string());
            }
            Err(e) => {
                warn!("Could not open session channel: {}", e);
            }
        }
        Ok(messages)
    }

    /// Spawns a proxy task for every forwarded channel received on `rx`.
    ///
    /// Returns once `rx` is closed, i.e. once every sender, including the one
    /// held by the SSH handler, has been dropped.
    pub async fn handle_forwarded_connections(
        &mut self,
        mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>,
    ) -> Result<()> {
        info!("Waiting for forwarded connections...");
        while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
            info!("New forwarded connection received");
            let local_addr = self.config.local_addr.clone();
            let local_port = self.config.local_port;
            tokio::spawn(async move {
                if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
                    error!("Error handling connection: {}", e);
                }
            });
        }
        warn!("Connection closed by server");
        Ok(())
    }

    /// Connects, sets up the tunnel and proxies connections until the session
    /// ends. Non-blank server messages are printed to stdout, trimmed and
    /// prefixed with `[Server]`.
    #[allow(dead_code)] // no caller in this crate at present
    pub async fn run(&mut self) -> Result<()> {
        self.run_with_message_handler(|message: String| {
            let message = message.trim();
            if !message.is_empty() {
                println!("[Server] {}", message);
            }
        })
        .await
    }

    /// Connects, sets up the tunnel and proxies connections until the session
    /// ends. Every server message is passed to `message_handler` on a
    /// background task.
    ///
    /// # Errors
    /// Propagates failures from [`ReverseSshClient::connect`] and
    /// [`ReverseSshClient::setup_reverse_tunnel`].
    pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
    where
        F: FnMut(String) + Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();
        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
        self.connect(tx, message_tx).await?;
        self.setup_reverse_tunnel().await?;
        tokio::spawn(async move {
            while let Some(message) = message_rx.recv().await {
                message_handler(message);
            }
        });
        self.handle_forwarded_connections(rx).await?;
        Ok(())
    }
}
async fn handle_connection(
mut channel: Channel<Msg>,
local_addr: &str,
local_port: u16,
) -> Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
info!("Connecting to local service {}:{}", local_addr, local_port);
let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
.parse()
.context("Invalid local address")?;
let mut local_stream = TcpStream::connect(local_socket_addr)
.await
.context("Failed to connect to local service")?;
info!("Connected to local service, starting bidirectional proxy");
let mut local_buf = vec![0u8; 8192];
loop {
tokio::select! {
msg = channel.wait() => {
match msg {
Some(russh::ChannelMsg::Data { data }) => {
debug!("Received {} bytes from SSH channel", data.len());
if let Err(e) = local_stream.write_all(&data).await {
error!("Failed to write to local service: {}", e);
break;
}
}
Some(russh::ChannelMsg::Eof) => {
debug!("Received EOF from SSH channel");
let _ = local_stream.shutdown().await;
break;
}
Some(russh::ChannelMsg::Close) => {
debug!("SSH channel closed");
break;
}
Some(other) => {
debug!("Received other channel message: {:?}", other);
}
None => {
debug!("SSH channel receiver closed");
break;
}
}
}
result = local_stream.read(&mut local_buf) => {
match result {
Ok(0) => {
debug!("Local connection closed");
break;
}
Ok(n) => {
debug!("Read {} bytes from local service", n);
if let Err(e) = channel.data(&local_buf[..n]).await {
error!("Failed to send data to SSH channel: {}", e);
break;
}
}
Err(e) => {
error!("Error reading from local service: {}", e);
break;
}
}
}
}
}
let _ = channel.eof().await;
let _ = channel.close().await;
info!("Connection proxy closed");
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key-authenticated config on port 22 that proxies to 127.0.0.1.
    fn key_auth_config(
        server_addr: &str,
        username: &str,
        bind_address: &str,
        remote_port: u32,
        local_port: u16,
    ) -> ReverseSshConfig {
        ReverseSshConfig {
            server_addr: server_addr.to_string(),
            server_port: 22,
            username: username.to_string(),
            key_path: Some("/path/to/key".to_string()),
            password: None,
            bind_address: bind_address.to_string(),
            remote_port,
            local_addr: "127.0.0.1".to_string(),
            local_port,
        }
    }

    #[test]
    fn test_config_creation() {
        let cfg = key_auth_config("example.com", "user", "", 8080, 3000);
        assert_eq!(cfg.server_addr, "example.com");
        assert_eq!(cfg.remote_port, 8080);
        assert!(cfg.bind_address.is_empty());
    }

    #[test]
    fn test_config_with_bind_address() {
        let cfg = key_auth_config("tuns.sh", "myuser", "dev", 80, 8000);
        assert_eq!(cfg.server_addr, "tuns.sh");
        assert_eq!(cfg.bind_address, "dev");
        assert_eq!(cfg.remote_port, 80);
    }
}