use std::collections::BTreeMap;
use std::io::{ErrorKind, Read, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crate::error::{Error, Result};
use crate::server::{ChannelEgress, ForwardContext, TcpipForwardHandler};
/// How long an idle accept loop sleeps between polls of its non-blocking
/// listener (and therefore between checks of its stop flag).
const ACCEPT_POLL_INTERVAL: Duration = Duration::from_millis(100);

/// One live forwarding listener: the accept thread plus the flag that asks it
/// to exit.
///
/// Dropping a `Binding` raises `stop` and then blocks until the accept thread
/// has returned.
struct Binding {
    /// Set to `true` to ask the accept thread to leave its loop.
    stop: Arc<AtomicBool>,
    /// Handle of the accept thread; `None` once it has been joined.
    handle: Option<JoinHandle<()>>,
}

impl Drop for Binding {
    fn drop(&mut self) {
        // Signal first, then wait; the worker only sees the flag between polls.
        self.stop.store(true, Ordering::SeqCst);
        let Some(worker) = self.handle.take() else {
            return;
        };
        // A panic inside the worker is deliberately swallowed: a destructor has
        // no sensible way to report it.
        let _ = worker.join();
    }
}
type AllowFilter = Box<dyn Fn(&str, &str, u16) -> bool + Send + Sync>;
pub struct DefaultTcpipForwardHandler {
bindings: Mutex<BTreeMap<(String, u16), Binding>>,
allow: Option<AllowFilter>,
}
impl Default for DefaultTcpipForwardHandler {
fn default() -> Self {
Self::new()
}
}
impl DefaultTcpipForwardHandler {
    /// Creates a handler with no active bindings and no bind policy (every
    /// request is permitted).
    pub fn new() -> Self {
        Self {
            bindings: Mutex::new(BTreeMap::new()),
            allow: None,
        }
    }

    /// Installs a policy callback that is consulted before every bind.
    ///
    /// The filter receives `(user, bind_address, bind_port)` exactly as the
    /// client requested them and returns `true` to allow the bind. Calling this
    /// again replaces any previously installed filter.
    #[must_use]
    pub fn with_allow_filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&str, &str, u16) -> bool + Send + Sync + 'static,
    {
        self.allow = Some(Box::new(filter));
        self
    }

    /// Returns whether the installed policy permits this bind; with no
    /// policy installed everything is allowed.
    fn allowed(&self, user: &str, bind_address: &str, bind_port: u16) -> bool {
        match &self.allow {
            Some(f) => f(user, bind_address, bind_port),
            None => true,
        }
    }

    /// Number of currently registered forwarding listeners.
    pub fn binding_count(&self) -> usize {
        match self.bindings.lock() {
            Ok(map) => map.len(),
            // A poisoned lock still guards a structurally valid map whose
            // listener threads are running; report them instead of a
            // misleading 0.
            Err(poisoned) => poisoned.into_inner().len(),
        }
    }
}
/// Pumps bytes in both directions between an accepted TCP connection and an
/// SSH channel, using three detached threads:
///
/// * TCP → channel: forwards every read as `ChannelEgress::Data`, then sends
///   `ChannelEgress::Eof` once the TCP peer stops sending (or on an error).
/// * channel → TCP: writes each chunk received from the channel to the socket
///   until the channel yields no more data, then half-closes the socket's
///   write side so the TCP peer sees end-of-stream.
/// * a joiner that waits for both pumps and finally sends `ChannelEgress::Close`.
///
/// If the socket cannot be duplicated, the channel is sent `Eof` and `Close`
/// immediately and the TCP connection is dropped.
fn spawn_splice(tcp: TcpStream, stream: crate::server::ChannelStream) {
    let (chan_rx, chan_tx) = stream.into_raw();
    let Ok(mut tcp_reader) = tcp.try_clone() else {
        let _ = chan_tx.send(ChannelEgress::Eof);
        let _ = chan_tx.send(ChannelEgress::Close);
        return;
    };
    let mut tcp_writer = tcp;

    let upstream_tx = chan_tx.clone();
    let upstream = thread::spawn(move || {
        let mut buf = [0u8; 32 * 1024];
        loop {
            match tcp_reader.read(&mut buf) {
                Ok(0) => break,
                Ok(n) => {
                    if upstream_tx
                        .send(ChannelEgress::Data(buf[..n].to_vec()))
                        .is_err()
                    {
                        break;
                    }
                }
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(_) => break,
            }
        }
        let _ = upstream_tx.send(ChannelEgress::Eof);
    });

    let downstream = thread::spawn(move || {
        while let Ok(Some(chunk)) = chan_rx.recv() {
            if tcp_writer.write_all(&chunk).is_err() {
                break;
            }
        }
        // Half-close: tell the TCP peer no more data is coming while still
        // letting the upstream pump read its replies. Shutting down the *read*
        // side instead would discard those replies and would not signal
        // end-of-stream to the peer.
        let _ = tcp_writer.shutdown(std::net::Shutdown::Write);
    });

    thread::spawn(move || {
        let _ = upstream.join();
        let _ = downstream.join();
        let _ = chan_tx.send(ChannelEgress::Close);
    });
}
fn resolve_bind(bind_address: &str, port: u16) -> Result<SocketAddr> {
match bind_address {
"" | "0.0.0.0" => Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)),
"::" => Ok(SocketAddr::new(
IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED),
port,
)),
"localhost" | "127.0.0.1" => Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)),
"::1" => Ok(SocketAddr::new(
IpAddr::V6(std::net::Ipv6Addr::LOCALHOST),
port,
)),
other => other
.parse::<IpAddr>()
.map(|ip| SocketAddr::new(ip, port))
.map_err(|_| Error::Protocol("tcpip-forward: invalid bind address")),
}
}
impl TcpipForwardHandler for DefaultTcpipForwardHandler {
fn bind(
&self,
user: &str,
bind_address: &str,
bind_port: u16,
ctx: ForwardContext,
) -> Result<u16> {
if !self.allowed(user, bind_address, bind_port) {
return Err(Error::Protocol("tcpip-forward: bind refused by policy"));
}
let addr = resolve_bind(bind_address, bind_port)?;
let listener = TcpListener::bind(addr)?;
let actual_port = listener.local_addr()?.port();
listener.set_nonblocking(true)?;
let stop = Arc::new(AtomicBool::new(false));
let stop_thread = Arc::clone(&stop);
let bind_address_owned = bind_address.to_string();
let handle = thread::spawn(move || {
while !stop_thread.load(Ordering::SeqCst) {
match listener.accept() {
Ok((conn, peer)) => {
let (orig_host, orig_port) = match peer {
SocketAddr::V4(a) => (a.ip().to_string(), a.port()),
SocketAddr::V6(a) => (a.ip().to_string(), a.port()),
};
match ctx.open_forwarded_tcpip(
&bind_address_owned,
actual_port,
&orig_host,
orig_port,
) {
Ok(channel_stream) => {
spawn_splice(conn, channel_stream);
}
Err(_) => {
let _ = conn.shutdown(std::net::Shutdown::Both);
}
}
}
Err(e) if e.kind() == ErrorKind::WouldBlock => {
thread::sleep(ACCEPT_POLL_INTERVAL);
}
Err(_) => break,
}
}
});
let mut map = self
.bindings
.lock()
.map_err(|_| Error::Protocol("tcpip-forward: lock poisoned"))?;
let key = (bind_address.to_string(), actual_port);
if let Some(existing) = map.remove(&key) {
drop(existing);
}
map.insert(
key,
Binding {
stop,
handle: Some(handle),
},
);
Ok(actual_port)
}
fn unbind(&self, _user: &str, bind_address: &str, bind_port: u16) -> Result<()> {
let mut map = self
.bindings
.lock()
.map_err(|_| Error::Protocol("tcpip-forward: lock poisoned"))?;
let key = (bind_address.to_string(), bind_port);
if let Some(binding) = map.remove(&key) {
drop(map);
drop(binding);
Ok(())
} else {
Err(Error::Protocol("cancel-tcpip-forward: no such binding"))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bind_port_zero_picks_and_returns_a_port() {
        let h = DefaultTcpipForwardHandler::new();
        let port = h
            .bind("u", "127.0.0.1", 0, ForwardContext::for_test_no_opens())
            .expect("bind");
        assert!(port > 0, "kernel-assigned port should be non-zero");
        assert_eq!(h.binding_count(), 1);
        h.unbind("u", "127.0.0.1", port).expect("unbind");
        assert_eq!(h.binding_count(), 0);
    }

    #[test]
    fn unbind_releases_the_listener_so_a_fresh_bind_succeeds() {
        let h = DefaultTcpipForwardHandler::new();
        let port = h
            .bind("u", "127.0.0.1", 0, ForwardContext::for_test_no_opens())
            .expect("first bind");
        h.unbind("u", "127.0.0.1", port).expect("unbind");
        let again = h
            .bind("u", "127.0.0.1", port, ForwardContext::for_test_no_opens())
            .expect("rebind released port");
        assert_eq!(again, port);
        h.unbind("u", "127.0.0.1", port).expect("final unbind");
    }

    #[test]
    fn unbind_of_unknown_binding_errors() {
        let h = DefaultTcpipForwardHandler::new();
        assert!(h.unbind("u", "127.0.0.1", 12345).is_err());
    }

    #[test]
    fn invalid_bind_address_is_rejected() {
        let h = DefaultTcpipForwardHandler::new();
        assert!(h
            .bind(
                "u",
                "not-an-ip-or-name",
                0,
                ForwardContext::for_test_no_opens(),
            )
            .is_err());
    }

    #[test]
    fn resolve_bind_maps_wildcards_loopbacks_and_literals() {
        let at = |ip: IpAddr| Some(SocketAddr::new(ip, 2222));
        assert_eq!(resolve_bind("", 2222).ok(), at(Ipv4Addr::UNSPECIFIED.into()));
        assert_eq!(resolve_bind("0.0.0.0", 2222).ok(), at(Ipv4Addr::UNSPECIFIED.into()));
        assert_eq!(
            resolve_bind("::", 2222).ok(),
            at(std::net::Ipv6Addr::UNSPECIFIED.into())
        );
        assert_eq!(resolve_bind("localhost", 2222).ok(), at(Ipv4Addr::LOCALHOST.into()));
        assert_eq!(resolve_bind("127.0.0.1", 2222).ok(), at(Ipv4Addr::LOCALHOST.into()));
        assert_eq!(
            resolve_bind("::1", 2222).ok(),
            at(std::net::Ipv6Addr::LOCALHOST.into())
        );
        assert_eq!(
            resolve_bind("192.0.2.7", 2222).ok(),
            at(Ipv4Addr::new(192, 0, 2, 7).into())
        );
        assert!(resolve_bind("example.invalid", 2222).is_err());
    }

    #[test]
    fn allow_filter_can_refuse_bind() {
        let h = DefaultTcpipForwardHandler::new()
            .with_allow_filter(|_user, addr, _port| addr == "127.0.0.1");
        let port = h
            .bind("u", "127.0.0.1", 0, ForwardContext::for_test_no_opens())
            .expect("loopback bind allowed");
        assert!(h
            .bind("u", "0.0.0.0", 0, ForwardContext::for_test_no_opens())
            .is_err());
        assert_eq!(h.binding_count(), 1);
        h.unbind("u", "127.0.0.1", port).expect("unbind");
    }

    #[test]
    fn allow_filter_sees_user() {
        let h = DefaultTcpipForwardHandler::new()
            .with_allow_filter(|user, _addr, _port| user == "alice");
        assert!(h
            .bind("bob", "127.0.0.1", 0, ForwardContext::for_test_no_opens())
            .is_err());
        let port = h
            .bind("alice", "127.0.0.1", 0, ForwardContext::for_test_no_opens())
            .expect("alice bind allowed");
        h.unbind("alice", "127.0.0.1", port).expect("unbind");
    }
}