use crate::utils::buffer_pool::global;
use anyhow::Result;
use russh::Channel;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, trace, warn};
/// Traffic and error counters for a single tunnel run.
///
/// Each counter sits behind an `Arc`, so another holder (for example the
/// periodic reporter used by `Tunnel::run_with_stats`) can observe the same
/// counters while the tunnel updates them. All counter accesses in this file
/// use `Ordering::Relaxed`; the values are statistics, not synchronization.
#[derive(Debug)]
pub struct TunnelStats {
/// Bytes read from the local TCP socket and successfully written to the SSH channel.
pub bytes_local_to_remote: Arc<AtomicU64>,
/// Bytes received from the SSH channel and successfully written to the local TCP socket.
pub bytes_remote_to_local: Arc<AtomicU64>,
/// Start instant used by `duration()`; `new()` sets it to the moment of construction.
pub started_at: Instant,
/// Number of I/O errors seen. This also counts TCP resets/aborts that the tunnel treats as a normal close.
pub error_count: Arc<AtomicU64>,
}
impl Default for TunnelStats {
fn default() -> Self {
Self::new()
}
}
impl TunnelStats {
    /// Creates a fresh set of statistics: every counter starts at zero and the
    /// start time is captured now.
    pub fn new() -> Self {
        let zeroed = || Arc::new(AtomicU64::new(0));
        Self {
            bytes_local_to_remote: zeroed(),
            bytes_remote_to_local: zeroed(),
            started_at: Instant::now(),
            error_count: zeroed(),
        }
    }

    /// Bytes forwarded so far, summed over both directions.
    pub fn total_bytes(&self) -> u64 {
        let outbound = self.bytes_local_to_remote.load(Ordering::Relaxed);
        let inbound = self.bytes_remote_to_local.load(Ordering::Relaxed);
        outbound + inbound
    }

    /// Wall-clock time elapsed since `started_at`.
    pub fn duration(&self) -> std::time::Duration {
        let since = self.started_at;
        since.elapsed()
    }

    /// Number of errors recorded so far.
    pub fn errors(&self) -> u64 {
        let counter = &self.error_count;
        counter.load(Ordering::Relaxed)
    }
}
/// Holds a set of tunnel statistics and a cancellation token.
///
/// NOTE(review): the forwarding entry points (`Tunnel::run`,
/// `Tunnel::run_with_stats`) are associated functions that take their own
/// cancellation token and build their own statistics. Nothing visible here
/// reads these fields.
// `dead_code` is allowed because both fields are constructed by `new()` but never read.
#[allow(dead_code)] pub struct Tunnel {
stats: TunnelStats,
cancel_token: CancellationToken,
}
impl Tunnel {
pub fn new() -> Self {
Self {
stats: TunnelStats::new(),
cancel_token: CancellationToken::new(),
}
}
pub async fn run(
mut tcp_stream: TcpStream,
mut ssh_channel: Channel<russh::client::Msg>,
cancel_token: CancellationToken,
) -> Result<TunnelStats> {
let stats = TunnelStats::new();
let mut buffer = global::get_medium_buffer();
debug!("Starting bidirectional tunnel");
loop {
tokio::select! {
result = tcp_stream.read(buffer.as_mut_slice()) => {
match result {
Ok(0) => {
trace!("TCP socket closed, ending tunnel");
break;
}
Ok(n) => {
match ssh_channel.data(&buffer.as_slice()[..n]).await {
Ok(_) => {
let bytes = stats.bytes_local_to_remote.fetch_add(n as u64, Ordering::Relaxed) + n as u64;
trace!("Forwarded {} bytes TCP→SSH (total: {})", n, bytes);
}
Err(e) => {
stats.error_count.fetch_add(1, Ordering::Relaxed);
error!("Failed to write to SSH channel: {}", e);
return Err(anyhow::anyhow!("SSH channel write error: {e}"));
}
}
}
Err(e) => {
stats.error_count.fetch_add(1, Ordering::Relaxed);
if e.kind() == std::io::ErrorKind::ConnectionAborted ||
e.kind() == std::io::ErrorKind::ConnectionReset {
trace!("TCP connection closed: {}", e);
break;
} else {
error!("TCP read error: {}", e);
return Err(anyhow::anyhow!("TCP read error: {e}"));
}
}
}
}
msg = ssh_channel.wait() => {
match msg {
Some(russh::ChannelMsg::Data { data }) => {
match tcp_stream.write_all(&data).await {
Ok(_) => {
let bytes = stats.bytes_remote_to_local.fetch_add(data.len() as u64, Ordering::Relaxed) + data.len() as u64;
trace!("Forwarded {} bytes SSH→TCP (total: {})", data.len(), bytes);
}
Err(e) => {
stats.error_count.fetch_add(1, Ordering::Relaxed);
if e.kind() == std::io::ErrorKind::BrokenPipe ||
e.kind() == std::io::ErrorKind::ConnectionAborted {
trace!("TCP connection closed: {}", e);
break;
} else {
error!("TCP write error: {}", e);
return Err(anyhow::anyhow!("TCP write error: {e}"));
}
}
}
}
Some(russh::ChannelMsg::Eof) => {
trace!("SSH channel EOF");
break;
}
Some(russh::ChannelMsg::Close) => {
trace!("SSH channel closed");
break;
}
Some(other) => {
trace!("Ignoring SSH channel message: {:?}", other);
}
None => {
trace!("SSH channel stream ended");
break;
}
}
}
_ = cancel_token.cancelled() => {
trace!("Tunnel cancelled");
break;
}
}
}
if let Err(e) = ssh_channel.eof().await {
warn!("Failed to send EOF to SSH channel: {}", e);
}
if let Err(e) = ssh_channel.close().await {
warn!("Failed to close SSH channel: {}", e);
}
let l2r_bytes = stats.bytes_local_to_remote.load(Ordering::Relaxed);
let r2l_bytes = stats.bytes_remote_to_local.load(Ordering::Relaxed);
let errors = stats.error_count.load(Ordering::Relaxed);
let duration = stats.duration();
debug!(
"Tunnel completed: {} bytes L→R, {} bytes R→L, {} errors, duration: {:?}",
l2r_bytes, r2l_bytes, errors, duration
);
Ok(stats)
}
pub async fn run_with_stats<F>(
tcp_stream: TcpStream,
ssh_channel: Channel<russh::client::Msg>,
cancel_token: CancellationToken,
mut stats_callback: F,
report_interval: std::time::Duration,
) -> Result<TunnelStats>
where
F: FnMut(&TunnelStats) + Send + 'static,
{
let stats = TunnelStats::new();
let stats_clone = TunnelStats {
bytes_local_to_remote: Arc::clone(&stats.bytes_local_to_remote),
bytes_remote_to_local: Arc::clone(&stats.bytes_remote_to_local),
started_at: stats.started_at,
error_count: Arc::clone(&stats.error_count),
};
let reporting_cancel = cancel_token.clone();
let stats_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(report_interval);
loop {
tokio::select! {
_ = interval.tick() => {
stats_callback(&stats_clone);
}
_ = reporting_cancel.cancelled() => {
stats_callback(&stats_clone);
break;
}
}
}
});
let result = Self::run(tcp_stream, ssh_channel, cancel_token).await;
stats_task.abort();
let _ = stats_task.await;
result
}
}
impl Default for Tunnel {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Counters start at zero, and the accessors reflect stored values.
    #[test]
    fn test_tunnel_stats() {
        let s = TunnelStats::new();
        assert_eq!(s.total_bytes(), 0);
        assert_eq!(s.errors(), 0);

        for (counter, value) in [
            (&s.bytes_local_to_remote, 100),
            (&s.bytes_remote_to_local, 200),
            (&s.error_count, 1),
        ] {
            counter.store(value, Ordering::Relaxed);
        }

        assert_eq!(s.total_bytes(), 300);
        assert_eq!(s.errors(), 1);
    }

    /// An already-cancelled token resolves `cancelled()` without waiting.
    #[tokio::test]
    async fn test_tunnel_cancellation() {
        let token = CancellationToken::new();
        token.cancel();
        let resolved_in_time = timeout(Duration::from_millis(100), token.cancelled())
            .await
            .is_ok();
        if !resolved_in_time {
            panic!("Cancellation should be immediate");
        }
    }

    /// Updates made through a cloned `Arc` are visible through the stats.
    #[test]
    fn test_stats_atomic_operations() {
        let stats = TunnelStats::new();
        let shared = Arc::clone(&stats.bytes_local_to_remote);
        [50, 25].iter().for_each(|&delta| {
            shared.fetch_add(delta, Ordering::Relaxed);
        });
        assert_eq!(shared.load(Ordering::Relaxed), 75);
        assert_eq!(stats.total_bytes(), 75);
    }
}