winio-winui3 0.4.0

WinUI3 bindings (subset) for Rust
use std::{ffi::c_void, ops::Deref, ptr::NonNull};

use windows::Win32::Foundation::{E_NOINTERFACE, E_NOTIMPL, E_POINTER, S_OK};
use windows_core::{
    factory, imp::WeakRefCount, ComObject, ComObjectInner, ComObjectInterface, IInspectable,
    IInspectable_Vtbl, IUnknownImpl, Interface, InterfaceRef, Result, RuntimeName, Type, TypeKind,
    GUID, HRESULT,
};

/// Wrapper that turns a user child-class value `T` into a composed WinRT
/// object: an outer object that aggregates an instance of a base runtime class
/// (`T::BaseType` when `T: ChildClass`).
///
/// `#[repr(transparent)]` guarantees the same layout as `T`, which is what allows
/// `Compose_Impl::get_impl` to reinterpret a `&T` as a `&Compose<T>`.
#[repr(transparent)]
pub struct Compose<T> {
    // The user's implementation value.
    inner: T,
}

/// Raw ABI signature of a composable factory's `CreateInstance`-style method,
/// as selected by `ChildClass::create_interface_fn`.
///
/// It follows the shape of WinRT composable factories: (outer, out inner, out result).
/// * `this` — the factory interface pointer (`factory.as_raw()` in `Compose::compose_with`).
/// * `outer` — the controlling outer object's `IInspectable`.
/// * `base` — out-parameter receiving the base object. `Compose::compose_with`
///   points it at `Compose_Impl::base`.
/// * `result` — out-parameter receiving the created instance, later converted
///   to `ChildClass::BaseType`.
pub type CreateInstanceFn = unsafe extern "system" fn(
    this: *mut c_void,
    outer: *mut c_void,
    base: *mut *mut c_void,
    result: *mut *mut c_void,
) -> HRESULT;

/// Hooks a user type must provide so [`Compose`] can derive it from a composable
/// WinRT base class.
pub trait ChildClass: ComObjectInner {
    /// The runtime class being derived from. `Compose::compose*` returns a value of this type.
    type BaseType: RuntimeName + TypeKind + Type<Self::BaseType, Abi = *mut c_void>;
    /// The base class's composable factory interface. `Compose::compose` obtains
    /// it with `factory::<BaseType, FactoryInterface>()`.
    type FactoryInterface: Interface;

    /// Picks the `CreateInstance` entry out of the factory interface's vtable.
    fn create_interface_fn(
        vtable: &<Self::FactoryInterface as Interface>::Vtable,
    ) -> CreateInstanceFn;

    /// Returns the slot holding the outer object's identity (`IInspectable`) vtable
    /// pointer. `Compose::into_outer` overwrites it with `Compose_Impl::VTABLE_IDENTITY`.
    fn identity_vtable(vtable: &mut Self::Outer) -> &mut &'static IInspectable_Vtbl;

    /// Returns the outer object's reference counter. `Compose_Impl` uses it for
    /// `AddRef`/`Release`.
    fn ref_count(vtable: &Self::Outer) -> &WeakRefCount;

    /// Converts `self` into its outer (vtable-carrying) representation.
    fn into_outer(self) -> Self::Outer;
}

/// Access to the composed base object, available on the outer object of any
/// [`ChildClass`].
pub trait ChildClassImpl: ComObjectInterface<IInspectable> + IUnknownImpl
where
    Self::Impl: ChildClass<Outer = Self>,
{
    /// Get the base object. Query the *Overrides interface to call the base methods.
    ///
    /// # Errors
    /// * `E_NOTIMPL` when `self` was not created through `Compose::compose*`.
    /// * `E_POINTER` when the object is composed but no base object is stored.
    fn base(&self) -> Result<&IInspectable> {
        let Some(composed) = Compose_Impl::<Self::Impl>::from_ref(self) else {
            return Err(E_NOTIMPL.into());
        };
        composed.base().ok_or_else(|| E_POINTER.into())
    }
}

impl<T: ChildClass> Compose<T>
where
    T::Outer: ComObjectInterface<IInspectable>,
{
    /// Creates a composed object deriving from `T::BaseType`. It uses the base
    /// class's activation factory `factory::<T::BaseType, T::FactoryInterface>()`.
    ///
    /// # Errors
    /// Propagates failures from obtaining the factory and from [`Compose::compose_with`].
    pub fn compose(inner: T) -> Result<T::BaseType> {
        Self::compose_with(inner, &factory::<T::BaseType, T::FactoryInterface>()?)
    }

    /// Creates a composed object using an explicitly supplied factory.
    ///
    /// First, `inner` becomes a `Compose_Impl` whose identity vtable routes through
    /// `Compose_Impl`. Then the factory method selected by `T::create_interface_fn`
    /// is called with that object as the outer. The factory writes the base object
    /// straight into `Compose_Impl::base`.
    ///
    /// # Errors
    /// Returns the failing `HRESULT` of the factory call. It also fails if the
    /// returned pointer cannot be converted into `T::BaseType`.
    pub fn compose_with(inner: T, factory: &T::FactoryInterface) -> Result<T::BaseType> {
        let this = Self { inner };
        unsafe {
            // `From<Compose<T>> for IInspectable` builds a `ComObject`, which for
            // `Compose<T>` heap-allocates a `Compose_Impl<T>` (see `into_object`).
            let outer: IInspectable = this.into();
            let outer__ = outer.as_raw();
            // NOTE(review): relies on the IInspectable pointer being the address of the
            // `Compose_Impl` itself. That holds only if the identity vtable pointer sits at
            // offset 0 of the `repr(C)` layout — confirm against the generated `T::Outer`.
            let r#impl = outer__ as *mut Compose_Impl<T>;
            // `Option<IInspectable>` (currently `None`) is handed to the factory as a
            // raw `*mut *mut c_void` out-parameter.
            let base__ = &mut (*r#impl).base;
            let mut result__ = std::ptr::null_mut();
            T::create_interface_fn(factory.vtable())(
                factory.as_raw(),
                outer__,
                base__ as *mut _ as _,
                &mut result__,
            )
            .and_then(|| Type::from_abi(result__))
            // `outer` is released when this scope ends. NOTE(review): the object is
            // presumably kept alive by the reference the factory returned in
            // `result__` — confirm.
        }
    }
}

impl<T: ChildClass> Compose<T> {
    #[inline(always)]
    fn into_outer(self) -> Compose_Impl<T> {
        let mut vtable = self.inner.into_outer();
        *T::identity_vtable(&mut vtable) = &Compose_Impl::<T>::VTABLE_IDENTITY;
        Compose_Impl { vtable, base: None }
    }
}

/// Outer object of a composed WinRT instance.
///
/// `#[repr(C)]` keeps `vtable` (the child's own outer object) at offset 0. That is
/// why a reference to the child outer object can be reinterpreted as a reference
/// to the enclosing `Compose_Impl` (`from_ref_unchecked` / `from_mut_unchecked`).
#[repr(C)]
#[allow(non_camel_case_types)] // the type name intentionally contains an underscore
pub struct Compose_Impl<T: ComObjectInner> {
    // The child's own outer object. `Compose::into_outer` replaces its identity
    // vtable with `VTABLE_IDENTITY`.
    vtable: T::Outer,
    // Base object written by the factory during `Compose::compose_with`.
    // It is `None` before that call.
    base: Option<IInspectable>,
}

impl<T: ComObjectInner> Compose_Impl<T> {
    /// Reinterprets a reference to the child's outer object as a reference to the
    /// enclosing `Compose_Impl`.
    ///
    /// # Safety
    /// The object should be created by `Compose::compose*`.
    /// Otherwise `this` is not embedded in a `Compose_Impl`, and the returned
    /// reference covers unrelated memory.
    pub unsafe fn from_ref_unchecked(this: &T::Outer) -> &Self {
        let enclosing: *const Self = (this as *const T::Outer).cast();
        &*enclosing
    }

    /// Mutable counterpart of [`Compose_Impl::from_ref_unchecked`].
    ///
    /// # Safety
    /// The object should be created by `Compose::compose*`.
    /// Otherwise `this` is not embedded in a `Compose_Impl`, and the returned
    /// reference covers unrelated memory.
    pub unsafe fn from_mut_unchecked(this: &mut T::Outer) -> &mut Self {
        let enclosing: *mut Self = (this as *mut T::Outer).cast();
        &mut *enclosing
    }

    /// Returns the base object stored by the factory, or `None` if none was stored.
    pub fn base(&self) -> Option<&IInspectable> {
        Option::as_ref(&self.base)
    }
}

impl<T: ComObjectInner> Compose_Impl<T>
where
    T::Outer: ComObjectInterface<IInspectable>,
{
    /// Checked version of `from_ref_unchecked`. Returns `None` unless `this`
    /// identifies itself as composed (see `is_composed`).
    pub fn from_ref(this: &T::Outer) -> Option<&Self> {
        // SAFETY: within this module only `Compose_Impl::QueryInterface` answers the
        // composed-marker IID, so success implies `this` lives inside a `Compose_Impl`.
        is_composed(this).then(move || unsafe { Self::from_ref_unchecked(this) })
    }

    /// Checked version of `from_mut_unchecked`. Returns `None` unless `this`
    /// identifies itself as composed (see `is_composed`).
    pub fn from_mut(this: &mut T::Outer) -> Option<&mut Self> {
        if !is_composed(this) {
            return None;
        }
        // SAFETY: same reasoning as `from_ref`.
        Some(unsafe { Self::from_mut_unchecked(this) })
    }
}

impl<T: ComObjectInner> Deref for Compose_Impl<T>
where
    T::Outer: Deref,
{
    type Target = <T::Outer as Deref>::Target;

    /// Forwards to the `Deref` impl of the embedded child outer object.
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &*self.vtable
    }
}

impl<T: ChildClass> Compose_Impl<T> {
    /// Identity (`IInspectable`) vtable that `Compose::into_outer` installs in place of
    /// the child's own. With it, identity calls dispatch to `Compose_Impl<T>`'s
    /// `IUnknownImpl`.
    /// NOTE(review): `T::BaseType` is presumably used as the runtime class name
    /// reported by `GetRuntimeClassName` — confirm against `IInspectable_Vtbl::new`.
    const VTABLE_IDENTITY: IInspectable_Vtbl =
        IInspectable_Vtbl::new::<Compose_Impl<T>, T::BaseType, 0>();
}

/// `IUnknown`/`IInspectable` behaviour of the composed outer object.
///
/// Most calls are forwarded to the embedded child outer object (`vtable`).
/// Reference counting uses the child's counter (`ChildClass::ref_count`), but
/// reclaims the whole `Compose_Impl` allocation.
impl<T: ChildClass> IUnknownImpl for Compose_Impl<T> {
    type Impl = Compose<T>;

    // `Compose<T>` is `#[repr(transparent)]` over `T`, so `&T` may be reinterpreted
    // as `&Compose<T>`.
    #[inline(always)]
    fn get_impl(&self) -> &Self::Impl {
        unsafe { &*(self.vtable.get_impl() as *const T as *const Compose<T>) }
    }

    // Same transparent reinterpretation as `get_impl`, for a mutable reference.
    #[inline(always)]
    fn get_impl_mut(&mut self) -> &mut Self::Impl {
        unsafe { &mut *(self.vtable.get_impl_mut() as *mut T as *mut Compose<T>) }
    }

    // Unwraps the child's value. The `base` field is dropped together with `self`.
    #[inline(always)]
    fn into_inner(self) -> Self::Impl {
        Compose {
            inner: self.vtable.into_inner(),
        }
    }

    /// Asks the child outer object first. Only when it answers `E_NOINTERFACE`:
    /// * `IS_COMPOSED_IID` succeeds with a non-null dangling pointer and adds no
    ///   reference. The result is only a marker and must never be used or released.
    /// * Otherwise the query is forwarded to the stored base object, if any.
    #[inline(always)]
    unsafe fn QueryInterface(&self, iid: *const GUID, interface: *mut *mut c_void) -> HRESULT {
        let res = self.vtable.QueryInterface(iid, interface);
        if res == E_NOINTERFACE {
            if *iid == IS_COMPOSED_IID {
                interface.write(std::ptr::dangling_mut());
                return S_OK;
            }
            if let Some(base) = &self.base {
                return base.query(iid, interface);
            }
        }
        res
    }

    // Increments the child's counter, which is shared by the whole composed object.
    #[inline(always)]
    fn AddRef(&self) -> u32 {
        T::ref_count(&self.vtable).add_ref()
    }

    /// Decrements the child's counter. When it reaches zero, this reclaims the
    /// `Box<Compose_Impl<T>>` created by `ComObjectInner::into_object`. Fields drop
    /// in declaration order: the child (`vtable`) first, then `base`.
    #[inline(always)]
    unsafe fn Release(self_: *mut Self) -> u32 {
        let remaining = T::ref_count(&(*self_).vtable).release();
        if remaining == 0 {
            _ = Box::from_raw(self_);
        }
        remaining
    }

    #[inline(always)]
    fn is_reference_count_one(&self) -> bool {
        self.vtable.is_reference_count_one()
    }

    #[inline(always)]
    unsafe fn GetTrustLevel(&self, value: *mut i32) -> HRESULT {
        self.vtable.GetTrustLevel(value)
    }

    // Adds a reference first, so the new `ComObject` owns one count of its own.
    #[inline(always)]
    fn to_object(&self) -> ComObject<Self::Impl> {
        self.AddRef();
        unsafe { ComObject::from_raw(NonNull::from(self)) }
    }
}

impl<T: ChildClass> ComObjectInner for Compose<T> {
    type Outer = Compose_Impl<T>;

    /// Moves the composed outer object to the heap and gives ownership of that
    /// allocation to a new `ComObject`.
    fn into_object(self) -> ComObject<Self> {
        let heap_outer: *mut Compose_Impl<T> = Box::into_raw(Box::new(self.into_outer()));
        // SAFETY: `Box::into_raw` never yields null. The allocation is reclaimed
        // with `Box::from_raw` by `Compose_Impl::Release`.
        unsafe { ComObject::from_raw(NonNull::new_unchecked(heap_outer)) }
    }
}

impl<T: ChildClass> From<Compose<T>> for IInspectable
where
    T::Outer: ComObjectInterface<IInspectable>,
{
    /// Wraps `value` in a new reference-counted object and returns its `IInspectable`.
    #[inline(always)]
    fn from(value: Compose<T>) -> Self {
        ComObject::new(value).into_interface()
    }
}

impl<T: ChildClass> ComObjectInterface<IInspectable> for Compose_Impl<T>
where
    T::Outer: ComObjectInterface<IInspectable>,
{
    /// Returns the `IInspectable` exposed by the embedded child outer object.
    #[inline(always)]
    fn as_interface_ref(&self) -> InterfaceRef<'_, IInspectable> {
        let child_outer: &T::Outer = &self.vtable;
        <T::Outer as ComObjectInterface<IInspectable>>::as_interface_ref(child_outer)
    }
}

/// Private, non-standard IID used only as a marker. `Compose_Impl::QueryInterface`
/// answers it, and `is_composed` probes for it to detect composed objects.
const IS_COMPOSED_IID: GUID = GUID::from_u128(0xb2ea198c_d3e0_4999_b821_e99271d67cce);

/// Just like [`windows_core::DYNAMIC_CAST_IID`], [`IS_COMPOSED_IID`] is not a standard IID.
/// The implemented `QueryInterface` doesn't increase the reference count but returns a non-null pointer.
fn is_composed<I: ComObjectInterface<IInspectable>>(interface: &I) -> bool {
    let mut is_composed = std::ptr::null_mut();
    let interface = interface.as_interface_ref();
    let res = unsafe { interface.query(&IS_COMPOSED_IID, &mut is_composed) };
    res.is_ok() && (!is_composed.is_null())
}