extern crate winapi;
use self::winapi::shared::minwindef::{WORD, DWORD, HMODULE, FARPROC};
use self::winapi::shared::ntdef::WCHAR;
use self::winapi::shared::winerror;
use self::winapi::um::{errhandlingapi, libloaderapi};
use util::{ensure_compatible_types, cstr_cow_from_bytes};
use std::ffi::{OsStr, OsString};
use std::{fmt, io, marker, mem, ptr};
use std::os::windows::ffi::{OsStrExt, OsStringExt};
use std::sync::atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering};
/// A loaded Windows module, wrapping its `HMODULE` handle.
///
/// The handle comes either from `LoadLibraryW` (in `Library::new`) or from the
/// caller (in `Library::from_raw`). It is passed to `FreeLibrary` when the value
/// is dropped, unless it was taken back out with `Library::into_raw`.
pub struct Library(HMODULE);
// `HMODULE` is a raw pointer, so these traits are not derived automatically.
// NOTE(review): soundness relies on the Win32 loader functions used here
// (GetProcAddress, FreeLibrary, GetModuleFileNameW) being callable from any
// thread with the same handle — confirm against the Win32 documentation.
unsafe impl Send for Library {}
unsafe impl Sync for Library {}
impl Library {
    /// Loads the module named by `filename` with `LoadLibraryW`.
    ///
    /// The name is converted into a NUL-terminated UTF-16 buffer. An
    /// `ErrorModeGuard` is held for the duration of the call, so the error mode
    /// is adjusted and then restored afterwards.
    ///
    /// # Errors
    ///
    /// When `LoadLibraryW` returns a null handle, the `GetLastError` code is
    /// returned as an `io::Error`.
    ///
    /// # Panics
    ///
    /// Panics if `LoadLibraryW` fails but `GetLastError` reports 0.
    #[inline]
    pub fn new<P: AsRef<OsStr>>(filename: P) -> ::Result<Library> {
        let wide: Vec<u16> = filename.as_ref().encode_wide().chain(Some(0)).collect();
        let _mode_guard = ErrorModeGuard::new();
        let outcome = with_get_last_error(|| {
            let module = unsafe { libloaderapi::LoadLibraryW(wide.as_ptr()) };
            if module.is_null() { None } else { Some(Library(module)) }
        });
        let result = outcome.map_err(|reported| match reported {
            Some(error) => error,
            None => panic!("LoadLibraryW failed but GetLastError did not report the error"),
        });
        // The UTF-16 buffer must outlive the LoadLibraryW call above.
        drop(wide);
        result
    }

    /// Looks up the exported symbol `symbol` with `GetProcAddress`.
    ///
    /// The name goes through `cstr_cow_from_bytes`. The tests in this file pass
    /// it both with and without a trailing NUL byte.
    ///
    /// # Safety
    ///
    /// The returned `Symbol<T>` reinterprets the raw `FARPROC` as a `T`. The
    /// caller must ensure `T` matches the real type of the exported symbol.
    ///
    /// # Errors
    ///
    /// Returns the error from `cstr_cow_from_bytes`. Returns the `GetLastError`
    /// code when `GetProcAddress` returns null.
    ///
    /// # Panics
    ///
    /// Panics if `GetProcAddress` fails but `GetLastError` reports 0.
    pub unsafe fn get<T>(&self, symbol: &[u8]) -> ::Result<Symbol<T>> {
        ensure_compatible_types::<T, FARPROC>();
        let name = try!(cstr_cow_from_bytes(symbol));
        Library::symbol_from(|| libloaderapi::GetProcAddress(self.0, name.as_ptr()))
    }

    /// Looks up an exported symbol by its ordinal with `GetProcAddress`.
    ///
    /// # Safety
    ///
    /// Same contract as `Library::get`: `T` must match the symbol's real type.
    ///
    /// # Errors
    ///
    /// Returns the `GetLastError` code when `GetProcAddress` returns null.
    ///
    /// # Panics
    ///
    /// Panics if `GetProcAddress` fails but `GetLastError` reports 0.
    pub unsafe fn get_ordinal<T>(&self, ordinal: WORD) -> ::Result<Symbol<T>> {
        ensure_compatible_types::<T, FARPROC>();
        Library::symbol_from(|| {
            // The ordinal is smuggled through the name-pointer argument as an
            // integer value, which GetProcAddress accepts in place of a name.
            let ordinal = ordinal as usize as *mut _;
            libloaderapi::GetProcAddress(self.0, ordinal)
        })
    }

    /// Runs one `GetProcAddress` lookup and wraps a non-null result as a `Symbol<T>`.
    fn symbol_from<T, F>(lookup: F) -> ::Result<Symbol<T>>
        where F: FnOnce() -> FARPROC
    {
        with_get_last_error(|| {
            let address = lookup();
            if address.is_null() {
                None
            } else {
                Some(Symbol { pointer: address, pd: marker::PhantomData })
            }
        }).map_err(|reported| match reported {
            Some(error) => error,
            None => panic!("GetProcAddress failed but GetLastError did not report the error"),
        })
    }

    /// Consumes the `Library` and returns its `HMODULE` without calling `FreeLibrary`.
    pub fn into_raw(self) -> HMODULE {
        let module = self.0;
        mem::forget(self);
        module
    }

    /// Wraps an existing `HMODULE` in a `Library`.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid module handle. The returned `Library` passes it
    /// to `FreeLibrary` when dropped.
    pub unsafe fn from_raw(handle: HMODULE) -> Library {
        Library(handle)
    }
}
impl Drop for Library {
    /// Releases the module with `FreeLibrary`.
    ///
    /// # Panics
    ///
    /// Panics if `FreeLibrary` fails, unless the thread is already panicking.
    /// A second panic during unwinding would abort the process, so in that case
    /// the failure is ignored and the original panic is allowed to propagate.
    fn drop(&mut self) {
        let released = with_get_last_error(|| {
            if unsafe { libloaderapi::FreeLibrary(self.0) == 0 } {
                None
            } else {
                Some(())
            }
        });
        if let Err(reported) = released {
            if !::std::thread::panicking() {
                match reported {
                    Some(error) => panic!("FreeLibrary failed: {}", error),
                    None => panic!("FreeLibrary failed but GetLastError did not report the error"),
                }
            }
        }
    }
}
impl fmt::Debug for Library {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
unsafe {
let mut buf: [WCHAR; 1024] = mem::uninitialized();
let len = libloaderapi::GetModuleFileNameW(self.0,
(&mut buf[..]).as_mut_ptr(), 1024) as usize;
if len == 0 {
f.write_str(&format!("Library@{:p}", self.0))
} else {
let string: OsString = OsString::from_wide(&buf[..len]);
f.write_str(&format!("Library@{:p} from {:?}", self.0, string))
}
}
}
}
/// A symbol obtained from a `Library`.
///
/// It is stored as the raw `FARPROC` address returned by `GetProcAddress` and
/// reinterpreted as `T` when dereferenced.
///
/// NOTE(review): the type carries no lifetime tied to the `Library` it came from.
/// Callers must keep the library loaded for as long as the symbol is used.
pub struct Symbol<T> {
    // Raw address returned by `GetProcAddress`.
    pointer: FARPROC,
    // Records the symbol's type `T` without storing a value of it.
    pd: marker::PhantomData<T>
}
impl<T> Symbol<T> {
    /// Consumes the symbol and hands back the raw `FARPROC` address it wrapped.
    pub fn into_raw(self) -> FARPROC {
        let address = self.pointer;
        mem::forget(self);
        address
    }
}
impl<T> Symbol<Option<T>> {
    /// Turns a `Symbol<Option<T>>` into `Some(Symbol<T>)` when the stored
    /// address is non-null, and into `None` when it is null.
    pub fn lift_option(self) -> Option<Symbol<T>> {
        if !self.pointer.is_null() {
            Some(Symbol { pointer: self.pointer, pd: marker::PhantomData })
        } else {
            None
        }
    }
}
// `FARPROC` is a raw pointer, so `Send`/`Sync` are not derived automatically.
// They are forwarded from `T`, the type the stored address is reinterpreted as.
unsafe impl<T: Send> Send for Symbol<T> {}
unsafe impl<T: Sync> Sync for Symbol<T> {}
impl<T> Clone for Symbol<T> {
fn clone(&self) -> Symbol<T> {
Symbol { ..*self }
}
}
impl<T> ::std::ops::Deref for Symbol<T> {
type Target = T;
fn deref(&self) -> &T {
unsafe {
mem::transmute(&self.pointer)
}
}
}
impl<T> fmt::Debug for Symbol<T> {
    /// Formats as `Symbol@<address>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Symbol@{:p}", self.pointer)
    }
}
// Set once `SetThreadErrorMode` has failed with `ERROR_CALL_NOT_IMPLEMENTED`.
// From then on `ErrorModeGuard` uses the process-wide `SetErrorMode` instead.
static USE_ERRORMODE: AtomicBool = ATOMIC_BOOL_INIT;
// Holds the error mode that was active before `ErrorModeGuard::new` changed it.
// That mode is written back when the guard is dropped.
struct ErrorModeGuard(DWORD);
impl ErrorModeGuard {
    /// Switches the error mode to `SEM_FAILCE` (value 1). Returns a guard
    /// holding the previous mode, or `None` when there is nothing to restore.
    ///
    /// `SetThreadErrorMode` is tried first. If it fails with
    /// `ERROR_CALL_NOT_IMPLEMENTED`, `USE_ERRORMODE` is set and the process-wide
    /// `SetErrorMode` is used for this and every later call.
    ///
    /// `None` is returned in two cases: `SetThreadErrorMode` fails for any other
    /// reason, or the previous mode was already exactly `SEM_FAILCE`.
    fn new() -> Option<ErrorModeGuard> {
        const SEM_FAILCE: DWORD = 1;
        // A guard is only needed when the old mode differs from the one we set.
        let guard_for = |previous: DWORD| -> Option<ErrorModeGuard> {
            if previous == SEM_FAILCE { None } else { Some(ErrorModeGuard(previous)) }
        };
        unsafe {
            if !USE_ERRORMODE.load(Ordering::Acquire) {
                let mut previous_mode = 0;
                if errhandlingapi::SetThreadErrorMode(SEM_FAILCE, &mut previous_mode) != 0 {
                    return guard_for(previous_mode);
                }
                // Read immediately after the failed call, before anything can
                // overwrite the thread's last-error value.
                if errhandlingapi::GetLastError() != winerror::ERROR_CALL_NOT_IMPLEMENTED {
                    return None;
                }
                USE_ERRORMODE.store(true, Ordering::Release);
            }
            guard_for(errhandlingapi::SetErrorMode(SEM_FAILCE))
        }
    }
}
impl Drop for ErrorModeGuard {
    /// Writes the saved error mode back.
    ///
    /// Uses `SetErrorMode` when `USE_ERRORMODE` is set, and `SetThreadErrorMode`
    /// otherwise. The return values are ignored.
    fn drop(&mut self) {
        let process_wide = USE_ERRORMODE.load(Ordering::Relaxed);
        unsafe {
            if process_wide {
                errhandlingapi::SetErrorMode(self.0);
            } else {
                errhandlingapi::SetThreadErrorMode(self.0, ptr::null_mut());
            }
        }
    }
}
/// Runs `closure`.
///
/// - If it returns `Some(value)`, that value is returned as `Ok(value)`.
/// - If it returns `None`, `GetLastError` is read immediately. The result is
///   `Err(Some(io::Error))` for a non-zero code, or `Err(None)` when the code
///   is 0.
fn with_get_last_error<T, F>(closure: F) -> Result<T, Option<io::Error>>
where F: FnOnce() -> Option<T> {
    match closure() {
        Some(value) => Ok(value),
        None => {
            let code = unsafe { errhandlingapi::GetLastError() };
            Err(if code != 0 {
                Some(io::Error::from_raw_os_error(code as i32))
            } else {
                None
            })
        }
    }
}
#[test]
fn works_getlasterror() {
    // Look the symbol up by a name without a trailing NUL byte.
    let kernel32 = Library::new("kernel32.dll").unwrap();
    let get_last_error: Symbol<unsafe extern "system" fn() -> DWORD> =
        unsafe { kernel32.get(b"GetLastError").unwrap() };
    unsafe {
        errhandlingapi::SetLastError(42);
        let expected = errhandlingapi::GetLastError();
        assert_eq!(expected, get_last_error());
    }
}
#[test]
fn works_getlasterror0() {
    // Look the symbol up by a name that already carries its NUL terminator.
    let kernel32 = Library::new("kernel32.dll").unwrap();
    let get_last_error: Symbol<unsafe extern "system" fn() -> DWORD> =
        unsafe { kernel32.get(b"GetLastError\0").unwrap() };
    unsafe {
        errhandlingapi::SetLastError(42);
        let expected = errhandlingapi::GetLastError();
        assert_eq!(expected, get_last_error());
    }
}