use futures::FutureExt;
use getopts::Options;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
/// Boxed, thread-safe error type used as the error half of the `Result`s
/// returned by `main` and `forward`.
type BoxedError = Box<dyn std::error::Error + Sync + Send + 'static>;
/// Set once in `main` from the `-d`/`--debug` flag; when true, `forward` logs
/// the number of bytes transferred in each direction for every connection.
static DEBUG: AtomicBool = AtomicBool::new(false);
/// Size in bytes of the stack buffer `copy_with_abort` fills on each read.
const BUF_SIZE: usize = 1024;
/// Writes the command-line usage text for `program` to `out`.
///
/// Only the file stem of `program` (no directory, no extension) is shown in
/// the usage line. If `program` has no file-name component at all (for
/// example an empty string or `..`), the string is shown as given instead of
/// panicking.
///
/// Write errors are deliberately ignored: this is best-effort help output
/// emitted right before the process exits.
fn print_usage(out: &mut dyn std::io::Write, program: &str, opts: Options) {
    let program_name = std::path::Path::new(program).file_stem().map_or_else(
        || std::borrow::Cow::Borrowed(program),
        |stem| stem.to_string_lossy(),
    );
    let brief = format!(
        "Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
        program_name
    );
    _ = write!(out, "{}", opts.usage(&brief));
}
/// Entry point: parses the command line and starts the proxy.
///
/// Expects exactly one positional argument, `REMOTE_HOST:PORT`, plus the
/// optional flags `-b BIND_ADDR`, `-l LOCAL_PORT`, `-d`, `-h` and `-V`.
/// Argument errors print a message (and usage where applicable) and exit with
/// a non-zero status; an invalid `-l` value or a failure inside `forward` is
/// returned as an error.
#[tokio::main]
async fn main() -> Result<(), BoxedError> {
    // argv may legally be empty on some platforms, so never index into it.
    let mut args = env::args();
    let program = args.next().unwrap_or_else(|| "tcpproxy".to_owned());
    let args: Vec<String> = args.collect();

    let mut opts = Options::new();
    opts.optopt(
        "b",
        "bind",
        "The address on which to listen for incoming requests, defaulting to localhost",
        "BIND_ADDR",
    );
    opts.optopt(
        "l",
        "local-port",
        "The local port to which tcpproxy should bind to, randomly chosen otherwise",
        "LOCAL_PORT",
    );
    opts.optflag("d", "debug", "Enable debug mode w/ connection logging");
    opts.optflag("h", "help", "Print usage info and exit");
    opts.optflag("V", "version", "Print version info and exit");

    let matches = match opts.parse(&args) {
        Ok(matches) => matches,
        Err(e) => {
            eprintln!("{}", e);
            print_usage(&mut std::io::stderr().lock(), &program, opts);
            std::process::exit(-1);
        }
    };

    if matches.opt_present("h") {
        print_usage(&mut std::io::stdout().lock(), &program, opts);
        std::process::exit(0);
    }

    if matches.opt_present("V") {
        println!(
            "tcpproxy {} - {}",
            env!("CARGO_PKG_VERSION"),
            env!("CARGO_PKG_REPOSITORY")
        );
        std::process::exit(0);
    }

    // Exactly one free argument (the upstream address) is accepted.
    let remote = match matches.free.as_slice() {
        [remote] => remote.clone(),
        _ => {
            print_usage(&mut std::io::stderr().lock(), &program, opts);
            std::process::exit(-1);
        }
    };

    if !remote.contains(':') {
        eprintln!("A remote port is required (REMOTE_ADDR:PORT)");
        std::process::exit(-1);
    }

    DEBUG.store(matches.opt_present("d"), Ordering::Relaxed);

    // Parse as u16 so that out-of-range or negative ports are rejected here
    // with a clear error instead of producing an unparsable bind address
    // later on. Port 0 lets the OS choose a free port.
    let local_port: i32 = match matches.opt_str("l") {
        Some(port) => port
            .parse::<u16>()
            .map(i32::from)
            .map_err(|e| format!("Invalid local port {port:?}: {e}"))?,
        None => 0,
    };

    let bind_addr = matches
        .opt_str("b")
        .unwrap_or_else(|| "127.0.0.1".to_owned());

    forward(&bind_addr, local_port, remote).await
}
/// Listens on `bind_ip:local_port` and proxies every accepted TCP connection
/// to `remote`.
///
/// `bind_ip` may be an IPv4 address, a bare IPv6 address, or a bracketed IPv6
/// address. `remote` is passed unchanged to `TcpStream::connect` for each new
/// client.
///
/// # Errors
///
/// Returns an error if the bind address cannot be parsed, if the listener
/// cannot be bound, or if accepting a connection fails. Failures on an
/// individual proxied connection are logged to stderr and do not stop the
/// accept loop.
async fn forward(bind_ip: &str, local_port: i32, remote: String) -> Result<(), BoxedError> {
    // A bare IPv6 literal must be bracketed before a port can be appended.
    let bind_addr = if !bind_ip.starts_with('[') && bind_ip.contains(':') {
        format!("[{}]:{}", bind_ip, local_port)
    } else {
        format!("{}:{}", bind_ip, local_port)
    };
    // The bind address comes from user input, so report a bad value as an
    // error rather than panicking.
    let bind_sock = bind_addr
        .parse::<std::net::SocketAddr>()
        .map_err(|e| format!("Failed to parse bind address {bind_addr:?}: {e}"))?;
    let listener = TcpListener::bind(&bind_sock).await?;
    println!("Listening on {}", listener.local_addr()?);

    // Leak the remote address once so that every spawned task can hold a
    // `'static` reference to it without cloning the string per connection.
    let remote: &str = Box::leak(remote.into_boxed_str());

    /// Copies bytes from `read` to `write` until the reader reports EOF (a
    /// zero-length read), the connection is reset or aborted (treated like
    /// EOF), or a message arrives on `abort`.
    ///
    /// Returns the total number of bytes written, or the first I/O error
    /// other than a connection reset/abort.
    async fn copy_with_abort<R, W>(
        read: &mut R,
        write: &mut W,
        mut abort: broadcast::Receiver<()>,
    ) -> tokio::io::Result<usize>
    where
        R: tokio::io::AsyncRead + Unpin,
        W: tokio::io::AsyncWrite + Unpin,
    {
        let mut copied = 0;
        let mut buf = [0u8; BUF_SIZE];
        loop {
            let bytes_read;
            tokio::select! {
                // Poll the branches in declaration order: pending data is
                // read before the abort signal is looked at.
                biased;
                result = read.read(&mut buf) => {
                    use std::io::ErrorKind::{ConnectionReset, ConnectionAborted};
                    bytes_read = result.or_else(|e| match e.kind() {
                        ConnectionReset | ConnectionAborted => Ok(0),
                        _ => Err(e)
                    })?;
                },
                _ = abort.recv() => {
                    break;
                }
            }
            if bytes_read == 0 {
                break;
            }
            write.write_all(&buf[0..bytes_read]).await?;
            copied += bytes_read;
        }
        Ok(copied)
    }

    loop {
        let (mut client, client_addr) = listener.accept().await?;
        tokio::spawn(async move {
            println!("New connection from {}", client_addr);
            let mut remote = match TcpStream::connect(remote).await {
                Ok(result) => result,
                Err(e) => {
                    eprintln!("Error establishing upstream connection: {e}");
                    return;
                }
            };
            let (mut client_read, mut client_write) = client.split();
            let (mut remote_read, mut remote_write) = remote.split();

            // Whichever direction finishes first signals the other to stop,
            // so both copies end together.
            let (cancel, _) = broadcast::channel::<()>(1);
            let (remote_copied, client_copied) = tokio::join! {
                copy_with_abort(&mut remote_read, &mut client_write, cancel.subscribe())
                    .then(|r| { let _ = cancel.send(()); async { r } }),
                copy_with_abort(&mut client_read, &mut remote_write, cancel.subscribe())
                    .then(|r| { let _ = cancel.send(()); async { r } }),
            };

            match client_copied {
                Ok(count) => {
                    if DEBUG.load(Ordering::Relaxed) {
                        eprintln!(
                            "Transferred {} bytes from proxy client {} to upstream server",
                            count, client_addr
                        );
                    }
                }
                Err(err) => {
                    eprintln!(
                        "Error writing bytes from proxy client {} to upstream server",
                        client_addr
                    );
                    eprintln!("{}", err);
                }
            };

            match remote_copied {
                Ok(count) => {
                    if DEBUG.load(Ordering::Relaxed) {
                        eprintln!(
                            "Transferred {} bytes from upstream server to proxy client {}",
                            count, client_addr
                        );
                    }
                }
                Err(err) => {
                    eprintln!(
                        "Error writing from upstream server to proxy client {}!",
                        client_addr
                    );
                    eprintln!("{}", err);
                }
            };
        });
    }
}