use remoc::{
codec,
rch::{mpsc, oneshot},
robj::rw_lock::LockError,
RemoteSend,
};
use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt,
ops::{Deref, DerefMut},
sync::{Arc, Weak},
};
use tokio::sync::{
RwLock as TokioRwLock, RwLockReadGuard as TokioRwLockReadGuard,
RwLockWriteGuard as TokioRwLockWriteGuard,
};
/// Error returned when committing a [`RwLockWriteGuard`] fails.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommitError {
    /// The channel towards the owner was closed before the new value could be sent.
    Dropped,
    /// Sending the new value failed, or no confirmation was received back.
    Failed,
}
impl fmt::Display for CommitError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Dropped => write!(f, "owner dropped"),
Self::Failed => write!(f, "commit failed"),
}
}
}
impl<T> From<oneshot::SendError<T>> for CommitError {
fn from(err: oneshot::SendError<T>) -> Self {
match err {
oneshot::SendError::Closed(_) => Self::Dropped,
oneshot::SendError::Failed => Self::Failed,
}
}
}
impl From<oneshot::RecvError> for CommitError {
fn from(_: oneshot::RecvError) -> Self {
Self::Failed
}
}
impl Error for CommitError {}
/// A read/write lock whose value lives on an owning side and which can be shared with
/// other endpoints by serializing it.
///
/// Handles created with [`RwLock::new`] (and their clones) access the value directly.
/// Handles obtained by deserialization forward read and write requests to the owner
/// over channels.
pub struct RwLock<T, Codec = codec::Default> {
    /// Direct access to the value; `None` for handles obtained by deserialization.
    owner: Option<RwLockOwner<T, Codec>>,
    /// Request senders towards the owner; this is the part that gets serialized.
    remote: Arc<RwLockRemote<T, Codec>>,
}
impl<T, Codec> RwLock<T, Codec> {
    /// Returns the shared value on the owning side, or `None` on a remote handle.
    pub fn owner(&self) -> Option<&Arc<TokioRwLock<T>>> {
        Some(&self.owner.as_ref()?.value)
    }

    /// Read-locks the owner's value directly, bypassing remote-reader bookkeeping.
    ///
    /// Returns `None` on a remote handle.
    pub async fn owner_read(&self) -> Option<TokioRwLockReadGuard<'_, T>> {
        let shared = &self.owner.as_ref()?.value;
        Some(shared.read().await)
    }

    /// Write-locks the owner's value directly, bypassing remote-reader bookkeeping.
    ///
    /// Returns `None` on a remote handle.
    pub async fn owner_write(&self) -> Option<TokioRwLockWriteGuard<'_, T>> {
        let shared = &self.owner.as_ref()?.value;
        Some(shared.write().await)
    }

    /// Wraps deserialized request channels into a handle that has no owner part.
    fn new_remote(remote: RwLockRemote<T, Codec>) -> Self {
        let remote = Arc::new(remote);
        Self {
            remote,
            owner: None,
        }
    }
}
impl<T: RemoteSend + Clone + Sync, Codec: codec::Codec> RwLock<T, Codec> {
    /// Creates a new owning lock around `value`.
    ///
    /// It also spawns the task that serves read and write requests from remote handles,
    /// so it must be called from within a Tokio runtime (`tokio::spawn`).
    pub fn new(value: T) -> Self {
        let (read_req_tx, read_req_rx) = mpsc::channel(1);
        let read_req_tx = read_req_tx.set_buffer();
        let read_req_rx = read_req_rx.set_buffer();
        let (write_req_tx, write_req_rx) = mpsc::channel(1);
        let write_req_tx = write_req_tx.set_buffer();
        let write_req_rx = write_req_rx.set_buffer();
        let owner = RwLockOwner {
            value: Arc::new(TokioRwLock::new(value)),
            drop_chanel: Arc::new(TokioRwLock::new(None)),
        };
        // The request task only gets weak references, so it does not keep the value alive.
        let weak_owner = owner.as_weak();
        let rw_lock = Self {
            owner: Some(owner),
            remote: Arc::new(RwLockRemote {
                read_req_tx,
                write_req_tx,
                is_frivolous: false,
            }),
        };
        tokio::spawn(Self::handle_owner_requests(
            weak_owner,
            read_req_rx,
            write_req_rx,
        ));
        rw_lock
    }

    /// Like [`RwLock::new`], but owner-side write guards modify the value in place
    /// instead of working on a copy that is only stored by `commit`.
    #[doc(hidden)]
    pub fn new_frivolous(value: T) -> Self {
        let mut this = Self::new(value);
        let remote = Arc::get_mut(&mut this.remote)
            .expect("a freshly created lock holds the only reference to its remote part");
        remote.is_frivolous = true;
        this
    }

    /// Acquires read access to the value.
    ///
    /// On the owning side this locks the value locally. On a remote handle it asks the owner
    /// for a copy; the owner's writers wait until the returned guard is dropped.
    ///
    /// # Errors
    /// Returns a [`LockError`] when the owner does not deliver the value.
    pub async fn read(&self) -> Result<RwLockReadGuard<'_, T, Codec>, LockError> {
        if let Some(owner) = self.owner.as_ref() {
            // Lock order: `drop_chanel` before `value`, the same order the request handler
            // uses. Acquiring both concurrently could leave this task holding `value` while
            // waiting for `drop_chanel`, which the handler holds while waiting for `value`.
            let _drop_chanel_guard = owner.drop_chanel.read().await;
            let value_guard = owner.value.read().await;
            return Ok(RwLockReadGuard::new_owner(value_guard));
        }
        let (value_tx, value_rx) = oneshot::channel();
        // A failed send drops the request together with `value_tx`, so awaiting
        // `value_rx` below fails and reports the error.
        let _ = self.remote.read_req_tx.send(ReadRequest { value_tx }).await;
        let value = value_rx.await?;
        Ok(RwLockReadGuard::new_remote(value))
    }

    /// Acquires write access to the value.
    ///
    /// On the owning side this waits for all remote read guards to be dropped and then
    /// locks the value locally. On a remote handle it asks the owner for the current value;
    /// changes are sent back by [`RwLockWriteGuard::commit`].
    ///
    /// # Errors
    /// Returns a [`LockError`] when the owner does not deliver the value.
    pub async fn write(&self) -> Result<RwLockWriteGuard<'_, T, Codec>, LockError> {
        if let Some(owner) = self.owner.as_ref() {
            // Lock order: `drop_chanel` (after all remote readers are gone) before `value`,
            // matching the request handler, so the two can never wait on each other.
            let _drop_chanel_guard = drop_remote_read_guard(&owner.drop_chanel).await;
            let value_guard = owner.value.write().await;
            let guard = if self.remote.is_frivolous {
                RwLockWriteGuard::new_owner_frivolous(value_guard)
            } else {
                RwLockWriteGuard::new_owner(value_guard)
            };
            return Ok(guard);
        }
        let (value_tx, value_rx) = oneshot::channel();
        let (new_value_tx, new_value_rx) = oneshot::channel();
        let (confirm_tx, confirm_rx) = oneshot::channel();
        // As in `read`, a failed send surfaces as an error when awaiting `value_rx`.
        let _ = self
            .remote
            .write_req_tx
            .send(WriteRequest {
                value_tx,
                new_value_rx,
                confirm_tx,
            })
            .await;
        let value = value_rx.await?;
        Ok(RwLockWriteGuard::new_remote(
            value,
            new_value_tx,
            confirm_rx,
        ))
    }

    /// Serves requests from remote handles, one at a time.
    ///
    /// Write requests are polled before read requests (`biased`). The loop ends when a
    /// request channel is closed or reports a final error, or when a request arrives after
    /// the owner has been dropped.
    async fn handle_owner_requests(
        weak_owner: WeakRwLockOwner<T, Codec>,
        mut read_req_rx: mpsc::Receiver<ReadRequest<T, Codec>, Codec, 1>,
        mut write_req_rx: mpsc::Receiver<WriteRequest<T, Codec>, Codec, 1>,
    ) {
        loop {
            tokio::select! {
                biased;
                res = write_req_rx.recv() => {
                    let WriteRequest {value_tx, new_value_rx, confirm_tx} = match res {
                        Ok(Some(req)) => req,
                        Ok(None) => break,
                        Err(err) if err.is_final() => break,
                        Err(_) => continue,
                    };
                    let Some(owner) = weak_owner.upgrade() else {
                        break
                    };
                    // Lock order: `drop_chanel` before `value`, as in the owner's `read`/`write`.
                    // `value` stays write-locked for the whole round trip. Releasing it between
                    // handing out the current value and storing the committed one would let a
                    // local writer take `value` and then block on `drop_chanel` (held here)
                    // while this task blocks on `value`: a deadlock. It would also let local
                    // updates be overwritten by the remote commit.
                    let _remote_write_guard = drop_remote_read_guard(&owner.drop_chanel).await;
                    let mut value_guard = owner.value.write().await;
                    if value_tx.send((*value_guard).clone()).is_err() {
                        continue
                    }
                    // An error here means the remote guard was dropped without committing.
                    if let Ok(new_value) = new_value_rx.await {
                        *value_guard = new_value;
                        let _ = confirm_tx.send(());
                    }
                }
                res = read_req_rx.recv() => {
                    let ReadRequest {value_tx} = match res {
                        Ok(Some(req)) => req,
                        Ok(None) => break,
                        Err(err) if err.is_final() => break,
                        Err(_) => continue,
                    };
                    let Some(remote_value) = weak_owner.make_remote_value().await else {
                        break
                    };
                    let _ = value_tx.send(remote_value);
                }
            }
        }
    }
}
impl<T: RemoteSend + Clone + Sync, Codec: codec::Codec> serde::Serialize for RwLock<T, Codec> {
#[inline]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.remote.serialize(serializer)
}
}
impl<'de, T: RemoteSend + Clone + Sync, Codec: codec::Codec> serde::Deserialize<'de>
for RwLock<T, Codec>
{
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let remote = RwLockRemote::deserialize(deserializer)?;
Ok(Self::new_remote(remote))
}
}
impl<T, Codec> Clone for RwLock<T, Codec> {
fn clone(&self) -> Self {
Self {
owner: self.owner.clone(),
remote: self.remote.clone(),
}
}
}
/// Read access to the value of an [`RwLock`], obtained from [`RwLock::read`].
///
/// On the owning side it wraps a Tokio read guard. On a remote handle it holds a copy of
/// the value, and the owner's writers wait until this guard is dropped.
pub struct RwLockReadGuard<'a, T, Codec = codec::Default> {
    /// Owner-side guard or remote copy.
    inner: RwLockReadGuardInner<'a, T, Codec>,
}
impl<'a, T, Codec> RwLockReadGuard<'a, T, Codec> {
    /// Wraps a local read guard held by the owning side.
    fn new_owner(value_guard: TokioRwLockReadGuard<'a, T>) -> Self {
        let inner = RwLockReadGuardInner::Owner(value_guard);
        Self { inner }
    }

    /// Wraps a value copy received from the owner.
    fn new_remote(value: RemoteValue<T, Codec>) -> Self {
        let inner = RwLockReadGuardInner::Remote(value);
        Self { inner }
    }
}
impl<T, Codec> Deref for RwLockReadGuard<'_, T, Codec> {
    type Target = T;

    /// Borrows either the locked owner value or the received remote copy.
    fn deref(&self) -> &T {
        match &self.inner {
            RwLockReadGuardInner::Remote(remote) => &remote.value,
            RwLockReadGuardInner::Owner(guard) => &**guard,
        }
    }
}
/// Write access to the value of an [`RwLock`], obtained from [`RwLock::write`].
///
/// Modifications are applied by [`RwLockWriteGuard::commit`]. The one exception is an
/// owner-side guard of a frivolous lock, which modifies the locked value in place.
pub struct RwLockWriteGuard<'a, T, Codec = codec::Default> {
    /// Owner-side guard (with optional staged copy) or remote copy plus commit channels.
    inner: RwLockWriteGuardInner<'a, T, Codec>,
}
impl<'a, T, Codec> RwLockWriteGuard<'a, T, Codec> {
    /// Wraps a local write guard without a staged copy: changes go straight to the value.
    fn new_owner_frivolous(value_guard: TokioRwLockWriteGuard<'a, T>) -> Self {
        let inner = RwLockWriteGuardInner::Owner {
            value_guard,
            new_value: None,
        };
        Self { inner }
    }
}
impl<'a, T: Clone, Codec> RwLockWriteGuard<'a, T, Codec> {
    /// Wraps a local write guard and stages a copy of the current value.
    ///
    /// Edits go to the copy, which `commit` writes back into the locked value.
    fn new_owner(value_guard: TokioRwLockWriteGuard<'a, T>) -> Self {
        let staged = T::clone(&value_guard);
        let inner = RwLockWriteGuardInner::Owner {
            new_value: Some(staged),
            value_guard,
        };
        Self { inner }
    }

    /// Wraps a value copy received from the owner together with the commit channels.
    fn new_remote(
        value: T,
        new_value_tx: oneshot::Sender<T, Codec>,
        confirm_rx: oneshot::Receiver<(), Codec>,
    ) -> Self {
        let inner = RwLockWriteGuardInner::Remote {
            value,
            new_value_tx,
            confirm_rx,
        };
        Self { inner }
    }
}
impl<'a, T: RemoteSend, Codec: codec::Codec> RwLockWriteGuard<'a, T, Codec> {
    /// Publishes the modified value and releases the guard.
    ///
    /// On the owning side a staged copy (if any) is stored into the locked value. On a
    /// remote handle the value is sent to the owner, and this waits for its confirmation.
    ///
    /// # Errors
    /// Returns a [`CommitError`] if sending the value or receiving the confirmation fails.
    pub async fn commit(self) -> Result<(), CommitError> {
        match self.inner {
            RwLockWriteGuardInner::Remote {
                value,
                new_value_tx,
                confirm_rx,
            } => {
                new_value_tx.send(value)?;
                confirm_rx.await?;
            }
            RwLockWriteGuardInner::Owner {
                new_value: Some(staged),
                mut value_guard,
            } => {
                *value_guard = staged;
            }
            RwLockWriteGuardInner::Owner {
                new_value: None,
                value_guard: _unchanged,
            } => {}
        }
        Ok(())
    }
}
impl<T, Codec> Deref for RwLockWriteGuard<'_, T, Codec> {
    type Target = T;

    /// Borrows the value being edited: the staged copy if there is one, otherwise the
    /// locked owner value, or the remote copy on a remote handle.
    fn deref(&self) -> &Self::Target {
        match &self.inner {
            RwLockWriteGuardInner::Owner {
                new_value,
                value_guard,
            } => new_value.as_ref().unwrap_or(value_guard),
            // `value` is already `&T` here; borrowing it again was a needless `&&T`.
            RwLockWriteGuardInner::Remote { value, .. } => value,
        }
    }
}
impl<T, Codec> DerefMut for RwLockWriteGuard<'_, T, Codec> {
    /// Mutably borrows the value being edited (staged copy first, then locked value,
    /// or the remote copy).
    fn deref_mut(&mut self) -> &mut T {
        match &mut self.inner {
            RwLockWriteGuardInner::Owner {
                new_value: Some(staged),
                ..
            } => staged,
            RwLockWriteGuardInner::Owner { value_guard, .. } => &mut **value_guard,
            RwLockWriteGuardInner::Remote { value, .. } => value,
        }
    }
}
struct RwLockOwner<T, Codec> {
value: Arc<TokioRwLock<T>>,
drop_chanel:
Arc<TokioRwLock<Option<(mpsc::Sender<(), Codec, 1>, mpsc::Receiver<(), Codec, 1>)>>>,
}
impl<T, Codec> RwLockOwner<T, Codec> {
    /// Creates non-owning references to both pieces of owner state.
    fn as_weak(&self) -> WeakRwLockOwner<T, Codec> {
        let value = Arc::downgrade(&self.value);
        let drop_chanel = Arc::downgrade(&self.drop_chanel);
        WeakRwLockOwner { value, drop_chanel }
    }
}
impl<T, Codec> Clone for RwLockOwner<T, Codec> {
fn clone(&self) -> Self {
Self {
value: self.value.clone(),
drop_chanel: self.drop_chanel.clone(),
}
}
}
/// Channel used to track outstanding remote read guards.
///
/// Sender clones travel with every remote read value. The owner waits for the receiver to
/// report that all senders are gone before writing.
type DropChanel<Codec> = (
    mpsc::Sender<(), Codec, 1>,
    mpsc::Receiver<(), Codec, 1>,
);
/// Non-owning counterpart of `RwLockOwner`, held by the request-handling task.
struct WeakRwLockOwner<T, Codec> {
    /// Weak reference to the protected value.
    value: Weak<TokioRwLock<T>>,
    /// Weak reference to the remote-reader drop channel slot.
    drop_chanel: Weak<TokioRwLock<Option<DropChanel<Codec>>>>,
}
impl<T, Codec> WeakRwLockOwner<T, Codec> {
    /// Returns strong owner state, or `None` once either part has been dropped.
    fn upgrade(&self) -> Option<RwLockOwner<T, Codec>> {
        let value = self.value.upgrade()?;
        let drop_chanel = self.drop_chanel.upgrade()?;
        Some(RwLockOwner { value, drop_chanel })
    }
}
impl<T: Clone, Codec> WeakRwLockOwner<T, Codec> {
    /// Builds the payload for a remote read: a copy of the value plus a drop-channel sender.
    ///
    /// Returns `None` if the owner has been dropped. If no drop channel exists yet, a new
    /// one is created and stored.
    async fn make_remote_value(&self) -> Option<RemoteValue<T, Codec>> {
        let value = self.value.upgrade()?;
        let drop_chanel = self.drop_chanel.upgrade()?;
        // The read guard is a temporary and is released at the end of this statement.
        let existing_tx = drop_chanel
            .read()
            .await
            .as_ref()
            .map(|(dropped_tx, _)| dropped_tx.clone());
        let dropped_tx = match existing_tx {
            Some(dropped_tx) => dropped_tx,
            None => {
                let (dropped_tx, dropped_rx) = mpsc::channel(1);
                let dropped_tx = dropped_tx.set_buffer();
                let dropped_rx = dropped_rx.set_buffer();
                *drop_chanel.write().await = Some((dropped_tx.clone(), dropped_rx));
                dropped_tx
            }
        };
        // The sender is obtained before the value is read. A writer that enters
        // `drop_remote_read_guard` after this point therefore waits for this reader's
        // sender to be dropped.
        let value = value.read().await.clone();
        Some(RemoteValue { value, dropped_tx })
    }
}
/// Write-locks the drop-channel slot and waits until every remote read guard handed out
/// so far has been dropped.
///
/// The current drop channel (if any) is taken out of the slot, so later remote reads
/// create a fresh one. The returned write guard keeps new remote reads from registering
/// while the caller holds it.
async fn drop_remote_read_guard<Codec>(
    drop_chanel: &Arc<TokioRwLock<Option<DropChanel<Codec>>>>,
) -> TokioRwLockWriteGuard<'_, Option<DropChanel<Codec>>> {
    let mut drop_chanel_write_guard = drop_chanel.write().await;
    if let Some((dropped_tx, mut dropped_rx)) = drop_chanel_write_guard.take() {
        // Drop the stored sender so only the clones held by remote readers remain.
        drop(dropped_tx);
        loop {
            match dropped_rx.recv().await {
                // All senders are gone: no remote read guard is left.
                Ok(None) => break,
                // The channel can no longer deliver anything. Retrying would spin forever
                // and block every writer, so stop waiting (same policy as the request
                // handler).
                Err(err) if err.is_final() => break,
                // Nothing in this file sends on this channel; non-final errors are retried.
                Ok(Some(())) | Err(_) => {}
            }
        }
    }
    drop_chanel_write_guard
}
/// Serializable part of an [`RwLock`]: the request channels to the owner's request task.
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "T: RemoteSend, Codec: codec::Codec",
    deserialize = "T: RemoteSend, Codec: codec::Codec"
))]
struct RwLockRemote<T, Codec> {
    /// Sends read requests to the owner.
    read_req_tx: mpsc::Sender<ReadRequest<T, Codec>, Codec, 1>,
    /// Sends write requests to the owner.
    write_req_tx: mpsc::Sender<WriteRequest<T, Codec>, Codec, 1>,
    /// When set, owner-side write guards edit the value in place (no staged copy).
    is_frivolous: bool,
}
/// Storage behind an [`RwLockReadGuard`].
enum RwLockReadGuardInner<'a, T, Codec> {
    /// Local read lock on the owner's value.
    Owner(TokioRwLockReadGuard<'a, T>),
    /// Copy received from the owner, including its drop-channel sender.
    Remote(RemoteValue<T, Codec>),
}
/// Storage behind an [`RwLockWriteGuard`].
enum RwLockWriteGuardInner<'a, T, Codec> {
    /// Local write lock on the owner's value.
    Owner {
        /// Staged copy that `commit` stores; `None` for frivolous locks (edit in place).
        new_value: Option<T>,
        /// Write lock on the owner's value.
        value_guard: TokioRwLockWriteGuard<'a, T>,
    },
    /// Copy received from the owner, plus channels used by `commit`.
    Remote {
        /// The value being edited.
        value: T,
        /// Sends the edited value back to the owner.
        new_value_tx: oneshot::Sender<T, Codec>,
        /// Receives the owner's confirmation that the value was stored.
        confirm_rx: oneshot::Receiver<(), Codec>,
    },
}
/// Payload of a remote read: a copy of the value and a drop-channel sender.
///
/// Owner-side writers wait until every such sender has been dropped.
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "T: RemoteSend, Codec: codec::Codec",
    deserialize = "T: RemoteSend, Codec: codec::Codec"
))]
struct RemoteValue<T, Codec = codec::Default> {
    /// Copy of the owner's value at the time of the read.
    value: T,
    /// Kept alive for as long as the remote read guard exists.
    dropped_tx: mpsc::Sender<(), Codec, 1>,
}
/// Read request sent from a remote handle to the owner's request task.
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "T: RemoteSend, Codec: codec::Codec",
    deserialize = "T: RemoteSend, Codec: codec::Codec"
))]
struct ReadRequest<T, Codec = codec::Default> {
    /// Where the owner sends the value copy and drop-channel sender.
    value_tx: oneshot::Sender<RemoteValue<T, Codec>, Codec>,
}
/// Write request sent from a remote handle to the owner's request task.
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "T: RemoteSend, Codec: codec::Codec",
    deserialize = "T: RemoteSend, Codec: codec::Codec"
))]
struct WriteRequest<T, Codec = codec::Default> {
    /// Where the owner sends the current value.
    value_tx: oneshot::Sender<T, Codec>,
    /// Where the owner receives the committed value from.
    new_value_rx: oneshot::Receiver<T, Codec>,
    /// Where the owner confirms that the committed value was stored.
    confirm_tx: oneshot::Sender<(), Codec>,
}