1use core::alloc::{Layout, LayoutError};
2use core::num::NonZeroUsize;
3
/// Memory requirement: a size in bytes together with an alignment in bytes.
///
/// Requirements are combined with [`StackReq::and`] (sizes add up) and [`StackReq::or`]
/// (the larger size wins). A requirement whose computation was invalid or overflowed is
/// represented by [`StackReq::OVERFLOW`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct StackReq {
    // Required alignment in bytes. `None` marks the overflow/invalid sentinel
    // (`StackReq::OVERFLOW`); otherwise a non-zero value.
    align: Option<NonZeroUsize>,
    // Required size in bytes.
    size: usize,
}
10
11impl Default for StackReq {
12 #[inline]
13 fn default() -> Self {
14 Self::empty()
15 }
16}
17
/// Rounds `a` up to the next multiple of `b`, or returns `None` if that overflows `usize`.
///
/// The mask arithmetic is only meaningful when `b` is a power of two; the callers in this
/// file pass alignments, which satisfy that.
#[inline(always)]
const fn try_round_up_pow2(a: usize, b: usize) -> Option<usize> {
    // For a power of two `b`, `b - 1` has exactly the bits below `b` set, and the two's
    // complement `-b` is the mask that clears them.
    let low_bits = b.wrapping_sub(1);
    let keep_mask = b.wrapping_neg();
    if let Some(bumped) = a.checked_add(low_bits) {
        Some(bumped & keep_mask)
    } else {
        None
    }
}
25
/// `const`-friendly maximum of two `usize` values (`Ord::max` is not callable in `const fn`).
#[inline(always)]
const fn max(a: usize, b: usize) -> usize {
    if b >= a {
        b
    } else {
        a
    }
}
30
31impl StackReq {
32 pub const EMPTY: Self = Self {
34 align: unsafe { Some(NonZeroUsize::new_unchecked(1)) },
35 size: 0,
36 };
37 pub const OVERFLOW: Self = Self { align: None, size: 0 };
39
40 #[inline]
42 pub const fn empty() -> StackReq {
43 Self::EMPTY
44 }
45
46 #[inline]
55 pub const fn new_aligned<T>(n: usize, align: usize) -> StackReq {
56 if align >= core::mem::align_of::<T>() && align.is_power_of_two() {
57 StackReq {
58 align: unsafe { Some(NonZeroUsize::new_unchecked(align)) },
59 size: core::mem::size_of::<T>(),
60 }
61 .array(n)
62 } else {
63 StackReq { align: None, size: 0 }
64 }
65 }
66
67 #[inline]
73 pub const fn new<T>(n: usize) -> StackReq {
74 StackReq::new_aligned::<T>(n, core::mem::align_of::<T>())
75 }
76
77 #[inline]
79 pub const fn size_bytes(&self) -> usize {
80 self.size
81 }
82
83 #[inline]
85 pub const fn align_bytes(&self) -> usize {
86 match self.align {
87 Some(align) => align.get(),
88 None => 0,
89 }
90 }
91
92 #[inline]
99 pub const fn unaligned_bytes_required(&self) -> usize {
100 match self.layout() {
101 Ok(layout) => layout.size() + (layout.align() - 1),
102 Err(_) => usize::MAX,
103 }
104 }
105
106 #[inline]
108 pub const fn layout(self) -> Result<Layout, LayoutError> {
109 Layout::from_size_align(self.size_bytes(), self.align_bytes())
110 }
111
112 #[inline]
119 pub const fn and(self, other: StackReq) -> StackReq {
120 match (self.align, other.align) {
121 (Some(left), Some(right)) => {
122 let align = max(left.get(), right.get());
123 let left = try_round_up_pow2(self.size, align);
124 let right = try_round_up_pow2(other.size, align);
125
126 match (left, right) {
127 (Some(left), Some(right)) => {
128 match left.checked_add(right) {
129 Some(size) => StackReq {
130 align: unsafe { Some(NonZeroUsize::new_unchecked(align)) },
132 size,
133 },
134 _ => StackReq::OVERFLOW,
135 }
136 },
137 _ => StackReq::OVERFLOW,
138 }
139 },
140 _ => StackReq::OVERFLOW,
141 }
142 }
143
144 #[inline]
151 pub const fn all_of(reqs: &[Self]) -> Self {
152 let mut total = StackReq::EMPTY;
153 let mut reqs = reqs;
154 while let Some((req, next)) = reqs.split_first() {
155 total = total.and(*req);
156 reqs = next;
157 }
158 total
159 }
160
161 #[inline]
168 pub const fn or(self, other: StackReq) -> StackReq {
169 match (self.align, other.align) {
170 (Some(left), Some(right)) => {
171 let align = max(left.get(), right.get());
172 let left = try_round_up_pow2(self.size, align);
173 let right = try_round_up_pow2(other.size, align);
174
175 match (left, right) {
176 (Some(left), Some(right)) => {
177 let size = max(left, right);
178 StackReq {
179 align: unsafe { Some(NonZeroUsize::new_unchecked(align)) },
181 size,
182 }
183 },
184 _ => StackReq::OVERFLOW,
185 }
186 },
187 _ => StackReq::OVERFLOW,
188 }
189 }
190
191 #[inline]
198 pub fn any_of(reqs: &[StackReq]) -> StackReq {
199 let mut total = StackReq::EMPTY;
200 let mut reqs = reqs;
201 while let Some((req, next)) = reqs.split_first() {
202 total = total.or(*req);
203 reqs = next;
204 }
205 total
206 }
207
208 #[inline]
210 pub const fn array(self, n: usize) -> StackReq {
211 match self.align {
212 Some(align) => {
213 let size = self.size.checked_mul(n);
214 match size {
215 Some(size) => StackReq { size, align: Some(align) },
216 None => StackReq::OVERFLOW,
217 }
218 },
219 None => StackReq::OVERFLOW,
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    // Rounding helper: exact multiples are kept, others bumped, overflow reported.
    #[test]
    fn round_up() {
        assert_eq!(try_round_up_pow2(0, 4), Some(0));
        assert_eq!(try_round_up_pow2(1, 4), Some(4));
        assert_eq!(try_round_up_pow2(2, 4), Some(4));
        assert_eq!(try_round_up_pow2(3, 4), Some(4));
        assert_eq!(try_round_up_pow2(4, 4), Some(4));
        assert_eq!(try_round_up_pow2(5, 4), Some(8));
        assert_eq!(try_round_up_pow2(7, 1), Some(7));
        assert_eq!(try_round_up_pow2(usize::MAX, 2), None);
    }

    #[test]
    fn overflow() {
        assert_eq!(StackReq::new::<u32>(usize::MAX).align_bytes(), 0);
    }

    #[test]
    fn and_overflow() {
        assert_eq!(StackReq::new::<u8>(usize::MAX).and(StackReq::new::<u8>(1)).align_bytes(), 0);
    }

    #[test]
    fn and_rounds_to_common_alignment() {
        let req = StackReq::new::<u8>(3).and(StackReq::new::<u32>(1));
        assert_eq!(req.size_bytes(), 8);
        assert_eq!(req.align_bytes(), 4);
    }

    #[test]
    fn or_takes_rounded_maximum() {
        let req = StackReq::new::<u8>(10).or(StackReq::new::<u32>(1));
        assert_eq!(req.size_bytes(), 12);
        assert_eq!(req.align_bytes(), 4);
    }

    #[test]
    fn all_of_and_any_of() {
        assert_eq!(StackReq::all_of(&[]), StackReq::EMPTY);
        assert_eq!(StackReq::any_of(&[]), StackReq::EMPTY);

        let reqs = [StackReq::new::<u8>(3), StackReq::new::<u64>(2), StackReq::new::<u16>(1)];
        assert_eq!(
            StackReq::all_of(&reqs),
            StackReq::EMPTY.and(reqs[0]).and(reqs[1]).and(reqs[2])
        );
        assert_eq!(
            StackReq::any_of(&reqs),
            StackReq::EMPTY.or(reqs[0]).or(reqs[1]).or(reqs[2])
        );
    }

    #[test]
    fn new_aligned_rejects_bad_alignment() {
        // Smaller than `align_of::<u32>()`.
        assert_eq!(StackReq::new_aligned::<u32>(1, 2), StackReq::OVERFLOW);
        // Not a power of two.
        assert_eq!(StackReq::new_aligned::<u8>(1, 3), StackReq::OVERFLOW);
        // Over-aligned is fine; the size is not padded.
        let req = StackReq::new_aligned::<u8>(3, 16);
        assert_eq!(req.size_bytes(), 3);
        assert_eq!(req.align_bytes(), 16);
    }

    #[test]
    fn unaligned_bytes_required() {
        assert_eq!(StackReq::new::<u32>(2).unaligned_bytes_required(), 11);
        assert_eq!(StackReq::EMPTY.unaligned_bytes_required(), 0);
        assert_eq!(StackReq::OVERFLOW.unaligned_bytes_required(), usize::MAX);
    }

    #[test]
    fn overflow_is_sticky() {
        assert_eq!(StackReq::OVERFLOW.and(StackReq::EMPTY), StackReq::OVERFLOW);
        assert_eq!(StackReq::EMPTY.or(StackReq::OVERFLOW), StackReq::OVERFLOW);
        assert_eq!(StackReq::OVERFLOW.array(0), StackReq::OVERFLOW);
        assert_eq!(StackReq::OVERFLOW.size_bytes(), 0);
    }
}