use {
clap::Parser,
rustix::{
fd::{AsFd, FromRawFd, OwnedFd, RawFd},
fs::{flock, OpenOptionsExt},
},
sd_notify::NotifyState,
std::{
collections::{HashMap, HashSet},
fmt::Display,
fs::{remove_file, File},
io::Cursor,
os::unix::net::{UnixListener, UnixStream},
path::PathBuf,
sync::{Arc, Mutex, OnceLock},
thread::spawn,
},
};
use uds::UnixStreamExt;
use wlproxy::proto::{self, read_arg_string};
use wlproxy::ObjType;
/// Resolves the compositor socket used when `--upstream` is not given:
/// `$XDG_RUNTIME_DIR/$WAYLAND_DISPLAY`, where `WAYLAND_DISPLAY` falls back to
/// `wayland-0` when unset.
///
/// `PathBuf::push` replaces the base when handed an absolute path, so an absolute
/// `WAYLAND_DISPLAY` is used verbatim.
///
/// # Panics
/// Panics if `XDG_RUNTIME_DIR` is unset or not valid Unicode.
fn default_upstream() -> PathBuf {
    let mut socket_path = PathBuf::from(
        std::env::var("XDG_RUNTIME_DIR").expect("XDG_RUNTIME_DIR must be set for Wayland"),
    );
    let display = std::env::var("WAYLAND_DISPLAY").unwrap_or_else(|_| String::from("wayland-0"));
    socket_path.push(display);
    socket_path
}
// Command-line options. Plain `//` comments are used deliberately: clap's derive turns
// `///` doc comments into `--help` text, which would change the program's output.
#[derive(Parser, Clone)]
#[command(name = "wlproxy")]
struct Args {
// Compositor socket to connect to for each client; when absent, main falls back to
// default_upstream() ($XDG_RUNTIME_DIR/$WAYLAND_DISPLAY).
#[arg(short = 'u', long = "upstream")]
upstream: Option<PathBuf>,
// Value substituted into the client's xdg_toplevel.set_app_id requests.
#[arg(short = 'a', long = "app-id")]
app_id: Option<String>,
// With --app-id: prepend the given value to the client's app id instead of replacing it.
#[arg(short = 'A', long = "prefix-app-id")]
prefix_app_id: bool,
// Value substituted into the client's xdg_toplevel.set_title requests.
#[arg(short = 't', long = "title")]
title: Option<String>,
// With --title: prepend the given value to the client's title instead of replacing it.
#[arg(short = 'T', long = "prefix-title")]
prefix_title: bool,
// Comma-separated interface names: their registry globals are hidden from the client
// and binds to them are dropped.
#[arg(short = 'b', long = "block", value_delimiter = ',')]
block: Vec<String>,
// Suppresses the startup warnings about --block names missing from the known list.
#[arg(short = 'q', long = "quiet")]
quiet: bool,
// Logs every proxied packet (and each modification/blocking decision) to stderr.
#[arg(long = "debug")]
debug: bool,
// Path of the listening socket clients connect to; a lock file is created next to it
// with the extension replaced by `.lock`.
downstream: PathBuf,
}
/// Returns the sorted, de-duplicated list of Wayland interface names bundled from
/// `known_protocols.txt`; used only to warn about likely typos in `--block`.
///
/// The embedded file is parsed on first call and cached for the life of the process.
fn known_protocols() -> &'static [&'static str] {
    static LIST: OnceLock<Vec<&'static str>> = OnceLock::new();
    LIST.get_or_init(|| parse_protocol_list(include_str!("../known_protocols.txt")))
}

/// Parses a protocol list with one interface name per line.
///
/// Surrounding whitespace is stripped from every line *before* it is stored, so stray
/// indentation or trailing spaces in the file do not produce names that never match.
/// Blank lines and lines starting with `#` are skipped. The result is sorted and
/// de-duplicated.
fn parse_protocol_list(text: &str) -> Vec<&str> {
    let mut list: Vec<&str> = text
        .lines()
        .map(str::trim)
        .filter(|name| !name.is_empty() && !name.starts_with('#'))
        .collect();
    list.sort_unstable();
    list.dedup();
    list
}
/// Converts a `Result` with any displayable error into a `Result<T, String>` whose
/// error message is prefixed with a human-readable context string.
trait Errorize<T> {
    /// Passes `Ok` through unchanged; turns `Err(e)` into `Err("{msg}: {e}")`.
    fn context(self, msg: &str) -> Result<T, String>;
}

impl<T, E: Display> Errorize<T> for Result<T, E> {
    fn context(self, msg: &str) -> Result<T, String> {
        self.map_err(|err| format!("{}: {}", msg, err))
    }
}
/// Reports whether the interface `name` appears (exact, case-sensitive match) in the
/// `--block` list.
fn is_blocked_interface(name: &str, block: &[String]) -> bool {
    for candidate in block {
        if candidate.as_str() == name {
            return true;
        }
    }
    false
}
/// Prints a warning to stderr for every `--block` entry that is not in the bundled list
/// of known Wayland interfaces (usually a typo). Silent when `quiet` is set or the block
/// list is empty.
fn validate_interfaces(block: &[String], quiet: bool) {
    if quiet || block.is_empty() {
        return;
    }
    let known = known_protocols();
    let unknown = block
        .iter()
        .filter(|name| !known.contains(&name.as_str()));
    for name in unknown {
        eprintln!(
            "Warning: unknown Wayland interface \"{}\" in --block list",
            name
        );
    }
}
/// Closes every raw descriptor in `fds` and leaves the vector empty (its capacity is
/// kept, so the same accumulator can be reused for the next packet).
fn drop_fds(fds: &mut Vec<RawFd>) {
    fds.drain(..)
        // SAFETY: every entry must be an open descriptor owned solely by this list; in this
        // file the lists are filled by `AncillaryReader` with freshly received descriptors,
        // and each is closed exactly once here.
        .map(|raw| unsafe { OwnedFd::from_raw_fd(raw) })
        .for_each(drop);
}
struct AncillaryReader<'a> {
reader: &'a UnixStream,
fds: &'a mut Vec<RawFd>,
}
impl<'a> std::io::Read for AncillaryReader<'a> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let mut fd_buf = [0i32; 8];
let (n, nfds) = self.reader.recv_fds(buf, &mut fd_buf)?;
self.fds.extend(&fd_buf[..nfds]);
Ok(n)
}
}
struct AncillaryWriter<'a> {
writer: &'a UnixStream,
fds: Vec<RawFd>,
}
impl<'a> AncillaryWriter<'a> {
fn new(writer: &'a UnixStream, fds: &[RawFd]) -> Self {
Self {
writer,
fds: fds.to_vec(),
}
}
}
impl<'a> std::io::Write for AncillaryWriter<'a> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let fds = std::mem::take(&mut self.fds);
if fds.is_empty() {
self.writer.write(buf)
} else {
self.writer.send_fds(buf, &fds)
}
}
fn flush(&mut self) -> std::io::Result<()> {
self.writer.flush()
}
}
fn handle_client_to_server(
downstream: &UnixStream,
upstream: &UnixStream,
objects: &Arc<Mutex<HashMap<u32, ObjType>>>,
xdgwmbase_type_id: &Arc<Mutex<Option<(u32, u32)>>>,
blocked_objects: &Arc<Mutex<HashSet<u32>>>,
args: &Args,
) -> Result<(), String> {
let mut ancillary_accum = vec![];
while let Some(mut packet) = proto::read_packet(&mut AncillaryReader {
reader: downstream,
fds: &mut ancillary_accum,
})
.context("Error reading message")?
{
let should_block = {
let mut objects = objects.lock().unwrap();
let mut blocked = blocked_objects.lock().unwrap();
if blocked.contains(&packet.id) {
true
} else {
let o = objects.get(&packet.id).cloned();
if args.debug {
eprintln!(
"Received packet from downstream for tracked object {:?} with {} ancillary FDs: {:?}",
o,
ancillary_accum.len(),
packet
);
}
if let Some(o) = o {
match o {
ObjType::Display => {
if packet.opcode == 1 {
let obj_id = proto::read_arg_uint(&mut Cursor::new(&packet.body))
.context("Error reading registry id")?;
objects.insert(obj_id, ObjType::Registry);
}
false
}
ObjType::Registry => {
if packet.opcode == 0 {
let mut cursor = Cursor::new(&packet.body);
let obj_type_id = proto::read_arg_uint(&mut cursor)
.context("Error/eof reading bind object type id")?;
let interface_name = proto::read_arg_string(&mut cursor)
.context("Error reading bind message type string")?;
let version = proto::read_arg_uint(&mut cursor)
.context("Error reading bind message version")?;
let obj_id = proto::read_arg_uint(&mut cursor)
.context("Error/eof reading bind object id")?;
if let Some(ref name) = interface_name {
if is_blocked_interface(name, &args.block) {
if args.debug {
eprintln!("Blocked bind for interface: {}", name);
}
blocked.insert(obj_id);
true
} else if let Some((want_type_id, _version)) =
*xdgwmbase_type_id.lock().unwrap()
{
if obj_type_id == want_type_id {
objects.insert(
obj_id,
ObjType::XdgWmBase { ver: version },
);
}
false
} else {
false
}
} else {
false
}
} else {
false
}
}
ObjType::XdgWmBase { ver } => match ver {
0..=6 => {
if packet.opcode == 2 {
let obj_id =
proto::read_arg_uint(&mut Cursor::new(&packet.body))
.context(
"Error reading xdg wm base create surface id",
)?;
objects.insert(obj_id, ObjType::XdgSurface { ver });
}
false
}
_ => {
return Err(
format!("Unsupported xdg_wm_base object version {ver}",),
)
}
},
ObjType::XdgSurface { ver } => match ver {
0..=6 => {
if packet.opcode == 1 {
let obj_id =
proto::read_arg_uint(&mut Cursor::new(&packet.body))
.context(
"Error reading xdg surface create toplevel id",
)?;
objects.insert(obj_id, ObjType::XdgToplevel { ver });
}
false
}
_ => {
return Err(
format!("Unsupported xdg_surface object version {ver}",),
)
}
},
ObjType::XdgToplevel { ver } => match ver {
0..=6 => {
match packet.opcode {
2 => {
if let Some(title) = &args.title {
let read_title =
read_arg_string(&mut packet.body.as_slice())
.context("Error reading app id message body")?;
packet.body.clear();
let new_title = if args.prefix_title {
format!(
"{}{}",
title,
read_title.unwrap_or_default()
)
} else {
title.clone()
};
proto::write_arg_string(&mut packet.body, &new_title)
.unwrap();
if args.debug {
eprintln!(
"Modified title; new message: {:?}",
packet
);
}
}
}
3 => {
if let Some(app_id) = &args.app_id {
let read_app_id =
read_arg_string(&mut packet.body.as_slice())
.context("Error reading app id message body")?;
packet.body.clear();
let new_app_id = if args.prefix_app_id {
format!(
"{}{}",
app_id,
read_app_id.unwrap_or_default()
)
} else {
app_id.clone()
};
proto::write_arg_string(&mut packet.body, &new_app_id)
.unwrap();
if args.debug {
eprintln!(
"Modified app id; new message: {:?}",
packet
);
}
}
}
_ => (),
};
false
}
_ => {
return Err(format!(
"Unsupported xdg_toplevel object version {ver}",
))
}
},
}
} else {
false
}
}
};
if should_block {
drop_fds(&mut ancillary_accum);
continue;
}
proto::write_packet(
&mut AncillaryWriter::new(upstream, &ancillary_accum),
&packet,
)
.context("Error writing message")?;
drop_fds(&mut ancillary_accum);
}
Ok(())
}
fn handle_server_to_client(
upstream: &UnixStream,
downstream: &UnixStream,
objects: &Arc<Mutex<HashMap<u32, ObjType>>>,
xdgwmbase_type_id: &Arc<Mutex<Option<(u32, u32)>>>,
args: &Args,
) -> Result<(), String> {
let mut ancillary_accum = vec![];
let mut cache_reg_id = None;
while let Some(packet) = proto::read_packet(&mut AncillaryReader {
reader: upstream,
fds: &mut ancillary_accum,
})
.context("Error reading message")?
{
if args.debug {
eprintln!(
"Received packet from upstream with {} ancillary FDs: {:?}",
ancillary_accum.len(),
packet
);
}
if (packet.id, packet.opcode) == (1, 1) {
let mut cursor = Cursor::new(&packet.body);
let obj_id =
proto::read_arg_uint(&mut cursor).context("Error reading display delete obj id")?;
objects.lock().unwrap().remove(&obj_id);
if cache_reg_id == Some(obj_id) {
cache_reg_id = None;
}
}
if let Some(reg_id) = match &cache_reg_id {
Some(r) => Some(*r),
None => {
if let Some(ObjType::Registry) = objects.lock().unwrap().get(&packet.id) {
cache_reg_id = Some(packet.id);
Some(packet.id)
} else {
None
}
}
} {
if reg_id == packet.id && packet.opcode == 0 {
let mut cursor = Cursor::new(&packet.body);
let type_id = proto::read_arg_uint(&mut cursor)
.context("Error reading global message type id")?;
let type_str = proto::read_arg_string(&mut cursor)
.context("Error reading global message type string")?;
let version = proto::read_arg_uint(&mut cursor)
.context("Error reading global message version")?;
if type_str.as_deref() == Some("xdg_wm_base") {
*xdgwmbase_type_id.lock().unwrap() = Some((type_id, version));
}
if let Some(ref name) = type_str {
if !args.block.is_empty() && args.block.iter().any(|b| b == name) {
if args.debug {
eprintln!("Blocked global: {}", name);
}
drop_fds(&mut ancillary_accum);
continue;
}
}
}
}
proto::write_packet(
&mut AncillaryWriter::new(downstream, &ancillary_accum),
&packet,
)
.context("Error writing message")?;
drop_fds(&mut ancillary_accum);
}
Ok(())
}
fn main() -> Result<(), String> {
let args = Args::parse();
validate_interfaces(&args.block, args.quiet);
let lock_path = args.downstream.with_extension("lock");
let filelock = File::options()
.mode(0o660)
.write(true)
.create(true)
.custom_flags(libc::O_CLOEXEC)
.open(&lock_path)
.context("Error opening lock file")?;
flock(
filelock.as_fd(),
rustix::fs::FlockOperation::NonBlockingLockExclusive,
).context("Error getting exclusive lock for downstream listener, is another compositor already listening?")?;
let _defer = defer::defer(|| {
_ = remove_file(&lock_path);
});
_ = remove_file(&args.downstream);
let downstream =
UnixListener::bind(&args.downstream).context("Error creating downstream listener")?;
let _defer1 = defer::defer(|| {
_ = remove_file(&args.downstream);
});
if let Ok(true) = sd_notify::booted() {
if args.debug {
eprintln!("Init detected as being systemd. Notifying of readiness.");
}
if let Err(e) = sd_notify::notify(&[NotifyState::Ready]) {
eprintln!("Warning, failed to notify systemd with error: {}", e);
}
}
loop {
let (downstream, _) = downstream
.accept()
.context("Error accepting downstream connection")?;
let upstream_path = args.upstream.clone().unwrap_or_else(default_upstream);
let upstream =
UnixStream::connect(&upstream_path).context("Error creating upstream connection")?;
let objects = Arc::new(Mutex::new(HashMap::new()));
objects.lock().unwrap().insert(1, ObjType::Display);
let xdgwmbase_type_id = Arc::new(Mutex::new(None));
let blocked_objects: Arc<Mutex<HashSet<u32>>> = Arc::new(Mutex::new(HashSet::new()));
spawn({
let downstream = downstream.try_clone().unwrap();
let upstream = upstream.try_clone().unwrap();
let objects = objects.clone();
let xdgwmbase_type_id = xdgwmbase_type_id.clone();
let blocked_objects = blocked_objects.clone();
let args = args.clone();
move || {
let _defer = defer::defer({
let downstream = downstream.try_clone().unwrap();
let upstream = upstream.try_clone().unwrap();
move || {
_ = downstream.shutdown(std::net::Shutdown::Both);
_ = upstream.shutdown(std::net::Shutdown::Both);
}
});
if let Err(e) = handle_client_to_server(
&downstream,
&upstream,
&objects,
&xdgwmbase_type_id,
&blocked_objects,
&args,
) {
eprintln!("Warning, client->server thread exiting with error: {}", e);
}
}
});
spawn({
let downstream = downstream.try_clone().unwrap();
let upstream = upstream.try_clone().unwrap();
let objects = objects.clone();
let xdgwmbase_type_id = xdgwmbase_type_id.clone();
let args = args.clone();
move || {
let _defer = defer::defer({
let downstream = downstream.try_clone().unwrap();
let upstream = upstream.try_clone().unwrap();
move || {
_ = downstream.shutdown(std::net::Shutdown::Both);
_ = upstream.shutdown(std::net::Shutdown::Both);
}
});
if let Err(e) = handle_server_to_client(
&upstream,
&downstream,
&objects,
&xdgwmbase_type_id,
&args,
) {
eprintln!("Warning, server->client thread exiting with error: {}", e);
}
}
});
}
}