flex_alloc_secure/
alloc.rs

1//! Support for virtual memory management, including memory protections.
2
3use core::alloc::Layout;
4use core::mem::transmute;
5use core::ptr::{self, NonNull};
6use core::sync::atomic::{AtomicUsize, Ordering};
7use core::{fmt, slice};
8
9#[cfg(all(windows, not(miri)))]
10use core::mem::MaybeUninit;
11
12use flex_alloc::alloc::{AllocError, Allocator, AllocatorDefault, AllocatorZeroizes};
13use flex_alloc::StorageError;
14use zeroize::Zeroize;
15
16#[cfg(all(unix, not(miri)))]
17use libc::{free, mlock, mprotect, munlock, posix_memalign};
18
19#[cfg(all(windows, not(miri)))]
20use windows_sys::Win32::System::{Memory, SystemInformation};
21
/// Fill byte written over every freshly allocated secure buffer, so that
/// reads of protected data that was never initialized are easy to recognize.
pub const UNINIT_ALLOC_BYTE: u8 = 0xDB;
24
/// Error reported by the page locking, unlocking and protection helpers in
/// this module when the underlying system call does not succeed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MemoryError;

impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Memory error")
    }
}

impl std::error::Error for MemoryError {}
36
/// Access permissions that can be applied to a range of memory pages.
///
/// Variants are declared from most to least restrictive; the derived ordering
/// follows that declaration order. The default is [`ProtectionMode::ReadWrite`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProtectionMode {
    /// Pages may be neither read nor written.
    NoAccess,
    /// Pages may be read but not written.
    ReadOnly,
    /// Pages may be both read and written.
    #[default]
    ReadWrite,
}
48
49impl ProtectionMode {
50    #[cfg(all(unix, not(miri)))]
51    pub(crate) const fn as_native(self) -> i32 {
52        match self {
53            Self::NoAccess => libc::PROT_NONE,
54            Self::ReadOnly => libc::PROT_READ,
55            Self::ReadWrite => libc::PROT_READ | libc::PROT_WRITE,
56        }
57    }
58
59    #[cfg(all(windows, not(miri)))]
60    pub(crate) const fn as_native(self) -> u32 {
61        match self {
62            Self::NoAccess => windows_sys::Win32::System::Memory::PAGE_NOACCESS,
63            Self::ReadOnly => windows_sys::Win32::System::Memory::PAGE_READONLY,
64            Self::ReadWrite => windows_sys::Win32::System::Memory::PAGE_READWRITE,
65        }
66    }
67}
68
69/// Fetch the system-specific page size.
70pub fn default_page_size() -> usize {
71    static CACHE: AtomicUsize = AtomicUsize::new(0);
72
73    let mut size = CACHE.load(Ordering::Relaxed);
74
75    if size == 0 {
76        #[cfg(miri)]
77        {
78            size = 4096;
79        }
80        #[cfg(all(target_os = "macos", not(miri)))]
81        {
82            size = unsafe { libc::vm_page_size };
83        }
84        #[cfg(all(unix, not(target_os = "macos"), not(miri)))]
85        {
86            size = unsafe { libc::sysconf(libc::_SC_PAGE_SIZE) } as usize;
87        }
88        #[cfg(all(windows, not(miri)))]
89        {
90            let mut sysinfo = MaybeUninit::<SystemInformation::SYSTEM_INFO>::uninit();
91            unsafe { SystemInformation::GetSystemInfo(sysinfo.as_mut_ptr()) };
92            size = unsafe { sysinfo.assume_init_ref() }.dwPageSize as usize;
93        }
94
95        debug_assert_ne!(size, 0);
96        // inputs to posix_memalign must be a multiple of the pointer size
97        debug_assert_eq!(size % size_of::<*const ()>(), 0);
98
99        CACHE.store(size, Ordering::Relaxed);
100    }
101
102    size
103}
104
105/// Allocate a page-aligned buffer. The alignment will be rounded up to a multiple of
106/// the platform pointer size if necessary.
107pub fn alloc_pages(len: usize) -> Result<NonNull<[u8]>, AllocError> {
108    let page_size = default_page_size();
109    let alloc_len = page_rounded_length(len, page_size);
110
111    #[cfg(miri)]
112    {
113        let addr =
114            unsafe { std::alloc::alloc(Layout::from_size_align_unchecked(alloc_len, page_size)) };
115        let range = ptr::slice_from_raw_parts_mut(addr, alloc_len);
116        NonNull::new(range).ok_or_else(|| AllocError)
117    }
118
119    #[cfg(all(unix, not(miri)))]
120    {
121        let mut addr = ptr::null_mut();
122        let ret = unsafe { posix_memalign(&mut addr, page_size, alloc_len) };
123        if ret == 0 {
124            let range = ptr::slice_from_raw_parts_mut(addr.cast(), alloc_len);
125            Ok(NonNull::new(range).expect("null allocation result"))
126        } else {
127            Err(AllocError)
128        }
129    }
130
131    #[cfg(all(windows, not(miri)))]
132    {
133        let addr = unsafe {
134            Memory::VirtualAlloc(
135                ptr::null_mut(),
136                alloc_len,
137                Memory::MEM_COMMIT | Memory::MEM_RESERVE,
138                Memory::PAGE_READWRITE,
139            )
140        };
141        let range = ptr::slice_from_raw_parts_mut(addr.cast(), alloc_len);
142        NonNull::new(range).ok_or_else(|| AllocError)
143    }
144}
145
146/// Release a buffer allocated by `alloc_aligned`.
147pub fn dealloc_pages(addr: *mut u8, len: usize) {
148    #[cfg(miri)]
149    {
150        let page_size = default_page_size();
151        let alloc_len = page_rounded_length(len, page_size);
152        unsafe {
153            std::alloc::dealloc(
154                addr,
155                Layout::from_size_align_unchecked(alloc_len, page_size),
156            )
157        };
158        return;
159    }
160
161    #[cfg(all(unix, not(miri)))]
162    {
163        let _ = len;
164        unsafe { free(addr.cast()) };
165    }
166
167    #[cfg(all(windows, not(miri)))]
168    {
169        let _ = len;
170        unsafe { Memory::VirtualFree(addr.cast(), 0, Memory::MEM_RELEASE) };
171    }
172}
173
174/// Prevent swapping for the given memory range.
175/// On supported platforms, avoid including the memory in core dumps.
176pub fn lock_pages(addr: *mut u8, len: usize) -> Result<(), MemoryError> {
177    #[cfg(miri)]
178    {
179        _ = (addr, len);
180        Ok(())
181    }
182    #[cfg(all(unix, not(miri)))]
183    {
184        #[cfg(target_os = "linux")]
185        madvise(addr, len, libc::MADV_DONTDUMP)?;
186        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
187        madvise(addr, len, libc::MADV_NOCORE)?;
188
189        let res = unsafe { mlock(addr.cast(), len) };
190        if res == 0 {
191            Ok(())
192        } else {
193            Err(MemoryError)
194        }
195    }
196    #[cfg(all(windows, not(miri)))]
197    {
198        let res = unsafe { Memory::VirtualLock(addr.cast(), len) };
199        if res != 0 {
200            Ok(())
201        } else {
202            Err(MemoryError)
203        }
204    }
205}
206
207/// Resume normal swapping behavior for the given memory range.
208pub fn unlock_pages(addr: *mut u8, len: usize) -> Result<(), MemoryError> {
209    #[cfg(miri)]
210    {
211        _ = (addr, len);
212        Ok(())
213    }
214    #[cfg(all(unix, not(miri)))]
215    {
216        #[cfg(target_os = "linux")]
217        madvise(addr, len, libc::MADV_DODUMP)?;
218        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
219        madvise(addr, len, libc::MADV_CORE)?;
220
221        let res = unsafe { munlock(addr.cast(), len) };
222        if res == 0 {
223            Ok(())
224        } else {
225            Err(MemoryError)
226        }
227    }
228    #[cfg(all(windows, not(miri)))]
229    {
230        let res = unsafe { Memory::VirtualUnlock(addr.cast(), len) };
231        if res != 0 {
232            Ok(())
233        } else {
234            Err(MemoryError)
235        }
236    }
237}
238
239/// Adjust the protection mode for a given memory range.
240pub fn set_page_protection(
241    addr: *mut u8,
242    len: usize,
243    mode: ProtectionMode,
244) -> Result<(), MemoryError> {
245    #[cfg(miri)]
246    {
247        _ = (addr, len, mode);
248        Ok(())
249    }
250    #[cfg(all(unix, not(miri)))]
251    {
252        let res = unsafe { mprotect(addr.cast(), len, mode.as_native()) };
253        if res == 0 {
254            Ok(())
255        } else {
256            Err(MemoryError)
257        }
258    }
259    #[cfg(all(windows, not(miri)))]
260    {
261        let mut prev_mode = MaybeUninit::<u32>::uninit();
262        let res = unsafe {
263            Memory::VirtualProtect(addr.cast(), len, mode.as_native(), prev_mode.as_mut_ptr())
264        };
265        if res != 0 {
266            Ok(())
267        } else {
268            Err(MemoryError)
269        }
270    }
271}
272
273#[cfg(unix)]
274#[allow(unused)]
275#[inline]
276fn madvise(addr: *mut u8, len: usize, advice: i32) -> Result<(), MemoryError> {
277    {
278        let res = unsafe { libc::madvise(addr.cast(), len, advice) };
279        if res == 0 {
280            Ok(())
281        } else {
282            Err(MemoryError)
283        }
284    }
285}
286
/// Round `len` up to the nearest multiple of `page_size`.
///
/// A `len` that is already a multiple (including zero) is returned unchanged.
/// Works for any non-zero `page_size`, not only powers of two.
///
/// # Panics
///
/// Panics if `page_size` is zero. Like ordinary addition, the result
/// overflows (panicking in debug builds) if the rounded value exceeds
/// `usize::MAX`.
#[inline]
pub fn page_rounded_length(len: usize, page_size: usize) -> usize {
    len.next_multiple_of(page_size)
}
292
/// An allocator which obtains a discrete number of virtual memory pages.
///
/// Each non-empty allocation is rounded up to whole pages and pre-filled with
/// [`UNINIT_ALLOC_BYTE`]. The pages are then flagged using `mlock`
/// (`VirtualLock` on Windows) to restrict them to physical memory; on Linux,
/// FreeBSD and DragonFly they are also excluded from core dumps. When the
/// allocation is released, the memory is securely zeroed before it is
/// unlocked and freed.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct SecureAlloc;
300
301impl SecureAlloc {
302    pub(crate) fn set_page_protection(
303        &self,
304        ptr: *mut u8,
305        len: usize,
306        mode: ProtectionMode,
307    ) -> Result<(), StorageError> {
308        if len != 0 {
309            let alloc_len = page_rounded_length(len, default_page_size());
310            set_page_protection(ptr, alloc_len, mode).map_err(|_| StorageError::ProtectionError)
311        } else {
312            Ok(())
313        }
314    }
315}
316
317unsafe impl Allocator for SecureAlloc {
318    #[inline]
319    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
320        debug_assert!(
321            layout.align() <= default_page_size(),
322            "alignment cannot exceed page size"
323        );
324        let layout_len = layout.size();
325        if layout_len == 0 {
326            // FIXME: use Layout::dangling when stabilized
327            // SAFETY: layout alignments are guaranteed to be non-zero.
328            #[allow(clippy::useless_transmute)]
329            let head = unsafe { NonNull::new_unchecked(transmute(layout.align())) };
330            Ok(NonNull::slice_from_raw_parts(head, 0))
331        } else {
332            let alloc = alloc_pages(layout_len).map_err(|_| AllocError)?;
333            let alloc_len = alloc.len();
334
335            // Initialize with uninitialized indicator value
336            // SAFETY: the allocated pointer is guaranteed to be valid and have a length
337            // equal to `alloc_len`.
338            unsafe { ptr::write_bytes(alloc.as_ptr().cast::<u8>(), UNINIT_ALLOC_BYTE, alloc_len) };
339
340            // Keep data page(s) out of swap
341            lock_pages(alloc.as_ptr().cast(), alloc_len).map_err(|_| AllocError)?;
342
343            Ok(alloc)
344        }
345    }
346
347    #[inline]
348    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
349        let len = layout.size();
350        if len > 0 {
351            let alloc_len = page_rounded_length(len, default_page_size());
352
353            // Zero protected data
354            let mem = unsafe { slice::from_raw_parts_mut(ptr.as_ptr(), alloc_len) };
355            mem.zeroize();
356
357            // Restore normal swapping behavior
358            unlock_pages(ptr.as_ptr().cast(), alloc_len).ok();
359
360            // Free the memory
361            dealloc_pages(ptr.as_ptr(), alloc_len);
362        }
363    }
364}
365
366impl AllocatorDefault for SecureAlloc {
367    const DEFAULT: Self = Self;
368}
369
370impl AllocatorZeroizes for SecureAlloc {}
371
#[cfg(test)]
mod tests {
    use core::alloc::Layout;
    use flex_alloc::alloc::Allocator;

    use crate::{alloc::UNINIT_ALLOC_BYTE, vec::SecureVec};

    use super::{default_page_size, page_rounded_length, SecureAlloc};

    #[test]
    fn check_extra_capacity() {
        let vec = SecureVec::<usize>::with_capacity(1);
        // We always allocate pages, so there should be plenty of room for more values.
        assert!(vec.capacity() > 1);
    }

    #[test]
    fn check_uninit() {
        let layout = Layout::new::<usize>();
        let buf = SecureAlloc.allocate(layout).expect("allocation error");
        // The block covers whole pages, so it is at least as large as requested.
        assert!(buf.len() >= layout.size());
        assert_eq!(buf.len() % default_page_size(), 0);
        // Every byte, not just the first few, carries the uninitialized marker.
        assert!(unsafe { buf.as_ref() }
            .iter()
            .all(|&b| b == UNINIT_ALLOC_BYTE));
        unsafe { SecureAlloc.deallocate(buf.cast(), layout) };
    }

    #[test]
    fn check_zero_sized() {
        let layout = Layout::from_size_align(0, 64).expect("valid layout");
        let buf = SecureAlloc.allocate(layout).expect("allocation error");
        assert_eq!(buf.len(), 0);
        // The dangling pointer must still honor the requested alignment.
        assert_eq!(buf.cast::<u8>().as_ptr() as usize % 64, 0);
        unsafe { SecureAlloc.deallocate(buf.cast(), layout) };
    }

    #[test]
    fn check_page_rounding() {
        let page = default_page_size();
        assert_eq!(page_rounded_length(0, page), 0);
        assert_eq!(page_rounded_length(1, page), page);
        assert_eq!(page_rounded_length(page, page), page);
        assert_eq!(page_rounded_length(page + 1, page), 2 * page);
    }
}