1use std::{
3 alloc::{Layout, handle_alloc_error},
4 fmt::Debug,
5 mem,
6 ops::{Deref, DerefMut},
7 ptr::{self, NonNull},
8 slice,
9};
10
11use bytemuck::Zeroable;
12
/// An owned buffer of `T` whose backing memory, on Linux, is an anonymous
/// mapping advised to use transparent huge pages.
///
/// Only the first `len` elements are exposed through `Deref`/`DerefMut`.
pub struct HugePageMemory<T> {
    /// Number of elements currently exposed as a slice.
    len: usize,
    /// Size of the backing allocation: on Linux the byte length of the
    /// mapped region, elsewhere the element capacity of the leaked `Vec`.
    capacity: usize,
    /// Start of the allocation; dangling when nothing is allocated
    /// (e.g. the `Default` value).
    ptr: NonNull<T>,
}
25
/// Huge page size assumed by this module, in bytes (2 MiB). It also caps the
/// alignment used when computing mapping layouts.
pub const HUGE_PAGE_SIZE: usize = 2 << 20;
27
28impl<T> HugePageMemory<T> {
29 #[inline]
30 pub fn len(&self) -> usize {
31 self.len
32 }
33
34 #[inline]
35 pub fn is_empty(&self) -> bool {
36 self.len() == 0
37 }
38
39 #[inline]
40 pub fn capacity(&self) -> usize {
41 self.capacity
42 }
43
44 #[inline]
48 pub fn set_len(&mut self, new_len: usize) {
49 assert!(new_len <= self.capacity());
50 self.len = new_len;
55 }
56}
57
58#[cfg(target_os = "linux")]
59impl<T: Zeroable> HugePageMemory<T> {
60 pub fn zeroed(len: usize) -> Self {
63 let layout = Self::layout(len);
64 let capacity = layout.size();
65 let ptr = unsafe {
66 let ptr = libc::mmap(
68 ptr::null_mut(),
69 capacity,
70 libc::PROT_READ | libc::PROT_WRITE,
71 libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
72 -1,
73 0,
74 );
75 if ptr == libc::MAP_FAILED {
76 handle_alloc_error(layout)
77 }
78 #[cfg(not(miri))]
79 if libc::madvise(ptr, capacity, libc::MADV_HUGEPAGE) != 0 {
80 let err = std::io::Error::last_os_error();
81 match err.raw_os_error() {
82 Some(
83 libc::ENOMEM
85 | libc::EINVAL) => {
87 libc::munmap(ptr, capacity);
88 handle_alloc_error(layout);
89 }
90 _ => {
92 tracing::warn!("Failed to enable huge pages: {}", err);
93 }
94 }
95 }
96 NonNull::new_unchecked(ptr.cast())
97 };
98
99 Self { ptr, len, capacity }
100 }
101}
102
103impl<T: Zeroable + Clone> HugePageMemory<T> {
104 pub fn grow_zeroed(&mut self, new_size: usize) {
106 if new_size <= self.capacity() {
108 self.set_len(new_size);
109 return;
110 }
111
112 #[cfg(target_os = "linux")]
113 {
114 self.grow_with_mremap(new_size);
115 }
116
117 #[cfg(not(target_os = "linux"))]
118 {
119 self.grow_with_mmap(new_size);
120 }
121 }
122
123 #[cfg(target_os = "linux")]
125 fn grow_with_mremap(&mut self, new_size: usize) {
126 let new_layout = Self::layout(new_size);
128 let new_capacity = new_layout.size();
129
130 let new_ptr = unsafe {
131 let remapped_ptr = libc::mremap(
132 self.ptr.as_ptr().cast(),
133 self.capacity,
134 new_capacity,
135 libc::MREMAP_MAYMOVE,
136 );
137
138 if remapped_ptr == libc::MAP_FAILED {
139 libc::munmap(self.ptr.as_ptr().cast(), self.capacity);
140 handle_alloc_error(new_layout);
141 }
142
143 #[cfg(not(miri))]
145 if libc::madvise(remapped_ptr, new_capacity, libc::MADV_HUGEPAGE) != 0 {
146 let err = std::io::Error::last_os_error();
147 tracing::warn!("Failed to enable huge pages after mremap: {}", err);
148 }
149
150 NonNull::new_unchecked(remapped_ptr.cast())
151 };
152
153 self.ptr = new_ptr;
155 self.capacity = new_capacity;
156 self.set_len(new_size);
157 }
158
159 #[allow(dead_code)]
161 fn grow_with_mmap(&mut self, new_size: usize) {
162 let mut new = Self::zeroed(new_size);
163 new[..self.len()].clone_from_slice(self);
164 *self = new;
165 }
166}
167
168#[cfg(target_os = "linux")]
169impl<T> HugePageMemory<T> {
170 fn layout(len: usize) -> Layout {
171 let size = len * mem::size_of::<T>();
172 let align = mem::align_of::<T>().min(HUGE_PAGE_SIZE);
173 let layout = Layout::from_size_align(size, align).expect("alloc too large");
174 layout.pad_to_align()
175 }
176}
177
178#[cfg(target_os = "linux")]
179impl<T> Drop for HugePageMemory<T> {
180 #[inline]
181 fn drop(&mut self) {
182 unsafe {
183 libc::munmap(self.ptr.as_ptr().cast(), self.capacity);
184 }
185 }
186}
187
#[cfg(not(target_os = "linux"))]
impl<T: Zeroable> HugePageMemory<T> {
    /// Allocates `len` zero-initialized elements on the global heap.
    ///
    /// Huge pages are only requested on Linux. Here the buffer is a leaked
    /// `Vec<T>` whose length and capacity both equal `len`; the matching
    /// `Drop` impl reassembles and frees it.
    pub fn zeroed(len: usize) -> Self {
        let storage: Vec<T> = allocate_zeroed_vec(len);
        // `Drop` rebuilds the vector with `capacity == len`, so the
        // allocation must not carry spare capacity.
        assert_eq!(storage.len(), storage.capacity());
        let raw = storage.leak().as_mut_ptr();
        Self {
            ptr: NonNull::new(raw).expect("not null"),
            len,
            capacity: len,
        }
    }
}
202
#[cfg(not(target_os = "linux"))]
impl<T> Drop for HugePageMemory<T> {
    /// Hands the buffer back to a `Vec` so the global allocator frees it.
    fn drop(&mut self) {
        // SAFETY: on this platform the buffer is either a leaked `Vec<T>`
        // built with `capacity == len`, or the empty `Default` value (dangling
        // pointer, zero length and capacity). `set_len` keeps `len` at or
        // below `capacity`.
        let owner = unsafe { Vec::from_raw_parts(self.ptr.as_ptr(), self.len, self.capacity) };
        drop(owner);
    }
}
209
210impl<T> Deref for HugePageMemory<T> {
211 type Target = [T];
212
213 #[inline]
214 fn deref(&self) -> &Self::Target {
215 unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
216 }
217}
218
219impl<T> DerefMut for HugePageMemory<T> {
220 #[inline]
221 fn deref_mut(&mut self) -> &mut Self::Target {
222 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
223 }
224}
225
226impl<T> Default for HugePageMemory<T> {
227 fn default() -> Self {
228 Self {
229 ptr: NonNull::dangling(),
230 len: 0,
231 capacity: 0,
232 }
233 }
234}
235
236impl<T: Debug> Debug for HugePageMemory<T> {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_list().entries(self.iter()).finish()
239 }
240}
241
242unsafe impl<T: Send> Send for HugePageMemory<T> {}
243unsafe impl<T: Sync> Sync for HugePageMemory<T> {}
244
245pub fn allocate_zeroed_vec<T: Zeroable>(len: usize) -> Vec<T> {
250 unsafe {
251 let size = len * mem::size_of::<T>();
252 let align = mem::align_of::<T>();
253 let layout = Layout::from_size_align(size, align).expect("len too large");
254 let zeroed = std::alloc::alloc_zeroed(layout);
255 Vec::from_raw_parts(zeroed as *mut T, len, len)
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::{HUGE_PAGE_SIZE, HugePageMemory};

    /// A zeroed buffer of one and a half huge pages reads as zero, is
    /// writable, and can be truncated with `set_len`.
    #[test]
    fn test_huge_page_memory() {
        let initial_len = HUGE_PAGE_SIZE + HUGE_PAGE_SIZE / 2;
        let mut mem = HugePageMemory::<u8>::zeroed(initial_len);
        // The exhaustive scan is skipped under Miri.
        #[cfg(not(miri))]
        assert!(mem.iter().all(|&byte| byte == 0));
        assert_eq!(mem[0], 0);
        assert_eq!(mem[initial_len - 1], 0);
        mem[42] = 5;
        mem.set_len(HUGE_PAGE_SIZE);
        assert_eq!(mem.len(), HUGE_PAGE_SIZE);
    }

    /// Truncating with `set_len` and then dropping must release the buffer
    /// cleanly.
    #[test]
    fn test_set_len_correct_dealloc() {
        let mut truncated = HugePageMemory::<u8>::zeroed(HUGE_PAGE_SIZE);
        truncated.set_len(HUGE_PAGE_SIZE / 2);
    }

    /// `set_len` refuses to go past the capacity.
    #[test]
    #[should_panic]
    fn test_set_len_panics() {
        let mut full = HugePageMemory::<u8>::zeroed(HUGE_PAGE_SIZE);
        full.set_len(HUGE_PAGE_SIZE + 1);
    }

    /// Growing keeps existing contents and zero-fills the new tail.
    #[test]
    fn test_grow() {
        let doubled = 2 * HUGE_PAGE_SIZE;
        let mut mem = HugePageMemory::<u8>::zeroed(HUGE_PAGE_SIZE);
        assert_eq!(mem[0], 0);
        mem[0] = 1;
        mem.grow_zeroed(doubled);
        assert_eq!(mem.len(), doubled);
        assert_eq!(mem.capacity(), doubled);
        assert_eq!(mem[0], 1);
        assert_eq!(mem[HUGE_PAGE_SIZE + 1], 0);
    }
}