use crate::D3DResult;
use std::{
fmt,
hash::{Hash, Hasher},
ops::Deref,
ptr,
};
use winapi::{ctypes::c_void, um::unknwnbase::IUnknown, Interface};
/// Reference-counting handle to a COM interface `T`.
///
/// Wraps a possibly-null raw interface pointer. A non-null handle owns one COM
/// reference, which is released when the handle is dropped. `#[repr(transparent)]`
/// gives it exactly the layout of `*mut T`.
#[repr(transparent)]
pub struct ComPtr<T>(*mut T)
where
    T: Interface;
impl<T: Interface> ComPtr<T> {
    /// Returns an empty handle that owns nothing.
    ///
    /// Useful as the target of a COM out-parameter (see [`ComPtr::mut_void`]).
    pub fn null() -> Self {
        Self(ptr::null_mut())
    }

    /// Wraps `raw`, taking an extra COM reference (`AddRef`) when it is non-null.
    ///
    /// The caller keeps whatever reference it already held. The returned handle
    /// releases only the reference added here.
    ///
    /// # Safety
    ///
    /// `raw` must be null or point to a live COM object implementing `T`. Its
    /// `IUnknown` methods are reached by reinterpreting the pointer as `*mut IUnknown`.
    pub unsafe fn from_raw(raw: *mut T) -> Self {
        let unknown = raw as *mut IUnknown;
        if let Some(object) = unknown.as_ref() {
            object.AddRef();
        }
        Self(raw)
    }

    /// Reports whether this handle holds a null pointer.
    pub fn is_null(&self) -> bool {
        self.0.is_null()
    }

    /// Returns the stored pointer as `*const T` without touching the reference count.
    pub fn as_ptr(&self) -> *const T {
        self.0 as *const T
    }

    /// Returns the stored pointer as `*mut T` without touching the reference count.
    pub fn as_mut_ptr(&self) -> *mut T {
        self.0
    }

    /// Returns the address of the stored pointer, typed for `void**` out-parameters.
    ///
    /// Writing through the result replaces the stored pointer without releasing
    /// the previous one. Call it on a null handle to avoid leaking a reference.
    pub fn mut_void(&mut self) -> *mut *mut c_void {
        self.mut_self() as *mut *mut c_void
    }

    /// Returns the address of the stored pointer, typed as `*mut *mut T`.
    ///
    /// Writing through the result replaces the stored pointer without releasing
    /// the previous one, exactly like [`ComPtr::mut_void`].
    pub fn mut_self(&mut self) -> *mut *mut T {
        &mut self.0 as *mut *mut T
    }
}
impl<T: Interface> ComPtr<T> {
    /// Reinterprets the stored pointer as a borrowed `IUnknown`.
    ///
    /// # Safety
    ///
    /// The handle must be non-null and point to a live COM object. Null is caught
    /// only by a debug assertion.
    pub unsafe fn as_unknown(&self) -> &IUnknown {
        debug_assert!(!self.is_null());
        let unknown = self.0 as *mut IUnknown;
        &*unknown
    }

    /// Asks the object for interface `U` via `QueryInterface`.
    ///
    /// Returns the (possibly null) resulting handle together with the `HRESULT`
    /// reported by `QueryInterface`. The caller is responsible for checking it.
    ///
    /// # Safety
    ///
    /// Same requirements as [`ComPtr::as_unknown`]: the handle must be non-null
    /// and point to a live COM object.
    pub unsafe fn cast<U>(&self) -> D3DResult<ComPtr<U>>
    where
        U: Interface,
    {
        debug_assert!(!self.is_null());
        let mut target = ComPtr::<U>::null();
        let iid = U::uuidof();
        let hr = self.as_unknown().QueryInterface(&iid, target.mut_void());
        (target, hr)
    }
}
impl<T: Interface> Clone for ComPtr<T> {
    /// Returns another handle to the same COM object.
    ///
    /// For a non-null handle this takes one extra COM reference (`AddRef`), which the
    /// new handle releases on drop. Cloning a null handle yields another null handle
    /// and touches no reference count. This mirrors `Drop`, which also skips null.
    fn clone(&self) -> Self {
        if !self.is_null() {
            // SAFETY: the pointer is non-null and, by the contract of the constructors
            // (`from_raw` / out-parameter filling), assumed to refer to a live COM object.
            unsafe {
                self.as_unknown().AddRef();
            }
        }
        ComPtr(self.0)
    }
}
impl<T: Interface> Drop for ComPtr<T> {
    /// Releases the COM reference owned by this handle. A null handle owns nothing
    /// and is left alone.
    fn drop(&mut self) {
        if self.is_null() {
            return;
        }
        // SAFETY: non-null handles own one reference to a live COM object (assumed
        // invariant of the constructors). That reference is given back here.
        unsafe {
            self.as_unknown().Release();
        }
    }
}
impl<T: Interface> Deref for ComPtr<T> {
    type Target = T;

    /// Borrows the interface so its methods can be called directly on the handle.
    ///
    /// A null handle is caught only by a debug assertion. In release builds,
    /// dereferencing one is undefined behaviour.
    fn deref(&self) -> &T {
        debug_assert!(!self.is_null());
        let raw: *const T = self.as_ptr();
        // SAFETY: assumes the handle is non-null (debug-asserted above) and points
        // at a live object for at least as long as `self` is borrowed.
        unsafe { &*raw }
    }
}
impl<T: Interface> fmt::Debug for ComPtr<T> {
    /// Formats the handle as `ComPtr( ptr: <address> )`, showing only the raw pointer.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let raw = self.0;
        f.write_fmt(format_args!("ComPtr( ptr: {:?} )", raw))
    }
}
impl<T: Interface> PartialEq<*mut T> for ComPtr<T> {
    /// Compares the stored pointer with a raw interface pointer by address.
    fn eq(&self, other: &*mut T) -> bool {
        let theirs: *mut T = *other;
        theirs == self.as_mut_ptr()
    }
}
impl<T: Interface> PartialEq for ComPtr<T> {
    /// Two handles are equal when they hold the same raw pointer address.
    ///
    /// This is pointer identity only. Two different interface pointers into the same
    /// object may compare unequal.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

/// Raw-pointer equality is reflexive, symmetric and transitive, so `ComPtr` is a full
/// equivalence relation. Combined with hashing the raw pointer, this lets a `ComPtr`
/// be used as a `HashMap` / `HashSet` key, which requires `Eq + Hash`.
impl<T: Interface> Eq for ComPtr<T> {}
impl<T: Interface> Hash for ComPtr<T> {
    /// Hashes the raw pointer address, matching the address-based equality of `ComPtr`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let raw: *mut T = self.0;
        raw.hash(state);
    }
}
/// Declares an enum whose variants each hold a `ComPtr` to a different interface
/// type, plus per-variant constructor / accessor / unwrap methods.
///
/// Input syntax (one entry per variant, separated by `;`):
/// `enum Name { Variant(Type), from_fn, as_fn, unwrap_fn; ... }`
///
/// For the variant at position `i`, the generated inherent methods are:
/// - `from_fn(ComPtr<Type>) -> Self` (`unsafe`): wraps the pointer in that variant.
/// - `as_fn(&self) -> Option<&ComPtr<Type>>`: `None` if `self` is a variant listed
///   *before* `i`, otherwise `Some`. For variants listed *after* `i`, the inner
///   reference is `std::mem::transmute`d to `&ComPtr<Type>`.
/// - `unwrap_fn(&self) -> &ComPtr<Type>`: like `as_fn`, but panics (reported at the
///   caller via `#[track_caller]`) for variants listed before `i`.
///
/// The enum also implements `Deref<Target = ComPtr<FirstType>>` through the first
/// variant's unwrap method.
///
/// NOTE(review): the transmutes are only sound if each later interface type can be used
/// as every earlier one. That means listing from base interface to most-derived, as the
/// macro name suggests. Nothing here checks this. Callers of the `unsafe` `from_*`
/// constructors are presumably responsible for it.
/// NOTE(review): the enum body repeats the remaining variants with `),+`, so at least
/// two variants are required even though the input pattern uses `*`.
#[macro_export]
macro_rules! weak_com_inheritance_chain {
// Public entry arm: emits the enum, its methods (via the internal recursion), and
// the `Deref` impl to the first variant's `ComPtr`.
(
$(#[$meta:meta])*
$vis:vis enum $name:ident {
$first_variant:ident($first_type:ty), $first_from_name:ident, $first_as_name:ident, $first_unwrap_name:ident $(;)?
$($variant:ident($type:ty), $from_name:ident, $as_name:ident, $unwrap_name:ident);* $(;)?
}
) => {
$(#[$meta])*
$vis enum $name {
$first_variant($crate::ComPtr<$first_type>),
$(
$variant($crate::ComPtr<$type>)
),+
}
impl $name {
// Start the recursion with an empty list of already-processed variants.
$crate::weak_com_inheritance_chain! {
@recursion_logic,
$vis,
;
$first_variant($first_type), $first_from_name, $first_as_name, $first_unwrap_name;
$($variant($type), $from_name, $as_name, $unwrap_name);*
}
}
impl std::ops::Deref for $name {
type Target = $crate::ComPtr<$first_type>;
fn deref(&self) -> &Self::Target {
self.$first_unwrap_name()
}
}
};
// Internal recursion step: render methods for the current ("this") variant, knowing
// which variants precede and follow it. Then recurse with "this" appended to the
// preceding list. The optional leading/trailing commas absorb the `, this` produced
// when the preceding list is still empty.
(
@recursion_logic,
$vis:vis,
$(,)? $($prev_variant:ident),* $(,)?;
$this_variant:ident($this_type:ty), $this_from_name:ident, $this_as_name:ident, $this_unwrap_name:ident $(;)?
$($next_variant:ident($next_type:ty), $next_from_name:ident, $next_as_name:ident, $next_unwrap_name:ident);*
) => {
$crate::weak_com_inheritance_chain! {
@render_members,
$vis,
$this_from_name, $this_as_name, $this_unwrap_name;
$($prev_variant),*;
$this_variant($this_type);
$($next_variant),*;
}
$crate::weak_com_inheritance_chain! {
@recursion_logic,
$vis,
$($prev_variant),* , $this_variant;
$($next_variant($next_type), $next_from_name, $next_as_name, $next_unwrap_name);*
}
};
// Internal base case: no variants left to process, so emit nothing.
(
@recursion_logic,
$vis:vis,
$($prev_variant:ident),*;
) => {};
// Internal: emit the from/as/unwrap methods for one variant. Earlier variants map to
// `None` / panic. The variant itself is returned directly. Later variants are
// reinterpreted as this variant's `ComPtr` via `transmute`.
(
@render_members,
$vis:vis,
$from_name:ident, $as_name:ident, $unwrap_name:ident;
$($prev_variant:ident),*;
$variant:ident($type:ty);
$($next_variant:ident),*;
) => {
$vis unsafe fn $from_name(value: $crate::ComPtr<$type>) -> Self {
Self::$variant(value)
}
$vis fn $as_name(&self) -> Option<&$crate::ComPtr<$type>> {
match *self {
$(
Self::$prev_variant(_) => None,
)*
Self::$variant(ref v) => Some(v),
$(
Self::$next_variant(ref v) => {
// Reinterprets `&ComPtr<Next>` as `&ComPtr<$type>`. This relies on `ComPtr`
// being `repr(transparent)` over a raw pointer and on the interface ordering
// assumption described on the macro.
Some(unsafe { std::mem::transmute(v) })
}
)*
}
}
#[track_caller]
$vis fn $unwrap_name(&self) -> &$crate::ComPtr<$type> {
match *self {
$(
Self::$prev_variant(_) => panic!(concat!("Tried to unwrap a ", stringify!($prev_variant), " as a ", stringify!($variant))),
)*
Self::$variant(ref v) => &*v,
$(
Self::$next_variant(ref v) => {
// Same reinterpretation as in the `as_*` accessor above.
unsafe { std::mem::transmute(v) }
}
)*
}
}
};
}