// waybackend 0.10.1
//
// A simple, low-level wayland client implementation
//! This module implements a [German String](https://cedardb.com/blog/german_strings/) specialized
//! for the Wayland protocol. See [GermanString] for more details.
//!
//! Note we currently emphasize memory efficiency, and do not focus particularly on string
//! comparison performance[^perf]. Our particular implementation therefore does **NOT** correspond
//! to what is described in the above link, though it follows the same principles.
//! 
//! [^perf]: That said, not having to follow an extra pointer to compare short strings, and using
//! less memory overall, should still generally lead to performance wins.

use ::alloc::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use core::mem;
use core::ops::Deref;
use core::ptr::NonNull;

/// Size in bytes of a thin pointer on the target platform.
const PTR_SIZE: usize = mem::size_of::<*const u8>();
/// Bytes of the first pointer-sized word left over after the `u16` length field (zero if a
/// pointer is not larger than a `u16`).
const PAD_SIZE: usize = {
    let len_field = mem::size_of::<u16>();
    if PTR_SIZE > len_field {
        PTR_SIZE - len_field
    } else {
        0
    }
};

/// Storage for the last pointer-sized word of a [`GermanString`].
///
/// The active field is decided by the string's length: short strings keep their trailing
/// inline bytes in `arr`, long strings keep the pointer to their heap buffer in `ptr`.
#[repr(C)]
union ArrOrPtr {
    /// Heap buffer holding all `len` bytes of a long string.
    ptr: NonNull<u8>,
    /// Inline bytes that follow the ones stored in `pad` (short strings only).
    arr: [u8; PTR_SIZE],
}

/// This is a GermanString optimized for Wayland strings. In particular, Wayland strings must be
/// shorter than 4096 bytes (the maximum message size), so their length fits in just 12 bits and
/// comfortably in the `u16` used here. On a 64-bit machine this lets us store up to 14 bytes
/// inline, without a heap allocation; longer strings are spilled to the heap.
///
/// This particular implementation is not currently focused on speeding up string comparisons.
/// This may change in the future, but for now the main use case is improving memory efficiency.
// `#[repr(C)]` keeps `len`, `pad` and `ptr` in this order and back to back (`pad` ends exactly
// where the pointer-aligned `ptr` starts), so an inline string occupies one contiguous byte run
// starting at offset 2.
#[repr(C)]
pub struct GermanString {
    /// Length of the string in bytes (kept below 4096 by the constructors' contract).
    len: u16,
    /// First `PAD_SIZE` bytes of an inline string; all zeros for heap strings.
    pad: [u8; PAD_SIZE],
    /// Remaining inline bytes, or the heap pointer when `len > PAD_SIZE + PTR_SIZE`.
    ptr: ArrOrPtr,
}

impl GermanString {
    /// Creates a `GermanString` holding a copy of `s`.
    ///
    /// Returns `None` if `s.len() >= 4096`
    ///
    /// You can safely call [unwrap](Option::unwrap) for any `&str` you received in a wayland event
    /// handler implementation.
    #[inline]
    pub fn new(s: &str) -> Option<Self> {
        if s.len() >= 4096 {
            None
        } else {
            // SAFETY: the length was checked just above.
            Some(unsafe { Self::new_unsafe(s) })
        }
    }

    /// Creates a `GermanString` holding a copy of `s`, without checking its length.
    ///
    /// Strings of at most `PAD_SIZE + PTR_SIZE` bytes are stored inline, zero-filled past their
    /// end; longer strings are copied into a heap allocation of exactly `s.len()` bytes.
    ///
    /// # Safety
    ///
    /// Only call this if you are sure `s.len() < 4096`. In [waybackend](crate), this is true for
    /// any `&str` you received in a wayland event handler.
    #[inline]
    pub unsafe fn new_unsafe(s: &str) -> Self {
        // The caller guarantees `s.len() < 4096`, so this cast cannot truncate.
        let len = s.len() as u16;
        if s.len() <= PTR_SIZE + PAD_SIZE {
            // Split the string across `pad` and `arr`. Under `#[repr(C)]` these sit right after
            // `len`, so the string bytes start at offset 2 of the struct. Filling the fields
            // directly avoids writing a `u16` through a byte buffer's (possibly misaligned)
            // pointer.
            let bytes = s.as_bytes();
            let (head, tail) = bytes.split_at(bytes.len().min(PAD_SIZE));
            let mut pad = [0u8; PAD_SIZE];
            pad[..head.len()].copy_from_slice(head);
            let mut arr = [0u8; PTR_SIZE];
            arr[..tail.len()].copy_from_slice(tail);
            Self {
                len,
                pad,
                ptr: ArrOrPtr { arr },
            }
        } else {
            let layout = Layout::array::<u8>(s.len()).unwrap();
            // SAFETY: `layout` is non-zero-sized because `s.len() > PTR_SIZE + PAD_SIZE`.
            let ptr = match NonNull::new(unsafe { alloc(layout) }) {
                Some(ptr) => ptr,
                None => handle_alloc_error(layout),
            };

            // SAFETY: `ptr` is a fresh allocation of `s.len()` bytes, so it is valid for writes
            // and cannot overlap `s`.
            unsafe { core::ptr::copy_nonoverlapping(s.as_ptr(), ptr.as_ptr(), s.len()) };

            Self {
                len,
                pad: [0; PAD_SIZE],
                ptr: ArrOrPtr { ptr },
            }
        }
    }

    /// Returns the first `PTR_SIZE` bytes of the representation: the native-endian `len`
    /// followed by `pad`. For inline strings that is the length plus the first `PAD_SIZE`
    /// string bytes; for heap strings `pad` is all zeros.
    #[inline]
    fn prefix(&self) -> [u8; PTR_SIZE] {
        let mut out = [0u8; PTR_SIZE];
        let (len_bytes, pad_bytes) = out.split_at_mut(mem::size_of::<u16>());
        len_bytes.copy_from_slice(&self.len.to_ne_bytes());
        pad_bytes.copy_from_slice(&self.pad);
        out
    }

    /// Returns the heap bytes of a long string.
    ///
    /// # Safety
    ///
    /// `self.ptr` must currently hold the heap pointer, i.e. `self.len > PAD_SIZE + PTR_SIZE`.
    #[inline]
    unsafe fn ptr_slice(&self) -> &[u8] {
        let len = self.len as usize;
        // SAFETY: per the caller's contract, `ptr` points to a live allocation of `len` bytes.
        unsafe { core::slice::from_raw_parts(self.ptr.ptr.as_ptr(), len) }
    }
}

impl Drop for GermanString {
    /// Releases the heap buffer of a long string. Inline strings own no allocation, so there is
    /// nothing to free for them.
    #[inline]
    fn drop(&mut self) {
        let len = usize::from(self.len);
        if len <= PAD_SIZE + PTR_SIZE {
            return;
        }
        // Must be the same layout the buffer was allocated with in `new_unsafe`.
        let layout = Layout::array::<u8>(len).unwrap();
        // SAFETY: long strings always own a heap buffer of `len` bytes allocated with `layout`.
        unsafe { dealloc(self.ptr.ptr.as_ptr(), layout) };
    }
}

// SAFETY: a `GermanString` exclusively owns its heap buffer (if any), and none of the impls in
// this module hand out mutable access to it, much like `Box<str>`. Moving or sharing it across
// threads is therefore sound even though `NonNull` is neither `Send` nor `Sync` on its own.
unsafe impl Send for GermanString {}
unsafe impl Sync for GermanString {}

impl Deref for GermanString {
    type Target = str;

    /// Views the stored bytes as a `&str`. The bytes come from the inline buffer or from the heap
    /// allocation, depending on the length.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let len = self.len as usize;
        let bytes: &[u8] = if len > PAD_SIZE + PTR_SIZE {
            // SAFETY: long strings always hold a pointer to a `len`-byte allocation.
            unsafe { core::slice::from_raw_parts(self.ptr.ptr.as_ptr(), len) }
        } else {
            // SAFETY: `#[repr(C)]` places `len` (2 bytes), `pad` and `ptr` back to back, so an
            // inline string occupies bytes `2..2 + len` of the struct.
            unsafe {
                let raw = mem::transmute::<&Self, &[u8; mem::size_of::<Self>()]>(self);
                raw.get_unchecked(2..2 + len)
            }
        };
        // SAFETY: the bytes were copied verbatim from a `&str` on construction.
        unsafe { core::str::from_utf8_unchecked(bytes) }
    }
}

impl AsRef<str> for GermanString {
    #[inline]
    fn as_ref(&self) -> &str {
        Deref::deref(self)
    }
}

impl core::fmt::Debug for GermanString {
    /// Shows the raw representation: the length, the padding bytes, and whichever union field is
    /// active for that length (inline bytes for short strings, the heap pointer for long ones).
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let is_inline = self.len <= (PAD_SIZE + PTR_SIZE) as u16;
        // SAFETY: `is_inline` selects the union field that was written on construction.
        let ptr_field: &dyn core::fmt::Debug = if is_inline {
            unsafe { &self.ptr.arr }
        } else {
            unsafe { &self.ptr.ptr }
        };
        let mut out = f.debug_struct("GermanString");
        out.field("len", &self.len);
        out.field("pad", &self.pad);
        out.field("ptr", ptr_field);
        out.finish()
    }
}

impl core::fmt::Display for GermanString {
    /// Writes the string contents. Width, fill, alignment and precision flags are honoured
    /// exactly as for `str` (e.g. `format!("{:>8}", s)` right-aligns). `Formatter::write_str`
    /// would silently ignore them.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.pad(Deref::deref(self))
    }
}

impl PartialEq<GermanString> for GermanString {
    /// Compares the length-and-prefix word first, then the remaining bytes.
    ///
    /// Inline strings are zero-filled past their length, so once the lengths are known to match,
    /// comparing the whole inline buffers is the same as comparing the strings.
    #[inline]
    fn eq(&self, other: &GermanString) -> bool {
        if self.prefix() != other.prefix() {
            return false;
        }
        // Equal prefixes imply equal lengths, so `self.len` tells us how both sides are stored.
        if self.len > (PAD_SIZE + PTR_SIZE) as u16 {
            // SAFETY: both strings are long, so both hold heap pointers.
            let (lhs, rhs) = unsafe { (self.ptr_slice(), other.ptr_slice()) };
            lhs == rhs
        } else {
            // SAFETY: both strings are short, so both hold inline bytes.
            unsafe { self.ptr.arr == other.ptr.arr }
        }
    }
}

// `eq` compares the full string contents, so it is a true equivalence relation.
impl Eq for GermanString {}

impl Ord for GermanString {
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.prefix().cmp(&other.prefix()).then_with(|| {
            // Note: if we got here, the prefixes are equal
            if self.len <= PTR_SIZE as u16 {
                core::cmp::Ordering::Equal
            } else if self.len <= (PAD_SIZE + PTR_SIZE) as u16 {
                unsafe { self.ptr.arr.cmp(&other.ptr.arr) }
            } else {
                unsafe {
                    let a = self.ptr_slice();
                    let b = other.ptr_slice();
                    Ord::cmp(a, b)
                }
            }
        })
    }
}

impl PartialOrd for GermanString {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Default for GermanString {
    #[inline]
    fn default() -> Self {
        Self {
            len: Default::default(),
            pad: Default::default(),
            ptr: ArrOrPtr {
                arr: Default::default(),
            },
        }
    }
}

impl core::str::FromStr for GermanString {
    type Err = ();

    /// Parses `s` into a `GermanString`. Fails with `()` when `s` is 4096 bytes or longer.
    #[inline]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match Self::new(s) {
            Some(german) => Ok(german),
            None => Err(()),
        }
    }
}

impl TryFrom<&str> for GermanString {
    type Error = ();

    /// Like [`GermanString::new`], but reports the too-long case as `Err(())`.
    #[inline]
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        s.parse()
    }
}

impl TryFrom<&mut str> for GermanString {
    type Error = ();

    /// Same as the `&str` conversion; the mutable borrow is only read from.
    #[inline]
    fn try_from(s: &mut str) -> Result<Self, Self::Error> {
        <Self as TryFrom<&str>>::try_from(s)
    }
}

impl core::borrow::Borrow<str> for GermanString {
    #[inline]
    fn borrow(&self) -> &str {
        Deref::deref(self)
    }
}

impl core::hash::Hash for GermanString {
    #[inline]
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        Deref::deref(self).hash(state)
    }
}

impl PartialEq<str> for GermanString {
    #[inline]
    fn eq(&self, other: &str) -> bool {
        Deref::deref(self).eq(other)
    }
}

impl<'a> PartialEq<&'a str> for GermanString {
    #[inline]
    fn eq(&self, other: &&'a str) -> bool {
        Deref::deref(self).eq(*other)
    }
}

impl PartialEq<GermanString> for str {
    #[inline]
    fn eq(&self, other: &GermanString) -> bool {
        self.eq(Deref::deref(other))
    }
}

impl PartialEq<GermanString> for &str {
    #[inline]
    fn eq(&self, other: &GermanString) -> bool {
        self.eq(&Deref::deref(other))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GermanString` from a literal known to be short enough.
    fn gs(s: &str) -> GermanString {
        GermanString::new(s).unwrap()
    }

    #[test]
    fn str_is_too_large() {
        let too_long = "a".repeat(4096);
        assert!(GermanString::new(&too_long).is_none());
    }

    #[test]
    fn inline_str_comparisons() {
        // Each entry pairs a string with near-misses that must compare unequal to it.
        let cases: [(&str, &[&str]); 2] = [
            (
                "hello world",
                &["hello worl", "hello_world", "hello worlt", "iello world"],
            ),
            ("hello", &["hell", "hellu", "iello"]),
        ];
        for &(text, near_misses) in &cases {
            let s1 = gs(text);
            assert_eq!(s1, text);
            assert_eq!(s1, gs(text));
            for &other in near_misses {
                assert_ne!(s1, other);
                assert_ne!(s1, gs(other));
            }
        }
    }

    #[test]
    fn heap_str_comparisons() {
        let mut text = "hello world".repeat(10);
        let s1 = gs(&text);
        assert_eq!(s1, text.as_str());
        assert_eq!(s1, gs(&text));

        // Change the final 'd' into 'c' so only the last heap byte differs.
        text.pop();
        text.push('c');
        assert_ne!(s1, text.as_str());
        assert_ne!(s1, gs(&text));
    }
}