use {
super::python_resources::PythonResourcesState,
lazy_static::lazy_static,
memory_module_sys::{
MemoryFreeLibrary, MemoryGetProcAddress, MemoryLoadLibraryEx, HCUSTOMMODULE,
},
std::collections::HashMap,
std::ffi::{c_void, CStr},
std::sync::atomic::{AtomicUsize, Ordering},
std::sync::Mutex,
winapi::shared::basetsd::SIZE_T,
winapi::shared::minwindef::{BOOL, DWORD, FARPROC, HINSTANCE__, LPVOID},
winapi::shared::ntdef::LPCSTR,
winapi::um::libloaderapi::{FreeLibrary, GetProcAddress, LoadLibraryA},
winapi::um::memoryapi::{VirtualAlloc, VirtualFree},
};
struct MemoryModule {
ptr: *const c_void,
ref_count: AtomicUsize,
}
struct MemoryModules {
modules: HashMap<String, MemoryModule>,
module_ptrs: Vec<*const c_void>,
}
unsafe impl Send for MemoryModules {}
lazy_static! {
static ref MEMORY_MODULES: Mutex<MemoryModules> = {
Mutex::new(MemoryModules {
modules: HashMap::new(),
module_ptrs: Vec::new(),
})
};
}
pub(crate) unsafe fn load_library_memory(
resources_state: &PythonResourcesState<u8>,
data: &[u8],
) -> *const c_void {
MemoryLoadLibraryEx(
data.as_ptr() as *const c_void,
data.len(),
default_alloc,
default_free,
custom_load_library,
custom_get_proc_address,
custom_free_library,
resources_state as *const PythonResourcesState<u8> as *mut c_void,
)
}
pub(crate) unsafe fn free_library_memory(module: *const c_void) {
MemoryFreeLibrary(module);
}
pub(crate) unsafe fn get_proc_address_memory(module: *const c_void, name: &CStr) -> FARPROC {
MemoryGetProcAddress(module, name.as_ptr())
}
#[no_mangle]
extern "C" fn default_alloc(
address: LPVOID,
size: SIZE_T,
allocation_type: DWORD,
protect: DWORD,
_user_data: *mut c_void,
) -> LPVOID {
unsafe { VirtualAlloc(address, size, allocation_type, protect) }
}
#[no_mangle]
extern "C" fn default_free(
address: LPVOID,
size: SIZE_T,
free_type: DWORD,
_user_data: *mut c_void,
) -> BOOL {
unsafe { VirtualFree(address, size, free_type) }
}
#[no_mangle]
extern "C" fn custom_load_library(filename: LPCSTR, user_data: *mut c_void) -> HCUSTOMMODULE {
assert!(!user_data.is_null());
let name = unsafe { CStr::from_ptr(filename) }.to_string_lossy();
{
let memory_state = MEMORY_MODULES.lock().unwrap();
if let Some(module) = memory_state.modules.get(name.as_ref()) {
module.ref_count.fetch_add(1, Ordering::Acquire);
return module.ptr;
}
}
let resources_state = unsafe {
(user_data as *const PythonResourcesState<u8>)
.as_ref()
.unwrap()
};
if let Some(entry) = resources_state.resources.get(name.as_ref()) {
if let Some(library_data) = &entry.in_memory_shared_library {
let res = unsafe { load_library_memory(resources_state, library_data) };
if !res.is_null() {
let mut memory_state = MEMORY_MODULES.lock().unwrap();
memory_state.modules.insert(
name.to_string(),
MemoryModule {
ptr: res,
ref_count: AtomicUsize::new(1),
},
);
memory_state.module_ptrs.push(res);
}
return res;
}
}
let result = unsafe { LoadLibraryA(filename) };
if result.is_null() {
std::ptr::null() as HCUSTOMMODULE
} else {
result as HCUSTOMMODULE
}
}
#[no_mangle]
extern "C" fn custom_get_proc_address(
module: HCUSTOMMODULE,
name: LPCSTR,
_user_data: *mut c_void,
) -> FARPROC {
{
let memory_state = MEMORY_MODULES.lock().unwrap();
if memory_state.module_ptrs.contains(&module) {
return unsafe { MemoryGetProcAddress(module, name) };
}
}
unsafe { GetProcAddress(module as *mut HINSTANCE__, name) }
}
#[no_mangle]
extern "C" fn custom_free_library(module: HCUSTOMMODULE, _user_data: *mut c_void) {
let mut memory_state = MEMORY_MODULES.lock().unwrap();
if let Some(index) = memory_state
.module_ptrs
.iter()
.position(|ptr| ptr == &module)
{
memory_state.module_ptrs.remove(index);
let mut free_module = None;
for (name, module_state) in &memory_state.modules {
if module_state.ptr == module {
if module_state.ref_count.fetch_sub(1, Ordering::Acquire) == 1 {
free_module = Some(name.to_string());
break;
}
}
}
if let Some(free_module) = free_module {
memory_state.modules.remove(&free_module);
std::mem::drop(memory_state);
return unsafe { MemoryFreeLibrary(module) };
}
}
unsafe {
FreeLibrary(module as *mut HINSTANCE__);
}
}