use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
#[cfg(feature = "server")]
use std::net::{SocketAddr, TcpListener};
#[cfg(feature = "server")]
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "server")]
use std::thread::JoinHandle;
#[cfg(feature = "server")]
use std::time::Duration;
#[cfg(feature = "server")]
use crate::error::{Error, Result};
#[cfg(feature = "server")]
use crate::server::{X11ForwardContext, X11ForwardHandle, X11ForwardHandler};
use crate::stream::{ChannelEgress, ChannelStream};
/// How long the accept thread sleeps after the non-blocking listener reports
/// `WouldBlock`, before re-checking the stop flag and polling `accept` again.
#[cfg(feature = "server")]
const ACCEPT_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Base TCP port for X displays: display number `n` maps to port `X_BASE_PORT + n`.
const X_BASE_PORT: u16 = 6000;
/// Lowest display number tried by `DefaultX11ForwardHandler::new()`.
#[cfg(feature = "server")]
const DEFAULT_MIN_DISPLAY: u16 = 10;
/// Highest display number tried by `DefaultX11ForwardHandler::new()`.
#[cfg(feature = "server")]
const DEFAULT_MAX_DISPLAY: u16 = 999;
/// Owns the accept thread behind one forwarded display.
///
/// Dropping the binding sets the stop flag and joins the thread.
#[cfg(feature = "server")]
struct X11Binding {
/// Flag polled by the accept loop; set to `true` on drop to request shutdown.
stop: Arc<AtomicBool>,
/// Join handle of the accept thread; taken (leaving `None`) when joined on drop.
handle: Option<JoinHandle<()>>,
}
#[cfg(feature = "server")]
impl Drop for X11Binding {
    /// Signals the accept thread to stop, then blocks until it has exited.
    fn drop(&mut self) {
        self.stop.store(true, Ordering::SeqCst);
        // `take()` leaves `None` behind, so the thread is joined at most once.
        let Some(worker) = self.handle.take() else {
            return;
        };
        // The join result (including a worker panic) is intentionally ignored.
        let _ = worker.join();
    }
}
/// X11 forward handler that listens on a loopback TCP port picked from a
/// configurable, inclusive range of display numbers.
#[cfg(feature = "server")]
pub struct DefaultX11ForwardHandler {
/// First display number to try (inclusive).
min_display: u16,
/// Last display number to try (inclusive).
max_display: u16,
}
#[cfg(feature = "server")]
impl Default for DefaultX11ForwardHandler {
    /// Builds the handler exactly as [`DefaultX11ForwardHandler::new`] does.
    fn default() -> Self {
        DefaultX11ForwardHandler::new()
    }
}
#[cfg(feature = "server")]
impl DefaultX11ForwardHandler {
    /// Creates a handler that searches display numbers
    /// `DEFAULT_MIN_DISPLAY..=DEFAULT_MAX_DISPLAY`.
    pub fn new() -> Self {
        Self::with_display_range(DEFAULT_MIN_DISPLAY, DEFAULT_MAX_DISPLAY)
    }

    /// Creates a handler that searches display numbers `min_display..=max_display`.
    ///
    /// The range is not validated here. An empty range (`min_display >
    /// max_display`) makes the later bind attempt fail with an `AddrInUse` error.
    pub fn with_display_range(min_display: u16, max_display: u16) -> Self {
        Self {
            min_display,
            max_display,
        }
    }

    /// Binds the first available loopback port in the configured display range.
    ///
    /// Returns the listener together with the display number `n` whose port
    /// (`X_BASE_PORT + n`) was bound.
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` with kind `AddrInUse` when no port in the range could
    /// be bound. Individual bind failures are not reported.
    fn bind_first_free(&self) -> Result<(TcpListener, u16)> {
        for n in self.min_display..=self.max_display {
            // A saturating add would map every display above 59535 onto port
            // 65535, so the reported display number would no longer match the
            // bound port. Higher displays overflow as well, so stop scanning.
            let Some(port) = X_BASE_PORT.checked_add(n) else {
                break;
            };
            let addr: SocketAddr = ([127u8, 0, 0, 1], port).into();
            if let Ok(listener) = TcpListener::bind(addr) {
                return Ok((listener, n));
            }
        }
        Err(Error::Io(std::io::Error::new(
            ErrorKind::AddrInUse,
            "x11-forward: no free display number in configured range",
        )))
    }
}
#[cfg(feature = "server")]
impl X11ForwardHandler for DefaultX11ForwardHandler {
    /// Binds a loopback display port and starts a background accept thread.
    ///
    /// Each accepted TCP connection is handed to `ctx.open_x11` together with
    /// the peer's address and port. On success the connection is spliced to the
    /// returned channel stream; on failure the connection is shut down.
    ///
    /// `_user`, `_single_connection`, `_auth_protocol` and `_auth_cookie` are
    /// not used by this implementation. `screen` only appears in the returned
    /// `display_env` string (`localhost:<display>.<screen>`).
    ///
    /// The accept thread runs until the returned handle's `stopper` is dropped
    /// or `accept` reports a non-transient error.
    ///
    /// # Errors
    ///
    /// Fails when no display port in the configured range can be bound, or when
    /// the listener cannot be switched to non-blocking mode.
    fn setup(
        &self,
        _user: &str,
        _single_connection: bool,
        _auth_protocol: &str,
        _auth_cookie: &str,
        screen: u32,
        ctx: X11ForwardContext,
    ) -> Result<X11ForwardHandle> {
        let (listener, display_number) = self.bind_first_free()?;
        // Non-blocking accept lets the thread notice the stop flag between polls.
        listener.set_nonblocking(true)?;
        let stop = Arc::new(AtomicBool::new(false));
        let stop_thread = Arc::clone(&stop);
        let handle = thread::spawn(move || {
            while !stop_thread.load(Ordering::SeqCst) {
                match listener.accept() {
                    Ok((conn, peer)) => {
                        // Whether an accepted socket inherits the listener's
                        // non-blocking flag is platform-dependent. The splice
                        // threads need blocking reads, so set it explicitly.
                        if conn.set_nonblocking(false).is_err() {
                            let _ = conn.shutdown(Shutdown::Both);
                            continue;
                        }
                        let orig_host = peer.ip().to_string();
                        let orig_port = u32::from(peer.port());
                        match ctx.open_x11(orig_host, orig_port) {
                            Ok(channel_stream) => {
                                spawn_tcp_splice(conn, channel_stream);
                            }
                            Err(_) => {
                                let _ = conn.shutdown(Shutdown::Both);
                            }
                        }
                    }
                    Err(e) if e.kind() == ErrorKind::WouldBlock => {
                        thread::sleep(ACCEPT_POLL_INTERVAL);
                    }
                    // A peer that gave up before we accepted must not take the
                    // whole listener down; just try the next connection.
                    Err(e)
                        if matches!(
                            e.kind(),
                            ErrorKind::ConnectionAborted | ErrorKind::Interrupted
                        ) =>
                    {
                        continue;
                    }
                    Err(_) => break,
                }
            }
        });
        let display_env = format!("localhost:{display_number}.{screen}");
        let binding = X11Binding {
            stop,
            handle: Some(handle),
        };
        Ok(X11ForwardHandle {
            display_env,
            display_number,
            stopper: Box::new(binding),
        })
    }
}
/// Bridges a TCP connection and a channel stream using detached threads.
///
/// One thread copies socket bytes into the channel and sends `Eof` when the
/// socket read side ends. A second thread writes channel chunks to the socket.
/// A third thread waits for both and then sends `Close`. If the socket cannot
/// be cloned, `Eof` and `Close` are sent immediately and nothing is spawned.
fn spawn_tcp_splice(tcp: TcpStream, stream: ChannelStream) {
    let (from_channel, to_channel) = stream.into_raw();

    // Reader and writer threads each need their own handle to the socket.
    let mut reader = match tcp.try_clone() {
        Ok(dup) => dup,
        Err(_) => {
            let _ = to_channel.send(ChannelEgress::Eof);
            let _ = to_channel.send(ChannelEgress::Close);
            return;
        }
    };
    let mut writer = tcp;
    let uplink = to_channel.clone();

    // socket -> channel
    let socket_to_channel = thread::spawn(move || {
        let mut scratch = [0u8; 32 * 1024];
        loop {
            let filled = match reader.read(&mut scratch) {
                Ok(0) => break,
                Ok(len) => len,
                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                Err(_) => break,
            };
            let payload = scratch[..filled].to_vec();
            if uplink.send(ChannelEgress::Data(payload)).is_err() {
                break;
            }
        }
        let _ = uplink.send(ChannelEgress::Eof);
    });

    // channel -> socket
    let channel_to_socket = thread::spawn(move || {
        loop {
            match from_channel.recv() {
                Ok(Some(bytes)) => {
                    if writer.write_all(&bytes).is_err() {
                        break;
                    }
                }
                _ => break,
            }
        }
        let _ = writer.shutdown(Shutdown::Read);
    });

    // `Close` is only sent once both directions have finished.
    thread::spawn(move || {
        let _ = socket_to_channel.join();
        let _ = channel_to_socket.join();
        let _ = to_channel.send(ChannelEgress::Close);
    });
}
/// Builds a callback that splices a channel stream to the Unix socket at `path`.
///
/// Returns `None` when `path` does not exist at the time of the call. The
/// returned callback connects afresh on every invocation; when the connect
/// fails it sends `Eof` followed by `Close` on the channel instead.
pub fn splice_to_unix_display_callback(
    path: PathBuf,
) -> Option<Arc<dyn Fn(ChannelStream) + Send + Sync + 'static>> {
    use std::os::unix::net::UnixStream;

    if !path.exists() {
        return None;
    }

    let callback: Arc<dyn Fn(ChannelStream) + Send + Sync + 'static> =
        Arc::new(move |stream: ChannelStream| {
            if let Ok(socket) = UnixStream::connect(&path) {
                spawn_unix_splice(socket, stream);
            } else {
                // The receive half stays bound until both messages are sent.
                let (_unused_rx, egress) = stream.into_raw();
                let _ = egress.send(ChannelEgress::Eof);
                let _ = egress.send(ChannelEgress::Close);
            }
        });
    Some(callback)
}
/// Builds a callback that splices a channel stream to a TCP display at `host:port`.
///
/// Nothing is connected when the callback is built. Each invocation opens a
/// new TCP connection; when the connect fails it sends `Eof` followed by
/// `Close` on the channel instead.
pub fn splice_to_tcp_display_callback(
    host: String,
    port: u16,
) -> Arc<dyn Fn(ChannelStream) + Send + Sync + 'static> {
    let connect_and_splice = move |stream: ChannelStream| {
        let Ok(socket) = TcpStream::connect((host.as_str(), port)) else {
            // The receive half stays bound until both messages are sent.
            let (_unused_rx, egress) = stream.into_raw();
            let _ = egress.send(ChannelEgress::Eof);
            let _ = egress.send(ChannelEgress::Close);
            return;
        };
        spawn_tcp_splice(socket, stream);
    };
    Arc::new(connect_and_splice)
}
/// Builds a splice callback for the display named by the `DISPLAY` environment
/// variable.
///
/// `DISPLAY` is split at its last `:` into a host part and a display part. The
/// display number is the text before the first `.` of the display part.
///
/// * Empty host or `unix` (for example `:0` or `unix:0.0`) selects the Unix
///   socket `/tmp/.X11-unix/X<n>`.
/// * Any other host (for example `example:10.0`) selects TCP to that host on
///   port `X_BASE_PORT + n`.
///
/// Returns `None` when `DISPLAY` is unset, not valid Unicode, has no `:`, has a
/// display number that does not parse as `u16`, names a Unix socket that does
/// not exist, or maps to a TCP port above `u16::MAX`.
pub fn splice_to_local_display_callback(
) -> Option<Arc<dyn Fn(ChannelStream) + Send + Sync + 'static>> {
    let raw = std::env::var("DISPLAY").ok()?;
    // An empty value has no ':' and is rejected here as well.
    let (host_part, display_part) = raw.rsplit_once(':')?;
    let n_str = display_part.split('.').next()?;
    let n: u16 = n_str.parse().ok()?;
    if host_part.is_empty() || host_part == "unix" {
        let path = PathBuf::from(format!("/tmp/.X11-unix/X{n}"));
        return splice_to_unix_display_callback(path);
    }
    // Reject display numbers whose port would not fit in a `u16`. A saturating
    // add would silently connect to port 65535, which is the wrong display.
    let port = X_BASE_PORT.checked_add(n)?;
    Some(splice_to_tcp_display_callback(host_part.to_string(), port))
}
/// Bridges a Unix-domain socket and a channel stream using detached threads.
///
/// One thread copies socket bytes into the channel and sends `Eof` when the
/// socket read side ends. A second thread writes channel chunks to the socket.
/// A third thread waits for both and then sends `Close`. If the socket cannot
/// be cloned, `Eof` and `Close` are sent immediately and nothing is spawned.
fn spawn_unix_splice(uds: std::os::unix::net::UnixStream, stream: ChannelStream) {
    use std::os::unix::net::UnixStream;

    let (from_channel, to_channel) = stream.into_raw();

    // Reader and writer threads each need their own handle to the socket.
    let mut reader: UnixStream = match uds.try_clone() {
        Ok(dup) => dup,
        Err(_) => {
            let _ = to_channel.send(ChannelEgress::Eof);
            let _ = to_channel.send(ChannelEgress::Close);
            return;
        }
    };
    let mut writer: UnixStream = uds;
    let uplink = to_channel.clone();

    // socket -> channel
    let socket_to_channel = thread::spawn(move || {
        let mut scratch = [0u8; 32 * 1024];
        loop {
            let count = match reader.read(&mut scratch) {
                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                Ok(0) | Err(_) => break,
                Ok(count) => count,
            };
            let payload = scratch[..count].to_vec();
            if uplink.send(ChannelEgress::Data(payload)).is_err() {
                break;
            }
        }
        let _ = uplink.send(ChannelEgress::Eof);
    });

    // channel -> socket
    let channel_to_socket = thread::spawn(move || {
        while let Ok(Some(bytes)) = from_channel.recv() {
            let delivered = writer.write_all(&bytes).is_ok();
            if !delivered {
                break;
            }
        }
        let _ = writer.shutdown(Shutdown::Read);
    });

    // `Close` is only sent once both directions have finished.
    thread::spawn(move || {
        let _ = socket_to_channel.join();
        let _ = channel_to_socket.join();
        let _ = to_channel.send(ChannelEgress::Close);
    });
}
// Tests that bind real loopback TCP ports; compiled only for test builds with
// the `server` feature enabled.
#[cfg(all(test, feature = "server"))]
mod tests {
use super::*;
// Checks that `setup` picks a display inside the requested range, reports a
// matching `display_env`, keeps the port bound while the handle is alive, and
// releases it once the handle is dropped.
#[test]
fn setup_binds_a_display_port() {
let h = DefaultX11ForwardHandler::with_display_range(900, 920);
let ctx = X11ForwardContext::for_test_no_opens();
let handle = h
.setup("u", false, "MIT-MAGIC-COOKIE-1", "deadbeef", 0, ctx)
.expect("setup");
let n = handle.display_number;
assert!((900..=920).contains(&n), "n out of range: {n}");
assert_eq!(handle.display_env, format!("localhost:{n}.0"));
let addr: SocketAddr = ([127u8, 0, 0, 1], 6000 + n).into();
assert!(
TcpListener::bind(addr).is_err(),
"port should be in use while the handle is alive"
);
drop(handle);
// Retry for up to 50 x 50 ms until the port can be bound again.
for _ in 0..50 {
if TcpListener::bind(addr).is_ok() {
break;
}
thread::sleep(Duration::from_millis(50));
}
assert!(
TcpListener::bind(addr).is_ok(),
"port should be free after the handle is dropped"
);
}
// Connects to the forwarded port and performs a single read with a 2 s timeout.
// NOTE(review): the read result is discarded, so nothing is asserted about the
// connection being closed. Presumably `for_test_no_opens()` makes `open_x11`
// fail and the intent is to observe EOF; confirm and assert on it.
#[test]
fn accepted_connection_is_closed_when_open_fails() {
let h = DefaultX11ForwardHandler::with_display_range(800, 820);
let ctx = X11ForwardContext::for_test_no_opens();
let handle = h
.setup("u", false, "MIT-MAGIC-COOKIE-1", "deadbeef", 0, ctx)
.expect("setup");
let addr: SocketAddr = ([127u8, 0, 0, 1], 6000 + handle.display_number).into();
let mut peer = TcpStream::connect_timeout(&addr, Duration::from_secs(2)).expect("connect");
peer.set_read_timeout(Some(Duration::from_secs(2)))
.expect("read timeout");
let mut buf = [0u8; 1];
let _ = peer.read(&mut buf);
}
// Smoke test: building the TCP callback must succeed; the callback is never
// invoked here, so no connection is attempted.
#[test]
fn tcp_display_callback_constructs() {
let _cb = splice_to_tcp_display_callback("127.0.0.1".to_string(), 65000);
}
}