use core::{
cmp,
fmt,
ops::{
self,
Deref, }, };
use kernel::{
num::Integer,
prelude::*, };
/// Evaluates to `true` if `$value`, of integer type `$type`, is representable within `$n` bits.
///
/// The value is shifted left so that only its `$n` low bits remain, then shifted back right.
/// Unsigned right shifts refill with zeros and signed ones with copies of the sign bit, so the
/// round trip is lossless exactly when `$value` fits within `$n` bits.
///
/// `$n` must be in `1..=<$type>::BITS`; otherwise the shift amount is out of range. Unlike the
/// `fits_within` function, this macro can be used in `const` contexts.
macro_rules! fits_within {
    ($value:expr, $type:ty, $n:expr) => {{
        let discarded_bits: u32 = <$type>::BITS - $n;
        let round_trip = ($value << discarded_bits) >> discarded_bits;
        round_trip == $value
    }};
}
/// Returns whether `value` can be represented within `num_bits` bits of its own type `T`.
///
/// Runtime counterpart of the `fits_within!` macro; `num_bits` must be in `1..=T::BITS`.
#[inline(always)]
fn fits_within<T>(value: T, num_bits: u32) -> bool
where
    T: Integer,
{
    fits_within!(value, T, num_bits)
}
/// An integer of type `T` whose value is representable within `N` bits.
///
/// For an unsigned `T` the value is in `0..2^N`. For a signed `T` it is in the two's complement
/// range `-2^(N-1)..2^(N-1)`. The type has the same layout as `T` (`#[repr(transparent)]`).
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, Hash)]
pub struct Bounded<T: Integer, const N: u32>(T);
/// Implements [`Bounded::new`] for each listed primitive integer type.
///
/// `new` takes its value as a const generic parameter, so the range check is done at build time
/// and works in `const` contexts.
macro_rules! impl_const_new {
    ($($type:ty)*) => {
        $(
            impl<const N: u32> Bounded<$type, N> {
                #[doc = ::core::concat!(
                    "Creates a [`Bounded`] holding the constant `VALUE`, checking at build time ",
                    "that it fits within `N` bits.\n\n",
                    "For example, `Bounded::<",
                    ::core::stringify!($type),
                    ", 4>::new::<7>()` creates a 4-bit value of `7`, whereas `Bounded::<",
                    ::core::stringify!($type),
                    ", 4>::new::<16>()` fails to build.")]
                pub const fn new<const VALUE: $type>() -> Self {
                    const_assert!(fits_within!(VALUE, $type, N));

                    // SAFETY: `VALUE` has just been checked at build time to fit within `N` bits.
                    unsafe { Self::__new(VALUE) }
                }
            }
        )*
    };
}

impl_const_new!(
    u8 u16 u32 u64 usize
    i8 i16 i32 i64 isize
);
impl<T, const N: u32> Bounded<T, N>
where
    T: Integer,
{
    /// Creates a [`Bounded`] holding `value` without checking that it fits within `N` bits.
    ///
    /// The build fails if `N` is zero or larger than the number of bits of `T`.
    ///
    /// # Safety
    ///
    /// `value` must be representable within `N` bits, i.e. `fits_within(value, N)` must hold.
    /// The rest of [`Bounded`], notably its [`Deref`] implementation, relies on this.
    const unsafe fn __new(value: T) -> Self {
        const_assert!(N != 0);
        const_assert!(N <= T::BITS);

        Self(value)
    }

    /// Tries to create a [`Bounded`] holding `value`.
    ///
    /// Returns [`None`] if `value` cannot be represented within `N` bits.
    pub fn try_new(value: T) -> Option<Self> {
        fits_within(value, N).then(|| {
            // SAFETY: `value` has just been checked to fit within `N` bits.
            unsafe { Self::__new(value) }
        })
    }

    /// Creates a [`Bounded`] holding `expr`, with the range check performed by `build_assert!`.
    ///
    /// The build fails unless the check can be resolved at build time. This is meant for values
    /// that are known once inlined, such as constants.
    // Always inlined so the check can be resolved with the caller's knowledge of `expr`.
    #[inline(always)]
    pub fn from_expr(expr: T) -> Self {
        crate::build_assert!(
            fits_within(expr, N),
            "Requested value larger than maximal representable value."
        );

        // SAFETY: `expr` has just been asserted to fit within `N` bits.
        unsafe { Self::__new(expr) }
    }

    /// Returns the underlying value.
    pub const fn get(self) -> T {
        self.0
    }

    /// Converts into a [`Bounded`] of the same type with `M` bits.
    ///
    /// The build fails if `M` is smaller than `N`, or larger than the number of bits of `T`.
    pub const fn extend<const M: u32>(self) -> Bounded<T, M> {
        const_assert!(
            M >= N,
            "Requested number of bits is less than the current representation."
        );

        // SAFETY: the value fits within `N` bits, and therefore within `M >= N` bits.
        unsafe { Bounded::__new(self.0) }
    }

    /// Tries to convert into a [`Bounded`] of the same type with `M` bits.
    ///
    /// Returns [`None`] if the value does not fit within `M` bits.
    pub fn try_shrink<const M: u32>(self) -> Option<Bounded<T, M>> {
        Bounded::<T, M>::try_new(self.get())
    }

    /// Converts into a [`Bounded`] with the same number of bits, backed by integer type `U`.
    ///
    /// `U` must have the same signedness as `T`. The build fails if `N` exceeds the number of
    /// bits of `U`.
    pub fn cast<U>(self) -> Bounded<U, N>
    where
        U: TryFrom<T> + Integer,
        T: Integer,
        U: Integer<Signedness = T::Signedness>,
    {
        // SAFETY: the value fits within `N` bits, and `U` has the same signedness as `T`. `U`
        // also has at least `N` bits, since `Bounded::<U, N>::__new` asserts it at build time.
        // So the value is representable by `U` and the conversion cannot fail.
        let value = unsafe { U::try_from(self.get()).unwrap_unchecked() };

        // SAFETY: the conversion preserved the value, which fits within `N` bits.
        unsafe { Bounded::__new(value) }
    }

    /// Shifts the value right by `SHIFT` bits, producing a [`Bounded`] with `RES` bits.
    ///
    /// The build fails unless both of these hold:
    /// - `SHIFT` is smaller than the number of bits of `T`;
    /// - `RES + SHIFT >= N`.
    pub fn shr<const SHIFT: u32, const RES: u32>(self) -> Bounded<T, RES> {
        const {
            // A primitive shift by the type's bit width or more panics when overflow checks
            // are enabled. Otherwise it masks the shift amount, which could leave a value that
            // does not fit within `RES` bits.
            assert!(SHIFT < T::BITS);
            assert!(RES + SHIFT >= N);
        }

        // SAFETY: shifting right by `SHIFT < T::BITS` reduces the number of bits needed to
        // represent the value by `SHIFT` (down to a minimum of one), and `RES + SHIFT >= N`.
        unsafe { Bounded::__new(self.0 >> SHIFT) }
    }

    /// Shifts the value left by `SHIFT` bits, producing a [`Bounded`] with `RES` bits.
    ///
    /// The build fails unless both of these hold:
    /// - `SHIFT` is smaller than the number of bits of `T`;
    /// - `RES >= N + SHIFT`.
    pub fn shl<const SHIFT: u32, const RES: u32>(self) -> Bounded<T, RES> {
        const {
            // Implied by `RES >= N + SHIFT` for any valid `N`. Checked explicitly so the shift
            // below can never be out of range.
            assert!(SHIFT < T::BITS);
            assert!(RES >= N + SHIFT);
        }

        // SAFETY: shifting left by `SHIFT` increases the number of bits needed to represent the
        // value by `SHIFT`, and `RES >= N + SHIFT`, so no significant bit is lost.
        unsafe { Bounded::__new(self.0 << SHIFT) }
    }
}
/// Gives read access to the underlying value.
impl<T, const N: u32> Deref for Bounded<T, N>
where
    T: Integer,
{
    type Target = T;

    /// Borrows the underlying value.
    ///
    /// The range check below only feeds the invariant to the optimizer, which may then elide
    /// range checks in calling code.
    fn deref(&self) -> &Self::Target {
        if fits_within(self.0, N) {
            &self.0
        } else {
            // SAFETY: the invariant of `Bounded` guarantees that its value fits within `N` bits,
            // so this branch is never taken.
            unsafe { core::hint::unreachable_unchecked() }
        }
    }
}
/// Fallible conversion of a value into a [`Bounded`].
///
/// The conversion succeeds only if the value converts into `T` and the result also fits within
/// `N` bits.
pub trait TryIntoBounded<T: Integer, const N: u32> {
    /// Tries to convert `self` into a [`Bounded<T, N>`].
    ///
    /// Returns [`None`] if `self` does not convert into `T`, or if the converted value does not
    /// fit within `N` bits.
    fn try_into_bounded(self) -> Option<Bounded<T, N>>;
}

/// Every type that can be converted into `T` can be tentatively converted into a
/// [`Bounded<T, N>`].
impl<T, U, const N: u32> TryIntoBounded<T, N> for U
where
    T: Integer,
    U: TryInto<T>,
{
    fn try_into_bounded(self) -> Option<Bounded<T, N>> {
        let Ok(value) = self.try_into() else {
            return None;
        };

        Bounded::<T, N>::try_new(value)
    }
}
/// Two [`Bounded`]s are equal when their values are, whatever their numbers of bits.
impl<T, U, const N: u32, const M: u32> PartialEq<Bounded<U, M>> for Bounded<T, N>
where
    T: Integer,
    U: Integer,
    T: PartialEq<U>,
{
    fn eq(&self, other: &Bounded<U, M>) -> bool {
        self.0 == other.0
    }
}

impl<T, const N: u32> Eq for Bounded<T, N> where T: Integer {}

/// [`Bounded`]s are ordered by value, whatever their numbers of bits.
impl<T, U, const N: u32, const M: u32> PartialOrd<Bounded<U, M>> for Bounded<T, N>
where
    T: Integer,
    U: Integer,
    T: PartialOrd<U>,
{
    fn partial_cmp(&self, other: &Bounded<U, M>) -> Option<cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, &other.0)
    }
}

impl<T, const N: u32> Ord for Bounded<T, N>
where
    T: Integer,
    T: Ord,
{
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}

/// A [`Bounded`] is equal to a bare `T` holding the same value.
impl<T, const N: u32> PartialEq<T> for Bounded<T, N>
where
    T: Integer,
    T: PartialEq,
{
    fn eq(&self, other: &T) -> bool {
        self.0 == *other
    }
}

/// A [`Bounded`] is ordered against a bare `T` by value.
impl<T, const N: u32> PartialOrd<T> for Bounded<T, N>
where
    T: Integer,
    T: PartialOrd,
{
    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, other)
    }
}
impl<T, const N: u32, const M: u32> ops::Add<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::Add<Output = T>,
{
type Output = T;
fn add(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() + rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::BitAnd<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::BitAnd<Output = T>,
{
type Output = T;
fn bitand(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() & rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::BitOr<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::BitOr<Output = T>,
{
type Output = T;
fn bitor(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() | rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::BitXor<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::BitXor<Output = T>,
{
type Output = T;
fn bitxor(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() ^ rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::Div<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::Div<Output = T>,
{
type Output = T;
fn div(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() / rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::Mul<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::Mul<Output = T>,
{
type Output = T;
fn mul(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() * rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::Rem<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::Rem<Output = T>,
{
type Output = T;
fn rem(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() % rhs.get()
}
}
impl<T, const N: u32, const M: u32> ops::Sub<Bounded<T, M>> for Bounded<T, N>
where
T: Integer,
T: ops::Sub<Output = T>,
{
type Output = T;
fn sub(self, rhs: Bounded<T, M>) -> Self::Output {
self.get() - rhs.get()
}
}
impl<T, const N: u32> ops::Add<T> for Bounded<T, N>
where
T: Integer,
T: ops::Add<Output = T>,
{
type Output = T;
fn add(self, rhs: T) -> Self::Output {
self.get() + rhs
}
}
impl<T, const N: u32> ops::BitAnd<T> for Bounded<T, N>
where
T: Integer,
T: ops::BitAnd<Output = T>,
{
type Output = T;
fn bitand(self, rhs: T) -> Self::Output {
self.get() & rhs
}
}
impl<T, const N: u32> ops::BitOr<T> for Bounded<T, N>
where
T: Integer,
T: ops::BitOr<Output = T>,
{
type Output = T;
fn bitor(self, rhs: T) -> Self::Output {
self.get() | rhs
}
}
impl<T, const N: u32> ops::BitXor<T> for Bounded<T, N>
where
T: Integer,
T: ops::BitXor<Output = T>,
{
type Output = T;
fn bitxor(self, rhs: T) -> Self::Output {
self.get() ^ rhs
}
}
impl<T, const N: u32> ops::Div<T> for Bounded<T, N>
where
T: Integer,
T: ops::Div<Output = T>,
{
type Output = T;
fn div(self, rhs: T) -> Self::Output {
self.get() / rhs
}
}
impl<T, const N: u32> ops::Mul<T> for Bounded<T, N>
where
T: Integer,
T: ops::Mul<Output = T>,
{
type Output = T;
fn mul(self, rhs: T) -> Self::Output {
self.get() * rhs
}
}
impl<T, const N: u32> ops::Neg for Bounded<T, N>
where
T: Integer,
T: ops::Neg<Output = T>,
{
type Output = T;
fn neg(self) -> Self::Output {
-self.get()
}
}
impl<T, const N: u32> ops::Not for Bounded<T, N>
where
T: Integer,
T: ops::Not<Output = T>,
{
type Output = T;
fn not(self) -> Self::Output {
!self.get()
}
}
impl<T, const N: u32> ops::Rem<T> for Bounded<T, N>
where
T: Integer,
T: ops::Rem<Output = T>,
{
type Output = T;
fn rem(self, rhs: T) -> Self::Output {
self.get() % rhs
}
}
impl<T, const N: u32> ops::Sub<T> for Bounded<T, N>
where
T: Integer,
T: ops::Sub<Output = T>,
{
type Output = T;
fn sub(self, rhs: T) -> Self::Output {
self.get() - rhs
}
}
impl<T, const N: u32> fmt::Display for Bounded<T, N>
where
T: Integer,
T: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
impl<T, const N: u32> fmt::Binary for Bounded<T, N>
where
T: Integer,
T: fmt::Binary,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
impl<T, const N: u32> fmt::LowerExp for Bounded<T, N>
where
T: Integer,
T: fmt::LowerExp,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
impl<T, const N: u32> fmt::LowerHex for Bounded<T, N>
where
T: Integer,
T: fmt::LowerHex,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
impl<T, const N: u32> fmt::Octal for Bounded<T, N>
where
T: Integer,
T: fmt::Octal,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
impl<T, const N: u32> fmt::UpperExp for Bounded<T, N>
where
T: Integer,
T: fmt::UpperExp,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
impl<T, const N: u32> fmt::UpperHex for Bounded<T, N>
where
T: Integer,
T: fmt::UpperHex,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.get().fmt(f)
}
}
/// Implements the marker trait `$trait` for `Bounded<T, BITS>`, for every listed literal `BITS`.
macro_rules! impl_size_rule {
    ($trait:ty, $($bits:literal)*) => {
        $(
            impl<T: Integer> $trait for Bounded<T, $bits> {}
        )*
    };
}
/// Marker trait implemented by `Bounded<T, M>` for every `M` in `N..=64`.
///
/// It is only implemented for `N` equal to 8, 16, 32 or 64. The conversions from primitive
/// integers into [`Bounded`] use it to require at least as many bits as the source type.
trait AtLeastXBits<const N: usize> {}
/// Implementations of [`AtLeastXBits`].
///
/// Each width `M` is listed once, under the largest `N` it reaches. Blanket impls then propagate
/// the trait to every smaller `N`: anything with at least 64 bits also has at least 32, and so
/// on.
mod atleast_impls {
    use super::*;
    // At least 64 bits: only 64 itself, the largest width listed.
    impl_size_rule!(AtLeastXBits<64>, 64);
    // At least 32 bits: 32..=63 listed below, plus everything that has at least 64 bits.
    impl<T> AtLeastXBits<32> for T where T: AtLeastXBits<64> {}
    impl_size_rule!(AtLeastXBits<32>,
        32 33 34 35 36 37 38 39
        40 41 42 43 44 45 46 47
        48 49 50 51 52 53 54 55
        56 57 58 59 60 61 62 63
    );
    // At least 16 bits: 16..=31 listed below, plus everything that has at least 32 bits.
    impl<T> AtLeastXBits<16> for T where T: AtLeastXBits<32> {}
    impl_size_rule!(AtLeastXBits<16>,
        16 17 18 19 20 21 22 23
        24 25 26 27 28 29 30 31
    );
    // At least 8 bits: 8..=15 listed below, plus everything that has at least 16 bits.
    impl<T> AtLeastXBits<8> for T where T: AtLeastXBits<16> {}
    impl_size_rule!(AtLeastXBits<8>, 8 9 10 11 12 13 14 15);
}
/// Implements lossless [`From`] conversions from each listed primitive into [`Bounded`].
///
/// The target [`Bounded`] must have the same signedness as the primitive and at least as many
/// bits.
macro_rules! impl_from_primitive {
    ($($prim:ty)*) => {
        $(
            #[doc = ::core::concat!(
                "Conversion from a [`",
                ::core::stringify!($prim),
                "`] into a [`Bounded`] of same signedness with enough bits to store it.")]
            impl<T, const N: u32> From<$prim> for Bounded<T, N>
            where
                $prim: Integer,
                T: Integer<Signedness = <$prim as Integer>::Signedness> + From<$prim>,
                Self: AtLeastXBits<{ <$prim as Integer>::BITS as usize }>,
            {
                fn from(value: $prim) -> Self {
                    let widened = T::from(value);

                    // SAFETY: `widened` holds the same value as `value`, and `T` has the same
                    // signedness as the source type. `N` is at least the source type's bit width
                    // (`AtLeastXBits` bound), so the value fits within `N` bits.
                    unsafe { Self::__new(widened) }
                }
            }
        )*
    };
}

impl_from_primitive!(
    u8 u16 u32 u64 usize
    i8 i16 i32 i64 isize
);
/// Marker trait implemented by `Bounded<T, M>` for every `M` in `1..=N`.
///
/// It is only implemented for `N` equal to 8, 16, 32 or 64. The conversions from [`Bounded`]
/// into primitive integers use it to require no more bits than the target type.
trait FitsInXBits<const N: usize> {}
/// Implementations of [`FitsInXBits`].
///
/// Each width `M` is listed once, under the smallest `N` that can hold it. Blanket impls then
/// propagate the trait to every larger `N`: anything that fits in 8 bits also fits in 16, and so
/// on.
mod fits_impls {
    use super::*;
    // Fits in 8 bits: 1..=8.
    impl_size_rule!(FitsInXBits<8>, 1 2 3 4 5 6 7 8);
    // Fits in 16 bits: everything that fits in 8 bits, plus 9..=16.
    impl<T> FitsInXBits<16> for T where T: FitsInXBits<8> {}
    impl_size_rule!(FitsInXBits<16>, 9 10 11 12 13 14 15 16);
    // Fits in 32 bits: everything that fits in 16 bits, plus 17..=32.
    impl<T> FitsInXBits<32> for T where T: FitsInXBits<16> {}
    impl_size_rule!(FitsInXBits<32>,
        17 18 19 20 21 22 23 24
        25 26 27 28 29 30 31 32
    );
    // Fits in 64 bits: everything that fits in 32 bits, plus 33..=64.
    impl<T> FitsInXBits<64> for T where T: FitsInXBits<32> {}
    impl_size_rule!(FitsInXBits<64>,
        33 34 35 36 37 38 39 40
        41 42 43 44 45 46 47 48
        49 50 51 52 53 54 55 56
        57 58 59 60 61 62 63 64
    );
}
/// Implements infallible [`From`] conversions from [`Bounded`] into each listed primitive.
///
/// The source [`Bounded`] must have the same signedness as the primitive and no more bits.
macro_rules! impl_into_primitive {
    ($($prim:ty)*) => {
        $(
            #[doc = ::core::concat!(
                "Conversion from a [`Bounded`] with no more bits than a [`",
                ::core::stringify!($prim),
                "`] and of same signedness into [`",
                ::core::stringify!($prim),
                "`]")]
            impl<T, const N: u32> From<Bounded<T, N>> for $prim
            where
                $prim: Integer + TryFrom<T>,
                T: Integer<Signedness = <$prim as Integer>::Signedness>,
                Bounded<T, N>: FitsInXBits<{ <$prim as Integer>::BITS as usize }>,
            {
                fn from(value: Bounded<T, N>) -> $prim {
                    let raw = value.get();

                    // SAFETY: `raw` fits within `N` bits, and `N` is no larger than the target
                    // type's bit width (`FitsInXBits` bound). Both types also have the same
                    // signedness, so the conversion cannot fail.
                    unsafe { <$prim>::try_from(raw).unwrap_unchecked() }
                }
            }
        )*
    };
}

impl_into_primitive!(
    u8 u16 u32 u64 usize
    i8 i16 i32 i64 isize
);
/// Converts a single-bit [`Bounded`] into a [`bool`]: zero is `false`, any other value `true`.
impl<T> From<Bounded<T, 1>> for bool
where
    T: Integer + Zeroable,
{
    fn from(value: Bounded<T, 1>) -> Self {
        let bit = value.get();
        bit != Zeroable::zeroed()
    }
}
impl<T, const N: u32> From<bool> for Bounded<T, N>
where
T: Integer + From<bool>,
{
fn from(value: bool) -> Self {
unsafe { Self::__new(T::from(value)) }
}
}
impl<T> Bounded<T, 1>
where
    T: Integer + Zeroable,
{
    /// Converts this single-bit value into a [`bool`]: `false` if it is zero, `true` otherwise.
    pub fn into_bool(self) -> bool {
        bool::from(self)
    }
}