use crate::*;
use std::convert::TryInto;
use bindings::{
Windows::Win32::Foundation::BSTR,
Windows::Win32::System::OleAutomation::{GetErrorInfo, SetErrorInfo},
Windows::Win32::System::WinRT::{ILanguageExceptionErrorInfo2, IRestrictedErrorInfo},
};
/// An error value pairing an `HRESULT` code with an optional
/// `IRestrictedErrorInfo` object that can supply a detailed message.
#[derive(Clone, PartialEq)]
pub struct Error {
// The result code this error represents; returned by `code()`.
code: HRESULT,
// Error-info object attached to this error, if any. `fast_error` leaves it
// as `None`; `new` and `From<HRESULT>` try to populate it from the thread's
// pending error information.
info: Option<IRestrictedErrorInfo>,
}
impl Error {
    /// Creates an error with the given `code` and `message`.
    ///
    /// The message is handed to `RoOriginateError`; afterwards the error
    /// information pending on the calling thread is fetched and, when it can
    /// be cast to `IRestrictedErrorInfo`, stored next to the code.
    pub fn new(code: HRESULT, message: &str) -> Self {
        let text: HSTRING = message.into();
        let info = unsafe {
            // Best effort: the return value of `RoOriginateError` is ignored.
            let _ = RoOriginateError(code, text.abi() as _);
            GetErrorInfo(0).and_then(|pending| pending.cast()).ok()
        };
        Self { code, info }
    }

    /// Creates an error that carries only `code`, without touching any
    /// error-info machinery.
    #[doc(hidden)]
    pub const fn fast_error(code: HRESULT) -> Self {
        Self { code, info: None }
    }

    /// Returns the `HRESULT` stored in this error.
    pub const fn code(&self) -> HRESULT {
        self.code
    }

    /// Returns a reference to the optional error-info object attached to
    /// this error.
    pub const fn info(&self) -> &Option<IRestrictedErrorInfo> {
        &self.info
    }

    /// Returns a human-readable message for this error.
    ///
    /// When an error-info object is attached and the code it reports equals
    /// this error's code, its restricted description is used (or its plain
    /// description when the restricted one is empty), with trailing
    /// whitespace removed. In every other case the message derived from the
    /// `HRESULT` itself is returned.
    pub fn message(&self) -> String {
        let info = match &self.info {
            Some(info) => info,
            None => return self.code.message(),
        };

        let mut description = BSTR::default();
        let mut restricted = BSTR::default();
        let mut capability_sid = BSTR::default();
        let mut reported = HRESULT(0);
        unsafe {
            // A failure is ignored: the out-values then keep their defaults.
            let _ = info.GetErrorDetails(
                &mut description,
                &mut reported,
                &mut restricted,
                &mut capability_sid,
            );
        }

        let chosen = if restricted.is_empty() {
            description
        } else {
            restricted
        };
        let text: String = chosen.try_into().unwrap_or_default();

        // Only trust the attached details if they describe the same code.
        if self.code != reported {
            return self.code.message();
        }
        text.trim_end().to_owned()
    }

    /// Extracts the 16-bit code from an `HRESULT` whose facility field
    /// (the 11 bits above the low word) equals 7, the Win32 facility.
    /// Returns `None` for any other facility.
    fn win32_code(&self) -> Option<u32> {
        const FACILITY_WIN32: u32 = 7;

        let raw = self.code.0;
        match (raw >> 16) & 0x7FF {
            FACILITY_WIN32 => Some(raw & 0xFFFF),
            _ => None,
        }
    }
}
impl std::convert::From<Error> for HRESULT {
fn from(error: Error) -> Self {
let code = error.code;
let info = error.info.and_then(|info| info.cast().ok());
unsafe {
let _ = SetErrorInfo(0, info);
}
code
}
}
impl std::convert::From<HRESULT> for Error {
    /// Builds an `Error` for `code`, attaching whatever error information is
    /// pending on the calling thread.
    ///
    /// * If the pending error object supports `IRestrictedErrorInfo`, it is
    ///   stored directly (after a best-effort request to capture the
    ///   propagation context).
    /// * If it is a plain error-info object, its description is used as the
    ///   message for a freshly originated error (`Error::new`).
    /// * If nothing is pending, an error with an empty message is originated.
    fn from(code: HRESULT) -> Self {
        // `GetErrorInfo` transfers ownership of the thread's error object to
        // the caller and clears the thread's error state. It must therefore
        // be called exactly once: a second call would never see the object
        // again, and the description fallback below would be unreachable.
        let pending = match unsafe { GetErrorInfo(0) } {
            Ok(pending) => pending,
            Err(_) => return Self::new(code, ""),
        };

        if let Ok(info) = pending.cast::<IRestrictedErrorInfo>() {
            // Best effort: ask the object to record the propagation context
            // when it supports that interface; failures are ignored.
            if let Ok(capture) = info.cast::<ILanguageExceptionErrorInfo2>() {
                unsafe {
                    let _ = capture.CapturePropagationContext(None);
                }
            }
            return Self {
                code,
                info: Some(info),
            };
        }

        // Legacy error object: reuse the one we already hold to recover its
        // description, then originate a restricted error carrying that text.
        let description = unsafe { pending.GetDescription().unwrap_or_default() };
        let description: String = description.try_into().unwrap_or_default();
        Self::new(code, &description)
    }
}
impl std::fmt::Debug for Error {
    /// Formats the error as a struct named `Error` with the code in
    /// zero-padded upper-case hex, the message, and, for codes in the Win32
    /// facility, the decimal Win32 code.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = fmt.debug_struct("Error");

        builder.field("code", &format_args!("{:#010X}", self.code.0));
        builder.field("message", &self.message());

        if let Some(win32) = self.win32_code() {
            builder.field("win32_code", &format_args!("{}", win32));
        }

        builder.finish()
    }
}
impl std::fmt::Display for Error {
    /// Writes the text returned by `message()` verbatim.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.message();
        fmt.write_str(text.as_str())
    }
}
// Marks `Error` as a standard error type. No methods are overridden, so the
// trait's defaults apply; `Debug` and `Display` are implemented above.
impl std::error::Error for Error {}
// Declares `RoOriginateError` as an import from `combase.dll` through the
// crate's `demand_load!` macro (presumably resolved at first use rather than
// at link time; confirm against the macro definition).
// `Error::new` calls it with the error code and the ABI value of an `HSTRING`
// message, and ignores the `i32` it returns.
demand_load! {
"combase.dll" {
fn RoOriginateError(code: HRESULT, message: RawPtr) -> i32;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An error without attached info reports the message derived from its
    /// `HRESULT`.
    #[test]
    fn test_message() {
        let cases = [
            (0, "The operation completed successfully."),
            (997, "Overlapped I/O operation is in progress."),
        ];

        for (win32, expected) in cases.iter() {
            let error = Error::fast_error(HRESULT::from_win32(*win32));
            assert_eq!(error.message(), *expected);
        }
    }

    /// A code created from a Win32 error round-trips through `win32_code`.
    #[test]
    fn win32_error_conversion() {
        let error = Error::fast_error(HRESULT::from_win32(18));
        assert_eq!(error.win32_code(), Some(18));
    }
}