use crate::Bits;
use core::cmp::Ordering;
use core::marker::PhantomData;
use core::mem;
use core::ptr::NonNull;
#[cfg_attr(feature = "fallback", path = "fallback.rs")]
#[cfg_attr(not(feature = "fallback"), path = "impl.rs")]
mod ptr_impl;
use ptr_impl::PtrImpl;
/// Source of the number of low pointer bits that a `PtrImpl` uses for the tag.
trait NumBits {
    /// How many least-significant pointer bits hold the tag.
    const BITS: u32;
}

/// Derives the bit count from `T`'s alignment: an alignment of `2^k`
/// leaves the `k` least-significant bits of every aligned pointer unused.
impl<T> NumBits for PhantomData<T> {
    const BITS: u32 = {
        let align = mem::align_of::<T>();
        align.trailing_zeros()
    };
}

/// Zero-sized marker that carries an explicitly chosen bit count `N`.
struct ConstBits<const N: Bits>;

impl<const N: Bits> NumBits for ConstBits<N> {
    const BITS: u32 = {
        // A round trip through `u32` must be lossless; otherwise `N` would be
        // silently truncated by the cast below.
        const_assert!(N as u32 as Bits == N, "`BITS` is too large");
        N as _
    };
}
/// Emits one compile-time check for the case `align_of::<T>() == 2^$n`.
///
/// `$pair` is a `(trailing_zeros_of_align, bits)` tuple. When its first
/// element equals `$n`, the second must be at most `$n`. One invocation per
/// alignment lets the failure message name the exact limit.
macro_rules! check_bits {
    ($pair:expr, $n:literal, $msg:literal $(,)?) => {
        const_assert!($pair.1 <= $n || $pair.0 != $n, $msg);
    };
}
impl<T, B: NumBits> PtrImpl<T, B> {
    /// Number of low pointer bits reserved for the tag, as chosen by `B`.
    pub const BITS: u32 = B::BITS;

    /// `2^BITS`. `wrapping_shl` cannot overflow during const evaluation.
    /// An out-of-range `BITS` is reported instead by the `checked_shl`
    /// assertion in `ASSERT`.
    pub const ALIGNMENT: usize = 1_usize.wrapping_shl(Self::BITS);

    /// Mask selecting the tag bits (`ALIGNMENT - 1`).
    pub const MASK: usize = Self::ALIGNMENT - 1;

    /// Compile-time validation of `BITS` against the layout of `T`.
    ///
    /// Evaluates to `true`. Each `const_assert!` (defined elsewhere in the
    /// crate) presumably aborts const evaluation with its message on failure.
    const ASSERT: bool = {
        let tag_bits = Self::BITS;
        let align = mem::align_of::<T>();
        let size = mem::size_of::<T>();
        let free_bits = align.trailing_zeros();

        // The per-alignment checks come first so that common alignments get
        // a message naming the concrete limit. The generic check follows.
        let pair = (free_bits, tag_bits);
        check_bits!(pair, 0, "`BITS` must be 0 (alignment of T is 1)");
        check_bits!(pair, 1, "`BITS` must be <= 1 (alignment of T is 2)");
        check_bits!(pair, 2, "`BITS` must be <= 2 (alignment of T is 4)");
        check_bits!(pair, 3, "`BITS` must be <= 3 (alignment of T is 8)");
        check_bits!(pair, 4, "`BITS` must be <= 4 (alignment of T is 16)");
        check_bits!(pair, 5, "`BITS` must be <= 5 (alignment of T is 32)");
        check_bits!(pair, 6, "`BITS` must be <= 6 (alignment of T is 64)");
        const_assert!(
            tag_bits <= free_bits,
            "`BITS` cannot exceed align_of::<T>().trailing_zeros()",
        );
        const_assert!(
            1_usize.checked_shl(tag_bits).is_some(),
            "`BITS` must be less than number of bits in `usize`",
        );

        // Layout sanity checks on `T` itself.
        const_assert!(align.is_power_of_two());
        const_assert!(size == 0 || size >= align);
        const_assert!(
            size % align == 0,
            "expected size of `T` to be a multiple of alignment",
        );
        true
    };

    /// Referencing `Self::ASSERT` here forces the constant to be evaluated
    /// whenever this function is instantiated for a concrete `T`/`B`.
    fn assert() {
        assert!(Self::ASSERT);
    }
}
// These impls are written by hand rather than derived. `#[derive]` would
// require `T` and `B` to implement each trait, but these impls place no bounds
// on `T` or `B`.
impl<T, B> Copy for PtrImpl<T, B> {}

impl<T, B> Clone for PtrImpl<T, B> {
    fn clone(&self) -> Self {
        // `PtrImpl` is `Copy`, so a clone is just a copy.
        *self
    }
}

// `PartialEq` is not visible here; presumably it is implemented in `ptr_impl`.
impl<T, B> Eq for PtrImpl<T, B> {}

impl<T, B> PartialOrd for PtrImpl<T, B> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to the total order from `Ord` (presumably implemented in
        // `ptr_impl`) so the two orderings cannot disagree.
        Some(self.cmp(other))
    }
}
// Generates the API shared by the tagged-pointer types:
// - pointer/tag accessors and setters;
// - `Clone`, `Copy`, `PartialEq`, `Eq`, `Ord`, `PartialOrd`, `Hash` and
//   `Debug` impls that forward to the wrapped value (`self.0`).
//
// Parameters:
// * `[$($ty_params)*]` - generic parameters for the `impl<...>` headers.
// * `[$($ty_args)*]`   - the matching arguments applied to `TaggedPtr<...>`.
// * `$doctest_context` - a hidden doctest line (it should start with `# `)
//   that defines a one-parameter `TaggedPtr<T>` alias. This lets the
//   generated examples compile for whichever type the macro is invoked on.
//
// Everything is emitted inside `const _: () = { ... };`, which keeps the
// `use` items below scoped to this expansion.
macro_rules! impl_tagged_ptr_common {
    (
        [$($ty_params:tt)*],
        [$($ty_args:tt)*],
        $doctest_context:literal $(,)?
    ) => { const _: () = {
        use core::cmp::Ordering;
        use core::fmt;
        use core::hash::{Hash, Hasher};
        use core::ptr::NonNull;

        impl<$($ty_params)*> TaggedPtr<$($ty_args)*> {
            /// Returns the stored pointer and tag as a `(pointer, tag)` pair.
            pub fn get(self) -> (NonNull<T>, usize) {
                self.0.get()
            }

            /// Returns the stored pointer (the first element of
            /// [`Self::get`]).
            pub fn ptr(self) -> NonNull<T> {
                self.get().0
            }

            /// Replaces the pointer while keeping the current tag.
            ///
            /// This method is equivalent to:
            ///
            /// ```
            /// # use core::ptr::NonNull;
            #[doc = $doctest_context]
            /// # trait Ext<T> { fn set_ptr(&mut self, ptr: NonNull<T>); }
            /// # impl<T> Ext<T> for TaggedPtr<T> {
            /// # fn set_ptr(&mut self, ptr: NonNull<T>) {
            /// *self = Self::new(ptr, self.tag());
            /// # }}
            /// ```
            ///
            /// See [`Self::new`] for the requirements on `ptr` and any panics.
            pub fn set_ptr(&mut self, ptr: NonNull<T>) {
                *self = Self::new(ptr, self.tag());
            }

            /// Returns the stored tag (the second element of [`Self::get`]).
            pub fn tag(self) -> usize {
                self.get().1
            }

            /// Replaces the tag while keeping the current pointer.
            ///
            /// This method is equivalent to:
            ///
            /// ```
            #[doc = $doctest_context]
            /// # trait Ext { fn set_tag(&mut self, tag: usize); }
            /// # impl<T> Ext for TaggedPtr<T> {
            /// # fn set_tag(&mut self, tag: usize) {
            /// *self = Self::new(self.ptr(), tag);
            /// # }}
            /// ```
            ///
            /// See [`Self::new`] for the requirements on `tag` and any panics.
            pub fn set_tag(&mut self, tag: usize) {
                *self = Self::new(self.ptr(), tag);
            }
        }

        // The trait impls below are written by hand so that they do not
        // require `T` to implement the corresponding trait.
        impl<$($ty_params)*> Clone for TaggedPtr<$($ty_args)*> {
            fn clone(&self) -> Self {
                *self
            }
        }

        impl<$($ty_params)*> Copy for TaggedPtr<$($ty_args)*> {}

        impl<$($ty_params)*> Eq for TaggedPtr<$($ty_args)*> {}

        // Equality, ordering and hashing all forward to the wrapped value,
        // so they are mutually consistent.
        impl<$($ty_params)*> PartialEq for TaggedPtr<$($ty_args)*> {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }

        impl<$($ty_params)*> Ord for TaggedPtr<$($ty_args)*> {
            fn cmp(&self, other: &Self) -> Ordering {
                self.0.cmp(&other.0)
            }
        }

        impl<$($ty_params)*> PartialOrd for TaggedPtr<$($ty_args)*> {
            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                Some(self.cmp(other))
            }
        }

        impl<$($ty_params)*> Hash for TaggedPtr<$($ty_args)*> {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.0.hash(state);
            }
        }

        // Formats as `TaggedPtr { ptr: .., tag: .. }`, showing the pointer
        // and tag separately rather than the packed representation.
        impl<$($ty_params)*> fmt::Debug for TaggedPtr<$($ty_args)*> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let (ptr, tag) = self.get();
                f.debug_struct("TaggedPtr")
                    .field("ptr", &ptr)
                    .field("tag", &tag)
                    .finish()
            }
        }
    }; };
}
pub mod implied;
// `TaggedPtr<T, BITS>`: a tagged pointer with an explicit tag width `BITS`.
// It is `#[repr(transparent)]` over `PtrImpl<T, ConstBits<BITS>>`, so it has
// exactly that type's layout.
// NOTE(review): `with_bits_doc!` (defined elsewhere) presumably attaches the
// public rustdoc. Confirm its matcher before adding `///` comments inside
// the invocation.
with_bits_doc! {
#[repr(transparent)]
pub struct TaggedPtr<T, const BITS: Bits>(PtrImpl<T, ConstBits<BITS>>);
}
impl<T, const BITS: Bits> TaggedPtr<T, BITS> {
    /// The number of tag bits this pointer stores: the `BITS` parameter,
    /// converted to `u32`.
    pub const BITS: u32 = BITS as _;

    /// The largest tag (inclusive) this pointer can store: the mask of the
    /// low [`Self::BITS`] bits, i.e. `2^BITS - 1`.
    pub const MAX_TAG: usize = Self::max_tag();

    // A separate `const fn` rather than an inline expression.
    // NOTE(review): presumably this keeps rustdoc from displaying the private
    // `PtrImpl`/`ConstBits` expression as the value of `MAX_TAG`. Confirm.
    const fn max_tag() -> usize {
        PtrImpl::<T, ConstBits<BITS>>::MASK
    }

    /// Creates a tagged pointer from `ptr` and `tag`.
    ///
    /// Construction is delegated to the internal `PtrImpl::new`.
    /// NOTE(review): its checks are defined outside this file. Presumably it
    /// panics if `tag` exceeds [`Self::MAX_TAG`] or if `ptr` is not aligned to
    /// at least `2^BITS`; confirm and document this under `# Panics`.
    pub fn new(ptr: NonNull<T>, tag: usize) -> Self {
        Self(PtrImpl::new(ptr, tag))
    }

    /// Creates a tagged pointer by delegating to the internal
    /// `PtrImpl::new_unchecked`. As the name suggests, this presumably omits
    /// the validation performed by [`Self::new`].
    ///
    /// # Safety
    ///
    /// The caller must satisfy every precondition of
    /// `PtrImpl::new_unchecked`.
    /// NOTE(review): that contract lives outside this file. Presumably `ptr`
    /// must be aligned to at least `2^BITS` and `tag` must not exceed
    /// [`Self::MAX_TAG`]; confirm and state it precisely here.
    pub unsafe fn new_unchecked(ptr: NonNull<T>, tag: usize) -> Self {
        // SAFETY: the arguments are forwarded unchanged, and this function's
        // `# Safety` contract is exactly that of `PtrImpl::new_unchecked`.
        Self(unsafe { PtrImpl::new_unchecked(ptr, tag) })
    }

    /// Creates a tagged pointer by delegating to the internal
    /// `PtrImpl::new_unchecked_dereferenceable`.
    ///
    /// # Safety
    ///
    /// The caller must satisfy every precondition of
    /// `PtrImpl::new_unchecked_dereferenceable`.
    /// NOTE(review): presumably these are the requirements of
    /// [`Self::new_unchecked`] plus the requirement that `ptr` is
    /// dereferenceable, as the name suggests. Confirm.
    pub unsafe fn new_unchecked_dereferenceable(
        ptr: NonNull<T>,
        tag: usize,
    ) -> Self {
        // SAFETY: the arguments are forwarded unchanged, and this function's
        // `# Safety` contract is exactly that of
        // `PtrImpl::new_unchecked_dereferenceable`.
        Self(unsafe { PtrImpl::new_unchecked_dereferenceable(ptr, tag) })
    }
}
// Generate the shared accessors and trait impls for `TaggedPtr<T, BITS>`.
// The string literal becomes a doc line inside the generated documentation.
// It defines the one-parameter `TaggedPtr<T>` alias used there.
impl_tagged_ptr_common! {
    [T, const BITS: Bits],
    [T, BITS],
    "# type TaggedPtr<T> = tagged_pointer::TaggedPtr<T, 0>;"
}