use std::marker::PhantomData;
use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex as SyncMutex};
use tokio::sync::{Mutex, MutexGuard};
use crate::internals::lincowcell::LinCowCellCapable;
/// A cell whose readers see immutable snapshots ("generations") while a single
/// writer at a time prepares and publishes the next generation.
///
/// - `T`: writer-side data; the constructor and transaction methods require
///   `T: LinCowCellCapable<R, U>`.
/// - `R`: the read-only snapshot stored in each generation.
/// - `U`: the working copy that a write transaction mutates before committing.
///
/// Readers use [`LinCowCell::read`]; writers use [`LinCowCell::write`] or
/// [`LinCowCell::try_write`] and publish with [`LinCowCellWriteTxn::commit`].
#[derive(Debug)]
pub struct LinCowCell<T, R, U> {
// Type marker only: `U` appears in the API but no `U` is stored in the cell.
updater: PhantomData<U>,
// Writer-side data behind a tokio (async) mutex; holding it makes a writer exclusive.
write: Mutex<T>,
// The generation handed to new readers; replaced in `commit`.
active: SyncMutex<Arc<LinCowCellInner<R>>>,
}
/// An in-progress write transaction returned by [`LinCowCell::write`] or
/// [`LinCowCell::try_write`].
///
/// While it exists it holds the cell's writer lock, so no other writer can start.
/// Mutate the working copy via `Deref`/`DerefMut`/`get_mut`, then call `commit` to
/// publish it; dropping the transaction without committing discards the changes.
#[derive(Debug)]
pub struct LinCowCellWriteTxn<'a, T, R, U> {
// The cell this transaction commits into.
caller: &'a LinCowCell<T, R, U>,
// Exclusive access to the writer-side data for the life of the transaction.
guard: MutexGuard<'a, T>,
// Working copy produced by `create_writer`.
work: U,
}
/// One published generation of read data.
///
/// When a commit publishes a new generation, the previous generation's `pin` is
/// set to point at it. A reader still holding an old generation therefore keeps
/// every newer generation alive too, so generations are released oldest-first.
#[derive(Debug)]
struct LinCowCellInner<R> {
// Link to the next (newer) generation; `None` until a commit supersedes this one.
pin: SyncMutex<Option<Arc<LinCowCellInner<R>>>>,
// The read-only snapshot for this generation.
data: R,
}
/// A read transaction returned by [`LinCowCell::read`].
///
/// It holds the generation that was active when it was created and dereferences
/// to that generation's `R`; commits made afterwards are not visible through it.
#[derive(Debug)]
pub struct LinCowCellReadTxn<'a, T, R, U> {
// Ties the transaction's lifetime to the cell; never read.
_caller: &'a LinCowCell<T, R, U>,
// The observed generation; keeping it alive also keeps newer generations alive via `pin`.
work: Arc<LinCowCellInner<R>>,
}
impl<R> LinCowCellInner<R> {
pub fn new(data: R) -> Self {
LinCowCellInner {
pin: SyncMutex::new(None),
data,
}
}
}
impl<R> Drop for LinCowCellInner<R> {
fn drop(&mut self) {
let mut current = self
.pin
.get_mut()
.map(|pin| pin.take())
.unwrap_or_else(|e| e.into_inner().take());
while let Some(arc) = current {
match Arc::into_inner(arc) {
Some(mut inner) => {
current = inner
.pin
.get_mut()
.map(|pin| pin.take())
.unwrap_or_else(|e| e.into_inner().take());
}
None => {
break;
}
}
}
}
}
impl<T, R, U> LinCowCell<T, R, U>
where
    T: LinCowCellCapable<R, U>,
{
    /// Create a cell holding `data`.
    ///
    /// The first generation's read data is `data.create_reader()`.
    pub fn new(data: T) -> Self {
        let r = data.create_reader();
        LinCowCell {
            updater: PhantomData,
            write: Mutex::new(data),
            active: SyncMutex::new(Arc::new(LinCowCellInner::new(r))),
        }
    }

    /// Begin a read transaction on the currently active generation.
    ///
    /// The transaction keeps that generation alive until dropped; commits made
    /// after this call are not visible through it.
    ///
    /// # Panics
    ///
    /// Panics if the internal `active` mutex is poisoned.
    pub fn read(&self) -> LinCowCellReadTxn<'_, T, R, U> {
        let rwguard = self.active.lock().unwrap();
        LinCowCellReadTxn {
            _caller: self,
            work: rwguard.clone(),
        }
    }

    /// Begin a write transaction, waiting asynchronously until no other writer
    /// holds the cell.
    ///
    /// The working copy comes from `create_writer()` on the writer-side data.
    /// Changes become visible to new readers only after
    /// [`LinCowCellWriteTxn::commit`]; dropping the transaction without
    /// committing discards them.
    pub async fn write(&self) -> LinCowCellWriteTxn<'_, T, R, U> {
        let write_guard = self.write.lock().await;
        let work: U = (*write_guard).create_writer();
        LinCowCellWriteTxn {
            caller: self,
            guard: write_guard,
            work,
        }
    }

    /// Like [`LinCowCell::write`], but returns `None` instead of waiting when
    /// another writer currently holds the cell.
    pub fn try_write(&self) -> Option<LinCowCellWriteTxn<'_, T, R, U>> {
        self.write
            .try_lock()
            .map(|write_guard| {
                let work: U = (*write_guard).create_writer();
                LinCowCellWriteTxn {
                    caller: self,
                    guard: write_guard,
                    work,
                }
            })
            .ok()
    }

    /// Turn `write`'s working copy into a new generation and make it active.
    ///
    /// The user-supplied `pre_commit` runs WITHOUT holding the `active` mutex.
    /// Readers are therefore never blocked on user code, and a panic in
    /// `pre_commit` cannot poison `active`. This is sound because only a holder of
    /// the `write` lock (which `guard` is) ever replaces `active`, so the snapshot
    /// taken below stays current until we swap it.
    fn commit(&self, write: LinCowCellWriteTxn<'_, T, R, U>) {
        let LinCowCellWriteTxn {
            caller,
            mut guard,
            work,
        } = write;
        // The locking argument above relies on the txn belonging to this cell.
        debug_assert!(
            std::ptr::eq(caller, self),
            "write transaction committed to a different LinCowCell"
        );

        // Snapshot the current generation; the temporary lock guard is
        // released at the end of this statement.
        let prev = Arc::clone(&*self.active.lock().unwrap());

        // Consume the working copy to build the next generation (user code,
        // deliberately run with no std lock held).
        let newdata = guard.pre_commit(work, &prev.data);
        let new_inner = Arc::new(LinCowCellInner::new(newdata));

        // Link old -> new so readers of `prev` keep the newer generation alive.
        *prev.pin.lock().unwrap() = Some(Arc::clone(&new_inner));

        // Publish. `prev` still holds a reference, so replacing the slot only
        // decrements a count; any release of older generations happens when
        // `prev` drops, after the `active` lock has been released.
        {
            let mut active = self.active.lock().unwrap();
            *active = new_inner;
        }
        // `prev` and then `guard` drop here; the writer lock is released last.
    }
}
impl<T, R, U> Deref for LinCowCellReadTxn<'_, T, R, U> {
type Target = R;
#[inline]
fn deref(&self) -> &R {
&self.work.data
}
}
impl<T, R, U> AsRef<R> for LinCowCellReadTxn<'_, T, R, U> {
    /// Borrow the observed generation's snapshot data.
    #[inline]
    fn as_ref(&self) -> &R {
        let Self { work, .. } = self;
        &work.data
    }
}
impl<T, R, U> LinCowCellWriteTxn<'_, T, R, U>
where
    T: LinCowCellCapable<R, U>,
{
    /// Mutable access to the transaction's working copy.
    #[inline]
    pub fn get_mut(&mut self) -> &mut U {
        let Self { work, .. } = self;
        work
    }

    /// Publish the working copy as the cell's new active generation, consuming
    /// the transaction (and with it the writer lock).
    pub fn commit(self) {
        let cell = self.caller;
        cell.commit(self);
    }
}
impl<T, R, U> Deref for LinCowCellWriteTxn<'_, T, R, U> {
    type Target = U;

    /// Shared access to the working copy.
    #[inline]
    fn deref(&self) -> &U {
        let Self { work, .. } = self;
        work
    }
}
impl<T, R, U> DerefMut for LinCowCellWriteTxn<'_, T, R, U> {
    /// Exclusive access to the working copy.
    #[inline]
    fn deref_mut(&mut self) -> &mut U {
        let Self { work, .. } = self;
        work
    }
}
impl<T, R, U> AsRef<U> for LinCowCellWriteTxn<'_, T, R, U> {
    /// Borrow the working copy (same target as `Deref`).
    #[inline]
    fn as_ref(&self) -> &U {
        &**self
    }
}
impl<T, R, U> AsMut<U> for LinCowCellWriteTxn<'_, T, R, U> {
    /// Mutably borrow the working copy (same target as `DerefMut`).
    #[inline]
    fn as_mut(&mut self) -> &mut U {
        &mut **self
    }
}
#[cfg(test)]
mod tests {
    use super::LinCowCell;
    use super::LinCowCellCapable;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[derive(Debug)]
    struct TestData {
        x: i64,
    }

    #[derive(Debug)]
    struct TestDataReadTxn {
        x: i64,
    }

    #[derive(Debug)]
    struct TestDataWriteTxn {
        x: i64,
    }

    impl LinCowCellCapable<TestDataReadTxn, TestDataWriteTxn> for TestData {
        fn create_reader(&self) -> TestDataReadTxn {
            TestDataReadTxn { x: self.x }
        }

        fn create_writer(&self) -> TestDataWriteTxn {
            TestDataWriteTxn { x: self.x }
        }

        fn pre_commit(
            &mut self,
            new: TestDataWriteTxn,
            _prev: &TestDataReadTxn,
        ) -> TestDataReadTxn {
            self.x = new.x;
            TestDataReadTxn { x: new.x }
        }
    }

    /// Uncommitted writes are discarded; committed writes are visible only to
    /// readers created afterwards.
    #[tokio::test]
    async fn test_simple_create() {
        let data = TestData { x: 0 };
        let cc = LinCowCell::new(data);
        let cc_rotxn_a = cc.read();
        println!("cc_rotxn_a -> {:?}", cc_rotxn_a);
        assert_eq!(cc_rotxn_a.work.data.x, 0);
        {
            let mut cc_wrtxn = cc.write().await;
            println!("cc_wrtxn -> {:?}", cc_wrtxn);
            assert_eq!(cc_wrtxn.work.x, 0);
            assert_eq!(cc_wrtxn.as_ref().x, 0);
            {
                let mut_ptr = cc_wrtxn.get_mut();
                assert_eq!(mut_ptr.x, 0);
                mut_ptr.x = 1;
                assert_eq!(mut_ptr.x, 1);
            }
            assert_eq!(cc_rotxn_a.work.data.x, 0);
        }
        assert_eq!(cc_rotxn_a.work.data.x, 0);
        {
            let mut cc_wrtxn = cc.write().await;
            println!("cc_wrtxn -> {:?}", cc_wrtxn);
            assert_eq!(cc_wrtxn.work.x, 0);
            assert_eq!(cc_wrtxn.as_ref().x, 0);
            {
                let mut_ptr = cc_wrtxn.get_mut();
                assert_eq!(mut_ptr.x, 0);
                mut_ptr.x = 2;
                assert_eq!(mut_ptr.x, 2);
            }
            assert_eq!(cc_rotxn_a.work.data.x, 0);
            cc_wrtxn.commit();
        }
        assert_eq!(cc_rotxn_a.work.data.x, 0);
        let cc_rotxn_c = cc.read();
        assert_eq!(cc_rotxn_c.work.data.x, 2);
    }

    /// Writer task: increments `x` and commits until it has observed a value of
    /// at least 500, asserting the value never goes backwards.
    async fn mt_writer(cc: Arc<LinCowCell<TestData, TestDataReadTxn, TestDataWriteTxn>>) {
        let mut last_value: i64 = 0;
        while last_value < 500 {
            let mut cc_wrtxn = cc.write().await;
            {
                let mut_ptr = cc_wrtxn.get_mut();
                assert!(mut_ptr.x >= last_value);
                last_value = mut_ptr.x;
                mut_ptr.x += 1;
            }
            cc_wrtxn.commit();
        }
    }

    /// Reader: repeatedly takes snapshots until `x` reaches 500, asserting that
    /// successive snapshots never go backwards.
    fn rt_reader(cc: Arc<LinCowCell<TestData, TestDataReadTxn, TestDataWriteTxn>>) {
        let mut last_value: i64 = 0;
        while last_value < 500 {
            let cc_rotxn = cc.read();
            {
                assert!(cc_rotxn.work.data.x >= last_value);
                last_value = cc_rotxn.work.data.x;
            }
        }
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_concurrent_create() {
        use std::time::Instant;
        let start = Instant::now();
        let data = TestData { x: 0 };
        let cc = Arc::new(LinCowCell::new(data));
        // `try_join!` surfaces the `JoinError` of any task that panicked (e.g. on
        // a failed assertion); `join!` followed by `let _ =` silently discarded it.
        tokio::try_join!(
            tokio::task::spawn_blocking({
                let cc = cc.clone();
                move || rt_reader(cc)
            }),
            tokio::task::spawn_blocking({
                let cc = cc.clone();
                move || rt_reader(cc)
            }),
            tokio::task::spawn_blocking({
                let cc = cc.clone();
                move || rt_reader(cc)
            }),
            tokio::task::spawn_blocking({
                let cc = cc.clone();
                move || rt_reader(cc)
            }),
            tokio::task::spawn(mt_writer(cc.clone())),
            tokio::task::spawn(mt_writer(cc.clone())),
        )
        .expect("a reader or writer task panicked");
        print!("Arc MT create :{:?} ", start.elapsed());
    }

    static GC_COUNT: AtomicUsize = AtomicUsize::new(0);

    #[derive(Debug, Clone)]
    struct TestGcWrapper<T> {
        data: T,
    }

    #[derive(Debug)]
    struct TestGcWrapperReadTxn<T> {
        _data: T,
    }

    #[derive(Debug)]
    struct TestGcWrapperWriteTxn<T> {
        data: T,
    }

    impl<T: Clone> LinCowCellCapable<TestGcWrapperReadTxn<T>, TestGcWrapperWriteTxn<T>>
        for TestGcWrapper<T>
    {
        fn create_reader(&self) -> TestGcWrapperReadTxn<T> {
            TestGcWrapperReadTxn {
                _data: self.data.clone(),
            }
        }

        fn create_writer(&self) -> TestGcWrapperWriteTxn<T> {
            TestGcWrapperWriteTxn {
                data: self.data.clone(),
            }
        }

        fn pre_commit(
            &mut self,
            new: TestGcWrapperWriteTxn<T>,
            _prev: &TestGcWrapperReadTxn<T>,
        ) -> TestGcWrapperReadTxn<T> {
            self.data = new.data.clone();
            TestGcWrapperReadTxn {
                _data: self.data.clone(),
            }
        }
    }

    // Counts how many read generations have been released.
    impl<T> Drop for TestGcWrapperReadTxn<T> {
        fn drop(&mut self) {
            GC_COUNT.fetch_add(1, Ordering::Release);
        }
    }

    /// Keeps committing increments until at least 50 generations were released.
    async fn test_gc_operation_thread(
        cc: Arc<
            LinCowCell<TestGcWrapper<i64>, TestGcWrapperReadTxn<i64>, TestGcWrapperWriteTxn<i64>>,
        >,
    ) {
        while GC_COUNT.load(Ordering::Acquire) < 50 {
            {
                let mut cc_wrtxn = cc.write().await;
                {
                    let mut_ptr = cc_wrtxn.get_mut();
                    mut_ptr.data += 1;
                }
                cc_wrtxn.commit();
            }
        }
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_gc_operation() {
        GC_COUNT.store(0, Ordering::Release);
        let data = TestGcWrapper { data: 0 };
        let cc = Arc::new(LinCowCell::new(data));
        // Propagate panics from the spawned writers instead of discarding them.
        tokio::try_join!(
            tokio::task::spawn(test_gc_operation_thread(cc.clone())),
            tokio::task::spawn(test_gc_operation_thread(cc.clone())),
            tokio::task::spawn(test_gc_operation_thread(cc.clone())),
            tokio::task::spawn(test_gc_operation_thread(cc.clone())),
        )
        .expect("a GC writer task panicked");
        assert!(GC_COUNT.load(Ordering::Acquire) >= 50);
    }

    /// Dropping the oldest reader of a 10k-generation chain must not recurse.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_long_chain_drop_no_stack_overflow() {
        let data = TestData { x: 0 };
        let cc = LinCowCell::new(data);
        let initial_read = cc.read();
        for i in 0..10000 {
            let mut write_txn = cc.write().await;
            write_txn.get_mut().x = i;
            write_txn.commit();
        }
        drop(initial_read);
        let final_read = cc.read();
        assert_eq!(final_read.work.data.x, 9999);
    }
}
#[cfg(test)]
mod tests_linear {
    use super::LinCowCell;
    use super::LinCowCellCapable;
    use std::sync::atomic::{AtomicUsize, Ordering};

    static GC_COUNT: AtomicUsize = AtomicUsize::new(0);

    #[derive(Debug, Clone)]
    struct TestGcWrapper<T> {
        data: T,
    }

    #[derive(Debug)]
    struct TestGcWrapperReadTxn<T> {
        _data: T,
    }

    #[derive(Debug)]
    struct TestGcWrapperWriteTxn<T> {
        data: T,
    }

    impl<T: Clone> LinCowCellCapable<TestGcWrapperReadTxn<T>, TestGcWrapperWriteTxn<T>>
        for TestGcWrapper<T>
    {
        fn create_reader(&self) -> TestGcWrapperReadTxn<T> {
            TestGcWrapperReadTxn {
                _data: self.data.clone(),
            }
        }

        fn create_writer(&self) -> TestGcWrapperWriteTxn<T> {
            TestGcWrapperWriteTxn {
                data: self.data.clone(),
            }
        }

        fn pre_commit(
            &mut self,
            new: TestGcWrapperWriteTxn<T>,
            _prev: &TestGcWrapperReadTxn<T>,
        ) -> TestGcWrapperReadTxn<T> {
            self.data = new.data.clone();
            TestGcWrapperReadTxn {
                _data: self.data.clone(),
            }
        }
    }

    // Counts how many read generations have been released.
    impl<T> Drop for TestGcWrapperReadTxn<T> {
        fn drop(&mut self) {
            GC_COUNT.fetch_add(1, Ordering::Release);
        }
    }

    /// Generations are released oldest-first: newer generations are only freed
    /// once every reader of an older generation is gone.
    #[tokio::test]
    async fn test_gc_operation_linear() {
        GC_COUNT.store(0, Ordering::Release);
        assert_eq!(GC_COUNT.load(Ordering::Acquire), 0);
        let data = TestGcWrapper { data: 0 };
        let cc = LinCowCell::new(data);
        // Two readers of generation 0.
        let cc_rotxn_a = cc.read();
        let cc_rotxn_a_2 = cc.read();
        {
            let mut cc_wrtxn = cc.write().await;
            {
                let mut_ptr = cc_wrtxn.get_mut();
                mut_ptr.data += 1;
            }
            cc_wrtxn.commit();
        }
        // Reader of generation 1.
        let cc_rotxn_b = cc.read();
        {
            let mut cc_wrtxn = cc.write().await;
            {
                let mut_ptr = cc_wrtxn.get_mut();
                mut_ptr.data += 1;
            }
            cc_wrtxn.commit();
        }
        // Reader of generation 2 (currently active).
        let cc_rotxn_c = cc.read();

        // Each reader sees exactly the generation that was active when it began.
        assert_eq!(cc_rotxn_a._data, 0);
        assert_eq!(cc_rotxn_a_2._data, 0);
        assert_eq!(cc_rotxn_b._data, 1);
        assert_eq!(cc_rotxn_c._data, 2);

        assert_eq!(GC_COUNT.load(Ordering::Acquire), 0);
        // Generation 1 is still pinned by generation 0.
        drop(cc_rotxn_b);
        assert_eq!(GC_COUNT.load(Ordering::Acquire), 0);
        // Generation 2 is still active and pinned.
        drop(cc_rotxn_c);
        assert_eq!(GC_COUNT.load(Ordering::Acquire), 0);
        // Generation 0 still has one reader.
        drop(cc_rotxn_a_2);
        assert_eq!(GC_COUNT.load(Ordering::Acquire), 0);
        // Last reader of generation 0 gone: generations 0 and 1 are released.
        drop(cc_rotxn_a);
        assert_eq!(GC_COUNT.load(Ordering::Acquire), 2);
    }
}