waybackend 0.10.1

A simple, low-level wayland client implementation
Documentation
use ::alloc::*;
use core::{num::NonZeroU32, ptr::NonNull};

use crate::types::ObjectId;

/// Object Manager for creating, removing, and maintaining Wayland Objects
///
/// `T` should be a simple enum, like so:
/// ```
/// #[derive(Clone, Copy, PartialEq)]
/// enum WaylandProtocol {
///     Display,
///     Registry,
///     LayerShell,
///     //...
/// }
/// ```
///
/// **It is important not to associate any data with this enum type, as it will greatly increase
/// our memory requirements for storing it.** To drive this point home we require users to
/// implement `Copy`, implying the type must be simple enough to be trivially copyable.
///
/// ### Why do it like this?
///
/// This is the lowest-overhead solution I could come up with. If the application developer knows
/// exactly which Wayland objects it will interact with, we have no need to create a generic object
/// manager that can deal with every possible use case. We can just make one that deals only with
/// the objects the developer actually cares about.
pub struct ObjectManager<T: Copy + PartialEq> {
    /// Heap array of `objects_cap` client-side object types. The index in this array + 1 is the
    /// object id: for example, if `objects[1] == LayerSurface`, then the object of id 2 is of
    /// the type "LayerSurface". The offset exists because id 0 is null. A `None` slot is a free
    /// id that may be handed out again.
    objects: NonNull<Option<T>>,
    /// Number of slots allocated in `objects`.
    objects_cap: u32,
    /// Index (i.e. id - 1) of the slot the next client-side object will be stored in.
    next: u32,

    /// Heap array of `server_objects_cap` object types created by the server. The index in this
    /// array + 0xFF000000 is the object id, because objects allocated by the server start at id
    /// 0xFF000000. A `None` slot is free.
    server_objects: NonNull<Option<T>>,
    /// Number of slots allocated in `server_objects`.
    server_objects_cap: u32,
    /// Index (i.e. id - 0xFF000000) of the slot the next server-created object will be stored in.
    server_next: u32,
}

impl<T: Copy + PartialEq> ObjectManager<T> {
    pub(crate) fn new(display: T) -> Self {
        let objects = allocate::<Option<T>, 1>(Some(display));
        let server_objects = allocate::<Option<T>, 1>(None);
        Self {
            objects,
            objects_cap: 1,
            next: 1,
            server_objects,
            server_objects_cap: 1,
            server_next: 0,
        }
    }

    pub fn create_from_server(&mut self, object: T) {
        if self.server_next == self.server_objects_cap {
            grow(&mut self.server_objects, &mut self.server_objects_cap, None);
        }

        unsafe {
            self.server_objects
                .add(self.server_next as usize)
                .write(Some(object))
        }

        // update next to the next available id
        self.server_next += 1;
        while self.server_next < self.server_objects_cap {
            if unsafe {
                self.server_objects
                    .add(self.server_next as usize)
                    .read()
                    .is_none()
            } {
                break;
            }
            self.server_next += 1;
        }
    }

    /// Returns the first objectId of the specified variant
    #[must_use]
    pub fn get_first(&self, variant: T) -> Option<ObjectId> {
        for i in 0..self.objects_cap as usize {
            if unsafe { self.objects.add(i).read() }.is_some_and(|obj| obj == variant) {
                return Some(ObjectId::new(NonZeroU32::new(i as u32 + 1).unwrap()));
            }
        }
        None
    }

    /// Gets all the objects that match the specified variant
    ///
    /// You can use this to, for example, find all binded outputs after initialization.
    pub fn get_all(&self, variant: T) -> impl Iterator<Item = ObjectId> {
        let objects = unsafe {
            core::slice::from_raw_parts(self.objects.as_ptr(), self.objects_cap as usize)
        };
        objects.iter().enumerate().filter_map(move |(i, obj)| {
            if obj.is_some_and(|obj| obj == variant) {
                Some(ObjectId::new(NonZeroU32::new(i as u32 + 1).unwrap()))
            } else {
                None
            }
        })
    }

    /// get the type of the wayland object from its id
    ///
    /// Returns
    ///   * 'Some(T)' if the object still exists
    ///   * 'None' if the object was already deleted, or wasn't registered
    #[must_use]
    pub fn get(&self, object_id: ObjectId) -> Option<T> {
        if object_id.created_by_server() {
            let pos = object_id.get().get() as usize - 0xFF000000;
            if pos < self.server_objects_cap as usize {
                return unsafe { self.server_objects.add(pos).read() };
            }
        } else {
            let pos = object_id.get().get() as usize - 1;
            if pos < self.objects_cap as usize {
                return unsafe { self.objects.add(pos).read() };
            }
        }
        None
    }

    /// creates a new Id to use in requests
    #[must_use]
    pub fn create(&mut self, object: T) -> ObjectId {
        if self.next == self.objects_cap {
            grow(&mut self.objects, &mut self.objects_cap, None);
        }

        let i = self.next as usize;
        unsafe { self.objects.add(self.next as usize).write(Some(object)) }

        // update next to the next available id
        self.next += 1;
        while self.next < self.objects_cap {
            if unsafe { self.objects.add(self.next as usize).read().is_none() } {
                break;
            }
            self.next += 1;
        }

        // SAFETY: we are adding one to make sure
        ObjectId::new(unsafe { NonZeroU32::new_unchecked(i as u32 + 1) })
    }

    /// removes the wayland object, if it still exists
    pub fn remove(&mut self, object_id: u32) -> Option<T> {
        if let Ok(object_id) = ObjectId::try_new(object_id) {
            if object_id.created_by_server() {
                let pos = object_id.get().get() - 0xFF000000;
                if pos < self.server_objects_cap {
                    let ret = unsafe { self.server_objects.add(pos as usize).read() };
                    unsafe { self.server_objects.add(pos as usize).write(None) };
                    if pos < self.server_next {
                        self.server_next = pos;
                    }
                    return ret;
                }
            } else {
                let pos = object_id.get().get() - 1;
                if pos < self.objects_cap {
                    let ret = unsafe { self.objects.add(pos as usize).read() };
                    unsafe { self.objects.add(pos as usize).write(None) };
                    if pos < self.next {
                        self.next = pos;
                    }
                    return ret;
                }
            }
        }
        None
    }
}

/// Allocates a heap array of `INITIAL_SIZE` elements, each initialized to `fill`.
///
/// The array is meant to be resized with `grow` and released with `deallocate`, always passing
/// the element count it currently has. If the allocator fails, this diverges through
/// `handle_alloc_error`.
///
/// Both a zero `INITIAL_SIZE` and a zero-sized `T` are rejected at compile time, because calling
/// `alloc::alloc` with a zero-sized layout is undefined behavior.
fn allocate<T: Copy, const INITIAL_SIZE: usize>(fill: T) -> NonNull<T> {
    const {
        assert!(INITIAL_SIZE > 0);
        assert!(core::mem::size_of::<T>() > 0);
    }

    let layout = alloc::Layout::array::<T>(INITIAL_SIZE).unwrap();
    // SAFETY: `layout` has a non-zero size (both factors are asserted non-zero above).
    let ptr = unsafe { alloc::alloc(layout) }.cast::<T>();
    let Some(ptr) = NonNull::new(ptr) else {
        alloc::handle_alloc_error(layout)
    };
    for i in 0..INITIAL_SIZE {
        // SAFETY: `i < INITIAL_SIZE`, so the write stays inside the fresh allocation.
        unsafe { ptr.add(i).write(fill) }
    }
    ptr
}

/// Doubles the capacity of the heap array at `ptr`, filling the new tail with `fill`.
///
/// `ptr`/`cap` must describe an array produced by `allocate` or a previous `grow`, holding
/// exactly `cap` (> 0) elements. Existing elements are preserved; on success both `ptr` and
/// `cap` are updated in place. If the allocator fails, this diverges through `handle_alloc_error`.
///
/// # Panics
/// If doubling `cap` overflows `u32`, or if the doubled array size is not a valid `Layout`.
#[cold]
fn grow<T: Copy>(ptr: &mut NonNull<T>, cap: &mut u32, fill: T) {
    // `realloc` with a zero-sized layout is undefined behavior.
    const { assert!(core::mem::size_of::<T>() > 0) }

    let old_cap = *cap;
    let old_layout = alloc::Layout::array::<T>(old_cap as usize).unwrap();

    // A wrapping multiply would produce a *smaller* capacity in release builds, after which
    // callers would write past the end of the shrunken allocation.
    let new_cap = old_cap
        .checked_mul(2)
        .expect("object table capacity overflowed u32");
    let new_layout = alloc::Layout::array::<T>(new_cap as usize).unwrap();

    // SAFETY: `ptr` was allocated by the global allocator with `old_layout` (see contract above),
    // and `new_layout.size()` is non-zero and valid, as checked by `Layout::array`.
    let new_ptr = unsafe {
        alloc::realloc(ptr.as_ptr().cast::<u8>(), old_layout, new_layout.size())
    }
    .cast::<T>();

    let Some(new_ptr) = NonNull::new(new_ptr) else {
        alloc::handle_alloc_error(new_layout)
    };
    for i in old_cap..new_cap {
        // SAFETY: `i < new_cap`, so the write stays inside the resized allocation.
        unsafe { new_ptr.add(i as usize).write(fill) }
    }
    *cap = new_cap;
    *ptr = new_ptr;
}

/// Frees a heap array that currently holds `cap` elements.
///
/// The layout is recomputed from `cap`, so `cap` must be exactly the element count the array
/// was allocated (or last grown) with.
fn deallocate<T>(ptr: NonNull<T>, cap: u32) {
    let raw = ptr.cast::<u8>().as_ptr();
    let layout = alloc::Layout::array::<T>(cap as usize).unwrap();
    // SAFETY: per the contract above, `raw` came from the global allocator with `layout`.
    unsafe { alloc::dealloc(raw, layout) }
}

impl<T: Copy + PartialEq> Drop for ObjectManager<T> {
    /// Releases the client table, then the server table, each at its current capacity.
    fn drop(&mut self) {
        let tables = [
            (self.objects, self.objects_cap),
            (self.server_objects, self.server_objects_cap),
        ];
        for (table, capacity) in tables {
            deallocate(table, capacity);
        }
    }
}

#[cfg(test)]
mod tests {
    extern crate std;
    use std::vec::Vec;

    use super::*;

    /// Minimal stand-in for an application's protocol enum.
    #[derive(Clone, Copy, Debug, PartialEq)]
    enum DummyProtocol {
        Display,
        Region,
        Surface,
    }

    fn obj_from_u32(u: u32) -> ObjectId {
        ObjectId::new(NonZeroU32::new(u).unwrap())
    }

    /// Views the manager's whole server table, including free (`None`) slots.
    fn server_objects<T: Copy + PartialEq>(objman: &ObjectManager<T>) -> &[Option<T>] {
        // SAFETY: the server table always holds `server_objects_cap` initialized slots, and the
        // returned slice borrows `objman`, so the table cannot be grown or freed meanwhile.
        unsafe {
            core::slice::from_raw_parts(
                objman.server_objects.as_ptr(),
                objman.server_objects_cap as usize,
            )
        }
    }

    #[test]
    fn creating_object_ids() {
        let mut manager = ObjectManager::new(DummyProtocol::Display);
        let id1 = manager.create(DummyProtocol::Region);
        assert_eq!(id1, obj_from_u32(2));
        let id2 = manager.create(DummyProtocol::Region);
        assert_eq!(id2, obj_from_u32(3));
        let id3 = manager.create(DummyProtocol::Region);
        assert_eq!(id3, obj_from_u32(4));

        manager.remove(id2.get().get());
        let id4 = manager.create(DummyProtocol::Region);
        assert_eq!(id4, id2);

        manager.remove(id1.get().get());
        let id5 = manager.create(DummyProtocol::Region);
        assert_eq!(id5, id1);

        manager.remove(id2.get().get());
        manager.remove(id1.get().get());
        let id6 = manager.create(DummyProtocol::Region);
        assert_eq!(id6, id1);

        let id7 = manager.create(DummyProtocol::Region);
        assert_eq!(id7, id2);
    }

    #[test]
    fn get_all() {
        let mut manager = ObjectManager::new(DummyProtocol::Display);
        let id1 = manager.create(DummyProtocol::Region);
        let id2 = manager.create(DummyProtocol::Surface);
        let id3 = manager.create(DummyProtocol::Region);
        let id4 = manager.create(DummyProtocol::Region);
        let id5 = manager.create(DummyProtocol::Surface);
        let id6 = manager.create(DummyProtocol::Surface);
        let id7 = manager.create(DummyProtocol::Region);
        let id8 = manager.create(DummyProtocol::Region);
        let id9 = manager.create(DummyProtocol::Region);
        let id10 = manager.create(DummyProtocol::Surface);

        let regions: Vec<ObjectId> = manager.get_all(DummyProtocol::Region).collect();
        let surfaces: Vec<ObjectId> = manager.get_all(DummyProtocol::Surface).collect();

        assert_eq!(&regions, &[id1, id3, id4, id7, id8, id9]);
        assert_eq!(&surfaces, &[id2, id5, id6, id10]);
    }

    #[test]
    fn get_and_remove_round_trip() {
        let mut manager = ObjectManager::new(DummyProtocol::Display);
        // The display is pre-registered as id 1.
        assert_eq!(manager.get(obj_from_u32(1)), Some(DummyProtocol::Display));
        assert_eq!(
            manager.get_first(DummyProtocol::Display),
            Some(obj_from_u32(1))
        );
        assert_eq!(manager.get_first(DummyProtocol::Surface), None);

        let surface = manager.create(DummyProtocol::Surface);
        assert_eq!(manager.get(surface), Some(DummyProtocol::Surface));
        assert_eq!(manager.get_first(DummyProtocol::Surface), Some(surface));

        // Ids past the end of the table are unknown rather than a panic.
        assert_eq!(manager.get(obj_from_u32(1000)), None);

        assert_eq!(
            manager.remove(surface.get().get()),
            Some(DummyProtocol::Surface)
        );
        assert_eq!(manager.get(surface), None);
        // A second removal finds nothing.
        assert_eq!(manager.remove(surface.get().get()), None);

        manager.create_from_server(DummyProtocol::Region);
        manager.create_from_server(DummyProtocol::Surface);
        assert_eq!(
            manager.get(obj_from_u32(0xFF000001)),
            Some(DummyProtocol::Surface)
        );
        assert_eq!(manager.remove(0xFF000001), Some(DummyProtocol::Surface));
        assert_eq!(manager.get(obj_from_u32(0xFF000001)), None);
    }

    #[test]
    fn creating_object_ids_from_server() {
        let mut manager = ObjectManager::new(DummyProtocol::Display);
        manager.create_from_server(DummyProtocol::Region);
        manager.create_from_server(DummyProtocol::Region);
        manager.create_from_server(DummyProtocol::Region);

        assert_eq!(
            &server_objects(&manager)[..3],
            &[Some(DummyProtocol::Region); 3]
        );
        assert_eq!(manager.server_next, 3);

        manager.remove(0xFF000002);
        assert_eq!(manager.server_next, 2);

        manager.create_from_server(DummyProtocol::Region);
        assert_eq!(
            &server_objects(&manager)[..3],
            &[Some(DummyProtocol::Region); 3]
        );
        assert_eq!(manager.server_next, 3);

        manager.remove(0xFF000001);
        assert_eq!(manager.server_next, 1);
        manager.create_from_server(DummyProtocol::Region);
        assert_eq!(
            &server_objects(&manager)[..3],
            &[Some(DummyProtocol::Region); 3]
        );
        assert_eq!(manager.server_next, 3);

        manager.remove(0xFF000001);
        manager.remove(0xFF000002);
        assert_eq!(manager.server_next, 1);
        manager.create_from_server(DummyProtocol::Region);
        assert_eq!(
            &server_objects(&manager)[..2],
            &[Some(DummyProtocol::Region); 2]
        );
        assert_eq!(manager.server_next, 2);

        manager.create_from_server(DummyProtocol::Region);
        assert_eq!(
            &server_objects(&manager)[..3],
            &[Some(DummyProtocol::Region); 3]
        );
        assert_eq!(manager.server_next, 3);
    }
}