#include "mimalloc.h"
#include "mimalloc-internal.h"
#if !defined(_WIN32)
#error "this file should only be included on Windows"
#endif
#include <windows.h>
#include <psapi.h>
// Forward declaration: mi_setmaxstdio is referenced by the `patches` table but is
// defined further down, after _mi_patches_apply which it calls.
static int __cdecl mi_setmaxstdio(int newmax);
// Evaluate and discard a parameter to silence unused-parameter warnings.
#define UNUSED(x) (void)(x)
// `_expand` replacement: try to resize `p` in place via mi_expand.
// Returns the block on success; on failure sets errno to ENOMEM and returns NULL.
static void* mi__expand(void* p, size_t newsize) {
  void* const resized = mi_expand(p, newsize);
  if (resized != NULL) {
    return resized;
  }
  errno = ENOMEM;
  return NULL;
}
// "Term" replacement for `_free_base` and the operator delete variants:
// the pointer is ignored and nothing is freed.
static void mi_free_term(void* p) {
  (void)p;
}
// "Term" replacement for `_realloc_base`: performs no allocation and always
// yields NULL.
static void* mi_realloc_term(void* p, size_t newsize) {
  (void)p;
  (void)newsize;
  void* const none = NULL;
  return none;
}
// "Term" replacement for `_recalloc` / `_recalloc_base`: performs no allocation
// and always yields NULL.
static void* mi__recalloc_term(void* p, size_t newcount, size_t newsize) {
  (void)p;
  (void)newcount;
  (void)newsize;
  void* const none = NULL;
  return none;
}
// "Term" replacement for `_expand`: never resizes, always yields NULL.
static void* mi__expand_term(void* p, size_t newsize) {
  (void)p;
  (void)newsize;
  void* const none = NULL;
  return none;
}
// "Term" replacement for `_msize` / `_msize_base`: reports a size of 0 for any pointer.
static size_t mi__msize_term(void* p) {
  (void)p;
  const size_t none = 0;
  return none;
}
// `_malloc_dbg` replacement: the debug bookkeeping arguments are ignored and
// the request is forwarded to `_malloc_base` (itself patched to mi_malloc).
static void* mi__malloc_dbg(size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = _malloc_base(size);
  return block;
}
// `_calloc_dbg` replacement: the debug bookkeeping arguments are ignored and
// the request is forwarded to `_calloc_base`.
static void* mi__calloc_dbg(size_t count, size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = _calloc_base(count, size);
  return block;
}
// `_realloc_dbg` replacement: the debug bookkeeping arguments are ignored and
// the request is forwarded to `_realloc_base`.
static void* mi__realloc_dbg(void* p, size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = _realloc_base(p, size);
  return block;
}
// `_free_dbg` replacement: the block type is ignored and `p` is handed to `_free_base`.
static void mi__free_dbg(void* p, int block_type) {
  (void)block_type;
  _free_base(p);
}
// `_recalloc_dbg` replacement: the debug bookkeeping arguments are ignored and
// the request goes straight to mi_recalloc.
static void* mi__recalloc_dbg(void* p, size_t count, size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = mi_recalloc(p, count, size);
  return block;
}
// `_expand_dbg` replacement: the debug bookkeeping arguments are ignored and
// the request is delegated to mi__expand.
static void* mi__expand_dbg(void* p, size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = mi__expand(p, size);
  return block;
}
// `_msize_dbg` replacement: the block type is ignored; reports mi_usable_size(p).
static size_t mi__msize_dbg(void* p, int block_type) {
  (void)block_type;
  const size_t usable = mi_usable_size(p);
  return usable;
}
// "Term" replacement for `_recalloc_dbg`: drops the debug arguments and
// delegates to mi__recalloc_term.
static void* mi__recalloc_dbg_term(void* p, size_t count, size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = mi__recalloc_term(p, count, size);
  return block;
}
// "Term" replacement for `_expand_dbg`: drops the debug arguments and
// delegates to mi__expand_term.
static void* mi__expand_dbg_term(void* p, size_t size, int block_type, const char* fname, int line) {
  (void)block_type;
  (void)fname;
  (void)line;
  void* const block = mi__expand_term(p, size);
  return block;
}
// "Term" replacement for `_msize_dbg`: drops the block type and delegates to
// mi__msize_term.
static size_t mi__msize_dbg_term(void* p, int block_type) {
  (void)block_type;
  const size_t usable = mi__msize_term(p);
  return usable;
}
// Signature of an exit callback.
typedef void (cbfun_t)(void);
// Signature of `_crt_atexit` / `_crt_at_quick_exit`; the replacements in this
// file return 0 on success or an errno value on failure.
typedef int (atexit_fun_t)(cbfun_t* fn);
// A pointer masked with `canary` (see encode/decode).
typedef uintptr_t encoded_t;

// Growable array of encoded exit callbacks.
typedef struct exit_list_s {
  encoded_t functions;   // encoded pointer to an array of `capacity` encoded callbacks
  size_t    count;       // number of registered callbacks
  size_t    capacity;    // number of allocated slots
} exit_list_t;

// Growth increment of an exit list, in entries.
#define MI_EXIT_INC (64)

static exit_list_t atexit_list        = { .functions = 0, .count = 0, .capacity = 0 };
static exit_list_t at_quick_exit_list = { .functions = 0, .count = 0, .capacity = 0 };
// Guards both exit lists.
static CRITICAL_SECTION atexit_lock;
// Random mask applied to every pointer stored in the exit lists.
static encoded_t canary;
// Recover a pointer from its canary-masked representation.
static inline void* decode(encoded_t x) {
  const encoded_t unmasked = x ^ canary;
  return (void*)unmasked;
}
// Mask a pointer with the canary before storing it in an exit list.
static inline encoded_t encode(void* p) {
  const uintptr_t raw = (uintptr_t)p;
  return raw ^ canary;
}
// Seed the random canary that masks exit-list pointers and reset both exit
// lists to the encoded form of an empty (NULL) array.
// Uses a prototyped `(void)` parameter list instead of the unprototyped `()`.
static void init_canary(void)
{
  canary = _mi_random_init(0);
  atexit_list.functions = encode(NULL);
  at_quick_exit_list.functions = encode(NULL);
}
static void mi_initialize_atexit(void) {
InitializeCriticalSection(&atexit_lock);
init_canary();
}
// Append callback `fn` to `list` under `atexit_lock`, growing the array by
// MI_EXIT_INC entries when it is full. Entries are stored canary-encoded.
// Returns 0 on success, EINVAL if `fn` is NULL, ENOMEM if the array could not grow.
static int mi_register_atexit(exit_list_t* list, cbfun_t* fn) {
  if (fn == NULL) return EINVAL;
  EnterCriticalSection(&atexit_lock);
  encoded_t* functions = (encoded_t*)decode(list->functions);
  if (list->count >= list->capacity) {
    // The array holds encoded_t entries, so size elements by that type.
    encoded_t* newf = (encoded_t*)mi_recalloc(functions, list->capacity + MI_EXIT_INC, sizeof(*functions));
    if (newf != NULL) {
      list->capacity += MI_EXIT_INC;
      list->functions = encode(newf);
      functions = newf;
    }
  }
  int result = ENOMEM;  // stays ENOMEM if growing failed
  if (list->count < list->capacity && functions != NULL) {
    functions[list->count] = encode(fn);
    list->count++;
    result = 0;
  }
  LeaveCriticalSection(&atexit_lock);
  return result;
}
// Replacement for the CRT's `_crt_atexit`: record `fn` in mimalloc's own atexit list.
static int mi__crt_atexit(cbfun_t* fn) {
  exit_list_t* const list = &atexit_list;
  return mi_register_atexit(list, fn);
}
// Replacement for the CRT's `_crt_at_quick_exit`: record `fn` in mimalloc's own
// quick-exit list.
static int mi__crt_at_quick_exit(cbfun_t* fn) {
  exit_list_t* const list = &at_quick_exit_list;
  return mi_register_atexit(list, fn);
}
// Detach all callbacks from `list` (under `atexit_lock`), then call them outside
// the lock in reverse registration order and free the array.
static void mi_execute_exit_list(exit_list_t* list) {
  EnterCriticalSection(&atexit_lock);
  exit_list_t clist = *list;
  memset(list, 0, sizeof(*list));
  // An empty list must hold encode(NULL), not a raw 0: a raw 0 decodes to the
  // canary value, which a later mi_register_atexit would pass to mi_recalloc
  // and a second execution would pass to mi_free.
  list->functions = encode(NULL);
  LeaveCriticalSection(&atexit_lock);

  encoded_t* functions = (encoded_t*)decode(clist.functions);
  if (functions != NULL) {
    for (size_t i = clist.count; i > 0; i--) {  // count down without unsigned underflow
      cbfun_t* fn = (cbfun_t*)decode(functions[i-1]);
      if (fn == NULL) break;  // registration rejects NULL, so this indicates corruption
      fn();
    }
    mi_free(functions);
  }
}
#if defined(_M_IX86) || defined(_M_X64)
// x86/x64: room for the longest jump emitted below, `jmp qword ptr [rip+0]`
// (6 bytes) followed by the 8-byte absolute target address.
#define MI_JUMP_SIZE 14

// The bytes at a patched function entry (used to save/restore the original code).
typedef struct mi_jump_s {
  uint8_t opcodes[MI_JUMP_SIZE];
} mi_jump_t;

// Write the MI_JUMP_SIZE bytes held in `saved` back over `current`.
void mi_jump_restore(void* current, const mi_jump_t* saved) {
  memcpy(current, &saved->opcodes, MI_JUMP_SIZE);
}

// Overwrite the entry of `current` with a jump to `target`. If `save` is non-NULL,
// the MI_JUMP_SIZE bytes at `current` are copied into it first. The caller is
// responsible for making `current` writable.
void mi_jump_write(void* current, void* target, mi_jump_t* save) {
  if (save != NULL) {
    memcpy(&save->opcodes, current, MI_JUMP_SIZE);
  }
  uint8_t* opcodes = ((mi_jump_t*)current)->opcodes;
  // `jmp rel32` (E9 + 4 bytes) is 5 bytes long and its displacement is relative
  // to the end of the instruction, i.e. to `current + 5`.
  const int64_t rel = (int64_t)((uint8_t*)target - (uint8_t*)current) - 5;
#ifdef _M_X64
  if (rel < INT32_MIN || rel > INT32_MAX) {
    // Out of rel32 range: `jmp qword ptr [rip+0]` (FF 25 00000000) followed by
    // the absolute 64-bit target address.
    const uint32_t rip_ofs = 0;
    const uint64_t abs_target = (uint64_t)(uintptr_t)target;
    opcodes[0] = 0xFF;
    opcodes[1] = 0x25;
    memcpy(&opcodes[2], &rip_ofs, sizeof(rip_ofs));
    memcpy(&opcodes[6], &abs_target, sizeof(abs_target));
    return;
  }
#endif
  // On 32-bit x86 the displacement wraps around the address space, so it always reaches.
  const uint32_t rel32 = (uint32_t)rel;
  opcodes[0] = 0xE9;
  memcpy(&opcodes[1], &rel32, sizeof(rel32));  // memcpy: the entry need not be 4-byte aligned
}

#elif defined(_M_ARM64)
// ARM64: two instructions (8 bytes) followed by the 8-byte absolute target address.
#define MI_JUMP_SIZE 16

// The bytes at a patched function entry (used to save/restore the original code).
typedef struct mi_jump_s {
  uint8_t opcodes[MI_JUMP_SIZE];
} mi_jump_t;

// Write the MI_JUMP_SIZE bytes held in `saved` back over `current`.
void mi_jump_restore(void* current, const mi_jump_t* saved) {
  memcpy(current, &saved->opcodes, MI_JUMP_SIZE);
}

// Overwrite the entry of `current` with a jump to `target`. If `save` is non-NULL,
// the MI_JUMP_SIZE bytes at `current` are copied into it first. The caller is
// responsible for making `current` writable and flushing the instruction cache.
void mi_jump_write(void* current, void* target, mi_jump_t* save) {
  if (save != NULL) {
    memcpy(&save->opcodes, current, MI_JUMP_SIZE);
  }
  uint8_t* opcodes = ((mi_jump_t*)current)->opcodes;
  // 0x58000050  ldr x16, .+8   ; load the 64-bit literal stored right after these two instructions
  // 0xD61F0200  br  x16        ; plain branch: the link register keeps the caller's return address
  static const uint8_t jump_opcodes[8] = { 0x50, 0x00, 0x00, 0x58, 0x00, 0x02, 0x1F, 0xD6 };
  // `ldr` loads an absolute address, so the literal is the target itself.
  const uint64_t abs_target = (uint64_t)(uintptr_t)target;
  memcpy(&opcodes[0], jump_opcodes, sizeof(jump_opcodes));
  memcpy(&opcodes[8], &abs_target, sizeof(abs_target));
}

#else
#error "define jump instructions for this platform"
#endif
// Which code is currently in effect at a patched function's entry.
typedef enum patch_apply_e {
  PATCH_NONE = 0,      // the original CRT code
  PATCH_TARGET,        // a jump to the patch's `target`
  PATCH_TARGET_TERM    // a jump to the patch's `target_term`
} patch_apply_t;
// One function patch. Field order matters: the patch table uses positional initializers.
typedef struct mi_patch_s {
  const char*   name;         // exported symbol to look up
  void*         original;     // address of the export; NULL until resolved
  void*         target;       // replacement used for PATCH_TARGET
  void*         target_term;  // replacement used for PATCH_TARGET_TERM (may be NULL)
  patch_apply_t applied;      // state currently written at `original`
  mi_jump_t     save;         // bytes saved from `original` by mi_jump_write; restored for PATCH_NONE
} mi_patch_t;
// Patch-table row initializers. `name` is the exported symbol (a string for the
// *_NAME* forms, an identifier that gets stringized for the others), `target` the
// replacement, and `term` the replacement used once mi_patches_enable_term is
// called. Each row starts unresolved (`original == NULL`) and unapplied
// (`PATCH_NONE`, matching the type of the `applied` field).
#define MI_PATCH_NAME3(name,target,term) { name, NULL, &target, &term, PATCH_NONE }
#define MI_PATCH_NAME2(name,target)      { name, NULL, &target, NULL, PATCH_NONE }
#define MI_PATCH3(name,target,term)      MI_PATCH_NAME3(#name, target, term)
#define MI_PATCH2(name,target)           MI_PATCH_NAME2(#name, target)
#define MI_PATCH1(name)                  MI_PATCH2(name,mi_##name)
// The CRT exports to redirect to mimalloc, terminated by a row whose `name` is NULL.
// The "??..." entries are MSVC-mangled C++ names; presumably the global
// operator new/new[]/delete/delete[] and their nothrow_t overloads.
// NOTE(review): the throwing operator new variants map to mi_malloc, which
// presumably returns NULL instead of throwing on failure — TODO confirm.
static mi_patch_t patches[] = {
  // exit handling
  MI_PATCH2(_crt_atexit, mi__crt_atexit),
  MI_PATCH2(_crt_at_quick_exit, mi__crt_at_quick_exit),
  MI_PATCH2(_setmaxstdio, mi_setmaxstdio),

  // base allocation functions
  MI_PATCH2(_malloc_base, mi_malloc),
  MI_PATCH2(_calloc_base, mi_calloc),
  MI_PATCH3(_realloc_base, mi_realloc,mi_realloc_term),
  MI_PATCH3(_free_base, mi_free,mi_free_term),
  MI_PATCH3(_expand, mi__expand,mi__expand_term),
  MI_PATCH3(_recalloc, mi_recalloc,mi__recalloc_term),
  MI_PATCH3(_msize, mi_usable_size,mi__msize_term),
  MI_PATCH_NAME3("_recalloc_base", mi_recalloc,mi__recalloc_term),
  MI_PATCH_NAME3("_msize_base", mi_usable_size,mi__msize_term),
  MI_PATCH2(_strdup, mi_strdup),
  MI_PATCH2(_strndup, mi_strndup),

  // debug variants
  MI_PATCH2(_malloc_dbg, mi__malloc_dbg),
  MI_PATCH2(_realloc_dbg, mi__realloc_dbg),
  MI_PATCH2(_calloc_dbg, mi__calloc_dbg),
  MI_PATCH2(_free_dbg, mi__free_dbg),
  MI_PATCH3(_expand_dbg, mi__expand_dbg, mi__expand_dbg_term),
  MI_PATCH3(_recalloc_dbg, mi__recalloc_dbg, mi__recalloc_dbg_term),
  MI_PATCH3(_msize_dbg, mi__msize_dbg, mi__msize_dbg_term),

#ifdef _WIN64
  MI_PATCH_NAME2("??2@YAPEAX_K@Z", mi_malloc),
  MI_PATCH_NAME2("??_U@YAPEAX_K@Z", mi_malloc),
  MI_PATCH_NAME3("??3@YAXPEAX@Z", mi_free, mi_free_term),
  MI_PATCH_NAME3("??_V@YAXPEAX@Z", mi_free, mi_free_term),
  MI_PATCH_NAME2("??2@YAPEAX_KAEBUnothrow_t@std@@@Z", mi_malloc),
  MI_PATCH_NAME2("??_U@YAPEAX_KAEBUnothrow_t@std@@@Z", mi_malloc),
  MI_PATCH_NAME3("??3@YAXPEAXAEBUnothrow_t@std@@@Z", mi_free, mi_free_term),
  MI_PATCH_NAME3("??_V@YAXPEAXAEBUnothrow_t@std@@@Z", mi_free, mi_free_term),
#else
  MI_PATCH_NAME2("??2@YAPAXI@Z", mi_malloc),
  MI_PATCH_NAME2("??_U@YAPAXI@Z", mi_malloc),
  MI_PATCH_NAME3("??3@YAXPAX@Z", mi_free, mi_free_term),
  MI_PATCH_NAME3("??_V@YAXPAX@Z", mi_free, mi_free_term),
  MI_PATCH_NAME2("??2@YAPAXIABUnothrow_t@std@@@Z", mi_malloc),
  MI_PATCH_NAME2("??_U@YAPAXIABUnothrow_t@std@@@Z", mi_malloc),
  MI_PATCH_NAME3("??3@YAXPAXABUnothrow_t@std@@@Z", mi_free, mi_free_term),
  MI_PATCH_NAME3("??_V@YAXPAXABUnothrow_t@std@@@Z", mi_free, mi_free_term),
#endif
  // Sentinel: every field spelled out so no value lands in the wrong member
  // (the previous `false` initialized `target_term`, not `applied`).
  { NULL, NULL, NULL, NULL, PATCH_NONE }
};
// Bring one patch into state `apply`:
//   PATCH_NONE        - restore the original entry bytes kept in `patch->save`;
//   PATCH_TARGET      - jump from the original function to `patch->target`;
//   PATCH_TARGET_TERM - jump to `patch->target_term` (or `target` if there is none).
// Returns true when the patch ends up in the requested state, including when it
// already was or when the function was never resolved. Returns false if the code
// could not be made writable.
static bool mi_patch_apply(mi_patch_t* patch, patch_apply_t apply)
{
  if (patch->original == NULL) return true;  // unresolved: nothing to patch
  if (apply == PATCH_TARGET_TERM && patch->target_term == NULL) apply = PATCH_TARGET;  // no term variant
  if (patch->applied == apply) return true;  // already in the requested state: not a failure
  DWORD protect = PAGE_READWRITE;
  if (!VirtualProtect(patch->original, MI_JUMP_SIZE, PAGE_EXECUTE_READWRITE, &protect)) return false;
  if (apply == PATCH_NONE) {
    mi_jump_restore(patch->original, &patch->save);
  }
  else {
    void* target = (apply == PATCH_TARGET ? patch->target : patch->target_term);
    mi_assert_internal(target != NULL);
    if (target != NULL) {
      // Save only while the entry still holds the original code. When switching
      // between PATCH_TARGET and PATCH_TARGET_TERM it holds our own jump, and
      // saving that would lose the bytes PATCH_NONE must restore.
      mi_jump_t* save = (patch->applied == PATCH_NONE ? &patch->save : NULL);
      mi_jump_write(patch->original, target, save);
    }
  }
  patch->applied = apply;
  VirtualProtect(patch->original, MI_JUMP_SIZE, protect, &protect);
  // Code was modified in place; make sure the CPU does not execute stale instructions.
  FlushInstructionCache(GetCurrentProcess(), patch->original, MI_JUMP_SIZE);
  return true;
}
// Move every patch in `patches` to state `apply`. If `previous` is non-NULL it
// receives the state that was active before the call. A request for the state
// already active is a no-op that returns true; otherwise returns false if any
// patch reported failure.
static bool _mi_patches_apply(patch_apply_t apply, patch_apply_t* previous) {
  static patch_apply_t state = PATCH_NONE;
  if (previous != NULL) {
    *previous = state;
  }
  if (state == apply) {
    return true;
  }
  state = apply;
  bool all_ok = true;
  for (mi_patch_t* p = patches; p->name != NULL; p++) {
    // call first so every patch is attempted even after a failure
    all_ok = mi_patch_apply(p, apply) && all_ok;
  }
  return all_ok;
}
mi_decl_export void mi_patches_disable(void) {
_mi_patches_apply(PATCH_NONE, NULL);
}
mi_decl_export bool mi_patches_enable(void) {
return _mi_patches_apply( PATCH_TARGET, NULL );
}
mi_decl_export bool mi_patches_enable_term(void) {
return _mi_patches_apply(PATCH_TARGET_TERM, NULL);
}
// Replacement for `_setmaxstdio`: remove all patches, call the real CRT function,
// then restore whatever patch state was active before. Presumably this keeps the
// CRT's internal allocations in that call on its own heap — TODO confirm.
static int __cdecl mi_setmaxstdio(int newmax) {
  patch_apply_t saved_state;
  _mi_patches_apply(PATCH_NONE, &saved_state);
  const int rc = _setmaxstdio(newmax);
  _mi_patches_apply(saved_state, NULL);
  return rc;
}
// For every patch that is neither applied nor resolved yet, look up its name
// among the exports of `mod` and record the address. Earlier modules win.
static void mi_module_resolve(HMODULE mod) {
  for (mi_patch_t* patch = patches; patch->name != NULL; patch++) {
    if (patch->applied != PATCH_NONE || patch->original != NULL) continue;
    void* addr = GetProcAddress(mod, patch->name);
    if (addr == NULL) continue;
    patch->original = addr;
  }
}
// Module names (written without a file extension) used by mi_patches_resolve
// to locate this DLL and the universal CRT in the module list.
#define MIMALLOC_NAME "mimalloc-override"
#define UCRTBASE_NAME "ucrtbase"
// The CRT's own `_crt_atexit` / `_crt_at_quick_exit`, looked up by
// mi_patches_resolve and called from DllEntry to register mimalloc's exit hooks.
static atexit_fun_t* crt_atexit = NULL;
static atexit_fun_t* crt_at_quick_exit = NULL;
// Case-insensitively check whether file name `basename` (e.g. "ucrtbase.dll")
// equals `name` once its extension is stripped (e.g. "ucrtbase").
static bool mi_module_name_is(const char* basename, const char* name) {
  const char* ext = strrchr(basename, '.');
  const size_t len = (ext == NULL ? strlen(basename) : (size_t)(ext - basename));
  return (len == strlen(name) && _strnicmp(basename, name, len) == 0);
}

// Enumerate the loaded modules, resolve the patch addresses from the main
// executable (index 0) and any C runtime module ("ucrt*" / "msvcr*"), and look
// up the CRT's `_crt_atexit` / `_crt_at_quick_exit`. In debug builds, warn when
// this DLL is not adjacent to ucrtbase in the module list.
// Returns false if the module list could not be obtained.
static bool mi_patches_resolve(void) {
  HANDLE process = GetCurrentProcess();
  HMODULE modules[400];
  DWORD needed = 0;
  if (!EnumProcessModules(process, modules, sizeof(modules), &needed) || needed == 0) return false;
  // `needed` is the size required for *all* modules and can exceed the buffer;
  // only the entries that fit were written.
  size_t count = needed / sizeof(HMODULE);
  const size_t max_count = sizeof(modules) / sizeof(modules[0]);
  if (count > max_count) count = max_count;

  size_t ucrtbase_index = 0;
  size_t mimalloc_index = 0;
  for (size_t i = 0; i < count; i++) {
    HMODULE mod = modules[i];
    char filename[MAX_PATH] = { 0 };
    // explicit ANSI variant: the buffer is `char` regardless of UNICODE settings
    DWORD slen = GetModuleFileNameA(mod, filename, MAX_PATH);
    if (slen == 0 || slen >= MAX_PATH) continue;  // failed or truncated
    filename[slen] = 0;
    const char* lastsep = strrchr(filename, '\\');
    const char* basename = (lastsep == NULL ? filename : lastsep + 1);
    // positions for the load-order check below (file names carry a ".dll" extension)
    if (mi_module_name_is(basename, MIMALLOC_NAME)) mimalloc_index = i;
    if (mi_module_name_is(basename, UCRTBASE_NAME)) ucrtbase_index = i;
    if (i == 0 || _strnicmp(basename, "ucrt", 4) == 0 || _strnicmp(basename, "msvcr", 5) == 0) {
      mi_module_resolve(mod);
      if (crt_atexit == NULL) crt_atexit = (atexit_fun_t*)GetProcAddress(mod, "_crt_atexit");
      if (crt_at_quick_exit == NULL) crt_at_quick_exit = (atexit_fun_t*)GetProcAddress(mod, "_crt_at_quick_exit");
    }
  }
#if (MI_DEBUG)
  // Only meaningful when both modules were found.
  if (mimalloc_index > 0 && ucrtbase_index > 0) {
    size_t diff = (mimalloc_index > ucrtbase_index ? mimalloc_index - ucrtbase_index : ucrtbase_index - mimalloc_index);
    if (diff != 1) {
      _mi_warning_message("warning: the \"mimalloc-override\" DLL seems not to load right before or after the C runtime (\"ucrtbase\").\n"
                          "  Try to fix this by changing the linking order.");
    }
  }
#endif
  return true;
}
// The CRT's DLL entry point; DllEntry performs its own work around a call to it.
extern BOOL WINAPI _DllMainCRTStartup(HINSTANCE inst, DWORD reason, LPVOID reserved);
// Fiber-local-storage index whose callback is mi_fls_unwind (allocated in DllEntry).
static DWORD mi_fls_unwind_entry;
// FLS callback registered in DllEntry: when invoked with a non-NULL value
// (DllEntry stores 1 in the slot) it re-enables the regular patches.
static void NTAPI mi_fls_unwind(PVOID value) {
  if (value == NULL) {
    return;
  }
  mi_patches_enable();
}
static void mi_patches_atexit(void) {
mi_execute_exit_list(&atexit_list);
mi_patches_enable_term(); }
static void mi_patches_at_quick_exit(void) {
mi_execute_exit_list(&at_quick_exit_list);
mi_patches_enable_term(); }
// DLL entry point of the override DLL. It wraps the CRT's _DllMainCRTStartup:
//  - on process attach: initialize the security cookie, run CRT startup, then
//    resolve the patch targets, set up the FLS callback and exit lists, and
//    enable the patches;
//  - on process detach: switch to the "term" replacements before CRT shutdown.
// Returns the result of CRT startup, or of mi_patches_resolve on attach.
__declspec(dllexport) BOOL WINAPI DllEntry(HINSTANCE inst, DWORD reason, LPVOID reserved) {
if (reason == DLL_PROCESS_ATTACH) {
// presumably needed because this custom entry point runs before the CRT's own
// startup code — TODO confirm
__security_init_cookie();
}
else if (reason == DLL_PROCESS_DETACH) {
// switch to the term replacements before the CRT's own detach processing runs
mi_patches_enable_term();
}
BOOL ok = _DllMainCRTStartup(inst, reason, reserved);
if (reason == DLL_PROCESS_ATTACH && ok) {
ok = mi_patches_resolve();
if (ok) {
// FLS slot with a non-NULL value so mi_fls_unwind re-enables the patches when invoked
mi_fls_unwind_entry = FlsAlloc(&mi_fls_unwind);
if (mi_fls_unwind_entry != FLS_OUT_OF_INDEXES) {
FlsSetValue(mi_fls_unwind_entry, (void*)1);
}
mi_initialize_atexit();
// Register with the real CRT *before* mi_patches_enable: afterwards
// `_crt_atexit` / `_crt_at_quick_exit` jump to mi__crt_atexit /
// mi__crt_at_quick_exit (see the patches table).
if (crt_atexit != NULL) (*crt_atexit)(&mi_patches_atexit);
if (crt_at_quick_exit != NULL) (*crt_at_quick_exit)(&mi_patches_at_quick_exit);
mi_patches_enable();
mi_stats_reset();
}
}
return ok;
}