use core::alloc::{Layout, LayoutError};
use core::num::NonZeroUsize;
/// Size and alignment requirements of a memory region.
///
/// Every constructor in this module keeps `align` a power of two, except for the overflow
/// sentinel ([`StackReq::OVERFLOW`]), which has `align == None` and `size == 0`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct StackReq {
/// Required alignment in bytes, or `None` if computing the requirement overflowed.
align: Option<NonZeroUsize>,
/// Required size in bytes.
size: usize,
}
impl Default for StackReq {
#[inline]
fn default() -> Self {
Self::empty()
}
}
/// Rounds `a` up to the next multiple of `b`, returning `None` if that overflows `usize`.
///
/// `b` must be a power of two. This is not checked, and for other values of `b` the result
/// is not a rounded value.
#[inline(always)]
const fn try_round_up_pow2(a: usize, b: usize) -> Option<usize> {
    // For a power-of-two `b`, `b - 1` masks exactly the low bits that must end up cleared.
    let low_bits = b.wrapping_sub(1);
    if let Some(bumped) = a.checked_add(low_bits) {
        Some(bumped & !low_bits)
    } else {
        None
    }
}
/// Returns the larger of `a` and `b`.
///
/// A `const`-compatible stand-in for `Ord::max`, which cannot be called from a `const fn`.
#[inline(always)]
const fn max(a: usize, b: usize) -> usize {
    if b > a {
        b
    } else {
        a
    }
}
impl StackReq {
    /// Requirement of zero bytes with an alignment of one.
    ///
    /// This is the starting value that [`StackReq::all_of`] and [`StackReq::any_of`] fold
    /// their inputs into.
    pub const EMPTY: Self = Self {
        // `NonZeroUsize::new` is a safe `const fn` and 1 is non-zero, so this is `Some(1)`
        // without needing `new_unchecked`.
        align: NonZeroUsize::new(1),
        size: 0,
    };

    /// Sentinel for a requirement whose computation overflowed or whose alignment was invalid.
    ///
    /// It is identified by its missing alignment (`align_bytes() == 0`) and is absorbing:
    /// [`StackReq::and`], [`StackReq::or`] and [`StackReq::array`] all return `OVERFLOW`
    /// when given it.
    pub const OVERFLOW: Self = Self { align: None, size: 0 };

    /// Returns [`StackReq::EMPTY`].
    #[inline]
    pub const fn empty() -> StackReq {
        Self::EMPTY
    }

    /// Requirement for `n` contiguous values of type `T`, with the start of the array aligned
    /// to `align` bytes.
    ///
    /// Returns [`StackReq::OVERFLOW`] if `align` is not a power of two, if it is smaller than
    /// `align_of::<T>()`, or if `n * size_of::<T>()` overflows `usize`.
    #[inline]
    pub const fn new_aligned<T>(n: usize, align: usize) -> StackReq {
        if align >= core::mem::align_of::<T>() && align.is_power_of_two() {
            StackReq {
                // A power of two is never zero, so this is always `Some(align)`.
                align: NonZeroUsize::new(align),
                size: core::mem::size_of::<T>(),
            }
            .array(n)
        } else {
            StackReq::OVERFLOW
        }
    }

    /// Requirement for `n` contiguous values of type `T` at `T`'s natural alignment.
    ///
    /// Returns [`StackReq::OVERFLOW`] if `n * size_of::<T>()` overflows `usize`.
    #[inline]
    pub const fn new<T>(n: usize) -> StackReq {
        StackReq::new_aligned::<T>(n, core::mem::align_of::<T>())
    }

    /// Required size in bytes (0 for [`StackReq::OVERFLOW`]).
    #[inline]
    pub const fn size_bytes(&self) -> usize {
        self.size
    }

    /// Required alignment in bytes, or 0 if this is [`StackReq::OVERFLOW`].
    #[inline]
    pub const fn align_bytes(&self) -> usize {
        match self.align {
            Some(align) => align.get(),
            None => 0,
        }
    }

    /// Number of bytes a buffer of arbitrary alignment must have to satisfy this requirement.
    ///
    /// This is `size + align - 1`, which leaves room for up to `align - 1` bytes of padding
    /// before an aligned start. Returns `usize::MAX` when [`StackReq::layout`] fails,
    /// including for [`StackReq::OVERFLOW`].
    #[inline]
    pub const fn unaligned_bytes_required(&self) -> usize {
        match self.layout() {
            // A valid `Layout` has `size` rounded up to `align` within `isize::MAX`, so adding
            // `align - 1` cannot overflow `usize`.
            Ok(layout) => layout.size() + (layout.align() - 1),
            Err(_) => usize::MAX,
        }
    }

    /// Converts this requirement into a [`Layout`].
    ///
    /// # Errors
    ///
    /// Fails for [`StackReq::OVERFLOW`] (its alignment is 0) and whenever
    /// [`Layout::from_size_align`] rejects the size/alignment pair.
    #[inline]
    pub const fn layout(self) -> Result<Layout, LayoutError> {
        Layout::from_size_align(self.size_bytes(), self.align_bytes())
    }

    /// Shared first step of [`StackReq::and`] and [`StackReq::or`]. Picks the larger of the
    /// two alignments and rounds both sizes up to it.
    ///
    /// Returns `(align, self_size, other_size)`, or `None` if either operand is
    /// [`StackReq::OVERFLOW`] or if rounding a size up overflows `usize`.
    #[inline]
    const fn unify_alignment(self, other: StackReq) -> Option<(NonZeroUsize, usize, usize)> {
        let (left, right) = match (self.align, other.align) {
            (Some(left), Some(right)) => (left, right),
            _ => return None,
        };
        let align = if left.get() >= right.get() { left } else { right };
        match (
            try_round_up_pow2(self.size, align.get()),
            try_round_up_pow2(other.size, align.get()),
        ) {
            (Some(lhs), Some(rhs)) => Some((align, lhs, rhs)),
            _ => None,
        }
    }

    /// Requirement large enough to hold both `self` and `other` at once.
    ///
    /// The result has the larger of the two alignments. Its size is the sum of both sizes,
    /// each first rounded up to that alignment. Returns [`StackReq::OVERFLOW`] if either
    /// operand is `OVERFLOW` or if the arithmetic overflows.
    #[inline]
    pub const fn and(self, other: StackReq) -> StackReq {
        match self.unify_alignment(other) {
            Some((align, left, right)) => match left.checked_add(right) {
                Some(size) => StackReq { align: Some(align), size },
                None => StackReq::OVERFLOW,
            },
            None => StackReq::OVERFLOW,
        }
    }

    /// Combines all requirements in `reqs` with [`StackReq::and`], starting from
    /// [`StackReq::EMPTY`].
    #[inline]
    pub const fn all_of(reqs: &[Self]) -> Self {
        let mut total = StackReq::EMPTY;
        let mut reqs = reqs;
        while let Some((req, next)) = reqs.split_first() {
            total = total.and(*req);
            reqs = next;
        }
        total
    }

    /// Requirement large enough to hold either `self` or `other`.
    ///
    /// The result has the larger of the two alignments. Its size is the larger of the two
    /// sizes, each first rounded up to that alignment. Returns [`StackReq::OVERFLOW`] if
    /// either operand is `OVERFLOW` or if rounding overflows.
    #[inline]
    pub const fn or(self, other: StackReq) -> StackReq {
        match self.unify_alignment(other) {
            Some((align, left, right)) => StackReq {
                align: Some(align),
                size: max(left, right),
            },
            None => StackReq::OVERFLOW,
        }
    }

    /// Combines all requirements in `reqs` with [`StackReq::or`], starting from
    /// [`StackReq::EMPTY`].
    ///
    /// `const` for consistency with [`StackReq::all_of`]; everything it calls is `const`.
    #[inline]
    pub const fn any_of(reqs: &[StackReq]) -> StackReq {
        let mut total = StackReq::EMPTY;
        let mut reqs = reqs;
        while let Some((req, next)) = reqs.split_first() {
            total = total.or(*req);
            reqs = next;
        }
        total
    }

    /// Requirement for `n` back-to-back copies of this requirement's payload.
    ///
    /// The size is multiplied by `n` as-is. It is *not* rounded up to the alignment first,
    /// and the alignment is kept unchanged, so only the start of the array is aligned.
    /// [`StackReq::new_aligned`] relies on this to describe `n` values of `T` with an
    /// over-aligned start. Returns [`StackReq::OVERFLOW`] if `self` is `OVERFLOW` or the
    /// multiplication overflows.
    #[inline]
    pub const fn array(self, n: usize) -> StackReq {
        match (self.align, self.size.checked_mul(n)) {
            (Some(align), Some(size)) => StackReq { size, align: Some(align) },
            _ => StackReq::OVERFLOW,
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `StackReq` arithmetic. They use `u8` and explicit alignments so the
    // expected values do not depend on the target's primitive alignments.
    use super::*;

    #[test]
    fn round_up() {
        assert_eq!(try_round_up_pow2(0, 4), Some(0));
        assert_eq!(try_round_up_pow2(1, 4), Some(4));
        assert_eq!(try_round_up_pow2(2, 4), Some(4));
        assert_eq!(try_round_up_pow2(3, 4), Some(4));
        assert_eq!(try_round_up_pow2(4, 4), Some(4));
        assert_eq!(try_round_up_pow2(5, 4), Some(8));
        assert_eq!(try_round_up_pow2(7, 1), Some(7));
        assert_eq!(try_round_up_pow2(usize::MAX, 1), Some(usize::MAX));
        assert_eq!(try_round_up_pow2(usize::MAX, 2), None);
    }

    #[test]
    fn default_is_empty() {
        assert_eq!(StackReq::default(), StackReq::EMPTY);
        assert_eq!(StackReq::empty().size_bytes(), 0);
        assert_eq!(StackReq::empty().align_bytes(), 1);
    }

    #[test]
    fn overflow() {
        assert_eq!(StackReq::new::<u32>(usize::MAX).align_bytes(), 0);
    }

    #[test]
    fn invalid_alignment() {
        // 3 is not a power of two.
        assert_eq!(StackReq::new_aligned::<u8>(1, 3), StackReq::OVERFLOW);
        // 0 is not a power of two either.
        assert_eq!(StackReq::new_aligned::<u8>(1, 0), StackReq::OVERFLOW);
    }

    #[test]
    fn aligned_array_and_layout() {
        let req = StackReq::new_aligned::<u8>(3, 4);
        assert_eq!(req.size_bytes(), 3);
        assert_eq!(req.align_bytes(), 4);
        assert_eq!(req.unaligned_bytes_required(), 6);
        let layout = req.layout().unwrap();
        assert_eq!((layout.size(), layout.align()), (3, 4));

        // `array` multiplies the size without rounding it to the alignment.
        assert_eq!(StackReq::new::<u8>(2).array(3).size_bytes(), 6);
        assert_eq!(StackReq::OVERFLOW.array(3), StackReq::OVERFLOW);
    }

    #[test]
    fn overflow_has_no_layout() {
        assert!(StackReq::OVERFLOW.layout().is_err());
        assert_eq!(StackReq::OVERFLOW.unaligned_bytes_required(), usize::MAX);
    }

    #[test]
    fn and_or() {
        let a = StackReq::new_aligned::<u8>(3, 4);
        let b = StackReq::new_aligned::<u8>(5, 8);
        let both = a.and(b);
        assert_eq!((both.size_bytes(), both.align_bytes()), (16, 8));
        let either = a.or(b);
        assert_eq!((either.size_bytes(), either.align_bytes()), (8, 8));

        let c = StackReq::new_aligned::<u8>(10, 4);
        let d = StackReq::new::<u8>(1);
        let both = c.and(d);
        assert_eq!((both.size_bytes(), both.align_bytes()), (16, 4));
        let either = c.or(d);
        assert_eq!((either.size_bytes(), either.align_bytes()), (12, 4));
    }

    #[test]
    fn all_of_any_of() {
        assert_eq!(StackReq::all_of(&[]), StackReq::EMPTY);
        assert_eq!(StackReq::any_of(&[]), StackReq::EMPTY);

        let a = StackReq::new_aligned::<u8>(3, 4);
        let b = StackReq::new_aligned::<u8>(5, 8);
        assert_eq!(StackReq::all_of(&[a, b]), a.and(b));

        let c = StackReq::new_aligned::<u8>(10, 4);
        let d = StackReq::new::<u8>(1);
        assert_eq!(StackReq::any_of(&[c, d]), c.or(d));
    }

    #[test]
    fn and_overflow() {
        assert_eq!(StackReq::new::<u8>(usize::MAX).and(StackReq::new::<u8>(1)).align_bytes(), 0);
    }

    #[test]
    fn or_overflow() {
        let huge = StackReq::new::<u8>(usize::MAX);
        assert_eq!(huge.or(StackReq::new_aligned::<u8>(1, 2)).align_bytes(), 0);
    }

    #[test]
    fn overflow_is_absorbing() {
        assert_eq!(StackReq::OVERFLOW.and(StackReq::EMPTY), StackReq::OVERFLOW);
        assert_eq!(StackReq::EMPTY.or(StackReq::OVERFLOW), StackReq::OVERFLOW);
    }
}