use core::alloc::Layout;
use core::mem::{size_of, transmute};
use core::ptr::{self, NonNull};
use core::sync::atomic::{AtomicUsize, Ordering};
use core::{fmt, slice};

#[cfg(all(windows, not(miri)))]
use core::mem::MaybeUninit;

use flex_alloc::alloc::{AllocError, Allocator, AllocatorDefault, AllocatorZeroizes};
use flex_alloc::StorageError;
use zeroize::Zeroize;

#[cfg(all(unix, not(miri)))]
use libc::{free, mlock, mprotect, munlock, posix_memalign};

#[cfg(all(windows, not(miri)))]
use windows_sys::Win32::System::{Memory, SystemInformation};

/// The byte value used to pre-fill new allocations.
pub const UNINIT_ALLOC_BYTE: u8 = 0xdb;

/// An error produced when a memory page locking or protection operation fails.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MemoryError;

impl fmt::Display for MemoryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Memory error")
    }
}

impl std::error::Error for MemoryError {}

/// The supported protection levels for a range of memory pages.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProtectionMode {
    /// The memory may not be read or written.
    NoAccess,
    /// The memory may only be read.
    ReadOnly,
    /// The memory may be read and written.
    #[default]
    ReadWrite,
}

impl ProtectionMode {
    #[cfg(all(unix, not(miri)))]
    pub(crate) const fn as_native(self) -> i32 {
        match self {
            Self::NoAccess => libc::PROT_NONE,
            Self::ReadOnly => libc::PROT_READ,
            Self::ReadWrite => libc::PROT_READ | libc::PROT_WRITE,
        }
    }

    #[cfg(all(windows, not(miri)))]
    pub(crate) const fn as_native(self) -> u32 {
        match self {
            Self::NoAccess => windows_sys::Win32::System::Memory::PAGE_NOACCESS,
            Self::ReadOnly => windows_sys::Win32::System::Memory::PAGE_READONLY,
            Self::ReadWrite => windows_sys::Win32::System::Memory::PAGE_READWRITE,
        }
    }
}

/// Determine the memory page size for the current platform, caching the
/// result after the first lookup.
pub fn default_page_size() -> usize {
    static CACHE: AtomicUsize = AtomicUsize::new(0);

    let mut size = CACHE.load(Ordering::Relaxed);

    if size == 0 {
        #[cfg(miri)]
        {
            size = 4096;
        }
        #[cfg(all(target_os = "macos", not(miri)))]
        {
            size = unsafe { libc::vm_page_size };
        }
        #[cfg(all(unix, not(target_os = "macos"), not(miri)))]
        {
            size = unsafe { libc::sysconf(libc::_SC_PAGE_SIZE) } as usize;
        }
        #[cfg(all(windows, not(miri)))]
        {
            let mut sysinfo = MaybeUninit::<SystemInformation::SYSTEM_INFO>::uninit();
            unsafe { SystemInformation::GetSystemInfo(sysinfo.as_mut_ptr()) };
            size = unsafe { sysinfo.assume_init_ref() }.dwPageSize as usize;
        }

        // The page size must be non-zero and a multiple of the pointer size.
        debug_assert_ne!(size, 0);
        debug_assert_eq!(size % size_of::<*const ()>(), 0);

        CACHE.store(size, Ordering::Relaxed);
    }

    size
}

/// Allocate a writable, page-aligned region of memory, rounded up to a
/// whole number of pages.
pub fn alloc_pages(len: usize) -> Result<NonNull<[u8]>, AllocError> {
    let page_size = default_page_size();
    let alloc_len = page_rounded_length(len, page_size);

    #[cfg(miri)]
    {
        let addr =
            unsafe { std::alloc::alloc(Layout::from_size_align_unchecked(alloc_len, page_size)) };
        let range = ptr::slice_from_raw_parts_mut(addr, alloc_len);
        NonNull::new(range).ok_or_else(|| AllocError)
    }

    #[cfg(all(unix, not(miri)))]
    {
        let mut addr = ptr::null_mut();
        let ret = unsafe { posix_memalign(&mut addr, page_size, alloc_len) };
        if ret == 0 {
            let range = ptr::slice_from_raw_parts_mut(addr.cast(), alloc_len);
            Ok(NonNull::new(range).expect("null allocation result"))
        } else {
            Err(AllocError)
        }
    }

    #[cfg(all(windows, not(miri)))]
    {
        let addr = unsafe {
            Memory::VirtualAlloc(
                ptr::null_mut(),
                alloc_len,
                Memory::MEM_COMMIT | Memory::MEM_RESERVE,
                Memory::PAGE_READWRITE,
            )
        };
        let range = ptr::slice_from_raw_parts_mut(addr.cast(), alloc_len);
        NonNull::new(range).ok_or_else(|| AllocError)
    }
}

/// Release a region of memory previously obtained from `alloc_pages`.
pub fn dealloc_pages(addr: *mut u8, len: usize) {
    #[cfg(miri)]
    {
        let page_size = default_page_size();
        let alloc_len = page_rounded_length(len, page_size);
        unsafe {
            std::alloc::dealloc(
                addr,
                Layout::from_size_align_unchecked(alloc_len, page_size),
            )
        };
        return;
    }

    #[cfg(all(unix, not(miri)))]
    {
        let _ = len;
        unsafe { free(addr.cast()) };
    }

    #[cfg(all(windows, not(miri)))]
    {
        let _ = len;
        unsafe { Memory::VirtualFree(addr.cast(), 0, Memory::MEM_RELEASE) };
    }
}

/// Lock a range of memory pages into physical memory, and exclude it from
/// core dumps where the platform supports doing so.
pub fn lock_pages(addr: *mut u8, len: usize) -> Result<(), MemoryError> {
    #[cfg(miri)]
    {
        _ = (addr, len);
        Ok(())
    }
    #[cfg(all(unix, not(miri)))]
    {
        #[cfg(target_os = "linux")]
        madvise(addr, len, libc::MADV_DONTDUMP)?;
        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
        madvise(addr, len, libc::MADV_NOCORE)?;

        let res = unsafe { mlock(addr.cast(), len) };
        if res == 0 {
            Ok(())
        } else {
            Err(MemoryError)
        }
    }
    #[cfg(all(windows, not(miri)))]
    {
        let res = unsafe { Memory::VirtualLock(addr.cast(), len) };
        if res != 0 {
            Ok(())
        } else {
            Err(MemoryError)
        }
    }
}

/// Unlock a range of memory pages, reversing the effect of `lock_pages`.
pub fn unlock_pages(addr: *mut u8, len: usize) -> Result<(), MemoryError> {
    #[cfg(miri)]
    {
        _ = (addr, len);
        Ok(())
    }
    #[cfg(all(unix, not(miri)))]
    {
        #[cfg(target_os = "linux")]
        madvise(addr, len, libc::MADV_DODUMP)?;
        #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
        madvise(addr, len, libc::MADV_CORE)?;

        let res = unsafe { munlock(addr.cast(), len) };
        if res == 0 {
            Ok(())
        } else {
            Err(MemoryError)
        }
    }
    #[cfg(all(windows, not(miri)))]
    {
        let res = unsafe { Memory::VirtualUnlock(addr.cast(), len) };
        if res != 0 {
            Ok(())
        } else {
            Err(MemoryError)
        }
    }
}

/// Adjust the access protection of a range of memory pages.
pub fn set_page_protection(
    addr: *mut u8,
    len: usize,
    mode: ProtectionMode,
) -> Result<(), MemoryError> {
    #[cfg(miri)]
    {
        _ = (addr, len, mode);
        Ok(())
    }
    #[cfg(all(unix, not(miri)))]
    {
        let res = unsafe { mprotect(addr.cast(), len, mode.as_native()) };
        if res == 0 {
            Ok(())
        } else {
            Err(MemoryError)
        }
    }
    #[cfg(all(windows, not(miri)))]
    {
        let mut prev_mode = MaybeUninit::<u32>::uninit();
        let res = unsafe {
            Memory::VirtualProtect(addr.cast(), len, mode.as_native(), prev_mode.as_mut_ptr())
        };
        if res != 0 {
            Ok(())
        } else {
            Err(MemoryError)
        }
    }
}

/// Call `libc::madvise` for a range of memory pages, mapping a non-zero
/// result to `MemoryError`.
#[cfg(unix)]
#[allow(unused)]
#[inline]
fn madvise(addr: *mut u8, len: usize, advice: i32) -> Result<(), MemoryError> {
    let res = unsafe { libc::madvise(addr.cast(), len, advice) };
    if res == 0 {
        Ok(())
    } else {
        Err(MemoryError)
    }
}

/// Round `len` up to a whole multiple of `page_size`, which must be a
/// power of two. A length of zero is returned unchanged.
#[inline(always)]
pub fn page_rounded_length(len: usize, page_size: usize) -> usize {
    len + ((page_size - (len & (page_size - 1))) % page_size)
}

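/// A memory allocator which obtains whole pages of locked memory.
///
/// New allocations are page-aligned and rounded up to a whole number of
/// pages, locked into physical memory (`mlock` / `VirtualLock`), and
/// pre-filled with [`UNINIT_ALLOC_BYTE`]. When an allocation is released,
/// its contents are zeroized before the pages are unlocked and freed.
///
/// A minimal usage sketch via the `Allocator` trait (illustrative only;
/// the surrounding imports and error handling are assumptions):
///
/// ```ignore
/// use core::alloc::Layout;
///
/// let layout = Layout::new::<[u8; 32]>();
/// let buf = SecureAlloc.allocate(layout).expect("allocation error");
/// // ... write to and read from the page-locked buffer ...
/// unsafe { SecureAlloc.deallocate(buf.cast(), layout) };
/// ```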
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct SecureAlloc;

impl SecureAlloc {
    pub(crate) fn set_page_protection(
        &self,
        ptr: *mut u8,
        len: usize,
        mode: ProtectionMode,
    ) -> Result<(), StorageError> {
        if len != 0 {
            let alloc_len = page_rounded_length(len, default_page_size());
            set_page_protection(ptr, alloc_len, mode).map_err(|_| StorageError::ProtectionError)
        } else {
            Ok(())
        }
    }
}

unsafe impl Allocator for SecureAlloc {
    #[inline]
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        debug_assert!(
            layout.align() <= default_page_size(),
            "alignment cannot exceed page size"
        );
        let layout_len = layout.size();
        if layout_len == 0 {
            // For a zero-sized allocation, produce a dangling pointer with the
            // requested alignment instead of reserving any pages.
            #[allow(clippy::useless_transmute)]
            let head = unsafe { NonNull::new_unchecked(transmute(layout.align())) };
            Ok(NonNull::slice_from_raw_parts(head, 0))
        } else {
            let alloc = alloc_pages(layout_len).map_err(|_| AllocError)?;
            let alloc_len = alloc.len();

            // Fill the new allocation with a fixed byte pattern to make use of
            // uninitialized memory easier to detect.
            unsafe { ptr::write_bytes(alloc.as_ptr().cast::<u8>(), UNINIT_ALLOC_BYTE, alloc_len) };

            // Lock the pages into physical memory.
            lock_pages(alloc.as_ptr().cast(), alloc_len).map_err(|_| AllocError)?;

            Ok(alloc)
        }
    }

    #[inline]
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        let len = layout.size();
        if len > 0 {
            let alloc_len = page_rounded_length(len, default_page_size());

            // Zero the contents of the allocation before releasing it.
            let mem = unsafe { slice::from_raw_parts_mut(ptr.as_ptr(), alloc_len) };
            mem.zeroize();

            // Unlock the pages; a failure here is ignored as the pages are
            // about to be released.
            unlock_pages(ptr.as_ptr().cast(), alloc_len).ok();

            dealloc_pages(ptr.as_ptr(), alloc_len);
        }
    }
}

impl AllocatorDefault for SecureAlloc {
    const DEFAULT: Self = Self;
}

impl AllocatorZeroizes for SecureAlloc {}

#[cfg(test)]
mod tests {
    use core::alloc::Layout;
    use flex_alloc::alloc::Allocator;

    use crate::{alloc::UNINIT_ALLOC_BYTE, vec::SecureVec};

    use super::SecureAlloc;

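    #[test]
    fn check_page_rounded_length() {
        // Illustrative sanity checks, assuming a power-of-two page size of
        // 4096: lengths are rounded up to the next whole page, while exact
        // multiples (and zero) are returned unchanged.
        assert_eq!(super::page_rounded_length(0, 4096), 0);
        assert_eq!(super::page_rounded_length(1, 4096), 4096);
        assert_eq!(super::page_rounded_length(4096, 4096), 4096);
        assert_eq!(super::page_rounded_length(4097, 4096), 8192);
    }

    #[test]
    fn check_page_rounded_allocation() {
        // Illustrative check: a request slightly larger than one page should
        // come back rounded up to a whole number of pages. The layout values
        // here are arbitrary examples.
        let page_size = super::default_page_size();
        let layout = Layout::from_size_align(page_size + 1, 8).expect("layout error");
        let buf = SecureAlloc.allocate(layout).expect("allocation error");
        assert_eq!(buf.len() % page_size, 0);
        assert!(buf.len() > page_size);
        unsafe { SecureAlloc.deallocate(buf.cast(), layout) };
    }
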
    #[test]
    fn check_extra_capacity() {
        let vec = SecureVec::<usize>::with_capacity(1);
        // a full page is allocated, so the capacity should exceed the request
        assert!(vec.capacity() > 1);
    }

    #[test]
    fn check_uninit() {
        let layout = Layout::new::<usize>();
        let buf = SecureAlloc.allocate(layout).expect("allocation error");
        #[allow(clippy::len_zero)]
        {
            assert!(buf.len() != 0 && unsafe { buf.as_ref() }[..4] == [UNINIT_ALLOC_BYTE; 4]);
        }
        unsafe { SecureAlloc.deallocate(buf.cast(), layout) };
    }
}