use core::{
any::Any,
ops::{Deref, DerefMut},
};
use alloc::{boxed::Box, sync::Arc};
use spin::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::{Error, error::Result};
/// Shared handle to an entry: a reference-counted pointer to a spin
/// `RwLock` guarding the type-erased [`EntryData`].
pub type EntryPtr = Arc<RwLock<EntryData>>;
/// A type-erased stored value plus a sequence id.
///
/// `EntryData::new` starts the sequence id at 1. `EntryWriteGuard` increments
/// it when it is dropped after the payload was accessed mutably.
pub struct EntryData {
/// Sequence id; bumped by `EntryWriteGuard` on drop after mutable access.
pub(crate) sequence_id: usize,
/// The boxed, type-erased payload; typed access goes through `downcast_*`.
pub(crate) data: Box<dyn Any + Send + Sync>,
}
impl Deref for EntryData {
    type Target = Box<dyn Any + Send + Sync>;

    /// Exposes the boxed, type-erased payload.
    fn deref(&self) -> &Self::Target {
        let Self { data, .. } = self;
        data
    }
}
impl DerefMut for EntryData {
    /// Exposes the boxed, type-erased payload mutably.
    ///
    /// Note: this does not touch `sequence_id`; only `EntryWriteGuard` bumps it.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { data, .. } = self;
        data
    }
}
impl EntryData {
    /// Boxes `value` as a type-erased payload with an initial sequence id of 1.
    pub fn new<T: Any + Send + Sync>(value: T) -> Self {
        let data: Box<dyn Any + Send + Sync> = Box::new(value);
        // NOTE(review): ids start at 1 rather than 0 — possibly so 0 can act
        // as a "never seen" sentinel for callers; confirm before changing.
        Self {
            sequence_id: 1,
            data,
        }
    }

    /// Borrows the boxed, type-erased payload.
    pub fn data(&self) -> &Box<dyn Any + Send + Sync> {
        let Self { data, .. } = self;
        data
    }

    /// Returns the current sequence id.
    pub const fn sequence_id(&self) -> usize {
        let Self { sequence_id, .. } = self;
        *sequence_id
    }
}
/// Typed, shared access to an entry's payload.
///
/// The constructors acquire a read lock and leak its guard. The lock stays
/// held until this guard is dropped, which decrements the reader count
/// manually. Because of the raw-pointer field, this type is neither `Send`
/// nor `Sync`.
pub struct EntryReadGuard<T: Any + Send + Sync> {
// Keeps the entry (and therefore the payload `ptr_t` points into) alive.
entry: EntryPtr,
// Points at the payload after a successful `downcast_ref::<T>()`.
ptr_t: *const T,
}
impl<T: Any + Send + Sync> Deref for EntryReadGuard<T> {
    type Target = T;

    /// Returns the downcast payload.
    fn deref(&self) -> &T {
        let payload: *const T = self.ptr_t;
        // SAFETY: `payload` was derived from the leaked read borrow of the entry
        // that `self.entry` keeps alive, and that read lock is held until this
        // guard is dropped, so no writer can invalidate it while `&self` lives.
        #[allow(unsafe_code)]
        unsafe {
            &*payload
        }
    }
}
impl<T: Any + Send + Sync> Drop for EntryReadGuard<T> {
fn drop(&mut self) {
#[allow(unsafe_code)]
unsafe {
self.entry.force_read_decrement();
}
}
}
impl<T: Any + Send + Sync> EntryReadGuard<T> {
pub fn new(key: &str, entry: EntryPtr) -> Result<Self> {
let ptr_t = {
let guard = entry.read();
let x = &RwLockReadGuard::leak(guard).data;
if let Some(t) = x.downcast_ref::<T>() {
let ptr_t: *const T = t;
ptr_t
} else {
#[allow(unsafe_code)]
unsafe {
entry.force_read_decrement();
}
return Err(Error::WrongType { key: key.into() });
}
};
Ok(Self { entry, ptr_t })
}
pub fn try_new(key: &str, entry: &EntryPtr) -> Result<Self> {
let ptr_t = {
if let Some(guard) = entry.try_read() {
let x = &RwLockReadGuard::leak(guard).data;
if let Some(t) = x.downcast_ref::<T>() {
let ptr_t: *const T = t;
ptr_t
} else {
#[allow(unsafe_code)]
unsafe {
entry.force_read_decrement();
}
return Err(Error::WrongType { key: key.into() });
}
} else {
return Err(Error::IsLocked { key: key.into() });
}
};
Ok(Self {
entry: entry.clone(),
ptr_t,
})
}
}
/// Typed, exclusive access to an entry's payload.
///
/// The constructors acquire the write lock and leak its guard. Dropping this
/// guard increments the entry's sequence id if the payload was accessed
/// mutably, then force-unlocks the write lock. Because of the raw-pointer
/// fields, this type is neither `Send` nor `Sync`.
pub struct EntryWriteGuard<T: Any + Send + Sync> {
// Keeps the entry (and therefore both pointers' targets) alive.
entry: EntryPtr,
// Points at the payload after a successful `downcast_mut::<T>()`.
ptr_t: *mut T,
// Points at the entry's `sequence_id` field.
ptr_seq_id: *mut usize,
// Set by `deref_mut`; decides whether `drop` bumps the sequence id.
modified: bool,
}
impl<T: Any + Send + Sync> Deref for EntryWriteGuard<T> {
    type Target = T;

    /// Returns the downcast payload for reading (does not mark it modified).
    fn deref(&self) -> &T {
        let payload: *const T = self.ptr_t;
        // SAFETY: `payload` points into the entry kept alive by `self.entry`,
        // and the write lock leaked at construction is still held, so nothing
        // else can access the payload while `&self` lives.
        #[allow(unsafe_code)]
        unsafe {
            &*payload
        }
    }
}
impl<T: Any + Send + Sync> DerefMut for EntryWriteGuard<T> {
    /// Returns the downcast payload mutably and marks the entry as modified.
    fn deref_mut(&mut self) -> &mut T {
        // Any mutable hand-out counts as a modification, whether or not the
        // caller actually writes through it.
        self.modified = true;
        let payload: *mut T = self.ptr_t;
        // SAFETY: `payload` points into the entry kept alive by `self.entry`,
        // the write lock leaked at construction is still held, and `&mut self`
        // prevents any other borrow of the payload through this guard.
        #[allow(unsafe_code)]
        unsafe {
            &mut *payload
        }
    }
}
impl<T: Any + Send + Sync> Drop for EntryWriteGuard<T> {
    /// Bumps the entry's sequence id if the payload was accessed mutably,
    /// then releases the write lock leaked by the constructor.
    fn drop(&mut self) {
        let (seq, was_modified) = (self.ptr_seq_id, self.modified);
        if was_modified {
            // SAFETY: `seq` points at `sequence_id` inside the entry kept alive
            // by `self.entry`, and the leaked write lock is still held here,
            // so this is the only access to that field.
            #[allow(unsafe_code)]
            unsafe {
                *seq += 1;
            }
        }
        // SAFETY: exactly one write guard was leaked on this lock at
        // construction, and this is the only place that releases it.
        #[allow(unsafe_code)]
        unsafe {
            self.entry.force_write_unlock();
        }
    }
}
impl<T: Any + Send + Sync> EntryWriteGuard<T> {
pub fn new(key: &str, entry: &EntryPtr) -> Result<Self> {
let (ptr_t, ptr_seq_id) = {
let mut guard = entry.write();
let ptr_seq_id: *mut usize = &raw mut guard.sequence_id;
let x = &mut RwLockWriteGuard::leak(guard).data;
if let Some(t) = x.downcast_mut::<T>() {
let ptr_t: *mut T = t;
(ptr_t, ptr_seq_id)
} else {
#[allow(unsafe_code)]
unsafe {
entry.force_write_unlock();
}
return Err(Error::WrongType { key: key.into() });
}
};
Ok(Self {
entry: entry.clone(),
ptr_t,
ptr_seq_id,
modified: false,
})
}
pub fn try_new(key: &str, entry: &EntryPtr) -> Result<Self> {
let (ptr_t, ptr_seq_id) = {
if let Some(mut guard) = entry.try_write() {
let ptr_seq_id: *mut usize = &raw mut guard.sequence_id;
let x = &mut RwLockWriteGuard::leak(guard).data;
if let Some(t) = x.downcast_mut::<T>() {
let ptr_t: *mut T = t;
(ptr_t, ptr_seq_id)
} else {
#[allow(unsafe_code)]
unsafe {
entry.force_write_unlock();
}
return Err(Error::WrongType { key: key.into() });
}
} else {
return Err(Error::IsLocked { key: key.into() });
}
};
Ok(Self {
entry: entry.clone(),
ptr_t,
ptr_seq_id,
modified: false,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug)]
    struct Dummy {
        _data: i32,
    }

    const fn is_normal<T: Sized + Send + Sync>() {}

    /// Builds a fresh entry holding an `i32`.
    fn int_entry(value: i32) -> EntryPtr {
        Arc::new(RwLock::new(EntryData::new(value)))
    }

    #[test]
    const fn normal_types() {
        is_normal::<Dummy>();
        is_normal::<EntryData>();
        is_normal::<EntryPtr>();
    }

    /// The sequence id moves only after mutable access, and read guards
    /// release their lock on drop. Run under Miri to check pointer validity.
    #[test]
    fn write_guard_bumps_sequence_only_after_mutable_access() {
        let entry = int_entry(5);
        {
            let Ok(guard) = EntryWriteGuard::<i32>::new("k", &entry) else {
                panic!("expected an i32 write guard");
            };
            assert_eq!(*guard, 5);
        }
        assert_eq!(entry.read().sequence_id(), 1);
        {
            let Ok(mut guard) = EntryWriteGuard::<i32>::new("k", &entry) else {
                panic!("expected an i32 write guard");
            };
            *guard += 1;
        }
        assert_eq!(entry.read().sequence_id(), 2);

        let Ok(reader) = EntryReadGuard::<i32>::new("k", entry.clone()) else {
            panic!("expected an i32 read guard");
        };
        assert_eq!(*reader, 6);
        assert!(entry.try_write().is_none());
        drop(reader);
        assert!(entry.try_write().is_some());
    }

    /// A type mismatch must report `WrongType` and give the lock back.
    #[test]
    fn wrong_type_releases_the_lock() {
        let entry = int_entry(5);
        assert!(matches!(
            EntryReadGuard::<u8>::new("k", entry.clone()),
            Err(Error::WrongType { .. })
        ));
        assert!(matches!(
            EntryReadGuard::<u8>::try_new("k", &entry),
            Err(Error::WrongType { .. })
        ));
        // Checked before the write attempts so a leaked read lock fails the
        // test instead of making `EntryWriteGuard::new` spin forever.
        assert!(entry.try_write().is_some());
        assert!(matches!(
            EntryWriteGuard::<u8>::new("k", &entry),
            Err(Error::WrongType { .. })
        ));
        assert!(matches!(
            EntryWriteGuard::<u8>::try_new("k", &entry),
            Err(Error::WrongType { .. })
        ));
        assert!(entry.try_write().is_some());
    }

    /// `try_new` reports `IsLocked` while a writer holds the entry.
    #[test]
    fn try_new_reports_a_held_write_lock() {
        let entry = int_entry(5);
        let Ok(_writer) = EntryWriteGuard::<i32>::try_new("k", &entry) else {
            panic!("expected an i32 write guard");
        };
        assert!(matches!(
            EntryReadGuard::<i32>::try_new("k", &entry),
            Err(Error::IsLocked { .. })
        ));
        assert!(matches!(
            EntryWriteGuard::<i32>::try_new("k", &entry),
            Err(Error::IsLocked { .. })
        ));
    }
}