use std::io::{self, Read, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{Receiver, SyncSender};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use crate::error::Result;
use crate::server::{ChannelEgress, ChannelStream, DirectTcpipHandler, DirectTcpipRequest};
/// Caller-supplied predicate deciding whether a `(host, port)` destination may be dialed.
type AllowFilter = Box<dyn Fn(&str, u16) -> bool + Send + Sync + 'static>;

/// Stock `direct-tcpip` handler: dials the destination requested by the SSH client and relays
/// bytes between that TCP connection and the channel.
pub struct DefaultDirectTcpipHandler {
    /// Timeout applied to each individual TCP connection attempt; `None` means a plain blocking
    /// connect with no explicit timeout.
    connect_timeout: Option<Duration>,
    /// Optional destination filter; when absent, every destination is permitted.
    allow: Option<AllowFilter>,
}
impl Default for DefaultDirectTcpipHandler {
fn default() -> Self {
Self::new()
}
}
impl DefaultDirectTcpipHandler {
    /// Per-attempt connect timeout installed by [`DefaultDirectTcpipHandler::new`].
    const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a handler that permits every destination and abandons each TCP connection
    /// attempt after 10 seconds.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allow: None,
            connect_timeout: Some(Self::DEFAULT_CONNECT_TIMEOUT),
        }
    }

    /// Restricts forwarding to destinations for which `filter(host, port)` returns `true`.
    ///
    /// `host` is the destination string exactly as requested by the client. It is checked
    /// before any name resolution happens. A predicate that *denies* specific names can
    /// therefore be sidestepped with an IP literal or an alternate name for the same machine,
    /// so prefer predicates that *allow* known-good destinations.
    ///
    /// Calling this again replaces the previously installed filter.
    #[must_use]
    pub fn with_allow_list<F>(mut self, filter: F) -> Self
    where
        F: Fn(&str, u16) -> bool + Send + Sync + 'static,
    {
        self.allow = Some(Box::new(filter));
        self
    }

    /// Sets the TCP connect timeout.
    ///
    /// The timeout applies to each resolved address separately. `None` falls back to a plain
    /// blocking `TcpStream::connect`.
    ///
    /// Note: `TcpStream::connect_timeout` rejects a zero duration, so
    /// `Some(Duration::ZERO)` makes every connection attempt fail.
    #[must_use]
    pub fn with_connect_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Returns whether the configured filter permits `host:port`.
    /// Without a filter, every destination is allowed.
    fn allowed(&self, host: &str, port: u16) -> bool {
        match &self.allow {
            Some(filter) => filter(host, port),
            None => true,
        }
    }
}
impl DirectTcpipHandler for DefaultDirectTcpipHandler {
    /// Dials the requested destination and relays traffic between it and `stream`.
    ///
    /// The request is declined quietly in three cases: the requested port does not fit in a
    /// `u16`, the allow filter rejects the destination, or the TCP connection cannot be
    /// established. In each case the handler returns `Ok(())` and drops `stream` without
    /// sending anything on it. The only error propagated is whatever `splice` returns.
    fn handle(
        &self,
        _user: &str,
        request: DirectTcpipRequest<'_>,
        stream: ChannelStream,
    ) -> Result<()> {
        let host = request.dest_host;
        // Out-of-range ports and filtered destinations both collapse to `None`; the filter is
        // only consulted once the port is known to be valid.
        let permitted_port = u16::try_from(request.dest_port)
            .ok()
            .filter(|&candidate| self.allowed(host, candidate));
        let port = match permitted_port {
            Some(p) => p,
            None => return Ok(()),
        };
        match connect_with_timeout(host, port, self.connect_timeout) {
            Ok(tcp) => splice(stream, tcp),
            Err(_) => Ok(()),
        }
    }
}
/// Resolves `host:port` and tries each resulting address in order, returning the first stream
/// that connects.
///
/// When `timeout` is set, each attempt uses `TcpStream::connect_timeout`; otherwise it uses a
/// plain blocking `TcpStream::connect`. The timeout is therefore per address, not overall.
/// Name resolution itself is not bounded by `timeout`.
///
/// # Errors
///
/// - Propagates name-resolution failures.
/// - Returns the error from the last failed attempt when every address fails.
/// - Returns `NotFound` when resolution yields no addresses at all.
fn connect_with_timeout(host: &str, port: u16, timeout: Option<Duration>) -> io::Result<TcpStream> {
    let mut last_failure: Option<io::Error> = None;
    let connected = format!("{host}:{port}")
        .to_socket_addrs()?
        .find_map(|addr| {
            let attempt = match timeout {
                Some(limit) => TcpStream::connect_timeout(&addr, limit),
                None => TcpStream::connect(addr),
            };
            match attempt {
                Ok(stream) => Some(stream),
                Err(err) => {
                    last_failure = Some(err);
                    None
                }
            }
        });
    match connected {
        Some(stream) => Ok(stream),
        None => Err(last_failure.unwrap_or_else(|| {
            io::Error::new(io::ErrorKind::NotFound, "no address resolved")
        })),
    }
}
/// How the channel→TCP pump (`copy_channel_to_tcp`) finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IngressEnd {
    /// The SSH client sent EOF (`None`). It has no more data, but the remote's reply may still
    /// be on its way, so only the TCP write half should be closed.
    Eof,
    /// The channel's sender was dropped or a TCP write failed. The connection should be torn
    /// down in both directions.
    Aborted,
}

/// Relays bytes between an SSH channel and a TCP connection until both directions are done,
/// then closes the channel.
///
/// The channel→TCP direction runs on the calling thread; the TCP→channel direction runs on a
/// spawned worker thread. Teardown follows TCP half-close semantics:
///
/// * When the remote stops sending (read EOF or read error), the worker immediately sends
///   `ChannelEgress::Eof`. The client therefore learns that the remote is done without first
///   having to finish its own direction; otherwise a remote that closes first would leave the
///   client and this handler waiting on each other.
/// * When the client sends EOF, only the TCP write half is shut down. The remote's reply keeps
///   flowing until the remote closes its side, so this call blocks for as long as the remote
///   keeps the connection open.
/// * When the channel's sender is dropped or a TCP write fails, the socket is shut down in
///   both directions so the worker unblocks promptly.
///
/// `ChannelEgress::Close` is always sent last, after the worker has exited. A panic in the
/// worker is swallowed.
///
/// # Errors
///
/// Returns an I/O error if the TCP stream cannot be cloned for the worker thread; nothing is
/// sent on the channel in that case.
fn splice(stream: ChannelStream, tcp: TcpStream) -> Result<()> {
    let (raw_rx, raw_tx) = stream.into_raw();
    let tcp_for_reader = tcp.try_clone().map_err(|e| {
        crate::error::Error::Io(io::Error::new(
            e.kind(),
            "direct-tcpip: TcpStream::try_clone failed",
        ))
    })?;
    let stop = Arc::new(AtomicBool::new(false));
    let stop_worker = Arc::clone(&stop);
    let tx_worker = raw_tx.clone();
    let worker = thread::spawn(move || {
        let mut reader = tcp_for_reader;
        copy_tcp_to_channel(&mut reader, &tx_worker, &stop_worker);
        // The remote finished sending (or the socket was shut down below): report EOF now
        // rather than waiting for the client's direction to finish first.
        let _ = tx_worker.send(ChannelEgress::Eof);
    });
    let mut writer = tcp;
    match copy_channel_to_tcp(&raw_rx, &mut writer) {
        IngressEnd::Eof => {
            // Half-close: pass the client's EOF on to the remote but keep reading its reply.
            let _ = writer.shutdown(Shutdown::Write);
        }
        IngressEnd::Aborted => {
            stop.store(true, Ordering::SeqCst);
            // Closing the read half too makes the worker's blocking `read` return.
            let _ = writer.shutdown(Shutdown::Both);
        }
    }
    let _ = worker.join();
    let _ = raw_tx.send(ChannelEgress::Close);
    Ok(())
}

/// Writes chunks received from the SSH channel to the TCP socket.
///
/// Returns `IngressEnd::Eof` when the client sends `None`, and `IngressEnd::Aborted` when the
/// channel's sender is dropped or a TCP write fails.
fn copy_channel_to_tcp(rx: &Receiver<Option<Vec<u8>>>, tcp: &mut TcpStream) -> IngressEnd {
    for message in rx {
        match message {
            Some(chunk) => {
                if tcp.write_all(&chunk).is_err() {
                    return IngressEnd::Aborted;
                }
            }
            None => return IngressEnd::Eof,
        }
    }
    // The iterator only ends when every sender is gone, i.e. without a clean EOF.
    IngressEnd::Aborted
}

/// Forwards bytes read from the TCP socket to the SSH channel as `ChannelEgress::Data`.
///
/// Stops in any of these cases:
/// - the socket reports EOF or a non-`Interrupted` error;
/// - the channel's receiver is gone;
/// - `stop` is observed set between reads.
///
/// `send` on the bounded channel blocks while it is full, which gives backpressure towards the
/// socket. Sending `Eof`/`Close` is left to the caller.
fn copy_tcp_to_channel(tcp: &mut TcpStream, tx: &SyncSender<ChannelEgress>, stop: &AtomicBool) {
    let mut buf = [0u8; 32 * 1024];
    while !stop.load(Ordering::SeqCst) {
        let n = match tcp.read(&mut buf) {
            Ok(0) => return,
            Ok(n) => n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(_) => return,
        };
        if tx.send(ChannelEgress::Data(buf[..n].to_vec())).is_err() {
            return;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{SocketAddr, TcpListener};
    use std::sync::mpsc;
    use std::time::{Duration, Instant};

    /// Overall budget for waiting on egress messages in one phase of a test.
    const WAIT: Duration = Duration::from_secs(5);
    /// Timeout for each individual egress poll.
    const POLL: Duration = Duration::from_millis(500);

    /// Spawns a one-shot echo server on an ephemeral loopback port.
    ///
    /// It accepts a single connection and writes back everything it reads until the peer
    /// closes its write half or an I/O error occurs.
    fn echo_server() -> (SocketAddr, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("addr");
        let handle = thread::spawn(move || {
            if let Ok((mut conn, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                while let Ok(n) = conn.read(&mut buf) {
                    if n == 0 || conn.write_all(&buf[..n]).is_err() {
                        break;
                    }
                }
            }
        });
        (addr, handle)
    }

    /// Builds a request for `dest_host:dest_port` from a fixed placeholder origin.
    fn request(dest_host: &str, dest_port: u32) -> DirectTcpipRequest<'_> {
        DirectTcpipRequest {
            dest_host,
            dest_port,
            orig_host: "client",
            orig_port: 0,
        }
    }

    /// Returns a channel stream together with the peer ends of its ingress and egress queues.
    ///
    /// The caller keeps the peer ends alive so neither queue reports disconnection.
    fn idle_stream() -> (
        ChannelStream,
        mpsc::Sender<Option<Vec<u8>>>,
        mpsc::Receiver<ChannelEgress>,
    ) {
        let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
        let (egress_tx, egress_rx) = mpsc::sync_channel::<ChannelEgress>(8);
        (ChannelStream::new(ingress_rx, egress_tx), ingress_tx, egress_rx)
    }

    #[test]
    fn direct_tcpip_round_trip_through_echo_server() {
        let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
        let (egress_tx, egress_rx) = mpsc::sync_channel::<ChannelEgress>(32);
        let stream = ChannelStream::new(ingress_rx, egress_tx);
        let (addr, echo) = echo_server();
        let dest_host = addr.ip().to_string();
        let dest_port = u32::from(addr.port());

        let handler = thread::spawn(move || {
            DefaultDirectTcpipHandler::new()
                .handle("test-user", request(&dest_host, dest_port), stream)
                .expect("handle");
        });

        ingress_tx
            .send(Some(b"ping".to_vec()))
            .expect("ingress send");

        let mut echoed = Vec::new();
        let deadline = Instant::now() + WAIT;
        while echoed.len() < 4 && Instant::now() < deadline {
            match egress_rx.recv_timeout(POLL) {
                Ok(ChannelEgress::Data(bytes)) => echoed.extend_from_slice(&bytes),
                Ok(ChannelEgress::Eof) | Ok(ChannelEgress::Close) | Err(_) => break,
            }
        }
        assert_eq!(&echoed, b"ping");

        ingress_tx.send(None).expect("ingress eof");
        drop(ingress_tx);

        let (mut saw_eof, mut saw_close) = (false, false);
        let deadline = Instant::now() + WAIT;
        while !(saw_eof && saw_close) && Instant::now() < deadline {
            match egress_rx.recv_timeout(POLL) {
                Ok(ChannelEgress::Data(_)) => {}
                Ok(ChannelEgress::Eof) => saw_eof = true,
                Ok(ChannelEgress::Close) => saw_close = true,
                Err(_) => break,
            }
        }
        assert!(saw_eof, "handler should send Eof on teardown");
        assert!(saw_close, "handler should send Close on teardown");
        handler.join().expect("handler thread");
        let _ = echo.join();
    }

    #[test]
    fn out_of_range_port_is_rejected_silently() {
        let (stream, _ingress_tx, _egress_rx) = idle_stream();
        DefaultDirectTcpipHandler::new()
            .handle("u", request("127.0.0.1", 70_000), stream)
            .expect("handle");
    }

    #[test]
    fn allow_list_rejects_silently() {
        let (stream, _ingress_tx, _egress_rx) = idle_stream();
        let handler =
            DefaultDirectTcpipHandler::new().with_allow_list(|host, _| host == "allowed.example");
        handler
            .handle("u", request("denied.example", 80), stream)
            .expect("handle");
    }
}