wraith/km/allocator.rs

//! Kernel pool memory allocation

use core::alloc::{GlobalAlloc, Layout};
use core::ffi::c_void;
use core::ptr::NonNull;

use super::error::{KmError, KmResult};

/// pool allocation type
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolType {
    /// non-paged pool (always resident in physical memory)
    NonPaged = 0,
    /// paged pool (can be paged out)
    Paged = 1,
    /// non-paged pool, no-execute (preferred on Windows 8 and later)
    NonPagedNx = 512,
    /// non-paged session pool (session-space drivers only)
    NonPagedSession = 32,
    /// paged session pool
    PagedSession = 33,
}

impl Default for PoolType {
    fn default() -> Self {
        Self::NonPagedNx
    }
}

/// pool allocation tag (4-byte identifier for debugging)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolTag(pub u32);

impl PoolTag {
    /// create from a 4-character tag; little-endian byte order makes the
    /// tag read correctly in pool-tracking tools such as PoolMon
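    ///
    /// a usage sketch (the tag is hypothetical; kernel doctests don't run,
    /// hence `ignore`):
    ///
    /// ```ignore
    /// let tag = PoolTag::from_chars(*b"Wrk1");
    /// assert_eq!(tag.0, u32::from_le_bytes(*b"Wrk1"));
    /// ```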
    pub const fn from_chars(chars: [u8; 4]) -> Self {
        Self(u32::from_le_bytes(chars))
    }

    /// default tag for wraith allocations
    pub const WRAITH: Self = Self::from_chars(*b"WRAT");
}

impl Default for PoolTag {
    fn default() -> Self {
        Self::WRAITH
    }
}

/// kernel pool allocator
pub struct PoolAllocator {
    pool_type: PoolType,
    tag: PoolTag,
}

impl PoolAllocator {
    /// create a new pool allocator with the specified type and tag
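    ///
    /// a usage sketch (tag `Wrk1` is illustrative):
    ///
    /// ```ignore
    /// let alloc = PoolAllocator::new(PoolType::Paged, PoolTag::from_chars(*b"Wrk1"));
    /// ```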
    pub const fn new(pool_type: PoolType, tag: PoolTag) -> Self {
        Self { pool_type, tag }
    }

    /// create non-paged (NX) allocator
    pub const fn non_paged() -> Self {
        Self::new(PoolType::NonPagedNx, PoolTag::WRAITH)
    }

    /// create paged allocator
    pub const fn paged() -> Self {
        Self::new(PoolType::Paged, PoolTag::WRAITH)
    }

    /// allocate memory from the pool
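    ///
    /// a usage sketch pairing an allocation with the unsafe free:
    ///
    /// ```ignore
    /// let alloc = PoolAllocator::non_paged();
    /// let ptr = alloc.allocate(128)?;
    /// // SAFETY: ptr came from this allocator and is not used after the free
    /// unsafe { alloc.free(ptr) };
    /// ```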
    pub fn allocate(&self, size: usize) -> KmResult<NonNull<u8>> {
        if size == 0 {
            return Err(KmError::InvalidParameter {
                context: "allocate: size cannot be zero",
            });
        }

        // SAFETY: FFI call into the kernel pool allocator; size is non-zero
        let ptr = unsafe {
            ExAllocatePoolWithTag(self.pool_type as u32, size, self.tag.0)
        };

        NonNull::new(ptr as *mut u8).ok_or(KmError::PoolAllocationFailed {
            size,
            pool_type: self.pool_type as u32,
        })
    }

    /// allocate zeroed memory from the pool
    pub fn allocate_zeroed(&self, size: usize) -> KmResult<NonNull<u8>> {
        let ptr = self.allocate(size)?;
        // SAFETY: ptr is valid for `size` bytes and we have exclusive ownership
        unsafe {
            core::ptr::write_bytes(ptr.as_ptr(), 0, size);
        }
        Ok(ptr)
    }

    /// free previously allocated memory
    ///
    /// # Safety
    /// `ptr` must have been allocated by this allocator (or one with the
    /// same tag) and must not be used after this call
    pub unsafe fn free(&self, ptr: NonNull<u8>) {
        // SAFETY: caller ensures ptr is a live pool allocation with our tag
        unsafe {
            ExFreePoolWithTag(ptr.as_ptr() as *mut c_void, self.tag.0);
        }
    }

    /// reallocate memory (allocate new, copy, free old); if `new_size` is
    /// zero, the old allocation is freed and an error is returned
    ///
    /// # Safety
    /// `old_ptr` must have been allocated by this allocator, and `old_size`
    /// must not exceed its allocated size
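    ///
    /// a usage sketch growing a 64-byte allocation to 128 bytes:
    ///
    /// ```ignore
    /// let alloc = PoolAllocator::non_paged();
    /// let ptr = alloc.allocate(64)?;
    /// // SAFETY: ptr was allocated by `alloc` with size 64
    /// let bigger = unsafe { alloc.reallocate(ptr, 64, 128)? };
    /// ```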
    pub unsafe fn reallocate(
        &self,
        old_ptr: NonNull<u8>,
        old_size: usize,
        new_size: usize,
    ) -> KmResult<NonNull<u8>> {
        if new_size == 0 {
            // SAFETY: caller ensures old_ptr is valid
            unsafe { self.free(old_ptr) };
            return Err(KmError::InvalidParameter {
                context: "reallocate: new_size cannot be zero",
            });
        }

        let new_ptr = self.allocate(new_size)?;

        // SAFETY: both pointers are valid; copy the smaller of the two sizes
        unsafe {
            let copy_size = core::cmp::min(old_size, new_size);
            core::ptr::copy_nonoverlapping(old_ptr.as_ptr(), new_ptr.as_ptr(), copy_size);
            self.free(old_ptr);
        }

        Ok(new_ptr)
    }
}

/// global pool allocator for use with the `alloc` crate
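///
/// a registration sketch for an `alloc`-enabled `no_std` driver crate
/// (the static name is illustrative):
///
/// ```ignore
/// #[global_allocator]
/// static GLOBAL: KernelAllocator = KernelAllocator;
/// ```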
pub struct KernelAllocator;

impl KernelAllocator {
    const ALLOCATOR: PoolAllocator = PoolAllocator::non_paged();
}

unsafe impl GlobalAlloc for KernelAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // kernel pool allocations are 16-byte aligned on x64 and 8-byte
        // aligned on x86; larger alignments require over-allocation
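        //
        // over-aligned allocations stash the raw pointer in a usize slot
        // immediately below the returned address (sketch of the layout):
        //
        //   raw .. padding .. [raw: usize][ user data: size bytes ]
        //                     ^aligned-8  ^aligned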
        let align = layout.align();
        let size = layout.size();

        if align <= 16 {
            match Self::ALLOCATOR.allocate(size) {
                Ok(ptr) => ptr.as_ptr(),
                Err(_) => core::ptr::null_mut(),
            }
        } else {
            // over-allocate to handle alignment, reserving an extra usize so
            // the stashed pointer always fits below the aligned address
            let header = core::mem::size_of::<usize>();
            let total_size = match size.checked_add(align + header) {
                Some(total) => total,
                None => return core::ptr::null_mut(),
            };
            match Self::ALLOCATOR.allocate(total_size) {
                Ok(ptr) => {
                    let raw = ptr.as_ptr() as usize;
                    // round up from raw + header, not raw: if raw were already
                    // aligned, the header slot would otherwise land before the
                    // start of the allocation
                    let aligned = (raw + header + align - 1) & !(align - 1);
                    let aligned_ptr = aligned as *mut u8;
                    // store the original pointer just below the aligned address
                    // SAFETY: aligned - sizeof(usize) >= raw, so the write is
                    // within our allocation
                    unsafe {
                        *((aligned_ptr as *mut usize).offset(-1)) = raw;
                    }
                    aligned_ptr
                }
                Err(_) => core::ptr::null_mut(),
            }
        }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        if ptr.is_null() {
            return;
        }

        let align = layout.align();

        let actual_ptr = if align <= 16 {
            ptr
        } else {
            // retrieve the original pointer
            // SAFETY: alloc stored it at ptr - sizeof(usize)
            let raw = unsafe { *((ptr as *mut usize).offset(-1)) };
            raw as *mut u8
        };

        if let Some(ptr) = NonNull::new(actual_ptr) {
            // SAFETY: ptr was allocated by our allocator
            unsafe { Self::ALLOCATOR.free(ptr) };
        }
    }

    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let new_layout = match Layout::from_size_align(new_size, layout.align()) {
            Ok(l) => l,
            Err(_) => return core::ptr::null_mut(),
        };

        // SAFETY: allocate new, copy the smaller size, deallocate old
        unsafe {
            let new_ptr = self.alloc(new_layout);
            if !new_ptr.is_null() {
                let copy_size = core::cmp::min(layout.size(), new_size);
                core::ptr::copy_nonoverlapping(ptr, new_ptr, copy_size);
                self.dealloc(ptr, layout);
            }
            new_ptr
        }
    }
}

/// RAII wrapper for pool allocations
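///
/// a usage sketch; the buffer is freed automatically when it goes out of scope:
///
/// ```ignore
/// let mut buf = PoolBuffer::zeroed(64, PoolType::NonPagedNx)?;
/// buf.as_mut_slice()[0] = 0xFF;
/// ```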
pub struct PoolBuffer {
    ptr: NonNull<u8>,
    size: usize,
    allocator: PoolAllocator,
}

impl PoolBuffer {
    /// allocate a new pool buffer
    pub fn new(size: usize, pool_type: PoolType) -> KmResult<Self> {
        let allocator = PoolAllocator::new(pool_type, PoolTag::WRAITH);
        let ptr = allocator.allocate(size)?;
        Ok(Self { ptr, size, allocator })
    }

    /// allocate a zeroed buffer
    pub fn zeroed(size: usize, pool_type: PoolType) -> KmResult<Self> {
        let allocator = PoolAllocator::new(pool_type, PoolTag::WRAITH);
        let ptr = allocator.allocate_zeroed(size)?;
        Ok(Self { ptr, size, allocator })
    }

    /// get a pointer to the buffer
    pub fn as_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }

    /// get the buffer as a slice
    pub fn as_slice(&self) -> &[u8] {
        // SAFETY: the buffer is valid for `size` bytes
        unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
    }

    /// get the buffer as a mutable slice
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        // SAFETY: the buffer is valid for `size` bytes and we have exclusive access
        unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
    }

    /// get the buffer size
    pub fn size(&self) -> usize {
        self.size
    }

    /// leak the buffer, returning the raw pointer; the caller becomes
    /// responsible for freeing it with an allocator using the same tag
    pub fn leak(self) -> NonNull<u8> {
        let ptr = self.ptr;
        core::mem::forget(self);
        ptr
    }
}

impl Drop for PoolBuffer {
    fn drop(&mut self) {
        // SAFETY: ptr was allocated by our allocator and is freed exactly once
        unsafe { self.allocator.free(self.ptr) };
    }
}

// kernel pool allocation imports; ExAllocatePoolWithTag is deprecated in
// favor of ExAllocatePool2 as of Windows 10 version 2004, but it is still
// exported by the kernel
extern "system" {
    fn ExAllocatePoolWithTag(PoolType: u32, NumberOfBytes: usize, Tag: u32) -> *mut c_void;
    fn ExFreePoolWithTag(P: *mut c_void, Tag: u32);
}