//! Documentation: a drop-in global allocator backed by mimalloc; see [`MiMalloc`].
#![no_std]

extern crate libc;
extern crate libmimalloc_sys as ffi;

use core::ffi::c_void;
use core::alloc::{GlobalAlloc, Layout};

use libmimalloc_sys::*;


/// Drop-in mimalloc global allocator.
///
/// Every request is forwarded to mimalloc's `*_aligned` entry points, so the
/// alignment in the `Layout` is honoured for all sizes. On Linux, requests of
/// at least 1 GiB whose alignment is at most 4 KiB bypass mimalloc and are
/// mapped directly with `mmap`, preferring huge pages.
///
/// ## Usage
/// ```rust,ignore
/// use mimalloc::MiMalloc;
///
/// #[global_allocator]
/// static GLOBAL: MiMalloc = MiMalloc;
/// ```
#[derive(Debug, Default, Clone, Copy)]
pub struct MiMalloc;

/// Size threshold (1 GiB) from which requests are served by `mmap` on Linux;
/// mapping lengths are also rounded up to a multiple of it.
#[cfg(target_os = "linux")]
const ONE_G: usize = 1 << 30;

/// Largest alignment that any `mmap` result is relied upon to satisfy.
///
/// Linux base pages are at least 4 KiB, and hugetlb mappings are aligned to
/// their (larger) huge page size, so every mapping is at least 4 KiB aligned.
#[cfg(target_os = "linux")]
const MMAP_MAX_ALIGN: usize = 4096;

/// Returns `true` when a block with `layout` is served by a direct `mmap`
/// instead of mimalloc.
///
/// `alloc`, `alloc_zeroed`, `dealloc` and `realloc` all consult this single
/// predicate. The layout is the only record of where a block came from, so a
/// block is handed to `munmap` only if it was mapped by this allocator.
#[cfg(target_os = "linux")]
#[inline]
fn uses_mmap(layout: Layout) -> bool {
    layout.size() >= ONE_G && layout.align() <= MMAP_MAX_ALIGN
}

/// Rounds `layout.size()` up to a multiple of 1 GiB.
///
/// This is the length passed to both `mmap` and `munmap`, so the two always agree.
#[cfg(target_os = "linux")]
#[inline]
fn hugepage_align(layout: Layout) -> usize {
    // `ONE_G` is a power of two, so adding `ONE_G - 1` and masking rounds up.
    // This cannot overflow: `Layout` sizes never exceed `isize::MAX`.
    let len_bytes = (layout.size() + (ONE_G - 1)) & !(ONE_G - 1);
    debug_assert_eq!(len_bytes % ONE_G, 0);
    len_bytes
}

/// Maps `len` bytes of private, anonymous, read/write memory with
/// `extra_flags` OR-ed into the mapping flags.
///
/// Returns null if `mmap` fails. Anonymous mappings are zero-filled by the kernel.
#[cfg(target_os = "linux")]
#[inline]
unsafe fn mmap_anonymous(len: usize, extra_flags: libc::c_int) -> *mut u8 {
    let ptr = libc::mmap(
        core::ptr::null_mut(),
        len,
        libc::PROT_READ | libc::PROT_WRITE,
        libc::MAP_PRIVATE | libc::MAP_ANONYMOUS | extra_flags,
        -1,
        0,
    );
    if ptr == libc::MAP_FAILED {
        core::ptr::null_mut()
    } else {
        ptr as *mut u8
    }
}

/// Maps a zero-filled region for `layout`, preferring huge pages.
///
/// If no huge pages are available, the same (1 GiB-rounded) length is mapped
/// with regular pages. Either mapping can therefore be released by
/// `hugepage_dealloc`. Returns null only if both attempts fail.
///
/// Example code for hugetlb mappings:
/// https://github.com/torvalds/linux/blob/master/tools/testing/selftests/vm/map_hugetlb.c
///
/// NOTE(review): the 1 GiB rounding also applies to the regular-page fallback.
/// That is only virtual address space until touched, but it does count
/// against the commit limit under strict overcommit.
#[cfg(target_os = "linux")]
#[inline]
unsafe fn hugepage_alloc_zeroed(layout: Layout) -> *mut u8 {
    let len_bytes = hugepage_align(layout);
    let ptr = mmap_anonymous(len_bytes, libc::MAP_HUGETLB);
    if !ptr.is_null() {
        return ptr;
    }
    // Never fall back to mimalloc here: `dealloc` unmaps every block for which
    // `uses_mmap` holds, so such a block must always come from `mmap`.
    mmap_anonymous(len_bytes, 0)
}

/// Unmaps a block returned by `hugepage_alloc_zeroed` for the same `layout`.
///
/// Returns `munmap`'s result: `0` on success, `-1` on failure.
#[cfg(target_os = "linux")]
#[inline]
unsafe fn hugepage_dealloc(ptr: *mut u8, layout: Layout) -> libc::c_int {
    libc::munmap(ptr as *mut libc::c_void, hugepage_align(layout))
}

/// Returns `true` for alignments this allocator refuses outright.
///
/// NOTE(review): on macOS, alignments above 2 GiB have always been rejected
/// with a null return, presumably to avoid a platform or mimalloc limit. Confirm.
#[inline]
fn unsupported_align(align: usize) -> bool {
    cfg!(target_os = "macos") && align > (1 << 31)
}

// mimalloc's size classes are word-granular, so plain `mi_malloc` may return a
// block that is only word-aligned (for example a 24-byte block on a 64-bit
// target). A "system-allocator MIN_ALIGN" shortcut is therefore unsound here,
// and the `*_aligned` entry points are used unconditionally.
unsafe impl GlobalAlloc for MiMalloc {
    /// Allocates a block for `layout`; returns null on failure.
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        if unsupported_align(layout.align()) {
            return core::ptr::null_mut();
        }
        #[cfg(target_os = "linux")]
        {
            if uses_mmap(layout) {
                return hugepage_alloc_zeroed(layout);
            }
        }
        mi_malloc_aligned(layout.size(), layout.align()) as *mut u8
    }

    /// Allocates a zero-filled block for `layout`; returns null on failure.
    #[inline]
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        if unsupported_align(layout.align()) {
            return core::ptr::null_mut();
        }
        #[cfg(target_os = "linux")]
        {
            if uses_mmap(layout) {
                // Fresh anonymous mappings are already zero-filled.
                return hugepage_alloc_zeroed(layout);
            }
        }
        mi_zalloc_aligned(layout.size(), layout.align()) as *mut u8
    }

    /// Releases a block previously returned for `layout`.
    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        #[cfg(target_os = "linux")]
        {
            if uses_mmap(layout) {
                // The block was mapped by `hugepage_alloc_zeroed`; it must not
                // reach `mi_free`, even if `munmap` unexpectedly fails.
                let ret = hugepage_dealloc(ptr, layout);
                debug_assert_eq!(ret, 0, "munmap failed for a block this allocator mapped");
                return;
            }
        }
        mi_free(ptr as *mut c_void);
    }

    /// Resizes the block at `ptr` to `new_size` bytes, keeping `layout.align()`.
    ///
    /// Returns null on failure, in which case the original block is untouched.
    #[inline]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        #[cfg(target_os = "linux")]
        {
            // SAFETY: the `GlobalAlloc::realloc` contract guarantees that
            // `new_size`, rounded up to `layout.align()`, does not overflow `isize`.
            let new_layout = Layout::from_size_align_unchecked(new_size, layout.align());
            let old_mapped = uses_mmap(layout);
            let new_mapped = uses_mmap(new_layout);

            if old_mapped && new_mapped && hugepage_align(layout) == hugepage_align(new_layout) {
                // The existing mapping already has the right length.
                return ptr;
            }
            if old_mapped || new_mapped {
                // Moving between `mmap` and mimalloc (or between mapping lengths)
                // is something neither `mi_realloc_aligned` nor `munmap` can do: copy.
                let new_ptr = self.alloc(new_layout);
                if !new_ptr.is_null() {
                    let kept = core::cmp::min(layout.size(), new_size);
                    core::ptr::copy_nonoverlapping(ptr, new_ptr, kept);
                    self.dealloc(ptr, layout);
                }
                return new_ptr;
            }
        }
        mi_realloc_aligned(ptr as *mut c_void, new_size, layout.align()) as *mut u8
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `ptr` is non-null and aligned as `layout` requires.
    fn assert_valid(ptr: *mut u8, layout: Layout) {
        assert!(!ptr.is_null(), "allocation failed for {:?}", layout);
        assert_eq!(
            ptr as usize % layout.align(),
            0,
            "misaligned block for {:?}",
            layout
        );
    }

    /// Allocates and frees one block of `size` bytes aligned to `align`.
    unsafe fn check_alloc(size: usize, align: usize) {
        let layout = Layout::from_size_align(size, align).unwrap();
        let alloc = MiMalloc;

        let ptr = alloc.alloc(layout);
        assert_valid(ptr, layout);
        alloc.dealloc(ptr, layout);
    }

    /// Allocates a block with `alloc_zeroed`, checks every byte is zero, then frees it.
    unsafe fn check_alloc_zeroed(size: usize, align: usize) {
        let layout = Layout::from_size_align(size, align).unwrap();
        let alloc = MiMalloc;

        let ptr = alloc.alloc_zeroed(layout);
        assert_valid(ptr, layout);
        let bytes = core::slice::from_raw_parts(ptr, size);
        assert!(bytes.iter().all(|&b| b == 0), "alloc_zeroed returned dirty memory");
        alloc.dealloc(ptr, layout);
    }

    /// Allocates and fills `old_size` bytes, reallocates to `new_size`, then checks
    /// alignment and preserved contents before freeing with the new layout.
    unsafe fn check_realloc(old_size: usize, align: usize, new_size: usize) {
        let layout = Layout::from_size_align(old_size, align).unwrap();
        let new_layout = Layout::from_size_align(new_size, align).unwrap();
        let alloc = MiMalloc;

        let ptr = alloc.alloc(layout);
        assert_valid(ptr, layout);
        core::ptr::write_bytes(ptr, 0xAB, old_size);

        let ptr = alloc.realloc(ptr, layout, new_size);
        assert_valid(ptr, new_layout);
        let kept = core::cmp::min(old_size, new_size);
        let bytes = core::slice::from_raw_parts(ptr, kept);
        assert!(bytes.iter().all(|&b| b == 0xAB), "realloc lost the old contents");

        // After `realloc` the block belongs to `new_layout`; freeing it with the
        // old layout would violate the `GlobalAlloc::dealloc` contract.
        alloc.dealloc(ptr, new_layout);
    }

    #[test]
    fn it_frees_allocated_memory() {
        unsafe { check_alloc(8, 8) }
    }

    #[test]
    fn it_frees_allocated_big_memory() {
        unsafe { check_alloc(1 << 20, 32) }
    }

    #[test]
    fn it_frees_zero_allocated_memory() {
        unsafe { check_alloc_zeroed(8, 8) }
    }

    #[test]
    fn it_frees_zero_allocated_big_memory() {
        unsafe { check_alloc_zeroed(1 << 20, 32) }
    }

    #[test]
    fn it_frees_reallocated_memory() {
        unsafe { check_realloc(8, 8, 16) }
    }

    #[test]
    fn it_frees_reallocated_big_memory() {
        unsafe { check_realloc(1 << 20, 32, 2 << 20) }
    }

    #[test]
    fn it_keeps_alignment_when_shrinking() {
        unsafe { check_realloc(64, 16, 8) }
    }

    #[test]
    fn it_respects_alignment_above_word_size() {
        unsafe {
            // 24 bytes is not a multiple of 16, so a word-granular size class
            // alone would not guarantee 16-byte alignment.
            let layout = Layout::from_size_align(24, 16).unwrap();
            let alloc = MiMalloc;

            let mut ptrs = [core::ptr::null_mut::<u8>(); 64];
            for slot in ptrs.iter_mut() {
                *slot = alloc.alloc(layout);
                assert_valid(*slot, layout);
            }
            for &ptr in ptrs.iter() {
                alloc.dealloc(ptr, layout);
            }
        }
    }
}