use super::HardError;
use errno::errno;
use libsodium_sys as sodium;
use std::ffi::c_void;
use std::ptr::NonNull;
/// Allocates libsodium-guarded memory large enough to hold one `T`.
///
/// The number of bytes requested from `sodium_malloc` is `size_of::<T>()`.
/// The memory is **not** initialized by this function; it must be written
/// before it is read as a `T`.
///
/// # Errors
///
/// Returns [`HardError::AllocationFailed`] carrying the current `errno` when
/// `sodium_malloc` returns a null pointer.
///
/// # Safety
///
/// - libsodium must already be initialized (`sodium_init`) before this is called.
/// - The returned pointer must be released with [`free`] exactly once and must
///   never be handed to any other deallocator.
/// - NOTE(review): alignment of the returned pointer for `T` is not checked
///   here; confirm libsodium's placement satisfies `align_of::<T>()`.
pub unsafe fn malloc<T>() -> Result<NonNull<T>, HardError> {
    let raw = sodium::sodium_malloc(std::mem::size_of::<T>()).cast::<T>();
    // Only consult errno on the failure path; on success its value is
    // meaningless and building the error eagerly is wasted work.
    NonNull::new(raw).ok_or_else(|| HardError::AllocationFailed(errno()))
}
/// Releases a region previously obtained from [`malloc`] via `sodium_free`.
///
/// # Safety
///
/// `ptr` must have been returned by [`malloc`] and not already freed. After
/// this call the pointer is dangling and must not be used again.
pub unsafe fn free<T>(ptr: NonNull<T>) {
    let region = ptr.as_ptr().cast::<c_void>();
    sodium::sodium_free(region)
}
/// Overwrites the `size_of::<T>()` bytes starting at `ptr` with zeros using
/// `sodium_memzero`.
///
/// # Safety
///
/// `ptr` must be valid for writes of `size_of::<T>()` bytes. In particular,
/// the region must not currently be protected with [`mprotect_noaccess`] or
/// [`mprotect_readonly`].
pub unsafe fn memzero<T>(ptr: NonNull<T>) {
    let byte_len = std::mem::size_of::<T>();
    let target = ptr.as_ptr().cast::<c_void>();
    sodium::sodium_memzero(target, byte_len)
}
/// Compares the `size_of::<T>()` bytes behind `a` and `b` with `sodium_memcmp`.
///
/// Returns `true` when `sodium_memcmp` reports `0` (the byte ranges match) and
/// `false` otherwise. libsodium documents `sodium_memcmp` as a constant-time
/// comparison, which makes it suitable for secrets.
///
/// # Safety
///
/// Both `a` and `b` must be valid for reads of `size_of::<T>()` bytes. Neither
/// region may be protected with [`mprotect_noaccess`].
pub unsafe fn memcmp<T>(a: NonNull<T>, b: NonNull<T>) -> bool {
    let byte_len = std::mem::size_of::<T>();
    let lhs = a.as_ptr().cast::<c_void>();
    let rhs = b.as_ptr().cast::<c_void>();
    let status = sodium::sodium_memcmp(lhs, rhs, byte_len);
    status == 0
}
/// Revokes all access to the region behind `ptr` via `sodium_mprotect_noaccess`.
///
/// # Errors
///
/// Returns [`HardError::MprotectNoAccessFailed`] carrying the current `errno`
/// when libsodium reports a non-zero status.
///
/// # Safety
///
/// `ptr` must come from [`malloc`] and must not have been passed to [`free`].
/// While the protection is in effect, the region must not be read or written.
/// Restore access first with [`mprotect_readonly`] or [`mprotect_readwrite`].
pub unsafe fn mprotect_noaccess<T>(ptr: NonNull<T>) -> Result<(), HardError> {
    match sodium::sodium_mprotect_noaccess(ptr.as_ptr().cast::<c_void>()) {
        0 => Ok(()),
        _ => Err(HardError::MprotectNoAccessFailed(errno())),
    }
}
/// Marks the region behind `ptr` read-only via `sodium_mprotect_readonly`.
///
/// # Errors
///
/// Returns [`HardError::MprotectReadOnlyFailed`] carrying the current `errno`
/// when libsodium reports a non-zero status.
///
/// # Safety
///
/// `ptr` must come from [`malloc`] and must not have been passed to [`free`].
/// While the protection is in effect, the region must not be written.
pub unsafe fn mprotect_readonly<T>(ptr: NonNull<T>) -> Result<(), HardError> {
    let region = ptr.as_ptr().cast::<c_void>();
    if sodium::sodium_mprotect_readonly(region) != 0 {
        return Err(HardError::MprotectReadOnlyFailed(errno()));
    }
    Ok(())
}
/// Restores read and write access to the region behind `ptr` via
/// `sodium_mprotect_readwrite`.
///
/// # Errors
///
/// Returns [`HardError::MprotectReadWriteFailed`] carrying the current `errno`
/// when libsodium reports a non-zero status.
///
/// # Safety
///
/// `ptr` must come from [`malloc`] and must not have been passed to [`free`].
pub unsafe fn mprotect_readwrite<T>(ptr: NonNull<T>) -> Result<(), HardError> {
    let status = sodium::sodium_mprotect_readwrite(ptr.as_ptr().cast::<c_void>());
    match status {
        0 => Ok(()),
        _ => Err(HardError::MprotectReadWriteFailed(errno())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{init, HardError};
    use std::ptr::NonNull;

    /// Allocating a range of sizes (1 byte up to 1 MiB) succeeds and every
    /// region can be released again.
    #[test]
    fn malloc_and_free() -> Result<(), HardError> {
        unsafe {
            init()?;
            let tiny: NonNull<u8> = malloc()?;
            let small: NonNull<[u8; 1 << 3]> = malloc()?;
            let medium: NonNull<[u8; 1 << 10]> = malloc()?;
            let large: NonNull<[u8; 1 << 20]> = malloc()?;

            free(tiny);
            free(small);
            free(medium);
            free(large);
        }
        Ok(())
    }

    /// After `memzero`, every byte of each allocated region reads back as zero.
    #[test]
    fn memzero_does_zero() -> Result<(), HardError> {
        unsafe {
            init()?;
            let tiny: NonNull<u8> = malloc()?;
            let small: NonNull<[u8; 1 << 3]> = malloc()?;
            let medium: NonNull<[u8; 1 << 10]> = malloc()?;
            let large: NonNull<[u8; 1 << 20]> = malloc()?;

            memzero(tiny);
            memzero(small);
            memzero(medium);
            memzero(large);

            assert_eq!(*tiny.as_ref(), 0);
            assert!(small.as_ref().iter().all(|&byte| byte == 0));
            assert!(medium.as_ref().iter().all(|&byte| byte == 0));
            assert!(large.as_ref().iter().all(|&byte| byte == 0));

            free(tiny);
            free(small);
            free(medium);
            free(large);
        }
        Ok(())
    }

    /// `memcmp` reports equality for identical contents and inequality when
    /// either the first or the last byte differs.
    #[test]
    fn memcmp_compare_works() -> Result<(), HardError> {
        const PATTERN: [u8; 8] = [0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe];
        unsafe {
            init()?;
            let lhs: NonNull<[u8; 1 << 3]> = malloc()?;
            let mut rhs: NonNull<[u8; 1 << 3]> = malloc()?;
            lhs.as_ptr().write(PATTERN);
            rhs.as_ptr().write(PATTERN);
            assert!(memcmp(lhs, rhs));

            // Mismatch in the first byte only.
            *rhs.as_mut() = PATTERN;
            rhs.as_mut()[0] = 0xff;
            assert!(!memcmp(lhs, rhs));

            // Mismatch in the last byte only.
            *rhs.as_mut() = PATTERN;
            rhs.as_mut()[7] = 0xff;
            assert!(!memcmp(lhs, rhs));

            free(lhs);
            free(rhs);
        }
        Ok(())
    }

    /// Cycling through no-access, read-only and read-write protection keeps
    /// the stored data intact and allows writes again once read-write is set.
    #[test]
    fn mprotect_works() -> Result<(), HardError> {
        unsafe {
            init()?;
            let mut region: NonNull<[u8; 32]> = malloc()?;
            region.as_ptr().write([0xfe; 32]);

            mprotect_noaccess(region)?;
            mprotect_readonly(region)?;
            assert_eq!(*region.as_ref(), [0xfe; 32]);

            mprotect_readwrite(region)?;
            *region.as_mut() = [0xba; 32];
            assert_eq!(*region.as_ref(), [0xba; 32]);

            free(region);
        }
        Ok(())
    }
}