// com-scrape-types 0.1.1
//
// Support code for bindings generated with com-scrape.
// Documentation
use std::ops::Deref;
use std::ptr::addr_of;
use std::sync::Arc;

use super::{ComPtr, ComRef, Guid, Interface};

macro_rules! offset_of {
    ($struct:ty, $field:tt) => {{
        // An uninitialized value provides a properly aligned base address. `addr_of!` projects to
        // the field without ever creating a reference to the uninitialized memory, and the byte
        // distance between the two raw pointers is the field offset (as an `isize`).
        let uninit = ::std::mem::MaybeUninit::<$struct>::uninit();
        let struct_ptr: *const $struct = uninit.as_ptr();
        let field_ptr = ::std::ptr::addr_of!((*struct_ptr).$field);
        (field_ptr as *const u8).offset_from(struct_ptr as *const u8)
    }};
}

/// Helper functionality used in generated virtual tables for Rust types.
///
/// The purpose of this trait is to allow the [`Construct`] implementations generated by
/// `com-scrape` to avoid hard-coding any particular reference counting logic or class layout, and
/// to instead allow this logic to be plugged in at the point where the COM object is constructed.
///
/// Currently, the only type implementing this trait is [`ComWrapper`], but it would be possible
/// for additional wrapper types to implement it in the future.
pub trait Wrapper<C: Class + ?Sized> {
    /// Given a pointer to an object's header, returns a pointer to the object itself.
    ///
    /// # Safety
    ///
    /// `ptr` must point to the header of a live object whose storage was allocated and laid out
    /// by this wrapper type (for [`ComWrapper`], an object created by [`ComWrapper::new`]).
    /// Implementations locate the object relative to `ptr`, so any other pointer produces an
    /// invalid result.
    unsafe fn data_from_header(ptr: *mut Header<C>) -> *mut C;

    /// Given a pointer to an object, returns a pointer to the object's header.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live object whose storage was allocated and laid out by this
    /// wrapper type. The header is located relative to `ptr`.
    unsafe fn header_from_data(ptr: *mut C) -> *mut Header<C>;

    /// Increments the reference count of an object and returns the resulting count.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live object managed by this wrapper type. Each call is expected to
    /// be balanced by a later call to [`Wrapper::release`]; otherwise the object is never freed.
    ///
    /// The returned count may already be stale if other threads change the count concurrently.
    /// Treat it as informational only.
    unsafe fn add_ref(ptr: *mut C) -> usize;

    /// Decrements the reference count of an object and returns the resulting count.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live object managed by this wrapper type, and the caller must own
    /// one of its references, which this call gives up. Once the count reaches zero the object
    /// may be destroyed, and `ptr` must not be used again.
    ///
    /// As with [`Wrapper::add_ref`], the returned count may be stale under concurrent use.
    unsafe fn release(ptr: *mut C) -> usize;
}

/// Produces the base class object, together with its virtual table, of interface `Self` for
/// class `C`.
///
/// `W` is the wrapper type that supplies class layout and reference counting. `OFFSET` is the
/// byte offset within the object header at which the constructed base class object will live.
///
/// Interface types generated by `com-scrape` implement this trait. It is consumed by the
/// [`MakeHeader`] implementations for tuples of interfaces.
///
/// # Safety
///
/// When `I::OBJ` is reinterpreted as `*const I::Vtbl` (as described in the
/// [safety documentation](Interface#safety) of [`Interface`]), it must be a valid pointer to an
/// instance of `I::Vtbl`.
pub unsafe trait Construct<C, W, const OFFSET: isize>: Interface
where
    C: Class,
    W: Wrapper<C>,
{
    /// The generated base class object.
    const OBJ: Self;
}

/// The set of COM interfaces a Rust type implements, written as a tuple of interface types.
///
/// An implementation supplies two things:
/// - the layout of the object header, with one base class object per interface;
/// - a lookup that maps an interface GUID to the byte offset of the matching base class object
///   within that header.
///
/// # Safety
///
/// Suppose `L::query(I::IID)` yields `Some(offset)` for an [`Interface`] `I`, and
/// `ptr: *mut L::Header` points to a valid `L::Header`. Then the pointer
/// `(ptr as *mut u8).offset(offset) as *mut I` must point to a valid instance of `I`.
pub unsafe trait InterfaceList {
    /// The object header type, holding a base class object for every interface in the list.
    type Header;

    /// Looks up the interface identified by `iid`.
    ///
    /// Returns `Some(offset)` when some interface in the list has the GUID `iid`, or transitively
    /// derives from an interface with that GUID. `offset` is the byte offset of that interface's
    /// base class object within `Self::Header`. Otherwise returns `None`.
    fn query(iid: &Guid) -> Option<isize>;
}

/// Builds the object header for class `C`, using wrapper `W`, from a list of interfaces.
///
/// Implemented for tuples of interface types. [`ComWrapper`] uses it to create the header that
/// precedes each wrapped Rust value.
///
/// # Safety
///
/// Suppose `L::query(I::IID)` yields `Some(offset)` for an [`Interface`] `I` (see
/// [`InterfaceList`]). Then the `I` value stored at byte offset `offset` within `L::Header`, when
/// reinterpreted as `*const I::Vtbl`, must be a valid pointer to an instance of `I::Vtbl`.
pub unsafe trait MakeHeader<C: Class, W: Wrapper<C>>: InterfaceList {
    /// The fully constructed header value.
    const HEADER: Self::Header;
}

/// A Rust type that can be exposed as a COM class.
///
/// Implementing this trait is a prerequisite for wrapping a value in [`ComWrapper`].
pub trait Class {
    /// The COM interfaces exposed by this type, written as a tuple of interface types, e.g.:
    ///
    /// ```ignore
    /// struct MyClass;
    ///
    /// impl Class for MyClass {
    ///     type Interfaces = (ISomeInterface, IAnotherInterface);
    /// }
    /// ```
    type Interfaces: InterfaceList;
}

/// Shorthand for the object header type of a [`Class`] `C`: the `Header` of its
/// [`InterfaceList`].
pub type Header<C> = <<C as Class>::Interfaces as InterfaceList>::Header;

// Implements `InterfaceList` and `MakeHeader` for a single tuple arity.
//
// Invoked as `interface_list!(HeaderN, I0 0, I1 1, ...)`. `$header` names the generated header
// struct, and each `$interface $index` pair gives a type parameter and its position in the
// tuple. For that arity the macro generates:
//   * a `#[repr(C)]` tuple struct `$header` holding one base class object per interface, in list
//     order;
//   * an `InterfaceList` impl for the tuple `(I0, I1, ...)`, with `$header` as its `Header`;
//   * a `MakeHeader` impl for the same tuple, which builds the header from each interface's
//     `Construct::OBJ`.
//
// NOTE(review): the byte offset of entry `$index` is computed as `$index * size_of::<*mut ()>()`.
// That equals the real field offset in the `#[repr(C)]` header only if every interface type is
// exactly pointer-sized. This is presumably guaranteed by the `Interface` contract; confirm.
macro_rules! interface_list {
    ($header:ident, $($interface:ident $index:tt),*) => {
        #[repr(C)]
        pub struct $header<$($interface),*>($($interface),*);

        unsafe impl<$($interface: Interface),*> InterfaceList for ($($interface,)*) {
            type Header = $header<$($interface),*>;

            fn query(iid: &Guid) -> Option<isize> {
                // Interfaces are checked in list order, so the first one whose `inherits(iid)`
                // returns true determines the result. `as` binds tighter than `*`, so the
                // product is computed in `isize`.
                $(
                    if $interface::inherits(iid) {
                        return Some($index * std::mem::size_of::<*mut ()>() as isize);
                    }
                )*

                None
            }
        }

        // Each interface's `OFFSET` const parameter is set to the same offset that `query`
        // reports for it. Its generated base class object therefore knows where it sits in the
        // header.
        unsafe impl<C, W $(, $interface)*> MakeHeader<C, W> for ($($interface,)*)
        where
            C: Class,
            W: Wrapper<C>,
            $($interface: Construct<C, W, { $index * std::mem::size_of::<*mut ()>() as isize }>,)*
        {
            const HEADER: Self::Header = $header($($interface::OBJ),*);
        }
    }
}

// Instantiate `InterfaceList`/`MakeHeader` for tuples of 1 through 15 interfaces. A `Class`
// whose `Interfaces` tuple is longer than 15 has no implementation here.
interface_list!(Header1, I0 0);
interface_list!(Header2, I0 0, I1 1);
interface_list!(Header3, I0 0, I1 1, I2 2);
interface_list!(Header4, I0 0, I1 1, I2 2, I3 3);
interface_list!(Header5, I0 0, I1 1, I2 2, I3 3, I4 4);
interface_list!(Header6, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5);
interface_list!(Header7, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6);
interface_list!(Header8, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7);
interface_list!(Header9, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8);
interface_list!(Header10, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9);
interface_list!(Header11, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10);
interface_list!(Header12, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11);
interface_list!(Header13, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11, I12 12);
interface_list!(Header14, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11, I12 12, I13 13);
interface_list!(Header15, I0 0, I1 1, I2 2, I3 3, I4 4, I5 5, I6 6, I7 7, I8 8, I9 9, I10 10, I11 11, I12 12, I13 13, I14 14);

/// Heap storage behind a [`ComWrapper`]: the COM object header followed by the Rust value.
///
/// `#[repr(C)]` fixes the field order, so `header` is at offset 0 and `data` follows it. The
/// `Wrapper` impl for `ComWrapper` relies on this to convert between pointers to the two fields
/// using fixed offsets.
#[repr(C)]
struct ComWrapperInner<C: Class> {
    /// One base class object for each interface in `C::Interfaces`.
    header: Header<C>,
    /// The wrapped Rust value.
    data: C,
}

/// A wrapper for constructing a reference-counted COM object from a Rust value.
///
/// `ComWrapper` represents an owning reference to the COM object, i.e. it will decrement the
/// object's reference count when it goes out of scope. Cloning it yields another owning
/// reference to the same object, and it dereferences to the wrapped value.
pub struct ComWrapper<C: Class> {
    // Shared storage for the header and the wrapped value. The `Arc` strong count serves as the
    // COM reference count:
    //   * `Wrapper::add_ref`/`Wrapper::release` adjust it directly;
    //   * each `ComPtr` handed out by `to_com_ptr` holds one strong reference.
    inner: Arc<ComWrapperInner<C>>,
}

impl<C: Class> Clone for ComWrapper<C> {
    fn clone(&self) -> ComWrapper<C> {
        ComWrapper {
            inner: self.inner.clone(),
        }
    }
}

// SAFETY: a `ComWrapper<C>` is a handle to shared storage (`Arc<ComWrapperInner<C>>`), and any
// clone can hand out `&C` through `Deref`. These bounds therefore follow `Arc<T>`'s own rule
// (`T: Send + Sync` for both `Send` and `Sync`), applied to the user data `C`.
// NOTE(review): the header is not covered by these bounds. Presumably its base class objects are
// safe to share across threads (e.g. pointers to immutable vtables); confirm.
unsafe impl<C: Class> Send for ComWrapper<C> where C: Send + Sync {}
unsafe impl<C: Class> Sync for ComWrapper<C> where C: Send + Sync {}

impl<C: Class> Deref for ComWrapper<C> {
    type Target = C;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.inner.data
    }
}

impl<C: Class> Wrapper<C> for ComWrapper<C> {
    #[inline]
    unsafe fn data_from_header(ptr: *mut Header<C>) -> *mut C {
        // Step back from the `header` field to the start of `ComWrapperInner`, then forward to
        // the `data` field.
        let base = (ptr as *mut u8).offset(-offset_of!(ComWrapperInner<C>, header));
        base.offset(offset_of!(ComWrapperInner<C>, data)) as *mut C
    }

    #[inline]
    unsafe fn header_from_data(ptr: *mut C) -> *mut Header<C> {
        let base = com_wrapper_inner_of(ptr) as *mut u8;
        base.offset(offset_of!(ComWrapperInner<C>, header)) as *mut Header<C>
    }

    #[inline]
    unsafe fn add_ref(ptr: *mut C) -> usize {
        let inner = com_wrapper_inner_of(ptr);
        // The count is sampled before incrementing, so under concurrent use it may be stale.
        let count_after = com_wrapper_strong_count(inner) + 1;
        Arc::increment_strong_count(inner);
        count_after
    }

    #[inline]
    unsafe fn release(ptr: *mut C) -> usize {
        let inner = com_wrapper_inner_of(ptr);
        // Sample the count first: the decrement below may free the allocation.
        let count_after = com_wrapper_strong_count(inner) - 1;
        Arc::decrement_strong_count(inner);
        count_after
    }
}

/// Recovers a pointer to the `ComWrapperInner` whose `data` field is at `ptr`.
///
/// # Safety
///
/// `ptr` must point at the `data` field of a live `ComWrapperInner<C>`.
#[inline]
unsafe fn com_wrapper_inner_of<C: Class>(ptr: *mut C) -> *mut ComWrapperInner<C> {
    (ptr as *mut u8).offset(-offset_of!(ComWrapperInner<C>, data)) as *mut ComWrapperInner<C>
}

/// Reads the current strong count of the `Arc` that owns `inner`, without changing it.
///
/// # Safety
///
/// `inner` must point to the `ComWrapperInner<C>` inside a live `Arc` allocation.
#[inline]
unsafe fn com_wrapper_strong_count<C: Class>(inner: *mut ComWrapperInner<C>) -> usize {
    // `ManuallyDrop` stops the temporarily reconstructed `Arc` from giving up a reference that
    // it does not own.
    let borrowed = std::mem::ManuallyDrop::new(Arc::from_raw(inner));
    Arc::strong_count(&*borrowed)
}

impl<C: Class> ComWrapper<C> {
    /// Creates a new reference-counted COM object that owns `data`.
    ///
    /// The allocation holds the header produced by `C::Interfaces::HEADER`, followed by `data`.
    #[inline]
    pub fn new(data: C) -> ComWrapper<C>
    where
        C: 'static,
        C::Interfaces: MakeHeader<C, Self>,
    {
        let storage = ComWrapperInner {
            header: C::Interfaces::HEADER,
            data,
        };

        ComWrapper {
            inner: Arc::new(storage),
        }
    }

    /// Returns a [`ComRef<I>`] to the object if `I` is in `C`'s interface list, or `None`
    /// otherwise.
    ///
    /// The object's reference count is not touched, and the returned reference borrows `self`.
    #[inline]
    pub fn as_com_ref<'a, I: Interface>(&'a self) -> Option<ComRef<'a, I>> {
        let offset = C::Interfaces::query(&I::IID)?;
        let storage = Arc::as_ptr(&self.inner);

        // SAFETY: `storage` points to the live `ComWrapperInner` owned by `self.inner`, which
        // outlives the returned borrow. The `InterfaceList` contract guarantees a valid `I` at
        // `offset` within its header.
        unsafe { Some(ComRef::from_raw_unchecked(Self::interface_at::<I>(storage, offset))) }
    }

    /// Returns a [`ComPtr<I>`] to the object if `I` is in `C`'s interface list, or `None`
    /// otherwise.
    ///
    /// The object's reference count is incremented only when a [`ComPtr`] is returned.
    #[inline]
    pub fn to_com_ptr<I: Interface>(&self) -> Option<ComPtr<I>> {
        let offset = C::Interfaces::query(&I::IID)?;
        // Turn one extra strong reference into a raw pointer and hand it to the `ComPtr`.
        let storage = Arc::into_raw(Arc::clone(&self.inner));

        // SAFETY: `storage` comes from `Arc::into_raw`, so the allocation is kept alive by the
        // reference handed over here. The `InterfaceList` contract guarantees a valid `I` at
        // `offset` within its header.
        unsafe { Some(ComPtr::from_raw_unchecked(Self::interface_at::<I>(storage, offset))) }
    }

    /// Computes a pointer to the base class object located `offset` bytes into the header of
    /// the `ComWrapperInner` at `storage`.
    ///
    /// # Safety
    ///
    /// `storage` must point to a live `ComWrapperInner<C>`, and `offset` must lie within its
    /// header, as is the case for offsets returned by `InterfaceList::query`.
    #[inline]
    unsafe fn interface_at<I>(storage: *const ComWrapperInner<C>, offset: isize) -> *mut I {
        (storage as *mut u8)
            .offset(offset_of!(ComWrapperInner<C>, header))
            .offset(offset) as *mut I
    }
}