use super::*;
use crate::ComInterface;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicPtr, Ordering};
/// Lazily populated cache for a single activation factory.
///
/// `C` identifies the runtime class and `I` the factory interface; neither is stored,
/// they only parameterize the type through `PhantomData`. `shared` holds the raw
/// interface pointer of a cached factory, or null while nothing has been cached yet.
#[doc(hidden)]
pub struct FactoryCache<C, I> {
    shared: AtomicPtr<std::ffi::c_void>,
    _c: PhantomData<C>,
    _i: PhantomData<I>,
}

impl<C, I> FactoryCache<C, I> {
    /// Creates an empty cache. Because this is a `const fn`, it can initialize a `static`.
    pub const fn new() -> Self {
        let shared = AtomicPtr::new(std::ptr::null_mut());
        Self { _i: PhantomData, _c: PhantomData, shared }
    }
}
impl<C: crate::RuntimeName, I: crate::ComInterface> FactoryCache<C, I> {
    /// Invokes `callback` with the activation factory for runtime class `C`, requested as
    /// interface `I`.
    ///
    /// A factory that implements `IAgileObject` is cached in `self` the first time it is
    /// activated and reused by later calls, from any thread. A non-agile factory is
    /// activated on every call, passed to `callback`, and released when the call returns.
    ///
    /// # Errors
    ///
    /// Returns the error from `factory::<C, I>()` if activation fails. Otherwise it returns
    /// whatever `callback` returns.
    pub fn call<R, F: FnOnce(&I) -> crate::Result<R>>(&self, callback: F) -> crate::Result<R> {
        loop {
            // Acquire pairs with the Release publication below. A thread that observes the
            // cached pointer therefore also observes everything the publishing thread did
            // before it stored that pointer.
            let ptr = self.shared.load(Ordering::Acquire);

            if !ptr.is_null() {
                // SAFETY: `ptr` was obtained from `factory.as_raw()` below, and that
                // reference was forgotten, so the cache keeps it alive. This borrows it as
                // `&I` without touching the reference count.
                // NOTE(review): assumes `I` is a transparent wrapper around its raw
                // interface pointer; confirm against `ComInterface`.
                return callback(unsafe { std::mem::transmute(&ptr) });
            }

            let factory = factory::<C, I>()?;

            if factory.cast::<IAgileObject>().is_err() {
                // Non-agile factories must not be shared across threads. Use this one once
                // and let it drop.
                return callback(&factory);
            }

            // Publish the agile factory. A strong compare-exchange is used because a
            // spurious failure (possible with the weak form) would needlessly release this
            // factory and activate a new one on the next iteration.
            let published = self
                .shared
                .compare_exchange(std::ptr::null_mut(), factory.as_raw(), Ordering::Release, Ordering::Relaxed)
                .is_ok();

            if published {
                // The cache now owns this reference. `FactoryCache` has no `Drop` impl, so
                // the reference is never released.
                std::mem::forget(factory);
            }
            // If the exchange failed, another thread published first. `factory` is dropped
            // here, and the next iteration's Acquire load picks up the winner's pointer.
        }
    }
}
// SAFETY: `FactoryCache::call` stores a factory pointer in `shared` only after the factory
// successfully casts to `IAgileObject`. Non-agile factories are passed to the callback
// once and never cached. Handing the cached pointer to callers on any thread is therefore
// treated as sound.
// NOTE(review): this impl is unconditional over `C` and `I`. Both are held only through
// `PhantomData`, so no `C` or `I` values are stored; confirm no future field breaks this.
unsafe impl<C, I> std::marker::Sync for FactoryCache<C, I> {}
/// Retrieves the activation factory for runtime class `C`, requested as interface `I`.
///
/// The factory is first requested from `RoGetActivationFactory`, resolved from
/// `combase.dll` through `delay_load`. If that export cannot be resolved, the attempt is
/// treated as failing with `CLASS_E_CLASSNOTAVAILABLE`. If the call returns
/// `CO_E_NOTINITIALIZED`, `CoIncrementMTAUsage` (from `ole32.dll`) is called and the
/// request is retried once.
///
/// If that path fails, DLL names derived from `C::NAME` by `search_path` are probed in turn
/// with `get_activation_factory`. The first factory found is cast to `I`.
///
/// # Errors
///
/// If no candidate DLL yields a factory, returns the error from the
/// `RoGetActivationFactory` path. If a DLL-provided factory is found, returns any error
/// from casting it to `I`.
pub fn factory<C: crate::RuntimeName, I: crate::ComInterface>() -> crate::Result<I> {
    // Out-parameter for the factory interface pointer.
    // NOTE(review): relies on `Option<I>` having the same layout as a nullable interface
    // pointer so it can be written through `*mut *mut c_void`; confirm for all `I`.
    let mut factory: Option<I> = None;
    let name = crate::HSTRING::from(C::NAME);
    let code = if let Some(function) = unsafe { delay_load::<RoGetActivationFactory>(crate::s!("combase.dll"), crate::s!("RoGetActivationFactory")) } {
        unsafe {
            // `name` is passed by bit-copy, so ownership of the HSTRING stays here.
            let mut code = function(std::mem::transmute_copy(&name), &I::IID, &mut factory as *mut _ as *mut _);
            if code == CO_E_NOTINITIALIZED {
                // The calling thread has not initialized COM. Increment MTA usage and retry.
                if let Some(mta) = delay_load::<CoIncrementMTAUsage>(crate::s!("ole32.dll"), crate::s!("CoIncrementMTAUsage")) {
                    // The cookie is discarded, so this increment is never undone here. The
                    // result is ignored because the retry below reports any failure that
                    // remains.
                    let mut cookie = std::ptr::null_mut();
                    let _ = mta(&mut cookie);
                }
                code = function(std::mem::transmute_copy(&name), &I::IID, &mut factory as *mut _ as *mut _);
            }
            code
        }
    } else {
        // `combase.dll` or its export is unavailable; report the class as unavailable.
        CLASS_E_CLASSNOTAVAILABLE
    };
    if code.is_ok() {
        // Turn a successful HRESULT plus the out-parameter into `Result<I>`.
        return code.and_some(factory);
    }
    // Keep the original failure so it can be reported if the DLL fallback also fails.
    let original: crate::Error = code.into();
    if let Some(i) = search_path(C::NAME, |library| unsafe { get_activation_factory(library, &name) }) {
        i.cast()
    } else {
        Err(original)
    }
}
/// Probes candidate DLL names derived from a dotted type name, most specific first.
///
/// Each step strips the final `.`-separated segment from `path` and appends `.dll`. For
/// `"A.B.TypeName"` the candidates are `"A.B.dll"` and then `"A.dll"`. Each candidate is
/// passed to `callback` as a null-terminated `PCSTR`.
///
/// Returns the first `Ok` value produced by `callback`, or `None` if the candidates run out
/// (including when `path` contains no `.` at all).
fn search_path<F, R>(mut path: &str, mut callback: F) -> Option<R>
where
    F: FnMut(crate::PCSTR) -> crate::Result<R>,
{
    // Null-terminated extension appended to every candidate prefix.
    const DLL_SUFFIX: &[u8] = b".dll\0";

    // One buffer sized for the longest possible candidate, shrunk and reused per probe.
    let mut buffer = vec![0u8; path.len() + DLL_SUFFIX.len()];

    loop {
        let dot = path.rfind('.')?;
        path = &path[..dot];

        let stem_len = path.len();
        buffer.truncate(stem_len + DLL_SUFFIX.len());
        let (stem, ext) = buffer.split_at_mut(stem_len);
        stem.copy_from_slice(path.as_bytes());
        ext.copy_from_slice(DLL_SUFFIX);

        match callback(crate::PCSTR::from_raw(buffer.as_ptr())) {
            Ok(found) => return Some(found),
            Err(_) => {}
        }
    }
}
unsafe fn get_activation_factory(library: crate::PCSTR, name: &crate::HSTRING) -> crate::Result<IGenericFactory> {
let function = delay_load::<DllGetActivationFactory>(library, crate::s!("DllGetActivationFactory")).ok_or_else(crate::Error::from_win32)?;
let mut abi = std::ptr::null_mut();
function(std::mem::transmute_copy(name), &mut abi).from_abi(abi)
}
/// Signature of `CoIncrementMTAUsage`, resolved from `ole32.dll` by `factory`.
type CoIncrementMTAUsage = extern "system" fn(cookie: *mut *mut std::ffi::c_void) -> crate::HRESULT;
/// Signature of `RoGetActivationFactory`, resolved from `combase.dll` by `factory`.
/// The first argument receives an HSTRING bit-copied into a raw pointer.
type RoGetActivationFactory = extern "system" fn(hstring: *mut std::ffi::c_void, interface: &crate::GUID, result: *mut *mut std::ffi::c_void) -> crate::HRESULT;
/// Signature of the `DllGetActivationFactory` export that `get_activation_factory` looks up
/// in candidate DLLs. The first argument receives an HSTRING bit-copied into a raw pointer.
type DllGetActivationFactory = extern "system" fn(name: *mut std::ffi::c_void, factory: *mut *mut std::ffi::c_void) -> crate::HRESULT;
#[cfg(test)]
mod tests {
    use super::*;

    /// `search_path` must probe "A.B.dll" and then "A.dll" for "A.B.TypeName". It returns
    /// the first value the callback accepts, or `None` when every candidate is rejected.
    #[test]
    fn dll_search() {
        let path = "A.B.TypeName";
        let expected_probes = vec!["A.B.dll", "A.dll"];

        // Accept only "A.dll": both candidates are probed and its value comes back.
        let mut probed = Vec::new();
        let accepted = search_path(path, |library| {
            probed.push(unsafe { library.to_string().unwrap() });
            let bytes = unsafe { library.as_bytes() };
            if bytes == &b"A.dll"[..] {
                Ok(42)
            } else {
                Err(crate::Error::OK)
            }
        });
        assert!(matches!(accepted, Some(42)));
        assert_eq!(probed, expected_probes);

        // Reject everything: both candidates are still probed and nothing is found.
        let mut probed = Vec::new();
        let rejected = search_path(path, |library| {
            probed.push(unsafe { library.to_string().unwrap() });
            crate::Result::<()>::Err(crate::Error::OK)
        });
        assert!(rejected.is_none());
        assert_eq!(probed, expected_probes);
    }
}