oxiblas_core/memory/stack.rs

use core::mem::{MaybeUninit, align_of, size_of};

use super::aligned_vec::AlignedVec;
use super::alloc::*;

/// A stack memory requirement: a number of bytes together with the
/// alignment those bytes must start at.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackReq {
    /// Number of bytes required.
    pub size: usize,
    /// Required alignment in bytes (a power of two).
    pub align: usize,
}

impl StackReq {
    /// The empty requirement: zero bytes with trivial alignment.
    pub const ZERO: StackReq = StackReq { size: 0, align: 1 };

    /// Creates a requirement for `size` bytes at alignment `align`.
    #[inline]
    pub const fn new(size: usize, align: usize) -> Self {
        StackReq { size, align }
    }

    /// Creates a requirement for `count` values of type `T`.
    #[inline]
    pub const fn new_for<T>(count: usize) -> Self {
        StackReq {
            size: count * size_of::<T>(),
            align: align_of::<T>(),
        }
    }

    /// Combines two requirements that must be satisfied *simultaneously*:
    /// `self`'s size is padded up to `other`'s alignment, the sizes are
    /// summed, and the stricter alignment wins.
    #[inline]
    pub const fn and(self, other: Self) -> Self {
        let align = if self.align > other.align {
            self.align
        } else {
            other.align
        };
        let padded_size = round_up_pow2(self.size, other.align);
        StackReq {
            size: padded_size + other.size,
            align,
        }
    }

    /// Combines two requirements of which at most one must be satisfied at
    /// a time: the larger size and the stricter alignment win.
    #[inline]
    pub const fn or(self, other: Self) -> Self {
        let align = if self.align > other.align {
            self.align
        } else {
            other.align
        };
        let size = if self.size > other.size {
            self.size
        } else {
            other.size
        };
        StackReq { size, align }
    }

    /// Raises the alignment to at least `align`, keeping the size unchanged.
    #[inline]
    pub const fn with_align(self, align: usize) -> Self {
        let new_align = if self.align > align {
            self.align
        } else {
            align
        };
        StackReq {
            size: self.size,
            align: new_align,
        }
    }
}
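
// A quick sanity sketch of the combinators (illustrative values only; the
// numbers follow directly from the definitions above):
//
//     let a = StackReq::new_for::<f64>(128); // size 1024, align 8
//     let b = StackReq::new_for::<i32>(64);  // size  256, align 4
//     a.and(b) == StackReq::new(1280, 8)     // coexisting buffers: sizes add
//     a.or(b)  == StackReq::new(1024, 8)     // disjoint lifetimes: sizes max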

/// Folds [`StackReq::and`] over any number of requirements: the total
/// scratch space needed when every buffer is live at the same time.
#[macro_export]
macro_rules! stack_req_all {
    ($($req:expr),* $(,)?) => {{
        let mut result = $crate::memory::StackReq::ZERO;
        $(
            result = result.and($req);
        )*
        result
    }};
}

/// Folds [`StackReq::or`] over any number of requirements: the scratch
/// space needed when at most one of the buffers is live at a time.
#[macro_export]
macro_rules! stack_req_any {
    ($($req:expr),* $(,)?) => {{
        let mut result = $crate::memory::StackReq::ZERO;
        $(
            result = result.or($req);
        )*
        result
    }};
}
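
// Usage sketch (hypothetical requirements; `n` stands in for a runtime
// dimension):
//
//     let scratch = stack_req_all!(
//         StackReq::new_for::<f64>(n * n), // work matrix
//         StackReq::new_for::<usize>(n),   // pivot indices
//     );
//
// `scratch` is then large enough for both buffers to live back to back.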

/// A bump allocator over a pre-allocated, aligned byte buffer. Allocations
/// only move an offset forward; `reset` reclaims everything at once.
pub struct MemStack {
    buffer: AlignedVec<u8>,
    offset: usize,
}

impl MemStack {
    /// Creates a stack large enough to satisfy `req`.
    pub fn new(req: StackReq) -> Self {
        // Round the capacity up to a multiple of the requested alignment so
        // the final allocation can always be padded into place.
        let size = round_up_pow2(req.size, req.align);
        MemStack {
            buffer: AlignedVec::zeros(size),
            offset: 0,
        }
    }

    /// Creates a stack with a raw capacity of `size` bytes.
    pub fn with_size(size: usize) -> Self {
        MemStack {
            buffer: AlignedVec::zeros(size),
            offset: 0,
        }
    }

    /// Returns the number of bytes still available.
    #[inline]
    pub fn remaining(&self) -> usize {
        self.buffer.len() - self.offset
    }

    /// Releases all allocations by rewinding the offset to the start.
    #[inline]
    pub fn reset(&mut self) {
        self.offset = 0;
    }

    /// Allocates an uninitialized slice of `count` values of type `T`.
    ///
    /// Panics if the remaining capacity is insufficient. The buffer's base
    /// alignment (provided by `AlignedVec`) is assumed to be at least
    /// `align_of::<T>()`.
    pub fn alloc<T>(&mut self, count: usize) -> &mut [MaybeUninit<T>] {
        let align = align_of::<T>();
        // Pad the current offset up to `T`'s alignment before carving out
        // the slice.
        let aligned_offset = round_up_pow2(self.offset, align);
        let size = count * size_of::<T>();
        let new_offset = aligned_offset + size;

        assert!(new_offset <= self.buffer.len(), "MemStack overflow");

        // SAFETY: the range [aligned_offset, new_offset) lies within the
        // buffer (checked above) and is suitably aligned for `T`.
        let ptr = unsafe { self.buffer.as_mut_ptr().add(aligned_offset) as *mut MaybeUninit<T> };
        self.offset = new_offset;

        unsafe { core::slice::from_raw_parts_mut(ptr, count) }
    }

    /// Allocates a zero-initialized slice of `count` values of type `T`.
    pub fn alloc_zeroed<T: bytemuck::Zeroable>(&mut self, count: usize) -> &mut [T] {
        let slice = self.alloc::<T>(count);
        // SAFETY: zeroing the bytes produces a valid value for any
        // `T: Zeroable`, so casting away the `MaybeUninit` wrapper is sound.
        unsafe {
            core::ptr::write_bytes(slice.as_mut_ptr() as *mut u8, 0, count * size_of::<T>());
            core::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut T, count)
        }
    }

    /// Splits off a sub-stack that takes over the remaining capacity of
    /// this stack; the parent is left fully consumed until `reset`.
    ///
    /// Since `MemStack` owns its storage, the sub-stack is backed by a
    /// fresh allocation of the remaining size rather than a view into the
    /// parent buffer.
    pub fn make_sub_stack(&mut self) -> MemStack {
        let remaining = self.remaining();

        // Hand the remaining budget over to the sub-stack.
        self.offset = self.buffer.len();

        MemStack {
            buffer: AlignedVec::zeros(remaining),
            offset: 0,
        }
    }
}
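
// A minimal test sketch exercising the API above. The expected numbers
// follow from the definitions in this file; `bytemuck`'s `Zeroable` impls
// for the integer primitives are the only outside assumption.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn req_combinators() {
        let a = StackReq::new_for::<f64>(128); // size 1024, align 8
        let b = StackReq::new_for::<i32>(64); // size 256, align 4
        assert_eq!(a.and(b), StackReq::new(1280, 8));
        assert_eq!(a.or(b), StackReq::new(1024, 8));
        assert_eq!(a.with_align(64).align, 64);
        // The macros fold the combinators, starting from `StackReq::ZERO`.
        assert_eq!(stack_req_all!(a, b), a.and(b));
        assert_eq!(stack_req_any!(a, b), a.or(b));
    }

    #[test]
    fn stack_alloc_and_reset() {
        let req = StackReq::new_for::<u64>(4).and(StackReq::new_for::<u32>(4));
        let mut stack = MemStack::new(req); // 48-byte buffer

        let a = stack.alloc_zeroed::<u64>(4);
        assert!(a.iter().all(|&x| x == 0));
        // `a`'s borrow has ended, so the next allocation is allowed.
        let b = stack.alloc_zeroed::<u32>(4);
        assert!(b.iter().all(|&x| x == 0));

        // `reset` rewinds the offset and makes the capacity reusable.
        stack.reset();
        assert_eq!(stack.remaining(), 48);
    }

    #[test]
    fn sub_stack_takes_remaining_budget() {
        let mut stack = MemStack::with_size(64);
        let _head = stack.alloc::<u8>(16);
        let sub = stack.make_sub_stack();
        assert_eq!(sub.remaining(), 48);
        assert_eq!(stack.remaining(), 0);
    }
}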