use std::{
fmt,
ops::RangeInclusive,
os::unix::io::OwnedFd,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
};
use wayland_backend::{
client::{Backend, InvalidId, ObjectData, ObjectId, WaylandError},
protocol::Message,
};
use crate::{
protocol::{wl_display, wl_registry},
Connection, Dispatch, EventQueue, Proxy, QueueHandle,
};
pub fn registry_queue_init<State>(
conn: &Connection,
) -> Result<(GlobalList, EventQueue<State>), GlobalError>
where
State: Dispatch<wl_registry::WlRegistry, GlobalListContents> + 'static,
{
let event_queue = conn.new_event_queue();
let display = conn.display();
let data = Arc::new(RegistryState {
globals: GlobalListContents { contents: Default::default() },
handle: event_queue.handle(),
initial_roundtrip_done: AtomicBool::new(false),
});
let registry = display.send_constructor(wl_display::Request::GetRegistry {}, data.clone())?;
conn.roundtrip()?;
data.initial_roundtrip_done.store(true, Ordering::Relaxed);
Ok((GlobalList { registry }, event_queue))
}
/// Handle to the registry created by `registry_queue_init`.
///
/// It wraps the `wl_registry` proxy; the list of globals itself is stored in
/// the proxy's object data and is reached through `contents()`.
#[derive(Debug)]
pub struct GlobalList {
// The registry proxy whose object data holds the `GlobalListContents`.
registry: wl_registry::WlRegistry,
}
impl GlobalList {
    /// Returns the list of globals tracked for this registry.
    ///
    /// # Panics
    ///
    /// Panics if the registry's object data is not a [`GlobalListContents`].
    pub fn contents(&self) -> &GlobalListContents {
        let contents = self.registry.data::<GlobalListContents>();
        contents.unwrap()
    }

    /// Binds the first global in the list whose interface matches `I`.
    ///
    /// `version` is the inclusive range of versions the caller can work with.
    /// The version actually bound is the lower of the version stored for the
    /// global and the end of that range. `qh` and `udata` are passed through
    /// to `wl_registry::bind` unchanged.
    ///
    /// # Errors
    ///
    /// * [`BindError::NotPresent`] if no global in the list has the interface
    ///   name of `I`.
    /// * [`BindError::UnsupportedVersion`] if the global's version is lower
    ///   than the start of `version`.
    ///
    /// # Panics
    ///
    /// Panics if the end of `version` exceeds the maximum version known to the
    /// proxy type `I`, or if the globals mutex is poisoned.
    pub fn bind<I, State, U>(
        &self,
        qh: &QueueHandle<State>,
        version: RangeInclusive<u32>,
        udata: U,
    ) -> Result<I, BindError>
    where
        I: Proxy + 'static,
        State: Dispatch<I, U> + 'static,
        U: Send + Sync + 'static,
    {
        let (min_version, max_version) = version.into_inner();
        let interface = I::interface();

        // Requesting more than the proxy type supports is treated as a
        // programming mistake rather than a recoverable runtime condition.
        if max_version > interface.version {
            panic!("Maximum version ({}) of {} was higher than the proxy's maximum version ({}); outdated wayland XML files?",
                max_version, interface.name, interface.version);
        }

        // The guard stays alive until the end of the function, so the list is
        // locked across the bind request just as it is during the lookup.
        let guard = self.contents().contents.lock().unwrap();

        let advertised = guard
            .iter()
            .find(|global| interface.name == global.interface.as_str())
            .ok_or(BindError::NotPresent)?;
        let (name, server_version) = (advertised.name, advertised.version);

        if server_version < min_version {
            return Err(BindError::UnsupportedVersion);
        }

        // Never bind above what the caller asked for, nor above what the
        // global's recorded version allows.
        let bound_version = server_version.min(max_version);
        Ok(self.registry.bind(name, bound_version, qh, udata))
    }

    /// Returns the underlying `wl_registry` proxy.
    pub fn registry(&self) -> &wl_registry::WlRegistry {
        let Self { registry } = self;
        registry
    }
}
/// Error type returned by `registry_queue_init`.
#[derive(Debug)]
pub enum GlobalError {
/// A `WaylandError` reported by the backend.
Backend(WaylandError),
/// An `InvalidId` error, i.e. an object id was rejected as invalid.
InvalidId(InvalidId),
}
impl std::error::Error for GlobalError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
GlobalError::Backend(source) => Some(source),
GlobalError::InvalidId(source) => std::error::Error::source(source),
}
}
}
impl std::fmt::Display for GlobalError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
GlobalError::Backend(source) => {
write!(f, "Backend error: {source}")
}
GlobalError::InvalidId(source) => write!(f, "{source}"),
}
}
}
impl From<WaylandError> for GlobalError {
fn from(source: WaylandError) -> Self {
GlobalError::Backend(source)
}
}
impl From<InvalidId> for GlobalError {
fn from(source: InvalidId) -> Self {
GlobalError::InvalidId(source)
}
}
/// Error type returned by `GlobalList::bind`.
#[derive(Debug)]
pub enum BindError {
/// The global exists, but its version is lower than the minimum version requested.
UnsupportedVersion,
/// No global with the requested interface is present in the list.
NotPresent,
}
// No `source` override: neither variant wraps an underlying error.
impl std::error::Error for BindError {}
impl fmt::Display for BindError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BindError::UnsupportedVersion {} => {
write!(f, "the requested version of the global is not supported")
}
BindError::NotPresent {} => {
write!(f, "the requested global was not found in the registry")
}
}
}
}
/// One entry of the globals list, built from a `wl_registry` `Global` event.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Global {
/// Numeric name of the global; used to bind it and matched against `GlobalRemove` events.
pub name: u32,
/// Interface name of the global, compared with the proxy's interface name when binding.
pub interface: String,
/// Version carried by the `Global` event for this global.
pub version: u32,
}
/// Mutex-protected list of the globals currently recorded for a registry.
///
/// The list is updated by the registry's event handler and read through
/// `with_list` / `clone_list`.
#[derive(Debug)]
pub struct GlobalListContents {
// Entries are pushed on `Global` events and removed on `GlobalRemove` events.
contents: Mutex<Vec<Global>>,
}
impl GlobalListContents {
    /// Runs `f` on a borrowed view of the current list of globals and returns
    /// its result.
    ///
    /// The internal mutex is held for the whole duration of `f`.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn with_list<T, F: FnOnce(&[Global]) -> T>(&self, f: F) -> T {
        let globals = self.contents.lock().unwrap();
        f(globals.as_slice())
    }

    /// Returns an owned snapshot of the current list of globals.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn clone_list(&self) -> Vec<Global> {
        self.with_list(|globals| globals.to_vec())
    }
}
// Object data attached to the `wl_registry` created by `registry_queue_init`.
struct RegistryState<State> {
// The tracked list of globals; this is what `data_as_any` exposes.
globals: GlobalListContents,
// Handle of the event queue that registry events are forwarded to.
handle: QueueHandle<State>,
// Set to `true` once the initial roundtrip has finished; events are only
// forwarded to the queue after that point.
initial_roundtrip_done: AtomicBool,
}
// `ObjectData` implementation for the registry: it keeps the globals list in
// sync with registry events and, once the initial roundtrip is done, also
// forwards each event to the user's event queue.
impl<State: 'static> ObjectData for RegistryState<State>
where
State: Dispatch<wl_registry::WlRegistry, GlobalListContents>,
{
/// Handles one message addressed to the registry.
///
/// The message is parsed as a `wl_registry` event and applied to the
/// globals list. If the initial roundtrip had already completed when the
/// message arrived, a copy of the message is additionally enqueued on the
/// event queue. Always returns `None`.
fn event(
self: Arc<Self>,
backend: &Backend,
msg: Message<ObjectId, OwnedFd>,
) -> Option<Arc<dyn ObjectData>> {
let conn = Connection::from_backend(backend.clone());
// The message is temporarily re-typed with an uninhabited fd type so that
// it can be cloned. The conversion closure is `unreachable!()`: this code
// assumes registry messages never carry a file descriptor.
#[derive(Debug, Clone)]
enum Void {}
let msg: Message<ObjectId, Void> = msg.map_fd(|_| unreachable!());
// Decide up front whether a copy must be forwarded to the event queue;
// nothing is forwarded until the initial roundtrip flag has been set.
let to_forward = if self.initial_roundtrip_done.load(Ordering::Relaxed) {
Some(msg.clone().map_fd(|v| match v {}))
} else {
None
};
// Convert back to the fd type expected by `parse_event`.
let msg = msg.map_fd(|v| match v {});
// A message that fails to parse is silently ignored.
if let Ok((_, event)) = wl_registry::WlRegistry::parse_event(&conn, msg) {
match event {
wl_registry::Event::Global { name, interface, version } => {
// A new global was announced: record it.
let mut guard = self.globals.contents.lock().unwrap();
guard.push(Global { name, interface, version });
}
wl_registry::Event::GlobalRemove { name: remove } => {
// A global went away: drop every entry carrying that name.
let mut guard = self.globals.contents.lock().unwrap();
guard.retain(|Global { name, .. }| name != &remove);
}
}
};
// The list has been updated by this point, so the forwarded event is
// enqueued only after the change it describes has been applied.
if let Some(msg) = to_forward {
self.handle
.inner
.lock()
.unwrap()
.enqueue_event::<wl_registry::WlRegistry, GlobalListContents>(msg, self.clone())
}
// No replacement object data is returned from this handler.
None
}
/// Called when the registry object is destroyed; there is nothing to clean up.
fn destroyed(&self, _id: ObjectId) {
}
/// Exposes the `GlobalListContents` as this object's user data.
///
/// NOTE(review): presumably this is what makes
/// `registry.data::<GlobalListContents>()` succeed in `GlobalList` — confirm
/// against the `Proxy::data` implementation.
fn data_as_any(&self) -> &dyn std::any::Any {
&self.globals
}
}