extern crate alloc;
use core::ffi::c_void;
use core::sync::atomic::{AtomicU32, Ordering};
use windows::Win32::Foundation::HMODULE;
use windows::Win32::System::LibraryLoader::{
GetModuleFileNameW, GetModuleHandleExW, GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
};
use windows_core::{ComObject, IUnknown, Interface, GUID, HRESULT, PCWSTR};
use crate::clsid::Clsid;
use crate::error::HResult;
use crate::raw::class_factory::{ApoClassFactory, ApoVTable};
use crate::raw::register;
/// Shared implementation behind a `DllGetClassObject`-style entry point.
///
/// Searches `registry` for the vtable whose `clsid` equals `*rclsid`, wraps it
/// in an [`ApoClassFactory`], and asks the resulting COM object for the
/// interface identified by `*riid`, storing the answer in `*ppv`.
///
/// Returns:
/// * `E_POINTER` when `ppv`, `rclsid` or `riid` is null (`*ppv` is still
///   cleared when only `rclsid`/`riid` is null);
/// * `CLASS_E_CLASSNOTAVAILABLE` when no registry entry matches;
/// * otherwise whatever the factory's `query` call returns.
///
/// # Safety
///
/// Every non-null pointer is dereferenced: `rclsid` and `riid` must point to
/// readable `GUID` values and `ppv` must point to writable storage for one
/// interface pointer.
pub unsafe fn dll_get_class_object_dispatch(
    rclsid: *const GUID,
    riid: *const GUID,
    ppv: *mut *mut c_void,
    registry: &[&'static ApoVTable],
) -> HRESULT {
    if ppv.is_null() {
        return HResult::E_POINTER.into();
    }
    // Clear the out-parameter up front so every later failure path leaves it null.
    // SAFETY: `ppv` was just checked to be non-null; the caller guarantees it is writable.
    unsafe { ppv.write(core::ptr::null_mut()) };

    if [rclsid, riid].iter().any(|guid| guid.is_null()) {
        return HResult::E_POINTER.into();
    }

    // SAFETY: `rclsid` is non-null here and the caller guarantees it is readable.
    let wanted = Clsid::from(unsafe { rclsid.read() });
    let entry = match registry.iter().find(|candidate| candidate.clsid == wanted) {
        Some(entry) => entry,
        None => return HResult::CLASS_E_CLASSNOTAVAILABLE.into(),
    };

    // The local `IUnknown` is dropped on return; `query` is what hands a
    // pointer to the caller through `ppv`.
    let unknown: IUnknown = ComObject::new(ApoClassFactory::new(entry)).into_interface();
    // SAFETY: `riid` and `ppv` are non-null and valid per the caller's contract.
    unsafe { unknown.query(riid, ppv) }
}
/// Process-wide counter backing `outstanding_inc`, `outstanding_dec` and
/// `outstanding_count`; starts at zero.
///
/// NOTE(review): presumably tracks live COM objects / server locks so that
/// `dll_can_unload_now_dispatch` can answer correctly — confirm against the
/// call sites of `outstanding_inc` / `outstanding_dec`.
static OUTSTANDING: AtomicU32 = AtomicU32::new(0);
/// Adds one to the module-wide `OUTSTANDING` counter (relaxed ordering).
pub fn outstanding_inc() {
    // The previous value is of no interest to callers.
    let _previous = OUTSTANDING.fetch_add(1, Ordering::Relaxed);
}
/// Subtracts one from the module-wide `OUTSTANDING` counter (relaxed ordering).
///
/// The subtraction wraps: calling this while the counter is zero leaves it at
/// `u32::MAX`, so every call must be paired with an earlier `outstanding_inc`.
pub fn outstanding_dec() {
    // The previous value is of no interest to callers.
    let _previous = OUTSTANDING.fetch_sub(1, Ordering::Relaxed);
}
#[must_use]
pub fn outstanding_count() -> u32 {
OUTSTANDING.load(Ordering::Relaxed)
}
/// Shared implementation behind a `DllCanUnloadNow`-style entry point.
///
/// Returns `HRESULT(0)` (`S_OK`) when `outstanding_count()` is zero and
/// `HRESULT(1)` (`S_FALSE`) otherwise.
pub fn dll_can_unload_now_dispatch() -> HRESULT {
    const S_OK: HRESULT = HRESULT(0);
    const S_FALSE: HRESULT = HRESULT(1);

    match outstanding_count() {
        0 => S_OK,
        _ => S_FALSE,
    }
}
/// Shared implementation behind a `DllRegisterServer`-style entry point.
///
/// Resolves this module's path with [`own_module_path`] and passes it, together
/// with each entry of `registry` in order, to `register::write_registry`.
///
/// Returns `HRESULT(0)` when every entry was written. Registration stops at the
/// first failure (path lookup or write) and that failure's code is returned;
/// entries written before the failure are not rolled back here.
pub fn dll_register_server_dispatch(registry: &[&'static ApoVTable]) -> HRESULT {
    let module_path = match own_module_path() {
        Err(failure) => return failure.code(),
        Ok(path) => path,
    };

    // `find_map` stops at the first error, matching fail-fast registration.
    registry
        .iter()
        .find_map(|vtable| register::write_registry(vtable, &module_path).err())
        .map_or(HRESULT(0), |failure| failure.code())
}
/// Shared implementation behind a `DllUnregisterServer`-style entry point.
///
/// Calls `register::clear_registry` for the CLSID of every entry in
/// `registry`, in order. Unregistration is best-effort: a failure for one
/// class does not prevent the remaining classes from being cleared, so a
/// single bad entry cannot leave every later class registered.
///
/// Returns `HRESULT(0)` when every entry was cleared, otherwise the code of
/// the first failure encountered.
pub fn dll_unregister_server_dispatch(registry: &[&'static ApoVTable]) -> HRESULT {
    let mut first_failure: Option<HRESULT> = None;
    for vtable in registry {
        if let Err(e) = register::clear_registry(&vtable.clsid) {
            // Remember only the earliest failure, but keep clearing the rest.
            if first_failure.is_none() {
                first_failure = Some(e.code());
            }
        }
    }
    first_failure.unwrap_or(HRESULT(0))
}
/// Returns the file-system path of the module (DLL or EXE) that contains this
/// function, as UTF-16 code units including the trailing NUL.
///
/// The module is located by asking the loader which image owns the address of
/// `own_module_path` itself; the module's reference count is left unchanged.
///
/// `GetModuleFileNameW` signals truncation by returning the buffer size, so
/// the buffer is doubled and the call retried until the whole path fits,
/// instead of handing back a silently truncated path.
///
/// # Errors
///
/// Returns the calling thread's last Win32 error when the module handle or the
/// file name cannot be obtained, or when the path still does not fit in the
/// largest buffer this function is willing to allocate.
pub fn own_module_path() -> windows_core::Result<alloc::vec::Vec<u16>> {
    /// Starting buffer size, in UTF-16 units.
    const INITIAL_CAPACITY: usize = 1024;
    /// Extended-length paths are limited to 32,767 units plus the terminator.
    const MAX_CAPACITY: usize = 32_768;

    let mut hmodule = HMODULE::default();
    let address = own_module_path as *const c_void;
    // SAFETY: with FROM_ADDRESS the "name" argument is interpreted as an
    // address inside the module rather than as a string, and `hmodule` is a
    // valid location for the resulting handle.
    unsafe {
        GetModuleHandleExW(
            GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
            PCWSTR(address.cast::<u16>()),
            &mut hmodule,
        )
    }?;

    let mut buf = alloc::vec![0u16; INITIAL_CAPACITY];
    loop {
        // SAFETY: `hmodule` was produced by the call above and `buf` is a
        // live, writable slice whose length is passed along by the wrapper.
        let written = unsafe { GetModuleFileNameW(Some(hmodule), &mut buf) } as usize;
        if written == 0 {
            return Err(windows_core::Error::from_thread());
        }
        if written < buf.len() {
            // `written` excludes the terminator; keep it in the result.
            buf.truncate(written + 1);
            return Ok(buf);
        }
        // `written == buf.len()` means the path was truncated to fit.
        if buf.len() >= MAX_CAPACITY {
            // The failed call left ERROR_INSUFFICIENT_BUFFER as the last error.
            return Err(windows_core::Error::from_thread());
        }
        let grown = buf.len() * 2;
        buf.resize(grown, 0);
    }
}
#[cfg(test)]
mod tests {
    extern crate alloc;
    use alloc::sync::Arc;

    use super::*;
    use crate::apo::{ApoCategory, ProcessInput, ProcessingObject};
    use crate::buffer::BufferFlags;
    use crate::instance::{AnyApoInstance, ApoInstance};
    use crate::realtime::RealtimeContext;

    /// Minimal processing object that copies its input straight to its output.
    struct Dummy;

    impl ProcessingObject for Dummy {
        const CLSID: Clsid = Clsid::from_u128(0xA0A0A0A0_0000_0000_0000_0000A0A0A0A0);
        const NAME: &'static str = "dummy";
        const COPYRIGHT: &'static str = "test";
        const CATEGORY: ApoCategory = ApoCategory::Sfx;

        fn new() -> Self {
            Dummy
        }

        fn process(
            &mut self,
            _rt: &RealtimeContext,
            input: ProcessInput<'_>,
            output: &mut [f32],
        ) -> BufferFlags {
            output.copy_from_slice(input.samples());
            input.flags()
        }
    }

    /// `create` hook stored in [`DUMMY_VT`].
    fn make_dummy_instance() -> Arc<dyn AnyApoInstance> {
        Arc::new(ApoInstance::<Dummy>::new())
    }

    /// Registry entry describing [`Dummy`].
    static DUMMY_VT: ApoVTable = ApoVTable {
        clsid: <Dummy as ProcessingObject>::CLSID,
        name: <Dummy as ProcessingObject>::NAME,
        copyright: <Dummy as ProcessingObject>::COPYRIGHT,
        category: <Dummy as ProcessingObject>::CATEGORY,
        create: make_dummy_instance,
    };

    /// Recognisable non-null value used to detect whether the out-slot was touched.
    fn sentinel() -> *mut c_void {
        0xDEAD_BEEF as *mut c_void
    }

    /// Calls `dll_get_class_object_dispatch` and returns its result together
    /// with the final content of the out-slot. With `ppv_null` set, a null
    /// `ppv` is passed and the slot keeps its sentinel value.
    fn dispatch(
        clsid: Clsid,
        riid: GUID,
        registry: &[&'static ApoVTable],
        ppv_null: bool,
    ) -> (HRESULT, *mut c_void) {
        let mut slot = sentinel();
        let slot_ptr: *mut *mut c_void = if ppv_null {
            core::ptr::null_mut()
        } else {
            &mut slot
        };
        let class_guid: GUID = clsid.into();
        // SAFETY: both GUID pointers refer to live locals; `slot_ptr` is either
        // null or points at `slot`, which outlives the call.
        let result =
            unsafe { dll_get_class_object_dispatch(&class_guid, &riid, slot_ptr, registry) };
        (result, slot)
    }

    #[test]
    fn dispatch_null_ppv_returns_e_pointer() {
        let (result, slot) = dispatch(Dummy::CLSID, IUnknown::IID, &[], true);
        assert_eq!(result, HResult::E_POINTER.into());
        // Nothing could be written through a null `ppv`.
        assert_eq!(slot, sentinel());
    }

    #[test]
    fn dispatch_unknown_clsid_returns_class_e_classnotavailable() {
        let missing = Clsid::from_u128(0xBADBAD00_0000_0000_0000_0000BADBAD00);
        let (result, slot) = dispatch(missing, IUnknown::IID, &[], false);
        assert_eq!(result, HResult::CLASS_E_CLASSNOTAVAILABLE.into());
        assert!(slot.is_null());
    }

    #[test]
    fn dispatch_matching_clsid_returns_class_factory() {
        use windows::Win32::System::Com::IClassFactory;

        let registry: &[&ApoVTable] = &[&DUMMY_VT];
        let (hr, slot) = dispatch(Dummy::CLSID, IClassFactory::IID, registry, false);
        assert!(hr.is_ok(), "expected S_OK from query, got {hr:?}");
        assert!(!slot.is_null());
        // Take ownership of the returned pointer so it is released on drop.
        // SAFETY: on success `slot` holds an `IClassFactory` pointer produced by `query`.
        drop(unsafe { IClassFactory::from_raw(slot) });
    }

    #[test]
    fn dispatch_matching_clsid_with_unsupported_riid_returns_no_interface() {
        let registry: &[&ApoVTable] = &[&DUMMY_VT];
        let unsupported = GUID::from_u128(0xCAFE0001_0000_0000_0000_000000000001);
        let (result, slot) = dispatch(Dummy::CLSID, unsupported, registry, false);
        assert_eq!(result, HResult::E_NOINTERFACE.into());
        assert!(slot.is_null());
    }
}