use core::{fmt, mem::align_of, num::NonZeroUsize, ptr::NonNull};
use crate::rust_nightly_apis::ptr;
/// A non-null pointer to `P` that carries up to [`Self::NUM_BITS`] bits of extra
/// data in the low bits of its address, which are always zero for a pointer
/// aligned to `align_of::<P>()`.
///
/// `MIN_BITS` is the minimum number of tag bits the user requires. Instantiating
/// this type for a `P` whose alignment provides fewer bits makes the
/// [`Self::NUM_BITS`] const assertion fail at compile time.
#[repr(transparent)]
pub struct TaggedPointer<P, const MIN_BITS: u32>(NonNull<P>);

impl<P, const MIN_BITS: u32> TaggedPointer<P, MIN_BITS> {
    /// Alignment, in bytes, of the pointed-to type `P`.
    pub const ALIGNMENT: usize = align_of::<P>();

    /// Mask selecting the low address bits that hold tag data.
    pub const DATA_MASK: usize = !Self::POINTER_MASK;

    /// Number of low address bits available for tag data: `log2(ALIGNMENT)`.
    ///
    /// Const evaluation fails (a compile-time error) if this is smaller than
    /// `MIN_BITS`.
    pub const NUM_BITS: u32 = {
        let num_bits = Self::ALIGNMENT.trailing_zeros();
        assert!(
            num_bits >= MIN_BITS,
            "need the alignment of the pointed to type to have sufficient bits"
        );
        num_bits
    };

    /// Mask selecting the high address bits that hold the actual pointer.
    pub const POINTER_MASK: usize = usize::MAX << Self::NUM_BITS;

    /// Wraps `pointer` with all tag bits cleared, or returns `None` if it is null.
    ///
    /// # Panics
    ///
    /// Panics if `pointer` is not aligned to [`Self::ALIGNMENT`], i.e. if any of
    /// its low [`Self::NUM_BITS`] bits are already set.
    pub fn new(pointer: *mut P) -> Option<Self> {
        if pointer.is_null() {
            return None;
        }
        // SAFETY: `pointer` was checked to be non-null just above, which is the
        // only safety requirement of `new_unchecked`.
        Some(unsafe { Self::new_unchecked(pointer) })
    }

    /// Wraps `pointer` with all tag bits cleared, without checking for null.
    ///
    /// # Safety
    ///
    /// `pointer` must be non-null (the requirement of [`NonNull::new_unchecked`]).
    ///
    /// # Panics
    ///
    /// Panics if `pointer` is not aligned to [`Self::ALIGNMENT`], i.e. if any of
    /// its low [`Self::NUM_BITS`] bits are already set.
    pub unsafe fn new_unchecked(pointer: *mut P) -> Self {
        // SAFETY: the caller guarantees that `pointer` is non-null.
        let unchecked_ptr = unsafe { NonNull::new_unchecked(pointer) };
        let ptr_addr = ptr::mut_addr(unchecked_ptr.as_ptr());
        // The data bits must start out clear: if they were not, `to_ptr` could
        // not recover the original address by masking them off.
        assert_eq!(
            ptr_addr & Self::DATA_MASK,
            0,
            "this pointer was not aligned"
        );
        TaggedPointer(unchecked_ptr)
    }

    /// Wraps `pointer` and stores `data` in its tag bits, or returns `None` if
    /// `pointer` is null.
    ///
    /// # Panics
    ///
    /// Panics if `pointer` is misaligned (see [`Self::new`]) or if `data` does
    /// not fit in the low [`Self::NUM_BITS`] bits (see [`Self::set_data`]).
    pub fn new_with_data(pointer: *mut P, data: usize) -> Option<Self> {
        let mut tagged_ptr = Self::new(pointer)?;
        tagged_ptr.set_data(data);
        Some(tagged_ptr)
    }

    /// Returns the wrapped pointer with all tag data bits cleared.
    #[inline]
    pub fn to_ptr(self) -> NonNull<P> {
        ptr::nonnull_map_addr(self.0, |ptr_addr| {
            // SAFETY: the constructors only accept non-null addresses whose data
            // bits are zero, and `set_data` never touches the pointer bits, so
            // `addr & POINTER_MASK` is that original non-null address.
            unsafe { NonZeroUsize::new_unchecked(ptr_addr.get() & Self::POINTER_MASK) }
        })
    }

    /// Returns the tag data stored in the low [`Self::NUM_BITS`] bits.
    #[inline]
    pub fn to_data(self) -> usize {
        let ptr_addr = ptr::mut_addr(self.0.as_ptr());
        ptr_addr & Self::DATA_MASK
    }

    /// Replaces the stored tag data with `data`, leaving the pointer bits intact.
    ///
    /// # Panics
    ///
    /// Panics if `data` has any bit set outside the low [`Self::NUM_BITS`] bits.
    pub fn set_data(&mut self, data: usize) {
        // This check is unconditional, so after it `data & DATA_MASK == data`
        // and no further masking of `data` is needed.
        assert_eq!(
            data & Self::POINTER_MASK,
            0,
            "cannot set more data beyond the lowest NUM_BITS"
        );
        self.0 = ptr::nonnull_map_addr(self.0, |ptr_addr| {
            // SAFETY: the pointer bits of the stored address are the original
            // non-null, aligned address (see `to_ptr`), so they are non-zero.
            let pointer_bits =
                unsafe { NonZeroUsize::new_unchecked(ptr_addr.get() & Self::POINTER_MASK) };
            pointer_bits | data
        });
    }
}
impl<P, const MIN_BITS: u32> From<NonNull<P>> for TaggedPointer<P, MIN_BITS> {
    /// Wraps an existing non-null pointer with all tag bits cleared.
    ///
    /// Panics, via the alignment assertion in `new_unchecked`, if `pointer` is
    /// not aligned to `ALIGNMENT`.
    fn from(pointer: NonNull<P>) -> Self {
        let raw = pointer.as_ptr();
        // SAFETY: `raw` was obtained from a `NonNull`, so it is non-null.
        unsafe { Self::new_unchecked(raw) }
    }
}
impl<P, const MIN_BITS: u32> From<TaggedPointer<P, MIN_BITS>> for NonNull<P> {
    /// Recovers the untagged pointer, discarding any stored tag data.
    fn from(pointer: TaggedPointer<P, MIN_BITS>) -> Self {
        // `to_ptr` already yields a `NonNull<P>`; re-wrapping it through
        // `NonNull::new_unchecked` would be a needless `unsafe` block.
        pointer.to_ptr()
    }
}
impl<P, const MIN_BITS: u32> From<&mut P> for TaggedPointer<P, MIN_BITS> {
fn from(reference: &mut P) -> Self {
unsafe { Self::new_unchecked(core::ptr::from_mut(reference)) }
}
}
impl<P, const MIN_BITS: u32> core::hash::Hash for TaggedPointer<P, MIN_BITS> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl<P, const MIN_BITS: u32> Ord for TaggedPointer<P, MIN_BITS> {
    /// Orders by the raw stored address, tag data bits included, so two
    /// pointers to the same location with different tag data are not equal.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}
impl<P, const MIN_BITS: u32> PartialOrd for TaggedPointer<P, MIN_BITS> {
    /// Delegates to the `Ord` implementation, so every pair is comparable.
    // Kept in the canonical `Some(self.cmp(other))` form that clippy's
    // `non_canonical_partial_ord_impl` lint expects when `Ord` is implemented.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
// Equality compares the raw stored address (pointer bits plus tag bits), which
// is a total equivalence relation, so `Eq` holds for any `P`.
impl<P, const MIN_BITS: u32> Eq for TaggedPointer<P, MIN_BITS> {}
impl<P, const MIN_BITS: u32> PartialEq for TaggedPointer<P, MIN_BITS> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
// Implemented by hand because `#[derive(Clone)]` would add a `P: Clone` bound;
// the wrapped `NonNull<P>` is copyable for any `P`.
impl<P, const MIN_BITS: u32> Clone for TaggedPointer<P, MIN_BITS> {
    /// Returns a bitwise copy, tag data included.
    // `*self` is the canonical clone for a `Copy` type (clippy
    // `non_canonical_clone_impl`).
    fn clone(&self) -> Self {
        *self
    }
}
// Implemented by hand because `#[derive(Copy)]` would add an unnecessary
// `P: Copy` bound; copying duplicates only the tagged address, not the pointee.
impl<P, const MIN_BITS: u32> Copy for TaggedPointer<P, MIN_BITS> {}
impl<P, const MIN_BITS: u32> fmt::Debug for TaggedPointer<P, MIN_BITS> {
    /// Shows the untagged pointer and the tag data as separate fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pointer = self.to_ptr();
        let data = self.to_data();
        let mut builder = f.debug_struct("TaggedPointer");
        builder.field("pointer", &pointer);
        builder.field("data", &data);
        builder.finish()
    }
}
impl<P, const MIN_BITS: u32> fmt::Pointer for TaggedPointer<P, MIN_BITS> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Pointer::fmt(&self.to_ptr(), f)
}
}
#[cfg(test)]
mod tests {
    use alloc::boxed::Box;

    use super::*;

    /// Tags a boxed `&str`, re-tags it, and checks that both the pointee and the
    /// tag data survive.
    // `&str` is only pointer-aligned, so the 3 tag bits needed for `0b101` exist
    // only on 64-bit targets; elsewhere the `NUM_BITS` const assertion would
    // fail to compile.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn successful_tag() {
        let pointee = "Hello world!";
        let pointer = Box::into_raw(Box::new(pointee));
        let tag_data = 0b101usize;
        let mut tagged_pointer =
            TaggedPointer::<&str, 3>::new_with_data(pointer, tag_data).unwrap();
        assert_eq!(unsafe { *tagged_pointer.to_ptr().as_ptr() }, "Hello world!");
        assert_eq!(tagged_pointer.to_data(), 0b101);
        tagged_pointer.set_data(0b010);
        assert_eq!(unsafe { *tagged_pointer.to_ptr().as_ptr() }, "Hello world!");
        assert_eq!(tagged_pointer.to_data(), 0b010);
        unsafe {
            drop(Box::from_raw(tagged_pointer.to_ptr().as_ptr()));
        }
    }

    #[test]
    fn create_pointer_set_and_retrieve_data() {
        let raw_pointer = Box::into_raw(Box::new(10));
        let mut p = TaggedPointer::<_, 2>::new(raw_pointer).unwrap();
        assert_eq!(p.to_data(), 0);
        p.set_data(1);
        assert_eq!(p.to_data(), 1);
        assert_eq!(unsafe { *p.to_ptr().as_ptr() }, 10);
        p.set_data(3);
        assert_eq!(p.to_data(), 3);
        assert_eq!(unsafe { *p.to_ptr().as_ptr() }, 10);
        unsafe {
            let _ = Box::from_raw(p.to_ptr().as_ptr());
        };
    }

    #[test]
    fn create_pointer_with_data_and_retrieve_data() {
        let raw_pointer = Box::into_raw(Box::new(30));
        let mut p = TaggedPointer::<_, 2>::new_with_data(raw_pointer, 3).unwrap();
        assert_eq!(p.to_data(), 3);
        assert_eq!(unsafe { *p.to_ptr().as_ptr() }, 30);
        p.set_data(0);
        assert_eq!(unsafe { *p.to_ptr().as_ptr() }, 30);
        assert_eq!(p.to_data(), 0);
        unsafe {
            let _ = Box::from_raw(p.to_ptr().as_ptr());
        };
    }

    #[test]
    #[should_panic = "cannot set more data beyond the lowest NUM_BITS"]
    fn set_data_beyond_capacity_u8() {
        let mut val = 0u8;
        let raw_ptr = &mut val as *mut _;
        let mut p = TaggedPointer::<_, 0>::new(raw_ptr).unwrap();
        p.set_data(0b1);
    }

    #[test]
    #[should_panic = "cannot set more data beyond the lowest NUM_BITS"]
    fn set_data_beyond_capacity_u16() {
        let mut val = 0u16;
        let raw_ptr = &mut val as *mut _;
        let mut p = TaggedPointer::<_, 1>::new(raw_ptr).unwrap();
        p.set_data(0b11);
    }

    #[test]
    #[should_panic = "cannot set more data beyond the lowest NUM_BITS"]
    fn set_data_beyond_capacity_u32() {
        let mut val = 0u32;
        let raw_ptr = &mut val as *mut _;
        let mut p = TaggedPointer::<_, 2>::new(raw_ptr).unwrap();
        p.set_data(0b111);
    }

    // `u64` is only 4-byte aligned on some 32-bit targets (e.g. i686), where
    // `TaggedPointer<u64, 3>` fails const evaluation.
    #[test]
    #[cfg(target_pointer_width = "64")]
    #[should_panic = "cannot set more data beyond the lowest NUM_BITS"]
    fn set_data_beyond_capacity_u64() {
        let mut val = 0u64;
        let raw_ptr = &mut val as *mut _;
        let mut p = TaggedPointer::<_, 3>::new(raw_ptr).unwrap();
        p.set_data(0b1111);
    }

    #[test]
    fn set_data_different_alignments() {
        let mut p0 = TaggedPointer::<_, 0>::new(Box::into_raw(Box::<[u8; 0]>::new([]))).unwrap();
        let mut p1 = TaggedPointer::<_, 0>::new(Box::into_raw(Box::new(false))).unwrap();
        let mut p2 = TaggedPointer::<_, 0>::new(Box::into_raw(Box::new(2u8))).unwrap();
        let mut p3 = TaggedPointer::<_, 1>::new(Box::into_raw(Box::new(3u16))).unwrap();
        let mut p4 = TaggedPointer::<_, 2>::new(Box::into_raw(Box::new(4u32))).unwrap();

        assert_eq!(p0.to_data(), 0);
        assert_eq!(unsafe { *p0.to_ptr().as_ptr() }.len(), 0);
        p0.set_data(0);
        assert_eq!(unsafe { *p0.to_ptr().as_ptr() }.len(), 0);
        assert_eq!(p0.to_data(), 0);

        assert_eq!(p1.to_data(), 0);
        assert!(unsafe { !*p1.to_ptr().as_ptr() });
        p1.set_data(0);
        assert_eq!(p1.to_data(), 0);
        assert!(unsafe { !*p1.to_ptr().as_ptr() });

        assert_eq!(p2.to_data(), 0);
        assert_eq!(unsafe { *p2.to_ptr().as_ptr() }, 2);
        p2.set_data(0);
        assert_eq!(p2.to_data(), 0);
        assert_eq!(unsafe { *p2.to_ptr().as_ptr() }, 2);

        assert_eq!(p3.to_data(), 0);
        assert_eq!(unsafe { *p3.to_ptr().as_ptr() }, 3);
        p3.set_data(1);
        assert_eq!(p3.to_data(), 1);
        assert_eq!(unsafe { *p3.to_ptr().as_ptr() }, 3);

        assert_eq!(p4.to_data(), 0);
        assert_eq!(unsafe { *p4.to_ptr().as_ptr() }, 4);
        p4.set_data(3);
        assert_eq!(p4.to_data(), 3);
        assert_eq!(unsafe { *p4.to_ptr().as_ptr() }, 4);

        // `u64` only provides 3 tag bits where it is 8-byte aligned; gate it so
        // the rest of this test still builds and runs on 32-bit targets.
        #[cfg(target_pointer_width = "64")]
        {
            let mut p5 = TaggedPointer::<_, 3>::new(Box::into_raw(Box::new(5u64))).unwrap();
            assert_eq!(p5.to_data(), 0);
            assert_eq!(unsafe { *p5.to_ptr().as_ptr() }, 5);
            p5.set_data(7);
            assert_eq!(p5.to_data(), 7);
            assert_eq!(unsafe { *p5.to_ptr().as_ptr() }, 5);
            unsafe {
                drop(Box::from_raw(p5.to_ptr().as_ptr()));
            }
        }

        unsafe {
            // Reclaim the zero-sized box too, for symmetry with the others.
            drop(Box::from_raw(p0.to_ptr().as_ptr()));
            drop(Box::from_raw(p1.to_ptr().as_ptr()));
            drop(Box::from_raw(p2.to_ptr().as_ptr()));
            drop(Box::from_raw(p3.to_ptr().as_ptr()));
            drop(Box::from_raw(p4.to_ptr().as_ptr()));
        }
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_alignment_bits_and_mask_values() {
        assert_eq!(TaggedPointer::<(), 0>::ALIGNMENT, 1);
        assert_eq!(TaggedPointer::<(), 0>::NUM_BITS, 0);
        assert_eq!(
            TaggedPointer::<(), 0>::POINTER_MASK,
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_usize
        );
        assert_eq!(TaggedPointer::<u8, 0>::ALIGNMENT, 1);
        assert_eq!(TaggedPointer::<u8, 0>::NUM_BITS, 0);
        assert_eq!(
            TaggedPointer::<u8, 0>::POINTER_MASK,
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_usize
        );
        assert_eq!(TaggedPointer::<u16, 1>::ALIGNMENT, 2);
        assert_eq!(TaggedPointer::<u16, 1>::NUM_BITS, 1);
        assert_eq!(
            TaggedPointer::<u16, 1>::POINTER_MASK,
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1110_usize
        );
        assert_eq!(TaggedPointer::<u32, 2>::ALIGNMENT, 4);
        assert_eq!(TaggedPointer::<u32, 2>::NUM_BITS, 2);
        assert_eq!(
            TaggedPointer::<u32, 2>::POINTER_MASK,
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1100_usize
        );
        assert_eq!(TaggedPointer::<u64, 3>::ALIGNMENT, 8);
        assert_eq!(TaggedPointer::<u64, 3>::NUM_BITS, 3);
        assert_eq!(
            TaggedPointer::<u64, 3>::POINTER_MASK,
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1000_usize
        );
        #[cfg(feature = "std")]
        let arch = std::env::consts::ARCH;
        #[cfg(not(feature = "std"))]
        let arch = "no_std";
        // Use the same `MIN_BITS` as the checks below; `<u128, 5>` names a type
        // whose `NUM_BITS` assertion would fail if it were ever evaluated.
        assert_eq!(
            TaggedPointer::<u128, 3>::ALIGNMENT,
            16,
            "Target architecture [{arch}]",
        );
        assert_eq!(TaggedPointer::<u128, 3>::NUM_BITS, 4);
        assert_eq!(
            TaggedPointer::<u128, 3>::POINTER_MASK,
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_0000_usize
        );
    }
}