use std::{
ops::{Deref, DerefMut},
ptr::NonNull,
};
use diskann_utils::{Reborrow, ReborrowMut};
/// Read access to a `Copy` value held by `self`.
///
/// The value is handed out by copy rather than by reference.
pub trait CopyRef {
    /// The `Copy` type produced by [`CopyRef::copy_ref`].
    type Target: Copy;

    /// Returns a copy of the value held by `self`.
    fn copy_ref(&self) -> <Self as CopyRef>::Target;
}

/// Write access to a `Copy` value held by `self`.
///
/// The written type is the same [`CopyRef::Target`] that is read.
pub trait CopyMut: CopyRef {
    /// Overwrites the value held by `self` with `value`.
    fn copy_mut(&mut self, value: <Self as CopyRef>::Target);
}
/// A by-value wrapper around a `'static` value.
///
/// `#[repr(transparent)]` guarantees that `Owned<T>` has exactly the layout of `T`.
#[derive(Default, Debug, Clone, Copy)]
#[repr(transparent)]
pub struct Owned<T: 'static>(pub T);
impl<T> From<T> for Owned<T>
where
T: 'static,
{
fn from(value: T) -> Self {
Self(value)
}
}
impl<T> Deref for Owned<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
impl<T> DerefMut for Owned<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl<T> CopyRef for Owned<T>
where
T: Copy,
{
type Target = T;
fn copy_ref(&self) -> T {
self.0
}
}
impl<T> CopyMut for Owned<T>
where
T: Copy,
{
fn copy_mut(&mut self, value: T) {
self.0 = value;
}
}
impl<'a, T: Copy> Reborrow<'a> for Owned<T> {
    type Target = Ref<'a, T>;

    /// Produces a shared [`Ref`] to the wrapped value, borrowed for `'a`.
    fn reborrow(&'a self) -> Self::Target {
        let inner: &'a T = &self.0;
        inner.into()
    }
}

impl<'a, T: Copy> ReborrowMut<'a> for Owned<T> {
    type Target = Mut<'a, T>;

    /// Produces an exclusive [`Mut`] to the wrapped value, borrowed for `'a`.
    fn reborrow_mut(&'a mut self) -> Self::Target {
        let inner: &'a mut T = &mut self.0;
        inner.into()
    }
}
/// A shared, read-only handle to a `T`, valid for the lifetime `'a`.
///
/// It wraps a [`NonNull<T>`] together with a `PhantomData<&'a T>` marker. The handle is
/// therefore `Copy` and has the same variance as `&'a T`: it is covariant in both `'a`
/// and `T`.
///
/// Note that the derived `Debug` prints the raw pointer, not the pointee.
#[derive(Debug)]
pub struct Ref<'a, T: ?Sized> {
    ptr: NonNull<T>,
    // Ties the handle to `'a` and gives it the variance of `&'a T`.
    _lifetime: std::marker::PhantomData<&'a T>,
}

impl<'a, T: ?Sized> Ref<'a, T> {
    /// Creates a `Ref` from a raw non-null pointer.
    ///
    /// # Safety
    ///
    /// `Ref` exposes the pointee as `&T` through safe code. For the whole lifetime `'a`
    /// of the returned handle and of every copy of it, `ptr` must therefore be valid to
    /// turn into a shared reference `&'a T`. Concretely:
    ///
    /// - it must be properly aligned;
    /// - it must point to an initialized `T`;
    /// - the pointee must not be mutated, except through interior mutability
    ///   (`UnsafeCell`).
    pub unsafe fn new(ptr: NonNull<T>) -> Self {
        Self {
            ptr,
            _lifetime: std::marker::PhantomData,
        }
    }
}

// Implemented by hand because `#[derive]` would add `T: Clone` / `T: Copy` bounds.
// Copying the handle only copies the pointer, exactly like copying a `&T`.
impl<T: ?Sized> Clone for Ref<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T: ?Sized> Copy for Ref<'_, T> {}

// SAFETY: `Ref<'_, T>` has the semantics of `&T`, and `&T: Send` holds exactly when
// `T: Sync`. Bounding on `T: Send` would be unsound. Because `Ref` is `Copy`, a
// `Ref<Cell<_>>` could be copied and moved to another thread, which would give two
// threads unsynchronized shared access to a `!Sync` value.
unsafe impl<T> Send for Ref<'_, T> where T: ?Sized + Sync {}

// SAFETY: `&T: Sync` holds exactly when `T: Sync`, and `Ref<'_, T>` only provides
// shared access to the pointee.
unsafe impl<T> Sync for Ref<'_, T> where T: ?Sized + Sync {}
impl<T: ?Sized> Deref for Ref<'_, T> {
    type Target = T;

    /// Borrows the pointee for as long as this handle is borrowed.
    fn deref(&self) -> &Self::Target {
        // SAFETY: every `Ref` is built from a `&'a T` or through the `unsafe`
        // constructor `Ref::new`. In both cases `ptr` must be usable as a shared
        // reference for the handle's lifetime, which outlives this borrow of `self`.
        unsafe { self.ptr.as_ref() }
    }
}

impl<T: Copy> CopyRef for Ref<'_, T> {
    type Target = T;

    /// Reads the pointee by value.
    ///
    /// The read is unaligned, so this method does not itself depend on the pointer's
    /// alignment.
    fn copy_ref(&self) -> Self::Target {
        let src: *const T = self.ptr.as_ptr();
        // SAFETY: the pointer is valid for reads for the handle's lifetime. That is
        // required of every `Ref` constructor. `T: Copy`, so duplicating the bits
        // leaves the pointee untouched.
        unsafe { src.read_unaligned() }
    }
}
impl<'a, T> From<&'a T> for Ref<'a, T>
where
T: ?Sized,
{
fn from(r: &'a T) -> Self {
unsafe { Self::new(r.into()) }
}
}
impl<'a, T> From<&'a mut T> for Ref<'a, T>
where
T: ?Sized,
{
fn from(r: &'a mut T) -> Self {
unsafe { Self::new(r.into()) }
}
}
impl<'short, T> Reborrow<'short> for Ref<'_, T>
where
T: ?Sized,
{
type Target = Ref<'short, T>;
fn reborrow(&'short self) -> Self::Target {
*self
}
}
/// An exclusive, mutable handle to a `T`, valid for the lifetime `'a`.
///
/// It wraps a [`NonNull<T>`] together with a `PhantomData<&'a mut T>` marker. The
/// handle therefore has the variance of `&'a mut T`: it is covariant in `'a` and
/// invariant in `T`.
///
/// Note that the derived `Debug` prints the raw pointer, not the pointee.
#[derive(Debug)]
pub struct Mut<'a, T: ?Sized> {
    ptr: NonNull<T>,
    // Ties the handle to `'a` and gives it the variance of `&'a mut T`.
    _lifetime: std::marker::PhantomData<&'a mut T>,
}

impl<'a, T: ?Sized> Mut<'a, T> {
    /// Creates a `Mut` from a raw non-null pointer.
    ///
    /// # Safety
    ///
    /// `Mut` exposes the pointee as `&T` and `&mut T` through safe code. For the whole
    /// lifetime `'a` of the returned handle, `ptr` must therefore be valid to turn into
    /// an exclusive reference `&'a mut T`:
    ///
    /// - it must be properly aligned;
    /// - it must point to an initialized `T`;
    /// - it must not be accessed through any other path while the handle is in use.
    pub unsafe fn new(ptr: NonNull<T>) -> Self {
        Mut {
            ptr,
            _lifetime: std::marker::PhantomData,
        }
    }
}

// SAFETY: `Mut<'_, T>` has the semantics of `&mut T`, and `&mut T: Send` holds exactly
// when `T: Send`.
unsafe impl<T> Send for Mut<'_, T> where T: ?Sized + Send {}

// SAFETY: `&mut T: Sync` holds exactly when `T: Sync`. Through `&Mut` only shared
// access to the pointee is available.
unsafe impl<T> Sync for Mut<'_, T> where T: ?Sized + Sync {}
impl<T: ?Sized> Deref for Mut<'_, T> {
    type Target = T;

    /// Borrows the pointee immutably for as long as `self` is borrowed.
    fn deref(&self) -> &Self::Target {
        // SAFETY: every `Mut` is built from a `&'a mut T` or through the `unsafe`
        // constructor `Mut::new`. In both cases `ptr` must be usable as a reference
        // for the handle's lifetime, which outlives this borrow of `self`.
        unsafe { self.ptr.as_ref() }
    }
}

impl<T: ?Sized> DerefMut for Mut<'_, T> {
    /// Borrows the pointee mutably for as long as `self` is mutably borrowed.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: as for `deref`. In addition, `&mut self` guarantees that no other
        // borrow obtained through this handle is alive at the same time.
        unsafe { self.ptr.as_mut() }
    }
}
impl<T: Copy> CopyRef for Mut<'_, T> {
    type Target = T;

    /// Reads the pointee by value using an unaligned load.
    fn copy_ref(&self) -> Self::Target {
        let src: *const T = self.ptr.as_ptr();
        // SAFETY: the pointer is valid for reads for the handle's lifetime. That is
        // required of every `Mut` constructor.
        unsafe { src.read_unaligned() }
    }
}

impl<T: Copy> CopyMut for Mut<'_, T> {
    /// Overwrites the pointee with `value` using an unaligned store.
    ///
    /// The old value is not dropped, which is harmless because `T: Copy`.
    fn copy_mut(&mut self, value: T) {
        let dst: *mut T = self.ptr.as_ptr();
        // SAFETY: the pointer is valid for writes for the handle's lifetime, and
        // `&mut self` rules out any concurrent access through this handle.
        unsafe { dst.write_unaligned(value) }
    }
}
impl<'a, T> From<&'a mut T> for Mut<'a, T>
where
T: ?Sized,
{
fn from(r: &'a mut T) -> Self {
unsafe { Self::new(r.into()) }
}
}
impl<'short, T> Reborrow<'short> for Mut<'_, T>
where
T: ?Sized,
{
type Target = Ref<'short, T>;
fn reborrow(&'short self) -> Self::Target {
unsafe { Ref::new(self.ptr) }
}
}
impl<'short, T> ReborrowMut<'short> for Mut<'_, T>
where
T: ?Sized,
{
type Target = Mut<'short, T>;
fn reborrow_mut(&'short mut self) -> Self::Target {
unsafe { Mut::new(self.ptr) }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Type-level check: accepts only a `Ref`.
    fn assert_is_ref<T>(_x: Ref<'_, T>) {}

    /// Type-level check: accepts only a `Mut`.
    fn assert_is_mut_ref<T>(_x: Mut<'_, T>) {}

    #[test]
    fn test_owned() {
        let mut owned = Owned::from(10.0_f32);
        assert_eq!(*owned, 10.0);

        // Write through `DerefMut`.
        *owned.deref_mut() = 5.0;
        assert_eq!(*owned, 5.0);

        // Shared reborrow.
        assert_eq!(owned.reborrow().copy_ref(), 5.0);
        assert_is_ref(owned.reborrow());

        // Exclusive reborrow.
        owned.reborrow_mut().copy_mut(1.0);
        assert_eq!(*owned, 1.0);
        assert_is_mut_ref(owned.reborrow_mut());
    }

    #[test]
    fn test_ref() {
        // Built from a shared reference.
        let value: f32 = 10.0;
        let handle = Ref::from(&value);
        assert_eq!(handle.copy_ref(), 10.0);
        assert_eq!(*handle, 10.0);
        assert_eq!(handle.reborrow().copy_ref(), 10.0);
        assert_is_ref(handle.reborrow());

        // Built from an exclusive reference.
        let mut value: f32 = 10.0;
        let handle: Ref<f32> = (&mut value).into();
        assert_eq!(handle.copy_ref(), 10.0);
    }

    #[test]
    fn test_ref_mut() {
        let mut value: f32 = 10.0;

        // Write through `CopyMut`; the change is visible via every access path.
        let mut handle = Mut::from(&mut value);
        assert_eq!(handle.copy_ref(), 10.0);
        assert_eq!(*handle, 10.0);
        assert_eq!(handle.reborrow().copy_ref(), 10.0);
        assert_is_ref(handle.reborrow());
        assert_is_mut_ref(handle.reborrow_mut());
        handle.copy_mut(5.0);
        assert_eq!(handle.copy_ref(), 5.0);
        assert_eq!(*handle, 5.0);
        assert_eq!(value, 5.0);

        // Write through `DerefMut`.
        let mut handle: Mut<f32> = (&mut value).into();
        *handle = 10.0;
        assert_eq!(handle.copy_ref(), 10.0);
        assert_eq!(*handle, 10.0);
        assert_eq!(value, 10.0);

        // Write through a reborrowed `Mut`.
        let mut handle: Mut<f32> = (&mut value).into();
        handle.reborrow_mut().copy_mut(1.0);
        assert_eq!(handle.copy_ref(), 1.0);
        assert_eq!(*handle, 1.0);
        assert_eq!(value, 1.0);
    }
}