Skip to main content

zk_nalloc/
bump.rs

1//! Core bump allocator for nalloc.
2//!
3//! A bump allocator is the fastest possible allocator: it simply increments
4//! a pointer. This module provides a thread-safe, atomic bump allocator
5//! optimized for ZK prover workloads with fallback support.
6
7use std::alloc::{GlobalAlloc, Layout, System};
8use std::ptr::NonNull;
9use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
10
11use crate::config::SECURE_WIPE_PATTERN;
12
/// Lock-free bump allocator over a fixed memory region, with optional fallback.
///
/// Every allocation advances one shared cursor using an atomic
/// compare-and-swap. Several threads can therefore allocate at the same time
/// without taking a lock; a thread that loses a race simply retries.
///
/// Once the region is used up, builds with the `fallback` feature hand the
/// request to the system allocator instead of failing.
pub struct BumpAlloc {
    /// Address of the next free byte; advanced atomically by allocations.
    cursor: AtomicUsize,
    /// First byte of the managed region. Fixed once constructed.
    base: NonNull<u8>,
    /// One-past-the-end address of the managed region. Fixed once constructed.
    limit: NonNull<u8>,
    /// Set once the arena has been reset after use.
    /// WitnessArena consults it to optimize zero-initialization.
    is_recycled: AtomicBool,
    /// How many requests were served by the system allocator (monitoring only).
    #[cfg(feature = "fallback")]
    fallback_count: AtomicUsize,
    /// Sum of the requested sizes served by the system allocator.
    #[cfg(feature = "fallback")]
    fallback_bytes: AtomicUsize,
}
38
39impl BumpAlloc {
40    /// Create a new bump allocator from a raw memory block.
41    ///
42    /// # Safety
43    /// The memory block `[base, base+size)` must be valid and writable.
44    #[inline]
45    pub unsafe fn new(base: *mut u8, size: usize) -> Self {
46        debug_assert!(!base.is_null());
47        debug_assert!(size > 0);
48
49        let base_nn = NonNull::new_unchecked(base);
50        let limit_nn = NonNull::new_unchecked(base.add(size));
51
52        Self {
53            base: base_nn,
54            limit: limit_nn,
55            cursor: AtomicUsize::new(base as usize),
56            is_recycled: AtomicBool::new(false),
57            #[cfg(feature = "fallback")]
58            fallback_count: AtomicUsize::new(0),
59            #[cfg(feature = "fallback")]
60            fallback_bytes: AtomicUsize::new(0),
61        }
62    }
63
64    /// Get the base pointer of this allocator.
65    #[inline]
66    pub fn base_ptr(&self) -> *mut u8 {
67        self.base.as_ptr()
68    }
69
70    /// Allocate memory with the given size and alignment.
71    ///
72    /// Returns a null pointer if there is not enough space and fallback is disabled.
73    /// With the `fallback` feature, falls back to system allocator.
74    #[inline(always)]
75    pub fn alloc(&self, size: usize, align: usize) -> *mut u8 {
76        // Runtime validation (Issue #6): prevent memory corruption from invalid inputs
77        if size == 0 || align == 0 || !align.is_power_of_two() {
78            return std::ptr::null_mut();
79        }
80
81        loop {
82            let current = self.cursor.load(Ordering::Relaxed);
83
84            // Issue #7: Use checked arithmetic to prevent integer overflow
85            let aligned = match current.checked_add(align - 1) {
86                Some(v) => v & !(align - 1),
87                None => return self.handle_exhaustion(size, align),
88            };
89            let next = match aligned.checked_add(size) {
90                Some(v) => v,
91                None => return self.handle_exhaustion(size, align),
92            };
93
94            if next > self.limit.as_ptr() as usize {
95                // Arena exhausted
96                return self.handle_exhaustion(size, align);
97            }
98
99            if self
100                .cursor
101                .compare_exchange_weak(current, next, Ordering::AcqRel, Ordering::Relaxed)
102                .is_ok()
103            {
104                return aligned as *mut u8;
105            }
106            // Contention: another thread allocated concurrently. Retry.
107        }
108    }
109
110    /// Handle arena exhaustion - either fallback or return null.
111    #[cold]
112    #[inline(never)]
113    fn handle_exhaustion(&self, size: usize, align: usize) -> *mut u8 {
114        #[cfg(debug_assertions)]
115        {
116            eprintln!(
117                "[nalloc] Arena exhausted: requested {} bytes (align {}), remaining {} bytes",
118                size,
119                align,
120                self.remaining()
121            );
122        }
123
124        #[cfg(feature = "fallback")]
125        {
126            // Fall back to system allocator
127            let layout = match Layout::from_size_align(size, align) {
128                Ok(l) => l,
129                Err(_) => return std::ptr::null_mut(),
130            };
131
132            let ptr = unsafe { System.alloc(layout) };
133
134            if !ptr.is_null() {
135                self.fallback_count.fetch_add(1, Ordering::Relaxed);
136                self.fallback_bytes.fetch_add(size, Ordering::Relaxed);
137
138                #[cfg(debug_assertions)]
139                eprintln!("[nalloc] Fallback allocation: {} bytes", size);
140            }
141
142            ptr
143        }
144
145        #[cfg(not(feature = "fallback"))]
146        {
147            std::ptr::null_mut()
148        }
149    }
150
151    /// Check if this arena has been recycled (reset after initial use).
152    #[inline]
153    pub fn is_recycled(&self) -> bool {
154        self.is_recycled.load(Ordering::Relaxed)
155    }
156
157    /// Get the number of fallback allocations (only with `fallback` feature).
158    #[cfg(feature = "fallback")]
159    #[inline]
160    pub fn fallback_count(&self) -> usize {
161        self.fallback_count.load(Ordering::Relaxed)
162    }
163
164    /// Get the total bytes allocated via fallback (only with `fallback` feature).
165    ///
166    /// **Note (Issue #9)**: This tracks the *requested* allocation size, not the actual
167    /// size allocated by the system allocator (which may be larger due to alignment
168    /// and internal bookkeeping). Use this for monitoring, not precise accounting.
169    #[cfg(feature = "fallback")]
170    #[inline]
171    pub fn fallback_bytes(&self) -> usize {
172        self.fallback_bytes.load(Ordering::Relaxed)
173    }
174
175    /// Reset the bump pointer to the base.
176    ///
177    /// # Safety
178    /// All previously allocated memory becomes invalid after this call.
179    ///
180    /// # Warning (Issue #10)
181    /// **Fallback allocations are NOT freed by reset.** When arena exhaustion triggers
182    /// fallback to the system allocator (with `fallback` feature), those allocations
183    /// must be individually deallocated via `GlobalAlloc::dealloc`. If using NAlloc
184    /// as the global allocator, this happens automatically when the memory is dropped.
185    /// However, if using arenas directly, be aware that reset only reclaims arena memory,
186    /// not system allocator memory.
187    #[inline]
188    pub unsafe fn reset(&self) {
189        self.cursor
190            .store(self.base.as_ptr() as usize, Ordering::SeqCst);
191        self.is_recycled.store(true, Ordering::Release);
192
193        #[cfg(feature = "fallback")]
194        {
195            // Reset fallback counters
196            self.fallback_count.store(0, Ordering::Relaxed);
197            self.fallback_bytes.store(0, Ordering::Relaxed);
198        }
199    }
200
201    /// Zero out all memory in the arena and reset the cursor.
202    ///
203    /// This is critical for security-sensitive applications like ZK provers,
204    /// where witness data must be wiped after use to prevent leakage.
205    ///
206    /// Uses volatile writes to prevent the compiler from optimizing away
207    /// the zeroing operation (dead store elimination).
208    ///
209    /// # Safety
210    /// All previously allocated memory becomes invalid after this call.
211    #[inline]
212    pub unsafe fn secure_reset(&self) {
213        let base = self.base.as_ptr();
214        let size = self.limit.as_ptr() as usize - base as usize;
215
216        // Use volatile writes to prevent dead store elimination.
217        // This ensures the memory is actually zeroed even if it's never read again.
218        Self::volatile_memset(base, SECURE_WIPE_PATTERN, size);
219
220        // Issue #5: Full memory barrier for multi-threaded safety.
221        // compiler_fence only prevents compiler reordering, not CPU reordering.
222        // Using atomic fence ensures other threads observe the zeroed memory
223        // before seeing the reset cursor.
224        fence(Ordering::SeqCst);
225
226        self.reset();
227    }
228
229    /// Volatile memset implementation that cannot be optimized away.
230    ///
231    /// This is critical for cryptographic security - we need to guarantee
232    /// that sensitive data is actually erased from memory.
233    #[inline(never)]
234    #[allow(unreachable_code)] // Platform-specific code paths return early, making fallback unreachable on some platforms
235    unsafe fn volatile_memset(ptr: *mut u8, value: u8, len: usize) {
236        // Method 1: Use platform-specific secure zeroing where available (for value == 0)
237        #[cfg(any(target_os = "linux", target_os = "android"))]
238        if value == 0 {
239            // explicit_bzero is guaranteed not to be optimized away
240            extern "C" {
241                fn explicit_bzero(s: *mut libc::c_void, n: libc::size_t);
242            }
243            explicit_bzero(ptr as *mut libc::c_void, len);
244            return;
245        }
246
247        #[cfg(target_vendor = "apple")]
248        {
249            // memset_s is guaranteed not to be optimized away (C11)
250            // Note: memset_s supports non-zero values
251            extern "C" {
252                fn memset_s(
253                    s: *mut libc::c_void,
254                    smax: libc::size_t,
255                    c: libc::c_int,
256                    n: libc::size_t,
257                ) -> libc::c_int;
258            }
259            let _ = memset_s(ptr as *mut libc::c_void, len, value as libc::c_int, len);
260            return;
261        }
262
263        #[cfg(target_os = "windows")]
264        if value == 0 {
265            // RtlSecureZeroMemory is guaranteed not to be optimized away
266            extern "system" {
267                fn RtlSecureZeroMemory(ptr: *mut u8, len: usize);
268            }
269            RtlSecureZeroMemory(ptr, len);
270            return;
271        }
272
273        // Issue #4: Generic volatile write loop for:
274        // - Non-zero values on Linux/Android/Windows (platform APIs only handle zero)
275        // - All values on other platforms
276        // Using usize-sized writes for better performance
277        let ptr_usize = ptr as *mut usize;
278        let pattern_usize = if value == 0 {
279            0usize
280        } else {
281            let mut p = 0usize;
282            for i in 0..std::mem::size_of::<usize>() {
283                p |= (value as usize) << (i * 8);
284            }
285            p
286        };
287
288        let full_words = len / std::mem::size_of::<usize>();
289        let remainder = len % std::mem::size_of::<usize>();
290
291        // Write full usize words
292        for i in 0..full_words {
293            std::ptr::write_volatile(ptr_usize.add(i), pattern_usize);
294        }
295
296        // Write remaining bytes
297        let remainder_ptr = ptr.add(full_words * std::mem::size_of::<usize>());
298        for i in 0..remainder {
299            std::ptr::write_volatile(remainder_ptr.add(i), value);
300        }
301    }
302
303    /// Returns the total capacity in bytes.
304    #[inline]
305    pub fn capacity(&self) -> usize {
306        self.limit.as_ptr() as usize - self.base.as_ptr() as usize
307    }
308
309    /// Returns the number of bytes currently allocated.
310    #[inline]
311    pub fn used(&self) -> usize {
312        self.cursor.load(Ordering::Relaxed) - self.base.as_ptr() as usize
313    }
314
315    /// Returns the number of bytes remaining.
316    #[inline]
317    pub fn remaining(&self) -> usize {
318        self.capacity() - self.used()
319    }
320}
321
322// Safety: BumpAlloc can be shared across threads because:
323// - `base` and `limit` are never modified after construction
324// - `cursor` uses atomic operations for thread-safe updates
325// - `is_recycled` uses atomic operations
326unsafe impl Send for BumpAlloc {}
327unsafe impl Sync for BumpAlloc {}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed allocator reports the whole buffer as free.
    #[test]
    fn test_nonnull_safety() {
        let mut buffer = vec![0u8; 1024];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        assert_eq!(alloc.capacity(), 1024);
        assert_eq!(alloc.used(), 0);
        assert_eq!(alloc.remaining(), 1024);
        assert!(!alloc.is_recycled());
    }

    /// The recycled flag is set by `reset`, not by allocation.
    #[test]
    fn test_recycled_flag() {
        let mut buffer = vec![0u8; 1024];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        assert!(!alloc.is_recycled());

        let _ = alloc.alloc(64, 8);
        assert!(!alloc.is_recycled());

        unsafe { alloc.reset() };
        assert!(alloc.is_recycled());
    }

    /// Zero sizes and invalid alignments are rejected without moving the cursor.
    #[test]
    fn test_invalid_requests_return_null() {
        let mut buffer = vec![0u8; 64];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        assert!(alloc.alloc(0, 8).is_null());
        assert!(alloc.alloc(8, 0).is_null());
        assert!(alloc.alloc(8, 3).is_null());
        assert_eq!(alloc.used(), 0);
    }

    /// `secure_reset` overwrites the whole arena (not just the used prefix) with
    /// `SECURE_WIPE_PATTERN` and rewinds the cursor.
    #[test]
    fn test_secure_reset_wipes_memory() {
        // Start from a byte guaranteed to differ from the wipe pattern so untouched
        // bytes cannot pass the check by accident.
        let stale = !SECURE_WIPE_PATTERN;
        let mut buffer = vec![stale; 1024];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Allocate and write data.
        let ptr = alloc.alloc(512, 8);
        assert!(!ptr.is_null());
        unsafe {
            std::ptr::write_bytes(ptr, stale, 512);
        }

        unsafe { alloc.secure_reset() };
        assert_eq!(alloc.used(), 0);

        for (i, &byte) in buffer.iter().enumerate() {
            assert_eq!(byte, SECURE_WIPE_PATTERN, "Byte {} not wiped", i);
        }
    }

    /// `volatile_memset` must handle regions that do not start on a word boundary.
    /// A non-zero value forces the generic path on Linux/Android/Windows.
    #[test]
    fn test_volatile_memset_unaligned_region() {
        let mut words = [0usize; 8];
        let total = std::mem::size_of_val(&words);
        let bytes = words.as_mut_ptr().cast::<u8>();
        // Skip one byte at each end so the head, word and tail phases all run.
        let len = total - 2;

        unsafe { BumpAlloc::volatile_memset(bytes.add(1), 0x5A, len) };

        let view = unsafe { std::slice::from_raw_parts(bytes, total) };
        assert_eq!(view[0], 0, "byte before the region was modified");
        assert!(view[1..1 + len].iter().all(|&b| b == 0x5A));
        assert_eq!(view[1 + len], 0, "byte after the region was modified");
    }

    /// Every power-of-two alignment up to 128 is honoured.
    #[test]
    fn test_alignment() {
        let mut buffer = vec![0u8; 4096];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        for align_pow in 0..8 {
            let align = 1usize << align_pow;
            let ptr = alloc.alloc(64, align);
            assert!(!ptr.is_null());
            assert_eq!((ptr as usize) % align, 0, "Alignment {} failed", align);
        }
    }

    /// Exhausting the arena routes the next request to the system allocator.
    #[test]
    #[cfg(feature = "fallback")]
    fn test_fallback_allocation() {
        // Create a tiny arena that will exhaust quickly.
        let mut buffer = vec![0u8; 256];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Fill the arena.
        let _ = alloc.alloc(256, 1);

        // This should trigger fallback.
        let ptr = alloc.alloc(64, 8);
        assert!(!ptr.is_null(), "Fallback allocation should succeed");

        assert!(alloc.fallback_count() > 0, "Fallback count should increase");
        assert!(alloc.fallback_bytes() >= 64, "Fallback bytes should track");

        // Fallback memory is owned by `System` and must be freed there.
        unsafe {
            System.dealloc(ptr, Layout::from_size_align(64, 8).unwrap());
        }
    }

    /// Without fallback, an exhausted arena returns null.
    #[test]
    #[cfg(not(feature = "fallback"))]
    fn test_exhaustion_returns_null() {
        let mut buffer = vec![0u8; 256];
        let alloc = unsafe { BumpAlloc::new(buffer.as_mut_ptr(), buffer.len()) };

        // Fill the arena.
        let _ = alloc.alloc(256, 1);

        // This should return null without fallback.
        let ptr = alloc.alloc(64, 8);
        assert!(
            ptr.is_null(),
            "Should return null when exhausted without fallback"
        );
    }
}