use std::fmt::{Debug, Formatter};
use std::ops::{Deref, DerefMut};
use std::os::raw::c_void;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use winapi::shared::guiddef::{IID, REFIID};
use winapi::shared::ntdef::HRESULT;
use winapi::um::unknwnbase::IUnknown;
pub const fn iid(data1: u32, data2: u16, data3: u16, data4: [u8; 8]) -> Iid {
Iid::new(data1, data2, data3, data4)
}
/// A GUID-style interface identifier made of four components.
///
/// `#[repr(C)]` pins the field order and padding. Other parts of this module
/// rely on that to reinterpret an `Iid` as winapi's `IID`.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Iid {
    /// First 32-bit component.
    pub data1: u32,
    /// Second component (16 bits).
    pub data2: u16,
    /// Third component (16 bits).
    pub data3: u16,
    /// Trailing eight bytes.
    pub data4: [u8; 8],
}

impl Iid {
    /// Assembles an `Iid` from its components; usable in `const` contexts.
    const fn new(data1: u32, data2: u16, data3: u16, data4: [u8; 8]) -> Iid {
        Self {
            data1,
            data2,
            data3,
            data4,
        }
    }
}
impl Into<IID> for Iid {
fn into(self) -> IID {
unsafe { std::mem::transmute(self) }
}
}
impl From<IID> for Iid {
fn from(iid: IID) -> Iid {
unsafe { std::mem::transmute(iid) }
}
}
impl Deref for Iid {
    type Target = IID;

    /// Borrows this `Iid` as winapi's `IID` without copying.
    fn deref(&self) -> &IID {
        let raw: *const IID = (self as *const Self).cast();
        // SAFETY: relies on `Iid` (`#[repr(C)]`: u32, u16, u16, [u8; 8]) having
        // the same layout as winapi's `GUID`; the borrow inherits `self`'s lifetime.
        unsafe { &*raw }
    }
}
impl DerefMut for Iid {
    /// Mutably borrows this `Iid` as winapi's `IID`, sharing the same storage.
    fn deref_mut(&mut self) -> &mut IID {
        let raw: *mut IID = (self as *mut Self).cast();
        // SAFETY: same layout assumption as `Deref`; the exclusive borrow of
        // `self` guarantees no aliasing for the returned reference's lifetime.
        unsafe { &mut *raw }
    }
}
/// Owning smart pointer to a COM interface of type `T`.
///
/// Each `ComPtr` accounts for one reference on the underlying object: `Clone`
/// calls `AddRef` (via `reference_add`) and `Drop` calls `Release` (via
/// `reference_remove`).
pub struct ComPtr<T> {
/// Interface pointer; the constructors reject null, so it is always non-null.
pointer: NonNull<T>,
}
impl<T> ComPtr<T> {
    /// Wraps `pointer`, taking over one reference the caller already owns.
    ///
    /// No `AddRef` is issued; the adopted reference is released when the
    /// returned `ComPtr` is dropped.
    ///
    /// # Safety
    ///
    /// `pointer` must refer to a live COM object whose layout begins with an
    /// `IUnknown`-compatible vtable (it is later reinterpreted as `IUnknown`),
    /// and the caller must be transferring ownership of one of its references.
    ///
    /// # Panics
    ///
    /// Panics if `pointer` is null.
    pub unsafe fn new(pointer: *mut T) -> ComPtr<T> {
        let pointer = NonNull::new(pointer).expect("cannot create `ComPtr` from null object.");
        ComPtr { pointer }
    }

    /// Wraps `pointer` as an additional owner, calling `AddRef` first.
    ///
    /// The caller keeps (and stays responsible for) its own reference.
    ///
    /// # Safety
    ///
    /// `pointer` must refer to a live COM object whose layout begins with an
    /// `IUnknown`-compatible vtable. No reference ownership is transferred.
    ///
    /// # Panics
    ///
    /// Panics if `pointer` is null.
    pub unsafe fn from(pointer: *mut T) -> ComPtr<T> {
        let pointer = NonNull::new(pointer).expect("cannot create `ComPtr` from null object.");
        reference_add(pointer.cast());
        ComPtr { pointer }
    }

    /// Returns a new owning pointer to the same object, typed as `U` (the type
    /// `T` dereferences to, e.g. a parent interface).
    ///
    /// The address is reused unchanged and no `QueryInterface` call is made.
    /// This therefore relies on `U` being a layout prefix of `T`.
    pub fn cast<U>(&self) -> ComPtr<U>
    where
        T: Deref<Target = U>,
    {
        // SAFETY: `self` holds a reference, so the object is alive; the extra
        // reference taken here is owned by the `ComPtr` returned below.
        unsafe { reference_add(self.pointer.cast()) };
        ComPtr {
            pointer: self.pointer.cast(),
        }
    }

    /// Queries the object for interface `U`, using the IID winapi associates
    /// with `U`.
    ///
    /// # Errors
    ///
    /// Returns the failing `HRESULT` reported by `QueryInterface`.
    pub fn query<U: winapi::Interface>(&self) -> Result<ComPtr<U>, HRESULT> {
        self.query_iid::<U>(&U::uuidof())
    }

    /// Calls `IUnknown::QueryInterface` with `iid` and wraps the result.
    ///
    /// The result is adopted with [`ComPtr::new`] (no extra `AddRef`). This
    /// relies on COM's rule that a successful `QueryInterface` returns an
    /// already-referenced pointer.
    ///
    /// # Errors
    ///
    /// Returns the `HRESULT` when it is negative (a failure code).
    ///
    /// # Panics
    ///
    /// Panics if `QueryInterface` reports success but yields a null pointer.
    pub fn query_iid<U>(&self, iid: REFIID) -> Result<ComPtr<U>, HRESULT> {
        let mut pointer = std::ptr::null_mut::<U>();
        let unknown = self.as_unknown();
        // SAFETY: `unknown` is the live object held by `self`; `pointer` is a
        // valid out-parameter location for the duration of the call.
        let hr = unsafe {
            (*unknown).QueryInterface(iid, &mut pointer as *mut *mut U as *mut *mut c_void)
        };
        // Non-negative HRESULTs are success codes (the `SUCCEEDED` test).
        if hr >= 0 {
            // SAFETY: on success `pointer` carries a reference we now own.
            Ok(unsafe { ComPtr::new(pointer) })
        } else {
            Err(hr)
        }
    }

    /// Mutably borrows the interface (equivalent to `DerefMut`).
    pub fn as_mut(&mut self) -> &mut T {
        // SAFETY: the pointer is non-null and the object stays alive while
        // `self` holds its reference; `&mut self` prevents aliasing through it.
        unsafe { self.pointer.as_mut() }
    }

    /// Returns the raw interface pointer; the reference count is unchanged.
    pub fn as_ptr(&self) -> *const T {
        self.pointer.as_ptr()
    }

    /// Returns the raw interface pointer as `*mut T`; the reference count is
    /// unchanged.
    pub fn as_mut_ptr(&self) -> *mut T {
        self.pointer.as_ptr()
    }

    /// Returns the object pointer reinterpreted as `IUnknown` (no reference
    /// count change). Relies on `T`'s layout starting with the `IUnknown` vtable.
    pub fn as_unknown(&self) -> *mut IUnknown {
        self.pointer.as_ptr() as *mut IUnknown
    }

    /// Gives up ownership without calling `Release` and returns the raw pointer.
    ///
    /// The caller becomes responsible for the reference this `ComPtr` held.
    pub fn into_raw(self) -> *mut T {
        let pointer = self.pointer.as_ptr();
        // Skip `Drop` so the owned reference is handed to the caller, not released.
        std::mem::forget(self);
        pointer
    }
}
impl<T> Deref for ComPtr<T> {
    type Target = T;

    /// Borrows the interface this pointer refers to.
    fn deref(&self) -> &T {
        let raw: *mut T = self.pointer.as_ptr();
        // SAFETY: `raw` is non-null and the object stays alive while `self`
        // holds its reference.
        unsafe { &*raw }
    }
}
impl<T> DerefMut for ComPtr<T> {
    /// Mutably borrows the interface this pointer refers to.
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.pointer.as_ptr();
        // SAFETY: non-null, kept alive by `self`'s reference, and `&mut self`
        // rules out other borrows through this handle.
        unsafe { &mut *raw }
    }
}
impl<T> Clone for ComPtr<T> {
    /// Creates another owning handle to the same object, bumping its reference
    /// count with `AddRef`.
    fn clone(&self) -> ComPtr<T> {
        let shared = self.pointer;
        // SAFETY: `self` keeps the object alive; the new reference belongs to
        // the handle returned below and is released by its `Drop`.
        unsafe { reference_add(shared.cast()) };
        ComPtr { pointer: shared }
    }
}
impl<T> Drop for ComPtr<T> {
fn drop(&mut self) {
unsafe {
reference_remove(self.pointer.cast());
}
}
}
impl<T> PartialEq<ComPtr<T>> for ComPtr<T> {
    /// Two handles are equal when they hold the very same interface address.
    fn eq(&self, other: &ComPtr<T>) -> bool {
        let (lhs, rhs) = (self.pointer.as_ptr(), other.pointer.as_ptr());
        lhs == rhs
    }
}
impl<T> Debug for ComPtr<T> {
fn fmt(&self, formatter: &mut Formatter) -> Result<(), std::fmt::Error> {
formatter
.debug_struct("ComPtr")
.field("type", unsafe { &std::intrinsics::type_name::<T>() })
.field("pointer", &self.pointer)
.finish()
}
}
pub unsafe fn reference_add(mut object: NonNull<IUnknown>) {
object.as_mut().AddRef();
}
pub unsafe fn reference_remove(mut object: NonNull<IUnknown>) {
object.as_mut().Release();
}
/// A thread-safe reference counter.
///
/// `increment` and `decrement` return the value after the update, the same
/// convention `AddRef`/`Release` use for their return values.
/// [`ReferenceCount::new`] starts at 1. [`ReferenceCount::zero`] and
/// `ReferenceCount::default()` start at 0.
#[derive(Default)]
pub struct ReferenceCount {
    value: AtomicUsize,
}

impl ReferenceCount {
    /// Creates a counter starting at 1, accounting for the creator's reference.
    pub fn new() -> ReferenceCount {
        ReferenceCount {
            value: AtomicUsize::new(1),
        }
    }

    /// Creates a counter starting at 0 (same as `Default`).
    pub fn zero() -> ReferenceCount {
        ReferenceCount {
            value: AtomicUsize::new(0),
        }
    }

    /// Adds one reference and returns the new count.
    pub fn increment(&self) -> usize {
        // Relaxed suffices: gaining a reference requires already holding one,
        // so no data needs to be published here.
        self.value.fetch_add(1, Ordering::Relaxed) + 1
    }

    /// Drops one reference and returns the new count.
    ///
    /// When the result is 0, an `Acquire` fence is issued. It pairs with the
    /// `Release` decrements made by other owners, so a caller that destroys the
    /// object afterwards observes all of their prior writes.
    ///
    /// # Panics
    ///
    /// With overflow checks enabled (debug builds), panics if the counter was
    /// already 0.
    pub fn decrement(&self) -> usize {
        let previous = self.value.fetch_sub(1, Ordering::Release);
        if previous == 1 {
            std::sync::atomic::fence(Ordering::Acquire);
        }
        previous - 1
    }
}
/// Heap representation of an object created by `ComLayout::into_com`: a vtable
/// pointer followed by the Rust value.
///
/// `#[repr(C)]` keeps `vtable` at offset 0. That is what lets `into_com` cast a
/// `*mut ComObject` to an interface pointer, and lets `from_com`/`destroy` cast
/// it back.
#[repr(C)]
struct ComObject<T, TVirtualFunctionTable> {
/// Function-table pointer; must remain the first field.
vtable: *const TVirtualFunctionTable,
/// The Rust-side state of the object.
value: T,
}
pub trait ComLayout<TInterface>
where
Self: Sized,
{
type VirtualFunctionTable: 'static;
fn vtable() -> &'static Self::VirtualFunctionTable;
unsafe fn into_com(self) -> *mut TInterface {
let native = box ComObject {
vtable: Self::vtable(),
value: self,
};
Box::into_raw(native) as *mut TInterface
}
unsafe fn from_com<'a>(pointer: *mut TInterface) -> &'a mut Self {
let object = pointer as *mut ComObject<Self, Self::VirtualFunctionTable>;
&mut (*object).value
}
fn into_com_ptr(self) -> ComPtr<TInterface> {
unsafe {
let pointer = self.into_com();
ComPtr::new(pointer)
}
}
unsafe fn destroy(pointer: *mut TInterface) {
let object = pointer as *mut ComObject<Self, Self::VirtualFunctionTable>;
Box::from_raw(object);
}
}