pub mod errors;
use self::errors::MemError;
use std::alloc::{Alloc, Global, Layout};
use std::mem;
use std::num::NonZeroUsize;
use std::ptr;
use std::ptr::NonNull;
/// Result type returned by the fallible functions of this module, using [`MemError`] as the error.
pub type MemResult<T> = Result<T, MemError>;
/// Requests `size` bytes with byte (`u8`) alignment from the global allocator.
///
/// The returned bytes are not initialised by this function.
///
/// # Errors
/// Fails if no valid `Layout` exists for `size`, or if the allocator reports an error.
///
/// # Safety
/// The caller owns the returned block and must release it exactly once, e.g. with
/// [`dealloc`] called with this same `size`.
pub unsafe fn alloc(size: NonZeroUsize) -> MemResult<NonNull<u8>> {
    let byte_align = mem::align_of::<u8>();
    let layout: Layout = Layout::from_size_align(size.get(), byte_align)?;
    match Global.alloc(layout) {
        Ok(block) => Ok(block),
        Err(alloc_err) => Err(alloc_err.into()),
    }
}
/// Returns a block obtained from [`alloc`] to the global allocator.
///
/// # Errors
/// Fails only if no valid `Layout` exists for `size`; in that case nothing is freed.
///
/// # Safety
/// `ptr` must have been returned by [`alloc`] called with exactly this `size`, must not
/// have been freed already, and must not be used after this call.
pub unsafe fn dealloc(ptr: NonNull<u8>, size: NonZeroUsize) -> MemResult<()> {
    Layout::from_size_align(size.get(), mem::align_of::<u8>())
        .map(|layout| Global.dealloc(ptr, layout))
        .map_err(Into::into)
}
/// Size in bytes of the length prefix stored in front of every string buffer: one `u32`.
pub const STR_LEN_BYTES: usize = mem::size_of::<u32>();
/// Copies `str` into a freshly allocated buffer laid out as
/// `[length: u32, little-endian][UTF-8 bytes of str]` and returns a pointer to its start.
///
/// The buffer is exactly `STR_LEN_BYTES + str.len()` bytes long and comes from [`alloc`].
/// Ownership passes to the caller, who should either hand it to [`read_str_from_fat_ptr`]
/// (which takes ownership of it) or free it with [`dealloc`] using that same size.
///
/// # Errors
/// - `str` is longer than `u32::MAX` bytes, so the length prefix cannot represent it.
/// - The total size overflows `usize`.
/// - The allocation fails.
///
/// # Safety
/// No precondition is placed on `str`; the caller is responsible for freeing the returned
/// buffer exactly once, as described above.
pub unsafe fn write_str_to_mem(str: &str) -> MemResult<NonNull<u8>> {
    use std::convert::TryFrom;

    let str_len = str.len();
    // The prefix is a u32: reject lengths it cannot encode rather than truncating them,
    // which would make the reader decode a wrong length.
    let prefix_len = u32::try_from(str_len)
        .map_err(|_| MemError::new("string length does not fit in u32 length prefix"))?;
    let total_len = STR_LEN_BYTES
        .checked_add(str_len)
        .ok_or_else(|| MemError::new("usize overflow occurred"))?;
    let len_as_bytes: [u8; STR_LEN_BYTES] = prefix_len.to_le_bytes();

    // SAFETY: total_len >= STR_LEN_BYTES > 0.
    let result_ptr = alloc(NonZeroUsize::new_unchecked(total_len))?;
    // SAFETY: result_ptr addresses total_len = STR_LEN_BYTES + str_len freshly allocated
    // bytes, so both copies stay in bounds and cannot overlap their sources.
    ptr::copy_nonoverlapping(len_as_bytes.as_ptr(), result_ptr.as_ptr(), STR_LEN_BYTES);
    ptr::copy_nonoverlapping(
        str.as_ptr(),
        result_ptr.as_ptr().add(STR_LEN_BYTES),
        str_len,
    );
    Ok(result_ptr)
}
/// Adopts `len` bytes at `ptr` as an owned `String` without copying them.
///
/// # Safety
/// Same contract as `String::from_raw_parts` with `capacity == len`:
/// - `ptr` must come from the global allocator, as a byte-aligned allocation of exactly
///   `len` bytes.
/// - Those bytes must be valid UTF-8.
/// - Nothing else may use or free `ptr` afterwards.
pub unsafe fn deref_str(ptr: *mut u8, len: usize) -> String {
    let owned_bytes = Vec::from_raw_parts(ptr, len, len);
    String::from_utf8_unchecked(owned_bytes)
}
/// Reads back a string produced by [`write_str_to_mem`], taking ownership of its buffer.
///
/// Steps:
/// 1. Decode the `STR_LEN_BYTES`-byte length prefix at `ptr`.
/// 2. Adopt the whole `STR_LEN_BYTES + len` byte allocation as a byte vector.
/// 3. Strip the prefix in place and return the remaining bytes as a `String`.
///
/// On success the allocation is reused by the returned `String`; on error it is freed.
/// Either way, `ptr` must not be used or freed again afterwards.
///
/// # Errors
/// - `STR_LEN_BYTES + len` overflows `usize`.
/// - The payload is not valid UTF-8.
///
/// # Safety
/// `ptr` must point to a live buffer created by [`write_str_to_mem`]. That is, it was
/// allocated by this module's [`alloc`] with a size of exactly `STR_LEN_BYTES + len` bytes,
/// where `len` is the encoded prefix.
pub unsafe fn read_str_from_fat_ptr(ptr: NonNull<u8>) -> MemResult<String> {
    let str_len = read_len(ptr.as_ptr()) as usize;
    let total_len = STR_LEN_BYTES
        .checked_add(str_len)
        .ok_or_else(|| MemError::new("usize overflow occurred"))?;
    // Adopt the buffer as raw bytes, not as a `String`: the length prefix is binary data
    // that is generally not UTF-8, and a `String` must never hold invalid UTF-8.
    // SAFETY: per the contract above, the buffer holds total_len initialised bytes from a
    // byte-aligned global allocation of total_len bytes, matching Vec<u8>'s layout for
    // capacity total_len.
    let mut bytes = Vec::from_raw_parts(ptr.as_ptr(), total_len, total_len);
    bytes.drain(..STR_LEN_BYTES);
    // On failure the error owns (and drops) the vector, so the allocation is still freed.
    String::from_utf8(bytes).map_err(|_| MemError::new("string payload is not valid UTF-8"))
}
unsafe fn read_len(ptr: *mut u8) -> u32 {
let mut str_len_as_bytes: [u8; STR_LEN_BYTES] = [0; STR_LEN_BYTES];
ptr::copy_nonoverlapping(ptr, str_len_as_bytes.as_mut_ptr(), STR_LEN_BYTES);
mem::transmute(str_len_as_bytes)
}
#[cfg(test)]
mod test {
    use super::*;
    use std::num::NonZeroUsize;

    #[test]
    fn alloc_dealloc_test() {
        let size = NonZeroUsize::new(123).unwrap();
        unsafe {
            let ptr = alloc(size).unwrap();
            dealloc(ptr, size).unwrap();
        }
    }

    #[test]
    fn write_and_read_str_test() {
        let src_str = "some string Ω";
        let ptr = unsafe { write_str_to_mem(src_str) }.unwrap();
        let result_str = unsafe { read_str_from_fat_ptr(ptr) }.unwrap();
        assert_eq!(src_str, result_str);
    }

    #[test]
    fn empty_str_round_trip_test() {
        let ptr = unsafe { write_str_to_mem("") }.unwrap();
        let result_str = unsafe { read_str_from_fat_ptr(ptr) }.unwrap();
        assert_eq!(result_str, "");
    }

    #[test]
    fn length_prefix_is_little_endian_test() {
        let src_str = "hello";
        unsafe {
            let ptr = write_str_to_mem(src_str).unwrap();
            let prefix = std::slice::from_raw_parts(ptr.as_ptr(), STR_LEN_BYTES);
            assert_eq!(prefix, &[5u8, 0, 0, 0][..]);
            let payload =
                std::slice::from_raw_parts(ptr.as_ptr().add(STR_LEN_BYTES), src_str.len());
            assert_eq!(payload, src_str.as_bytes());
            // Hand the buffer back so it gets freed; the slices above are not used again.
            assert_eq!(read_str_from_fat_ptr(ptr).unwrap(), src_str);
        }
    }

    /// Builds an ASCII string of `len` repeated 'Q' characters.
    fn create_big_str(len: usize) -> String {
        "Q".repeat(len)
    }

    /// Repeated 1 MiB round trips; a leak in write/read would show up as memory growth.
    #[test]
    fn lot_of_write_and_read_str_test() {
        let mb_str = create_big_str(1024 * 1024);
        for _ in 1..10_000 {
            let ptr = unsafe { write_str_to_mem(&mb_str) }.unwrap();
            let result_str = unsafe { read_str_from_fat_ptr(ptr) }.unwrap();
            assert_eq!(mb_str, result_str);
        }
    }
}