// nucomcore 0.1.1
//
// Nucom, cross-platform Component Object Model implementation
// Documentation
use crate::idl::iunknown::IUnknownUnsafeExt;
use crate::IUnknown;
use std::cell::Cell;
use std::marker::PhantomData;
use std::ptr::NonNull;

#[cfg(doc)]
use crate::idl::iunknown::IUnknownVtbl;

// Private supertraits behind the public `MtKind` and `SafetyKind` marker
// traits. This module is not `pub`, so (unless it is re-exported elsewhere)
// code outside the crate cannot name these traits and therefore cannot
// introduce additional marker kinds.
mod private {
    /// Private supertrait required by the public `MtKind` marker trait.
    pub trait MtKind {}

    /// Private supertrait required by the public `SafetyKind` marker trait.
    pub trait SafetyKind {}
}

/// Marker trait implemented by the types that select the threading style of a
/// COM reference.
///
/// The available kinds are [`Apartment`] and [`Multi`].
pub trait MtKind: private::MtKind {}

// Every type that carries the private marker automatically gets the public one.
impl<Kind: private::MtKind> MtKind for Kind {}

/// Apartment threaded COM object, i.e. generally not thread safe.
///
/// This marker is neither `Send` nor `Sync` because of its raw-pointer field.
/// The `Send`/`Sync` impls of [`ComRef`] require the threading marker to be
/// `Send`/`Sync`, so they do not apply to references using this marker.
pub struct Apartment {
    // A raw pointer field opts the struct out of the `Send` and `Sync` auto
    // traits.
    _v: *const (),
}

// Registers `Apartment` as a threading kind through the private supertrait.
impl private::MtKind for Apartment {}

// There is no separate "Send + !Sync" kind: sharing can be emulated by
// AddRef-ing the COM object and then sending it to the other thread.
/// Multi-threaded COM object, which can be freely shared between threads.
///
/// This marker contains only a unit field, so it is `Send` and `Sync`; the
/// `Send`/`Sync` impls of [`ComRef`] are conditional on exactly that.
pub struct Multi {
    // Private unit field; the struct has no public fields.
    _v: (),
}

// Registers `Multi` as a threading kind through the private supertrait.
impl private::MtKind for Multi {}

/// Marker trait implemented by the types that select whether a COM reference
/// exposes the safe or the unsafe function interface.
///
/// The available kinds are [`Safe`] and [`Unsafe`].
pub trait SafetyKind: private::SafetyKind {}

// Every type that carries the private marker automatically gets the public one.
impl<Kind: private::SafetyKind> SafetyKind for Kind {}

/// The COM object behind this pointer is safe.
///
/// References carrying this marker expose the safe function interface. One way
/// to obtain such a reference is [`ComRef::into_safe`], whose caller vouches
/// that every method of the object is safe to call with any argument.
pub struct Safe {
    // Private unit field; the struct has no public fields.
    _v: (),
}

// Registers `Safe` as a safety kind through the private supertrait.
impl private::SafetyKind for Safe {}

/// The COM object behind this pointer is unsafe.
///
/// References carrying this marker expose the unsafe function interface. It is
/// the default safety marker of [`ComRef`] and [`ComPtr`], and any reference
/// can be converted to it with [`ComRef::into_unsafe`].
pub struct Unsafe {
    // Private unit field; the struct has no public fields.
    _v: (),
}

// Registers `Unsafe` as a safety kind through the private supertrait.
impl private::SafetyKind for Unsafe {}

/// Declares this type as a COM interface. This is implemented by nuidl on the
/// base interface type, e.g. [`IUnknown`].
pub trait ComInterface {
    /// Type associated with the interface's virtual function table.
    /// NOTE(review): meaning inferred from the name; confirm against the code
    /// nuidl generates.
    type Vtbl;
    /// Default threading kind of the interface. This is the threading marker
    /// selected by [`ComRef::new_with_default_threading`].
    type Mt: MtKind;
}

/// Asserts that casting a `*mut Self` pointer to `*mut T` results in a valid
/// COM interface pointer. This is automatically implemented by nuidl; you
/// shouldn't implement this yourself.
///
/// This interface is implemented transitively, i.e. if `T: ComHierarchy<U>` and
/// `U: ComHierarchy<V>`, then `T: ComHierarchy<V>`.
///
/// # Safety
///
/// This is used to safely implement [`ComRef::upcast`]. As a result, casting a
/// `*mut Self` pointer to `*mut T` must result in a valid COM interface
/// pointer.
/// Otherwise, a safe reference `ComRef<Self>` may be cast to a safe reference
/// `ComRef<T>`, which then results in wrong method calls.
pub unsafe trait ComHierarchy<T>: ComInterface {}

/// The nucom analog to [`core::ops::Deref`], returning [`ComRef<T>`] instead of
/// `&T`.
///
/// # Safety
///
/// NOTE(review): this trait is declared `unsafe`, but its contract is not
/// spelled out in this file. Presumably [`com_deref`](ComDeref::com_deref) must
/// return a reference to a valid COM object for which the `Mt` and `Safety`
/// markers are truthful — confirm before adding new implementations.
pub unsafe trait ComDeref {
    /// The COM interface type the returned reference points to.
    type Target: ComInterface;
    /// Threading marker of the returned reference.
    type Mt: MtKind;
    /// Safety marker of the returned reference.
    type Safety: SafetyKind;

    /// Dereferences `self` to return the underlying COM reference.
    fn com_deref(&self) -> ComRef<Self::Target, Self::Mt, Self::Safety>;
}

// A bare non-null pointer dereferences to an apartment-threaded reference that
// exposes the unsafe interface, i.e. exactly what `ComRef::new` produces.
unsafe impl<T: ComInterface> ComDeref for NonNull<T> {
    type Target = T;
    type Mt = Apartment;
    type Safety = Unsafe;

    /// Wraps a copy of the pointer in a [`ComRef`] via [`ComRef::new`].
    fn com_deref(&self) -> ComRef<T, Apartment, Unsafe> {
        let raw: NonNull<T> = *self;
        ComRef::new(raw)
    }
}

/// Raw reference for COM objects. This does not manage the COM object (i.e.
/// call AddRef/Release).
///
/// `T` is the pointed-to interface type, `M` the threading marker
/// ([`Apartment`] by default) and `S` the safety marker ([`Unsafe`] by
/// default). The only non-zero-sized field is the pointer, so with
/// `#[repr(transparent)]` the struct has the same layout as `NonNull<T>`.
#[repr(transparent)]
pub struct ComRef<T, M: MtKind = Apartment, S: SafetyKind = Unsafe> {
    // Pointer to the COM interface; never null.
    ptr: NonNull<T>,
    // Zero-sized: carries the `M`/`S` markers at the type level. The `Cell<T>`
    // component makes the struct invariant in `T`.
    // NOTE(review): that is presumably why `Cell<T>` is used — confirm intent.
    _marker: PhantomData<(Cell<T>, M, S)>,
}

impl<T: ComInterface> ComRef<T, Apartment, Unsafe> {
    /// Wraps `ptr` in a `ComRef` that uses apartment threading semantics and
    /// exposes the unsafe interface.
    #[inline]
    pub const fn new(ptr: NonNull<T>) -> Self {
        // SAFETY (presumed): `Apartment` + `Unsafe` are the default markers and
        // add no thread-sharing or safe-call promise about the object.
        unsafe { Self::new_with_characteristics(ptr) }
    }

    /// Wraps `ptr` in a `ComRef` that uses the interface's default threading
    /// semantics ([`T::Mt`](ComInterface::Mt)) and exposes the unsafe
    /// interface.
    ///
    /// # Safety
    ///
    /// If [`T::Mt`](ComInterface::Mt) is [`Multi`], the underlying COM object must support being
    /// shared between threads.
    ///
    /// Note that this function is not declared `unsafe`, so the compiler does
    /// not enforce the requirement above; upholding it is left to the caller.
    #[inline]
    pub const fn new_with_default_threading(ptr: NonNull<T>) -> ComRef<T, T::Mt, Unsafe> {
        // SAFETY (presumed): relies on the caller honouring the requirement
        // documented above.
        unsafe { ComRef::<T, T::Mt, Unsafe>::new_with_characteristics(ptr) }
    }
}

impl<T: ComInterface, M: MtKind, S: SafetyKind> ComRef<T, M, S> {
    /// Builds a `ComRef` around `ptr` with caller-chosen threading (`M`) and
    /// safety (`S`) markers. All public constructors and conversions funnel
    /// through this function.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not written down in this file. Presumably
    /// the caller must ensure the chosen markers are truthful for the object
    /// behind `ptr` (see [`into_multi_threaded`](Self::into_multi_threaded) and
    /// [`into_safe`](Self::into_safe)).
    #[inline]
    const unsafe fn new_with_characteristics(ptr: NonNull<T>) -> Self {
        ComRef {
            ptr,
            _marker: PhantomData,
        }
    }

    /// Returns this `ComRef` as having apartment threading semantics, i.e.
    /// it cannot be shared between threads.
    #[inline]
    pub const fn into_apartment_threaded(self) -> ComRef<T, Apartment, S> {
        // SAFETY (presumed): switching to `Apartment` only removes the ability
        // to move or share the reference across threads; `S` is unchanged.
        unsafe { ComRef::new_with_characteristics(self.ptr) }
    }

    /// Returns this `ComRef` as having multithreaded semantics, i.e. it can
    /// be freely shared between threads.
    ///
    /// # Safety
    ///
    /// The underlying COM object must explicitly support being shared between
    /// threads.
    #[inline]
    pub const unsafe fn into_multi_threaded(self) -> ComRef<T, Multi, S> {
        // SAFETY: the caller guarantees the object may be shared across
        // threads, per this function's contract.
        unsafe { ComRef::new_with_characteristics(self.ptr) }
    }

    /// Returns this `ComRef` with the unsafe function interface.
    #[inline]
    pub const fn into_unsafe(self) -> ComRef<T, M, Unsafe> {
        // SAFETY (presumed): the `Unsafe` marker makes no promise about the
        // object's methods; `M` is unchanged.
        unsafe { ComRef::new_with_characteristics(self.ptr) }
    }

    /// Returns this `ComRef` with the safe function interface.
    ///
    /// # Safety
    ///
    /// You must ensure that all of the underlying COM object's methods are safe
    /// to call with any argument.
    #[inline]
    pub const unsafe fn into_safe(self) -> ComRef<T, M, Safe> {
        // SAFETY: the caller guarantees every method is safe to call, per this
        // function's contract.
        unsafe { ComRef::new_with_characteristics(self.ptr) }
    }

    /// Upcasts the COM reference to a base interface `U`.
    ///
    /// The threading and safety markers are carried over unchanged.
    #[inline]
    pub const fn upcast<U>(&self) -> ComRef<U, M, S>
    where
        T: ComHierarchy<U>,
        U: ComInterface,
    {
        // SAFETY: T is a COM interface which has U as one of its base
        // interfaces. As such, a pointer to T is also a valid pointer to U.
        unsafe { ComRef::new_with_characteristics(self.ptr.cast::<U>()) }
    }

    /// Returns the underlying non-null pointer to T.
    #[inline]
    pub const fn into_raw(self) -> NonNull<T> {
        self.ptr
    }

    /// Returns the underlying pointer to T.
    #[inline]
    pub const fn as_ptr(self) -> *mut T {
        self.ptr.as_ptr()
    }
}

unsafe impl<T: ComInterface, M: MtKind, S: SafetyKind> ComDeref for ComRef<T, M, S> {
    type Target = T;
    type Mt = M;
    type Safety = S;

    #[inline]
    fn com_deref(&self) -> Self {
        *self
    }
}

// SAFETY (presumed): a `ComRef` is only a pointer plus type-level markers, and
// whether the object may be used from another thread is expressed by the
// threading marker `M`. Of the markers declared in this file, only `Multi` is
// `Send`; `Apartment` is not.
unsafe impl<T, M: MtKind + Send, S: SafetyKind> Send for ComRef<T, M, S> {}

// SAFETY (presumed): same reasoning as for `Send`; sharing is allowed exactly
// when the threading marker `M` is itself `Sync`.
unsafe impl<T, M: MtKind + Sync, S: SafetyKind> Sync for ComRef<T, M, S> {}

// `Clone` and `Copy` are written by hand rather than derived: a derive would
// add `T: Clone`/`T: Copy` (and the same for `M` and `S`) bounds, although
// only the pointer and a zero-sized marker are actually copied.
impl<T, M: MtKind, S: SafetyKind> Clone for ComRef<T, M, S> {
    /// Returns a bitwise copy of the reference. The COM object's reference
    /// count is not touched.
    #[inline]
    fn clone(&self) -> Self {
        // Canonical `Clone` for a `Copy` type.
        *self
    }
}

impl<T, M: MtKind, S: SafetyKind> Copy for ComRef<T, M, S> {}

/// A smart COM pointer which manages the underlying COM object's reference
/// count by calling [`IUnknown::AddRef`](IUnknownVtbl#structfield.AddRef) and
/// [`IUnknown::Release`](IUnknownVtbl#structfield.Release) where appropriate.
///
/// Cloning calls `AddRef` and dropping calls `Release`. The type parameters
/// have the same meaning as on [`ComRef`]; `T` must additionally have
/// [`IUnknown`] in its interface hierarchy. The struct is a
/// `#[repr(transparent)]` wrapper around a [`ComRef`].
#[repr(transparent)]
pub struct ComPtr<T: ComHierarchy<IUnknown>, M: MtKind = Apartment, S: SafetyKind = Unsafe> {
    // The wrapped reference; `Release` is called through it when the `ComPtr`
    // is dropped.
    ptr: ComRef<T, M, S>,
}

impl<T: ComHierarchy<IUnknown>, M: MtKind, S: SafetyKind> ComPtr<T, M, S> {
    /// Constructs a new COM pointer without calling `AddRef`.
    ///
    /// The returned pointer takes over one existing reference to the object
    /// and gives it back by calling `Release` when dropped.
    ///
    /// # Safety
    ///
    /// Calling this method creates a ComPtr which calls Release on drop. That
    /// deferred Release call must be safe.
    pub unsafe fn new(ptr: ComRef<T, M, S>) -> Self {
        Self { ptr }
    }

    /// Constructs a new COM pointer, calling `AddRef` on the object.
    ///
    /// # Safety
    ///
    /// This calls AddRef on the underlying COM object regardless of the
    /// reference's safety marker. Additionally, the returned ComPtr calls
    /// Release on drop. That deferred Release call must be safe.
    pub unsafe fn adopt(ptr: ComRef<T, M, S>) -> Self {
        // SAFETY: the caller guarantees, per this function's contract, that
        // calling AddRef on the object behind `ptr` is sound.
        unsafe {
            ptr.into_unsafe().AddRef();
        }
        Self { ptr }
    }

    /// Returns the inner reference and destroys this smart pointer without
    /// calling Release.
    ///
    /// The reference previously held by this `ComPtr` is thereby handed over
    /// to the caller.
    pub fn into_inner(self) -> ComRef<T, M, S> {
        // `ManuallyDrop` suppresses `Drop::drop`, so no Release is issued.
        let this = std::mem::ManuallyDrop::new(self);
        this.ptr
    }

    /// Returns a reference to the underlying object.
    ///
    /// The reference count is not changed.
    pub fn as_ref(&self) -> ComRef<T, M, S> {
        self.ptr
    }

    /// Returns an unsafe reference to the underlying object.
    ///
    /// The reference count is not changed.
    pub fn as_unsafe(&self) -> ComRef<T, M, Unsafe> {
        self.ptr.into_unsafe()
    }

    /// Upcasts the COM reference to a base interface `U`.
    ///
    /// The result is a plain [`ComRef`]; the reference count is not changed.
    #[inline]
    pub fn upcast<U>(&self) -> ComRef<U, M, S>
    where
        T: ComHierarchy<U>,
        U: ComInterface,
    {
        self.ptr.upcast()
    }
}

unsafe impl<T: ComHierarchy<IUnknown>, M: MtKind, S: SafetyKind> ComDeref for ComPtr<T, M, S> {
    type Target = T;
    type Mt = M;
    type Safety = S;

    #[inline]
    fn com_deref(&self) -> ComRef<T, M, S> {
        self.ptr
    }
}

impl<T: ComHierarchy<IUnknown>, M: MtKind, S: SafetyKind> Clone for ComPtr<T, M, S> {
    /// Creates a second owning pointer to the same object, calling `AddRef` so
    /// that the `Release` issued when the clone is dropped is balanced.
    fn clone(&self) -> Self {
        let inner = self.ptr;

        // SAFETY: AddRef must always be safe to call on a valid COM object.
        unsafe {
            inner.into_unsafe().AddRef();
        }

        Self { ptr: inner }
    }
}

impl<T: ComHierarchy<IUnknown>, M: MtKind, S: SafetyKind> Drop for ComPtr<T, M, S> {
    /// Gives up the reference held by this pointer by calling `Release` on the
    /// underlying object.
    fn drop(&mut self) {
        let inner = self.ptr.into_unsafe();

        // SAFETY: ComPtr construction asserts that the inner object is safe to
        // release.
        unsafe {
            inner.Release();
        }
    }
}