use std::cell::Cell;
use std::fmt;
use std::marker::PhantomData;
use std::rc::Rc;
use windows::Win32::System::Com::{
COINIT, COINIT_APARTMENTTHREADED, COINIT_MULTITHREADED, CoInitializeEx, CoUninitialize,
};
use windows::core::HRESULT;
/// COM apartment model that can be requested when initializing COM on a thread.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type {
/// Maps to `COINIT_APARTMENTTHREADED`.
Sta,
/// Maps to `COINIT_MULTITHREADED`.
Mta,
}
impl Type {
    /// Returns the `COINIT` flag that corresponds to this apartment model.
    #[must_use]
    pub fn as_raw(&self) -> COINIT {
        match *self {
            Self::Sta => COINIT_APARTMENTTHREADED,
            Self::Mta => COINIT_MULTITHREADED,
        }
    }

    /// Returns the integer value wrapped by the `COINIT` flag for this apartment model.
    #[must_use]
    pub fn as_code(&self) -> i32 {
        let flag = self.as_raw();
        flag.0
    }
}
/// `HRESULT` with the bit pattern `0x8001_0106`, reinterpreted as a signed value.
///
/// Per the Windows documentation for `CoInitializeEx`, this code is returned when the
/// thread was previously initialized with a different concurrency model.
const RPC_E_CHANGED_MODE: HRESULT = HRESULT(0x8001_0106_u32.cast_signed());
thread_local! {
// Per-thread flag, initially `false`. `initialize_as` sets it and short-circuits while it
// is `true`; `uninitialize` only calls `CoUninitialize` when it is `true`.
static COM_INIT: Cell<bool> = const { Cell::new(false) };
}
/// Outcome reported by `initialize_as` when it does not fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Initialization {
/// `CoInitializeEx` returned `HRESULT(0)` for this call.
Success,
/// The per-thread flag was already set, or `CoInitializeEx` returned `HRESULT(1)` or
/// `RPC_E_CHANGED_MODE`.
AlreadyInitialized,
}
pub fn initialize_as(kind: Type) -> Result<Initialization, Error> {
COM_INIT.with(|state| {
if state.get() {
return Ok(Initialization::AlreadyInitialized);
}
let coinit: COINIT = match kind {
Type::Sta => COINIT_APARTMENTTHREADED,
Type::Mta => COINIT_MULTITHREADED,
};
let hresult = unsafe { CoInitializeEx(None, coinit) };
match hresult {
HRESULT(0) => {
state.set(true);
Ok(Initialization::Success)
}
HRESULT(1) | RPC_E_CHANGED_MODE => {
state.set(true);
Ok(Initialization::AlreadyInitialized)
}
_ => Err(windows::core::Error::from_hresult(hresult).context(None, "CoInitializeEx")),
}
})
}
pub(crate) fn ensure_initialized() -> crate::Result<()> {
match initialize_as(Type::Sta) {
Ok(_) => Ok(()),
Err(err) => Err(crate::Error::Com(err)),
}
}
/// Releases the COM initialization recorded for the current thread, if there is one.
///
/// Does nothing when the per-thread flag is not set. The flag is cleared as part of the
/// call so that repeated calls cannot issue more than one `CoUninitialize`, and so that a
/// later `initialize_as` performs a real initialization again instead of reporting
/// `AlreadyInitialized` for a thread whose COM state has been torn down.
pub fn uninitialize() {
    COM_INIT.with(|state| {
        // `replace` reads the old value and clears the flag in a single step.
        if state.replace(false) {
            // SAFETY: the flag is only set after a `CoInitializeEx` call on this thread,
            // so this call balances that initialization.
            unsafe { CoUninitialize() };
        }
    });
}
/// Guard value tying COM initialization on the current thread to a scope.
///
/// Created by `Session::new`; when dropped it calls `uninitialize` only if
/// `uninit_on_drop` is set.
#[derive(Debug)]
pub struct Session {
// True when the `initialize_as` call made for this session reported `Initialization::Success`.
uninit_on_drop: bool,
// `Rc` is neither `Send` nor `Sync`, so this marker keeps `Session` from crossing threads.
_no_send: PhantomData<Rc<()>>,
}
impl Session {
    /// Initializes COM on the current thread with `kind` and returns a guard for it.
    ///
    /// The guard will call `uninitialize` on drop only when `initialize_as` reported
    /// `Initialization::Success` for this call.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `initialize_as`.
    pub fn new(kind: Type) -> Result<Self, Error> {
        initialize_as(kind).map(|outcome| Self {
            uninit_on_drop: matches!(outcome, Initialization::Success),
            _no_send: PhantomData,
        })
    }
}
impl Drop for Session {
    /// Calls `uninitialize` when this session was flagged to do so at construction.
    fn drop(&mut self) {
        if !self.uninit_on_drop {
            return;
        }
        uninitialize();
    }
}
/// Names the call an error is attributed to.
///
/// Displayed as `interface::method` when an interface name is present, and as just
/// `method` otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MethodInfo<'a> {
    /// Optional interface name placed before the method when formatting.
    pub interface: Option<&'a str>,
    /// Name of the method or function.
    pub method: &'a str,
}

impl fmt::Display for MethodInfo<'_> {
    /// Writes the optional `interface::` prefix followed by the method name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(owner) = self.interface {
            write!(f, "{owner}::")?;
        }
        f.write_str(self.method)
    }
}
/// Error from a COM call, pairing the attributed method with the underlying `windows` error.
///
/// Displayed as `COM error from {method}: {source}`.
#[derive(Debug, thiserror::Error)]
#[error("COM error from {method}: {source}")]
pub struct Error {
/// Interface and method the failing call is attributed to.
pub method: MethodInfo<'static>,
/// Underlying error value from the `windows` crate.
#[source]
pub source: windows::core::Error,
}
/// Extension trait for turning an error value into this module's `Error` by attaching
/// the interface and method it came from.
pub(crate) trait ComErrorExt {
/// Consumes `self` and returns an `Error` attributed to `interface` (optional) and `method`.
fn context(self, interface: Option<&'static str>, method: &'static str) -> Error;
}
impl ComErrorExt for windows::core::Error {
    /// Wraps this `windows` error in an `Error` attributed to `interface` and `method`.
    fn context(self, interface: Option<&'static str>, method: &'static str) -> Error {
        let method = MethodInfo { interface, method };
        Error {
            method,
            source: self,
        }
    }
}
/// Extension trait for converting a result into `crate::Result` while attaching the
/// interface and method the call is attributed to.
pub(crate) trait ComResultExt<T> {
/// Consumes `self` and returns a `crate::Result<T>` attributed to `interface` (optional)
/// and `method`.
fn context(self, interface: Option<&'static str>, method: &'static str) -> crate::Result<T>;
}
impl<T> ComResultExt<T> for windows::core::Result<T> {
fn context(self, interface: Option<&'static str>, method: &'static str) -> crate::Result<T> {
self.map_err(|source| crate::Error::Com(source.context(interface, method)))
}
}