// oxiblas_core/memory/stack.rs

//! Memory management utilities for OxiBLAS.
//!
//! This module provides:
//! - Aligned memory allocation
//! - Stack-based temporary allocation (StackReq pattern)
//! - Cache-aware data layout utilities
//! - Prefetch hints for cache optimization
//! - Memory pool for temporary allocations
//! - Custom allocator support via the `Alloc` trait
use core::mem::{MaybeUninit, align_of, size_of};

use super::aligned_vec::AlignedVec;
use super::alloc::*;
15
// =============================================================================
// StackReq - Scratch space requirements
// =============================================================================
19
/// Represents the memory requirements for an operation.
///
/// This is used to pre-compute the scratch space needed for algorithms,
/// allowing efficient allocation strategies.
///
/// NOTE(review): the rounding helpers used with this type (`round_up_pow2`)
/// suggest `align` must be a nonzero power of two, as produced by
/// `align_of` — confirm; the invariant is not enforced here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackReq {
    /// Size in bytes.
    pub size: usize,
    /// Alignment in bytes.
    pub align: usize,
}
31
32impl StackReq {
33    /// Zero requirements.
34    pub const ZERO: StackReq = StackReq { size: 0, align: 1 };
35
36    /// Creates a new stack requirement.
37    #[inline]
38    pub const fn new(size: usize, align: usize) -> Self {
39        StackReq { size, align }
40    }
41
42    /// Creates a requirement for a given type and count.
43    #[inline]
44    pub const fn new_for<T>(count: usize) -> Self {
45        StackReq {
46            size: count * size_of::<T>(),
47            align: align_of::<T>(),
48        }
49    }
50
51    /// Combines two requirements (both must be satisfied).
52    #[inline]
53    pub const fn and(self, other: Self) -> Self {
54        let align = if self.align > other.align {
55            self.align
56        } else {
57            other.align
58        };
59        let size1 = round_up_pow2(self.size, other.align);
60        StackReq {
61            size: size1 + other.size,
62            align,
63        }
64    }
65
66    /// Takes the maximum of two requirements (either one is sufficient).
67    #[inline]
68    pub const fn or(self, other: Self) -> Self {
69        let align = if self.align > other.align {
70            self.align
71        } else {
72            other.align
73        };
74        let size = if self.size > other.size {
75            self.size
76        } else {
77            other.size
78        };
79        StackReq { size, align }
80    }
81
82    /// Returns the requirement aligned to a larger alignment.
83    #[inline]
84    pub const fn with_align(self, align: usize) -> Self {
85        let new_align = if self.align > align {
86            self.align
87        } else {
88            align
89        };
90        StackReq {
91            size: self.size,
92            align: new_align,
93        }
94    }
95}
96
/// Combines multiple stack requirements (all must be satisfied).
///
/// Expands to a left-to-right fold of [`StackReq::and`] over the given
/// requirements, starting from [`StackReq::ZERO`]. Accepts a trailing
/// comma; with no arguments it yields `StackReq::ZERO`.
#[macro_export]
macro_rules! stack_req_all {
    ($($req:expr),* $(,)?) => {{
        let mut acc = $crate::memory::StackReq::ZERO;
        $( acc = acc.and($req); )*
        acc
    }};
}
108
/// Takes the maximum of multiple stack requirements.
///
/// Expands to a left-to-right fold of [`StackReq::or`] over the given
/// requirements, starting from [`StackReq::ZERO`]. Accepts a trailing
/// comma; with no arguments it yields `StackReq::ZERO`.
#[macro_export]
macro_rules! stack_req_any {
    ($($req:expr),* $(,)?) => {{
        let mut acc = $crate::memory::StackReq::ZERO;
        $( acc = acc.or($req); )*
        acc
    }};
}
120
// =============================================================================
// MemStack - Stack-based temporary allocation
// =============================================================================
124
/// A memory stack for temporary allocations.
///
/// This provides fast, stack-based allocation for scratch space needed
/// by algorithms. Allocations are invalidated when the stack is reset.
pub struct MemStack {
    // Backing storage. NOTE(review): `alloc` only aligns offsets relative
    // to the buffer base, so correctness assumes the base pointer of
    // `AlignedVec` is at least as aligned as any type allocated — confirm
    // AlignedVec's alignment guarantee.
    buffer: AlignedVec<u8>,
    // Bump-allocation cursor: number of bytes of `buffer` already handed out.
    offset: usize,
}
133
134impl MemStack {
135    /// Creates a new memory stack with the given requirement.
136    pub fn new(req: StackReq) -> Self {
137        let size = round_up_pow2(req.size, req.align);
138        MemStack {
139            buffer: AlignedVec::zeros(size),
140            offset: 0,
141        }
142    }
143
144    /// Creates a new memory stack with the given size.
145    pub fn with_size(size: usize) -> Self {
146        MemStack {
147            buffer: AlignedVec::zeros(size),
148            offset: 0,
149        }
150    }
151
152    /// Returns the remaining capacity.
153    #[inline]
154    pub fn remaining(&self) -> usize {
155        self.buffer.len() - self.offset
156    }
157
    /// Resets the stack, invalidating all allocations.
    ///
    /// Only the bump cursor is rewound; the buffer contents are not
    /// cleared, so callers must not rely on previously handed-out slices
    /// after a reset.
    #[inline]
    pub fn reset(&mut self) {
        self.offset = 0;
    }
163
164    /// Allocates a slice of the given type.
165    ///
166    /// # Panics
167    /// Panics if there's not enough space.
168    pub fn alloc<T>(&mut self, count: usize) -> &mut [MaybeUninit<T>] {
169        let align = align_of::<T>();
170        let aligned_offset = round_up_pow2(self.offset, align);
171        let size = count * size_of::<T>();
172        let new_offset = aligned_offset + size;
173
174        assert!(new_offset <= self.buffer.len(), "MemStack overflow");
175
176        let ptr = unsafe { self.buffer.as_mut_ptr().add(aligned_offset) as *mut MaybeUninit<T> };
177        self.offset = new_offset;
178
179        unsafe { core::slice::from_raw_parts_mut(ptr, count) }
180    }
181
    /// Allocates and zeros a slice of the given type.
    ///
    /// The `T: bytemuck::Zeroable` bound guarantees the all-zero bit
    /// pattern is a valid `T`, so the returned slice is fully initialized.
    ///
    /// # Panics
    /// Panics (via [`MemStack::alloc`]) if there's not enough space.
    pub fn alloc_zeroed<T: bytemuck::Zeroable>(&mut self, count: usize) -> &mut [T] {
        let slice = self.alloc::<T>(count);
        // Zero the memory
        // SAFETY: `slice` covers exactly `count * size_of::<T>()` writable
        // bytes, so the write_bytes stays in bounds. After zeroing, every
        // element holds the all-zero pattern, which `Zeroable` certifies as
        // a valid `T`; reinterpreting the `MaybeUninit<T>` storage as
        // initialized `T`s is therefore sound.
        unsafe {
            core::ptr::write_bytes(slice.as_mut_ptr() as *mut u8, 0, count * size_of::<T>());
            core::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut T, count)
        }
    }
191
192    /// Creates a sub-stack with the remaining memory.
193    ///
194    /// This is useful for recursive algorithms that need to pass
195    /// scratch space to sub-operations.
196    ///
197    /// Note: This is a placeholder implementation. A full implementation
198    /// would use raw pointers to share the buffer without ownership transfer.
199    pub fn make_sub_stack(&mut self) -> MemStack {
200        let _remaining = self.remaining();
201        let _ptr = unsafe { self.buffer.as_mut_ptr().add(self.offset) };
202
203        // Mark all remaining as used
204        self.offset = self.buffer.len();
205
206        // TODO: Implement proper sub-stack that shares buffer
207        MemStack {
208            buffer: AlignedVec::new(),
209            offset: 0,
210        }
211    }
212}