memsafe 0.1.8

A cross-platform Rust library for securely wrapping data in memory
Documentation
// src/lib.rs
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::ptr;
use std::io;

#[cfg(unix)]
use libc::{self, c_void};
#[cfg(windows)]
use winapi::um::{
    memoryapi::{VirtualAlloc, VirtualFree, VirtualLock, VirtualProtect, VirtualUnlock},
    winnt::{MEM_COMMIT, MEM_RESERVE, PAGE_NOACCESS, PAGE_READONLY, PAGE_READWRITE, MEM_RELEASE}, // Added MEM_RELEASE
};

#[derive(Debug)]
pub struct MemoryError(io::Error);

impl From<io::Error> for MemoryError {
    fn from(err: io::Error) -> Self {
        MemoryError(err)
    }
}

impl MemoryError {
    pub fn inner(&self) -> &io::Error {
        &self.0
    }
}

pub struct MemSafe<T> {
    ptr: *mut T,
    len: usize,
    is_writable: UnsafeCell<bool>,
}

impl<T> MemSafe<T> {
    pub fn new(value: T) -> Result<Self, MemoryError> {
        let size = std::mem::size_of::<T>();

        #[cfg(unix)]
        {
            let ptr = unsafe {
                libc::mmap(
                    ptr::null_mut(),
                    size,
                    libc::PROT_READ | libc::PROT_WRITE,
                    libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
                    -1,
                    0,
                )
            };
            if ptr == libc::MAP_FAILED {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            let mem_safe = MemSafe {
                ptr: ptr as *mut T,
                len: size,
                is_writable: UnsafeCell::new(true),
            };
            unsafe { ptr::write(mem_safe.ptr, value); }
            mem_safe.lock_memory()?;
            mem_safe.set_memory_advice()?;
            mem_safe.make_noaccess()?;
            Ok(mem_safe)
        }

        #[cfg(windows)]
        {
            let ptr = unsafe {
                VirtualAlloc(
                    ptr::null_mut(),
                    size,
                    MEM_COMMIT | MEM_RESERVE,
                    PAGE_READWRITE,
                )
            };
            if ptr.is_null() {  // Fixed: Use is_null() instead of MAP_FAILED
                return Err(MemoryError(io::Error::last_os_error()));
            }
            let mem_safe = MemSafe {
                ptr: ptr as *mut T,
                len: size,
                is_writable: UnsafeCell::new(true),
            };
            unsafe { ptr::write(mem_safe.ptr, value); }
            mem_safe.lock_memory()?;
            mem_safe.make_noaccess()?;
            Ok(mem_safe)
        }
    }

    fn make_noaccess(&self) -> Result<(), MemoryError> {
        #[cfg(unix)]
        unsafe {
            if libc::mprotect(self.ptr as *mut c_void, self.len, libc::PROT_NONE) != 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            *self.is_writable.get() = false;
            Ok(())
        }

        #[cfg(windows)]
        unsafe {
            let mut old_protect = 0;
            if VirtualProtect(self.ptr as *mut _, self.len, PAGE_NOACCESS, &mut old_protect) == 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            *self.is_writable.get() = false;
            Ok(())
        }
    }

    fn make_writable(&self) -> Result<(), MemoryError> {
        #[cfg(unix)]
        unsafe {
            if libc::mprotect(self.ptr as *mut c_void, self.len, libc::PROT_READ | libc::PROT_WRITE) != 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            *self.is_writable.get() = true;
            Ok(())
        }

        #[cfg(windows)]
        unsafe {
            let mut old_protect = 0;
            if VirtualProtect(self.ptr as *mut _, self.len, PAGE_READWRITE, &mut old_protect) == 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            *self.is_writable.get() = true;
            Ok(())
        }
    }

    fn make_readonly(&self) -> Result<(), MemoryError> {
        #[cfg(unix)]
        unsafe {
            if libc::mprotect(self.ptr as *mut c_void, self.len, libc::PROT_READ) != 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            *self.is_writable.get() = false;
            Ok(())
        }

        #[cfg(windows)]
        unsafe {
            let mut old_protect = 0;
            if VirtualProtect(self.ptr as *mut _, self.len, PAGE_READONLY, &mut old_protect) == 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            *self.is_writable.get() = false;
            Ok(())
        }
    }

    fn lock_memory(&self) -> Result<(), MemoryError> {
        #[cfg(unix)]
        unsafe {
            if libc::mlock(self.ptr as *const c_void, self.len) != 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            Ok(())
        }

        #[cfg(windows)]
        unsafe {
            if VirtualLock(self.ptr as *mut _, self.len) == 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
            Ok(())
        }
    }

    #[cfg(target_os = "linux")]
    fn set_memory_advice(&self) -> Result<(), MemoryError> {
        unsafe {
            if libc::madvise(self.ptr as *mut c_void, self.len, libc::MADV_DONTDUMP) != 0 {
                return Err(MemoryError(io::Error::last_os_error()));
            }
        }
        Ok(())
    }

    #[cfg(not(target_os = "linux"))]
    fn set_memory_advice(&self) -> Result<(), MemoryError> {
        Ok(())
    }
}

impl<T> Deref for MemSafe<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        unsafe {
            if !*self.is_writable.get() {
                self.make_readonly().expect("Failed to make readable");
            }
            &*self.ptr
        }
    }
}

impl<T> DerefMut for MemSafe<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.make_writable().expect("Failed to make writable");
        unsafe { &mut *self.ptr }
    }
}

impl<T> Drop for MemSafe<T> {
    fn drop(&mut self) {
        #[cfg(unix)]
        unsafe {
            self.make_writable().ok();
            ptr::drop_in_place(self.ptr);
            ptr::write_bytes(self.ptr as *mut u8, 0, self.len);
            libc::munlock(self.ptr as *const c_void, self.len);
            libc::munmap(self.ptr as *mut c_void, self.len);
        }

        #[cfg(windows)]
        unsafe {
            self.make_writable().ok();
            ptr::drop_in_place(self.ptr);
            ptr::write_bytes(self.ptr as *mut u8, 0, self.len);
            VirtualUnlock(self.ptr as *mut _, self.len);
            VirtualFree(self.ptr as *mut _, 0, MEM_RELEASE);  // Now in scope
        }
    }
}