use crate::raw::HRESULT;
use crate::typelib::*;
use std::convert::TryInto;
use std::ffi::{c_void, CStr, CString};
/// Untyped Win32 handle, used below for registry keys (`hKey`) and module
/// handles (`hModule`).
// Named after the Win32 type it mirrors, which is why the acronym lint is allowed.
#[allow(clippy::upper_case_acronyms)]
type HANDLE = *mut c_void;
/// Converts a `snake_case` identifier to PascalCase.
///
/// Underscores are removed and the first character of every `_`-separated
/// segment (including the very first one) is upper-cased using Unicode case
/// mapping; every other character is copied unchanged.
fn pascal_case<T: AsRef<str>>(input: T) -> String
{
    let text = input.as_ref();
    let mut result = String::with_capacity(text.len());
    let mut at_word_start = true;

    for ch in text.chars() {
        match (ch, at_word_start) {
            // Separator: dropped, and the next character starts a new word.
            ('_', _) => at_word_start = true,
            // First character of a word: may expand to several chars (e.g. 'ß' -> "SS").
            (_, true) => {
                result.extend(ch.to_uppercase());
                at_word_start = false;
            }
            (_, false) => result.push(ch),
        }
    }

    result
}
/// Win32 `LSTATUS`: registry functions return a Win32 error code where zero
/// (`ERROR_SUCCESS`) means success. This is *not* an HRESULT: failures are small
/// positive numbers, so HRESULT-style success checks would accept them.
type LSTATUS = i32;

/// `ERROR_SUCCESS` Win32 error code.
const ERROR_SUCCESS: LSTATUS = 0;
/// `ERROR_FILE_NOT_FOUND` Win32 error code (the key does not exist).
const ERROR_FILE_NOT_FOUND: LSTATUS = 2;
/// `KEY_SET_VALUE` access right: sufficient for writing values into the key.
const KEY_SET_VALUE: u32 = 0x0002;
/// `REG_OPTION_NON_VOLATILE`: the default, persistent key option.
const REG_OPTION_NON_VOLATILE: u32 = 0;
/// `REG_SZ` value type: a NUL-terminated string.
const REG_SZ: u32 = 1;

// The registry API is exported by advapi32.dll (import library Advapi32.lib).
#[link(name = "advapi32")]
extern "system" {
    pub fn RegCreateKeyExA(
        hKey: HANDLE,
        sub_key_path: *const i8,
        reserved: u32,
        class: *mut u8,
        options: u32,
        samDesired: u32,
        lpSecurityAttributes: *mut c_void,
        phkResult: *mut HANDLE,
        lpdwDisposition: *mut u32,
    ) -> LSTATUS;
    pub fn RegCloseKey(hKey: HANDLE) -> LSTATUS;
    pub fn RegSetValueExA(
        hKey: HANDLE,
        lpValueName: *const i8,
        reserved: u32,
        dwType: u32,
        lpData: *const u8,
        cbData: u32,
    ) -> LSTATUS;
    pub fn RegDeleteKeyA(hKey: HANDLE, lpSubKey: *const i8) -> LSTATUS;
}

#[link(name = "kernel32")]
extern "system" {
    pub fn GetModuleFileNameA(hModule: HANDLE, lpFilename: *mut u8, nSize: u32) -> u32;
}

/// Raw `HKEY_CLASSES_ROOT` handle. WinReg.h defines it as
/// `(HKEY)(ULONG_PTR)((LONG)0x80000000)`: the `LONG` cast makes the value
/// sign-extended on 64-bit targets.
const HKEY_CLASSES_ROOT: HANDLE = (0x8000_0000_u32 as i32) as isize as HANDLE;

/// A registry key handle. Handles obtained from [`Key::open_key`] are closed on
/// drop; the predefined `HKEY_CLASSES_ROOT` handle is never closed.
struct Key(HANDLE);

/// `HKEY_CLASSES_ROOT`. Being a `const`, every use creates a fresh temporary
/// `Key`; `Drop` skips this handle so those temporaries do not close the root.
const CLASSES_ROOT: Key = Key(HKEY_CLASSES_ROOT);

impl Key
{
    /// Opens the subkey `path` below this key for writing values, creating it
    /// if it does not exist (via `RegCreateKeyExA`).
    ///
    /// # Errors
    /// `E_INVALIDARG` if `path` contains an interior NUL byte; `E_FAIL` if the
    /// registry call fails.
    pub fn open_key(&self, path: &str) -> Result<Key, HRESULT>
    {
        let sub_key = CString::new(path).map_err(|_| crate::raw::E_INVALIDARG)?;
        let mut result: HANDLE = ::std::ptr::null_mut();
        let mut disposition: u32 = 0;

        // SAFETY: `sub_key` is NUL-terminated and outlives the call, the
        // out-pointers refer to live locals, and the optional class and
        // security-attribute arguments are passed as null.
        let status = unsafe {
            RegCreateKeyExA(
                self.0,
                sub_key.as_ptr(),
                0,
                ::std::ptr::null_mut(),
                REG_OPTION_NON_VOLATILE,
                KEY_SET_VALUE,
                ::std::ptr::null_mut(),
                &mut result,
                &mut disposition,
            )
        };
        Key::check_status(status)?;
        Ok(Key(result))
    }

    /// Deletes the subkey `path` below this key.
    ///
    /// A key that does not exist is treated as already deleted, so
    /// unregistering twice is not an error.
    ///
    /// # Errors
    /// `E_INVALIDARG` if `path` contains an interior NUL byte; `E_FAIL` if the
    /// registry call fails for any other reason.
    pub fn delete_key(&self, path: &str) -> Result<(), HRESULT>
    {
        let sub_key = CString::new(path).map_err(|_| crate::raw::E_INVALIDARG)?;

        // SAFETY: `sub_key` is NUL-terminated and outlives the call.
        let status = unsafe { RegDeleteKeyA(self.0, sub_key.as_ptr()) };
        match status {
            ERROR_FILE_NOT_FOUND => Ok(()),
            _ => Key::check_status(status),
        }
    }

    /// Writes the `REG_SZ` value `name` (empty for the key's default value)
    /// with the string `value`.
    ///
    /// # Errors
    /// `E_INVALIDARG` if `name` or `value` contains an interior NUL byte or the
    /// data is too large; `E_FAIL` if the registry call fails.
    pub fn set_string_value(&self, name: &str, value: &str) -> Result<(), HRESULT>
    {
        let data = CString::new(value).map_err(|_| crate::raw::E_INVALIDARG)?;
        let value_name = CString::new(name).map_err(|_| crate::raw::E_INVALIDARG)?;

        // For REG_SZ the byte count must include the terminating NUL.
        let bytes = data.as_bytes_with_nul();
        let size: u32 = bytes
            .len()
            .try_into()
            .map_err(|_| crate::raw::E_INVALIDARG)?;

        // SAFETY: `value_name` is NUL-terminated, and `bytes` is valid for
        // reads of `size` bytes; both outlive the call.
        let status =
            unsafe { RegSetValueExA(self.0, value_name.as_ptr(), 0, REG_SZ, bytes.as_ptr(), size) };
        Key::check_status(status)
    }

    /// Converts a registry `LSTATUS` into a `Result`.
    ///
    /// NOTE(review): the specific Win32 error is collapsed into `E_FAIL`;
    /// mapping it with HRESULT_FROM_WIN32 would preserve it.
    fn check_status(status: LSTATUS) -> Result<(), HRESULT>
    {
        if status == ERROR_SUCCESS {
            Ok(())
        } else {
            Err(crate::raw::E_FAIL)
        }
    }
}

impl Drop for Key
{
    fn drop(&mut self)
    {
        // The predefined root is not owned by any `Key`; never close it.
        if self.0 == HKEY_CLASSES_ROOT {
            return;
        }
        // SAFETY: any other handle came from a successful RegCreateKeyExA in
        // `open_key` and is closed exactly once, here. A close failure cannot be
        // reported from `drop`, so the status is ignored.
        let _ = unsafe { RegCloseKey(self.0) };
    }
}
/// Writes the registry entries for the type library `lib` and each of its
/// coclasses, using the file path of the module `dll` as the server location.
///
/// # Errors
/// Returns the first error from resolving the module path or from a registry
/// operation.
pub fn register(dll: HANDLE, lib: TypeLib) -> Result<(), HRESULT>
{
    let do_register = true;
    register_or_unregister(dll, lib, do_register)
}
/// Deletes the registry keys that `register` creates for the type library
/// `lib` and its coclasses.
///
/// # Errors
/// Returns the first error from resolving the module path or from a registry
/// operation.
pub fn unregister(dll: HANDLE, lib: TypeLib) -> Result<(), HRESULT>
{
    let do_register = false;
    register_or_unregister(dll, lib, do_register)
}
/// Shared implementation of `register` and `unregister`.
///
/// Derives the PascalCase library name and the `_`-separated version used in
/// ProgIDs, resolves the file path of `dll`, and then processes the type
/// library entries followed by every coclass in `lib`, in declaration order.
/// `do_register` selects writing (`true`) or deleting (`false`) the entries.
///
/// # Errors
/// Returns the first error from resolving the module path or from any registry
/// operation; later entries are not processed.
fn register_or_unregister(dll: HANDLE, lib: TypeLib, do_register: bool) -> Result<(), HRESULT>
{
    let lib_name = pascal_case(&lib.name);
    let lib_version = lib.version.replace('.', "_");
    let module_path = get_module_path(dll)?;

    register_typelib(&module_path, &lib, &lib_name, do_register)?;

    for type_info in lib.types.iter() {
        // Only coclasses get registry entries of their own.
        if let TypeInfo::Class(cls) = type_info {
            register_class(&module_path, &lib, &lib_name, &lib_version, cls, do_register)?;
        }
    }

    Ok(())
}
/// Registers (or, when `do_register` is `false`, unregisters) the type library
/// `lib` under `HKEY_CLASSES_ROOT\TypeLib\{libid}`.
///
/// * `path` - module path written as the type library location for the
///   current platform.
/// * `lib_name` - PascalCase library name used in the descriptions.
///
/// Entries are processed in reverse declaration order, so when unregistering
/// child keys are deleted before their parents. Entries with an empty value
/// only create the key and write no value.
///
/// # Errors
/// Returns the first error from opening, writing, or deleting a registry key.
pub fn register_typelib(
    path: &str,
    lib: &TypeLib,
    lib_name: &str,
    do_register: bool,
) -> Result<(), HRESULT>
{
    let description = format!("{} TypeLib", lib_name);

    // Platform subkey: "win64" for 64-bit modules, "win32" for 32-bit ones.
    // Keying on pointer width (instead of listing only x86/x86_64) keeps those
    // targets unchanged and lets other Windows targets such as aarch64 compile.
    let arch = if cfg!(target_pointer_width = "64") {
        "win64"
    } else {
        "win32"
    };

    // (key path, value name, value) triples; an empty value name addresses the
    // key's default value.
    let data = vec![
        (format!("TypeLib\\{}", lib.libid), "", description),
        (
            format!("TypeLib\\{}\\{}", lib.libid, lib.version),
            "",
            format!("{} {}", lib_name, lib.version),
        ),
        // NOTE(review): the "0" segment is presumably the language-neutral LCID.
        (
            format!("TypeLib\\{}\\{}\\0", lib.libid, lib.version),
            "",
            String::new(),
        ),
        (
            format!("TypeLib\\{}\\{}\\0\\{}", lib.libid, lib.version, arch),
            "",
            path.to_string(),
        ),
        (
            format!("TypeLib\\{}\\{}\\FLAGS", lib.libid, lib.version),
            "",
            "0".to_string(),
        ),
    ];

    for (key_path, value_name, value) in data.iter().rev() {
        if do_register {
            let key = CLASSES_ROOT.open_key(key_path)?;
            if !value.is_empty() {
                key.set_string_value(value_name, value)?;
            }
        } else if value_name.is_empty() {
            CLASSES_ROOT.delete_key(key_path)?;
        }
    }
    Ok(())
}
/// Registers (or, when `do_register` is `false`, unregisters) the COM class
/// `cls` under `HKEY_CLASSES_ROOT`.
///
/// Two ProgIDs are maintained: the version-independent `<Lib>.<Class>` and the
/// versioned `<Lib>.<Class>.<lib_version>`. The `CLSID\{clsid}` key records the
/// in-process server `path` (with `ThreadingModel` set to `Both`), both ProgIDs
/// and the owning type library id.
///
/// Entries are processed in reverse declaration order, so when unregistering
/// child keys are deleted before their parents. When unregistering, only the
/// entries that address a key's default value (empty value name) delete a key.
///
/// # Errors
/// Returns the first error from opening, writing, or deleting a registry key.
pub fn register_class(
    path: &str,
    lib: &TypeLib,
    lib_name: &str,
    lib_version: &str,
    cls: &CoClass,
    do_register: bool,
) -> Result<(), HRESULT>
{
    let prog_id = format!("{}.{}", lib_name, cls.name);
    let versioned_prog_id = format!("{}.{}.{}", lib_name, cls.name, lib_version);
    let class_description = format!("{} {} Class", lib_name, cls.name);

    // (key path, value name, value) triples.
    let entries: [(String, &str, String); 10] = [
        (prog_id.clone(), "", class_description.clone()),
        (versioned_prog_id.clone(), "", class_description.clone()),
        (format!("{}\\CLSID", versioned_prog_id), "", cls.clsid.to_string()),
        (format!("{}\\CurVer", prog_id), "", versioned_prog_id.clone()),
        (format!("CLSID\\{}", cls.clsid), "", class_description),
        (
            format!("CLSID\\{}\\InprocServer32", cls.clsid),
            "",
            path.to_string(),
        ),
        (
            format!("CLSID\\{}\\InprocServer32", cls.clsid),
            "ThreadingModel",
            "Both".to_string(),
        ),
        (format!("CLSID\\{}\\ProgID", cls.clsid), "", versioned_prog_id),
        (
            format!("CLSID\\{}\\TypeLib", cls.clsid),
            "",
            lib.libid.to_string(),
        ),
        (
            format!("CLSID\\{}\\VersionIndependentProgID", cls.clsid),
            "",
            prog_id,
        ),
    ];

    for (key_path, value_name, value) in entries.iter().rev() {
        match do_register {
            true => {
                let key = CLASSES_ROOT.open_key(key_path)?;
                key.set_string_value(value_name, value)?;
            }
            false if value_name.is_empty() => {
                CLASSES_ROOT.delete_key(key_path)?;
            }
            false => {}
        }
    }
    Ok(())
}
fn get_module_path(dll_handle: HANDLE) -> Result<String, HRESULT>
{
Ok(unsafe {
let mut path = Vec::new();
path.reserve(1024);
let path_len = GetModuleFileNameA(
dll_handle,
path.as_mut_ptr(),
path.capacity()
.try_into()
.map_err(|_| intercom::raw::E_INVALIDARG)?,
);
let path_len = path_len as usize;
if path_len == 0 || path_len >= path.capacity() {
return Err(intercom::raw::E_FAIL);
}
path.set_len(path_len + 1);
CStr::from_bytes_with_nul(&path)
.map_err(|_| intercom::raw::E_FAIL)?
.to_owned()
}
.to_str()
.map_err(|_| intercom::raw::E_FAIL)?
.to_string())
}