1use std::{
2 alloc::{handle_alloc_error, Layout},
3 fmt::Debug,
4 mem,
5 ops::{Deref, DerefMut},
6 ptr::{self, NonNull},
7 slice,
8};
9
10use bytemuck::Zeroable;
11
/// A buffer of `T` that exposes its first `len` elements as a slice (via `Deref`/`DerefMut`).
///
/// On unix the storage is an anonymous private `mmap` region for which transparent huge
/// pages are requested; on other targets it is a leaked `Vec<T>`.
pub struct HugePageMemory<T> {
    /// Start of the buffer; `NonNull::dangling()` when nothing is allocated (see `Default`).
    ptr: NonNull<T>,
    /// Number of elements exposed through `Deref`/`DerefMut`.
    len: usize,
    /// Size of the backing allocation in bytes (on unix: `Layout::size()`, the length passed
    /// to `mmap`/`mremap`/`munmap`).
    capacity: usize,
}
24
/// Size in bytes (2 MiB) of one huge page as assumed by this module.
pub const HUGE_PAGE_SIZE: usize = 1 << 21;
26
27impl<T> HugePageMemory<T> {
28 #[inline]
29 pub fn len(&self) -> usize {
30 self.len
31 }
32
33 #[inline]
34 pub fn is_empty(&self) -> bool {
35 self.len() == 0
36 }
37
38 #[inline]
39 pub fn capacity(&self) -> usize {
40 self.capacity
41 }
42
43 #[inline]
47 pub fn set_len(&mut self, new_len: usize) {
48 assert!(new_len <= self.capacity());
49 #[allow(unused_unsafe)]
54 unsafe {
55 self.len = new_len;
56 }
57 }
58}
59
60#[cfg(target_family = "unix")]
61impl<T: Zeroable> HugePageMemory<T> {
62 pub fn zeroed(len: usize) -> Self {
65 let layout = Self::layout(len);
66 let capacity = layout.size();
67 let ptr = unsafe {
68 let ptr = libc::mmap(
70 ptr::null_mut(),
71 capacity,
72 libc::PROT_READ | libc::PROT_WRITE,
73 libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
74 -1,
75 0,
76 );
77 if ptr == libc::MAP_FAILED {
78 handle_alloc_error(layout)
79 }
80 #[cfg(not(miri))]
81 if libc::madvise(ptr, capacity, libc::MADV_HUGEPAGE) != 0 {
82 let err = std::io::Error::last_os_error();
83 match err.raw_os_error() {
84 Some(
85 libc::ENOMEM
87 | libc::EINVAL) => {
89 libc::munmap(ptr, capacity);
90 handle_alloc_error(layout);
91 }
92 _ => {
94 tracing::warn!("Failed to enable huge pages: {}", err);
95 }
96 }
97 }
98 NonNull::new_unchecked(ptr.cast())
99 };
100
101 Self { ptr, len, capacity }
102 }
103}
104
105impl<T: Zeroable + Clone> HugePageMemory<T> {
106 pub fn grow_zeroed(&mut self, new_size: usize) {
108 if new_size <= self.capacity() {
110 self.set_len(new_size);
111 return;
112 }
113
114 #[cfg(target_os = "linux")]
115 {
116 self.grow_with_mremap(new_size);
117 }
118
119 #[cfg(not(target_os = "linux"))]
120 {
121 self.grow_with_mmap(new_size);
122 }
123 }
124
125 #[cfg(target_os = "linux")]
127 fn grow_with_mremap(&mut self, new_size: usize) {
128 let new_layout = Self::layout(new_size);
130 let new_capacity = new_layout.size();
131
132 let new_ptr = unsafe {
133 let remapped_ptr = libc::mremap(
134 self.ptr.as_ptr().cast(),
135 self.capacity,
136 new_capacity,
137 libc::MREMAP_MAYMOVE,
138 );
139
140 if remapped_ptr == libc::MAP_FAILED {
141 libc::munmap(self.ptr.as_ptr().cast(), self.capacity);
142 handle_alloc_error(new_layout);
143 }
144
145 #[cfg(not(miri))]
147 if libc::madvise(remapped_ptr, new_capacity, libc::MADV_HUGEPAGE) != 0 {
148 let err = std::io::Error::last_os_error();
149 tracing::warn!("Failed to enable huge pages after mremap: {}", err);
150 }
151
152 NonNull::new_unchecked(remapped_ptr.cast())
153 };
154
155 self.ptr = new_ptr;
157 self.capacity = new_capacity;
158 self.set_len(new_size);
159 }
160
161 #[allow(dead_code)]
163 fn grow_with_mmap(&mut self, new_size: usize) {
164 let mut new = Self::zeroed(new_size);
165 new[..self.len()].clone_from_slice(self);
166 *self = new;
167 }
168}
169
170#[cfg(target_family = "unix")]
171impl<T> HugePageMemory<T> {
172 fn layout(len: usize) -> Layout {
173 let size = len * mem::size_of::<T>();
174 let align = mem::align_of::<T>().min(HUGE_PAGE_SIZE);
175 let layout = Layout::from_size_align(size, align).expect("alloc too large");
176 layout.pad_to_align()
177 }
178}
179
180#[cfg(target_family = "unix")]
181impl<T> Drop for HugePageMemory<T> {
182 #[inline]
183 fn drop(&mut self) {
184 unsafe {
185 libc::munmap(self.ptr.as_ptr().cast(), self.capacity);
186 }
187 }
188}
189
#[cfg(not(target_family = "unix"))]
impl<T: Zeroable> HugePageMemory<T> {
    /// Creates a buffer of `len` zero-initialized elements from the global allocator.
    ///
    /// Huge pages are not requested on non-unix targets.
    pub fn zeroed(len: usize) -> Self {
        let v = allocate_zeroed_vec(len);
        // The buffer is later rebuilt into a `Vec`, so the allocation must be exactly `len`
        // elements. Zero-sized types always report `usize::MAX` capacity and never allocate.
        assert!(mem::size_of::<T>() == 0 || v.len() == v.capacity());
        // Byte size, matching the meaning of `capacity` on unix (`Layout::size()`).
        let capacity = mem::size_of_val(v.as_slice());
        let ptr = NonNull::new(v.leak().as_mut_ptr()).expect("not null");
        Self { ptr, len, capacity }
    }
}
200
#[cfg(not(target_family = "unix"))]
impl<T> Drop for HugePageMemory<T> {
    /// Rebuilds the `Vec` the buffer was leaked from and drops it, freeing the allocation.
    fn drop(&mut self) {
        // `set_len` may have lowered `len`. The allocation's element count is therefore
        // derived from `capacity` (bytes, as on unix) rather than from `len`.
        let allocated = match mem::size_of::<T>() {
            0 => self.len,
            elem_size => self.capacity / elem_size,
        };
        // SAFETY: `ptr` came from a leaked `Vec<T>` holding `allocated` elements (or is
        // dangling with zero capacity, as in `Default`), and the first `len <= allocated`
        // elements are initialized.
        mem::drop(unsafe { Vec::from_raw_parts(self.ptr.as_ptr(), self.len, allocated) });
    }
}
207
208impl<T> Deref for HugePageMemory<T> {
209 type Target = [T];
210
211 #[inline]
212 fn deref(&self) -> &Self::Target {
213 unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
214 }
215}
216
217impl<T> DerefMut for HugePageMemory<T> {
218 #[inline]
219 fn deref_mut(&mut self) -> &mut Self::Target {
220 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
221 }
222}
223
224impl<T> Default for HugePageMemory<T> {
225 fn default() -> Self {
226 Self {
227 ptr: NonNull::dangling(),
228 len: 0,
229 capacity: 0,
230 }
231 }
232}
233
234impl<T: Debug> Debug for HugePageMemory<T> {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.debug_list().entries(self.iter()).finish()
237 }
238}
239
// SAFETY: the buffer is reached only through this value (no `Clone` or shared handle is
// defined in this file), so sending it to another thread moves the `T`s with it, as for
// `Box<[T]>`. Sound when `T: Send`.
unsafe impl<T: Send> Send for HugePageMemory<T> {}
// SAFETY: through `&HugePageMemory<T>` only `&[T]` (via `Deref`) and plain integers
// (`len`, `capacity`) are reachable, so sharing it is sound when `T: Sync`.
unsafe impl<T: Sync> Sync for HugePageMemory<T> {}
242
243pub fn allocate_zeroed_vec<T: Zeroable>(len: usize) -> Vec<T> {
245 unsafe {
246 let size = len * mem::size_of::<T>();
247 let align = mem::align_of::<T>();
248 let layout = Layout::from_size_align(size, align).expect("len too large");
249 let zeroed = std::alloc::alloc_zeroed(layout);
250 Vec::from_raw_parts(zeroed as *mut T, len, len)
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::{HugePageMemory, HUGE_PAGE_SIZE};

    /// A 1.5-huge-page buffer starts out all zero, is writable, and can be shortened.
    #[test]
    fn test_huge_page_memory() {
        let mut buf = HugePageMemory::<u8>::zeroed(HUGE_PAGE_SIZE + HUGE_PAGE_SIZE / 2);
        // The full scan is skipped under Miri (presumably because it is too slow there).
        #[cfg(not(miri))]
        assert!(buf.iter().all(|&byte| byte == 0));
        let last = buf.len() - 1;
        assert_eq!(buf[0], 0);
        assert_eq!(buf[last], 0);
        buf[42] = 5;
        buf.set_len(HUGE_PAGE_SIZE);
        assert_eq!(buf.len(), HUGE_PAGE_SIZE);
    }

    /// Lengthening past the allocation must be rejected.
    #[test]
    #[should_panic]
    fn test_set_len_panics() {
        let mut buf = HugePageMemory::<u8>::zeroed(HUGE_PAGE_SIZE);
        buf.set_len(HUGE_PAGE_SIZE + 1);
    }

    /// Growing keeps existing contents and zero-fills the new tail.
    #[test]
    fn test_grow() {
        let mut buf = HugePageMemory::<u8>::zeroed(HUGE_PAGE_SIZE);
        assert_eq!(buf[0], 0);
        buf[0] = 1;
        let doubled = 2 * HUGE_PAGE_SIZE;
        buf.grow_zeroed(doubled);
        assert_eq!(buf.len(), doubled);
        assert_eq!(buf.capacity(), doubled);
        assert_eq!(buf[0], 1);
        assert_eq!(buf[HUGE_PAGE_SIZE + 1], 0);
    }
}