use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt,
ops::{Deref, DerefMut},
sync::Arc,
};
use tracing::Instrument;
use super::msg::{ReadRequest, Value, WriteRequest};
use crate::{
RemoteSend, chmux, codec, exec,
rch::{base, mpsc, oneshot},
};
/// Error obtaining the current value from the lock's owner.
///
/// Returned by [`ReadLock::read`], [`RwLock::read`] and [`RwLock::write`] when the
/// reply carrying the value cannot be received.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum LockError {
    /// The reply channel was closed without a value (displayed as "owner dropped").
    Dropped,
    /// A receive error reported by the underlying reply channel.
    RemoteReceive(base::RecvError),
    /// A connect error reported by the underlying reply channel.
    RemoteConnect(chmux::ConnectError),
    /// A listen error reported by the underlying reply channel.
    RemoteListen(chmux::ListenerError),
}
impl fmt::Display for LockError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Dropped => write!(f, "owner dropped"),
Self::RemoteReceive(err) => write!(f, "receive error: {err}"),
Self::RemoteConnect(err) => write!(f, "connect error: {err}"),
Self::RemoteListen(err) => write!(f, "listen error: {err}"),
}
}
}
impl From<oneshot::RecvError> for LockError {
fn from(err: oneshot::RecvError) -> Self {
match err {
oneshot::RecvError::Closed => Self::Dropped,
oneshot::RecvError::RemoteReceive(err) => Self::RemoteReceive(err),
oneshot::RecvError::RemoteConnect(err) => Self::RemoteConnect(err),
oneshot::RecvError::RemoteListen(err) => Self::RemoteListen(err),
}
}
}
// Marker impl built on the `Display` and `Debug` impls; `source()` keeps its default,
// so the wrapped remote error is exposed only through `Display`/`Debug`.
impl Error for LockError {}
/// Error committing a modified value via [`WriteGuard::commit`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CommitError {
    /// Sending the new value failed because its channel was closed
    /// (displayed as "owner dropped").
    Dropped,
    /// Sending the new value failed for another reason, or the confirmation
    /// could not be received.
    Failed,
}
impl fmt::Display for CommitError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Dropped => write!(f, "owner dropped"),
Self::Failed => write!(f, "commit failed"),
}
}
}
impl<T> From<oneshot::SendError<T>> for CommitError {
fn from(err: oneshot::SendError<T>) -> Self {
match err {
oneshot::SendError::Closed(_) => Self::Dropped,
oneshot::SendError::Failed => Self::Failed,
}
}
}
impl From<oneshot::RecvError> for CommitError {
fn from(_: oneshot::RecvError) -> Self {
Self::Failed
}
}
// Marker impl built on the `Display` and `Debug` impls; `source()` keeps its default (None).
impl Error for CommitError {}
/// Read-only handle to a value held by an owner, which is reached through `req_tx`.
///
/// Reads are served from a local cache that is refreshed on demand. Clones share the
/// same cache because the cache is behind an `Arc`.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
pub struct ReadLock<T, Codec = codec::Default> {
    /// Channel for requesting the current value from the owner.
    req_tx: mpsc::Sender<ReadRequest<T, Codec>, Codec, 1>,
    /// Locally cached value. It is not serialized, so a deserialized handle starts
    /// with an empty cache (`empty_cache`).
    #[serde(skip)]
    #[serde(default = "empty_cache")]
    cache: Arc<tokio::sync::RwLock<Option<Value<T, Codec>>>>,
}
/// Builds a fresh, shareable cache slot holding no value.
///
/// Also used as the serde default for `ReadLock::cache`.
fn empty_cache<T, Codec>() -> Arc<tokio::sync::RwLock<Option<Value<T, Codec>>>> {
    let slot = tokio::sync::RwLock::new(None);
    Arc::new(slot)
}
impl<T, Codec> Clone for ReadLock<T, Codec> {
fn clone(&self) -> Self {
Self { req_tx: self.req_tx.clone(), cache: self.cache.clone() }
}
}
impl<T, Codec> fmt::Debug for ReadLock<T, Codec> {
    /// Prints only the type name; neither the channel nor the cached value is shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("ReadLock")
    }
}
impl<T, Codec> ReadLock<T, Codec>
where
    T: RemoteSend + Sync,
    Codec: codec::Codec,
{
    /// Creates a read lock that requests values over `read_req_tx`.
    ///
    /// The local cache starts out empty, so the first read always contacts the owner.
    pub(crate) fn new(read_req_tx: mpsc::Sender<ReadRequest<T, Codec>, Codec, 1>) -> Self {
        Self { req_tx: read_req_tx, cache: empty_cache() }
    }

    /// Returns a read guard on a valid cached value.
    ///
    /// If the cache is empty or invalidated, it first fetches a fresh value from the owner.
    ///
    /// # Errors
    /// Returns a [`LockError`] if a fresh value was needed and could not be received.
    async fn fetch(&self) -> Result<tokio::sync::RwLockReadGuard<'_, Value<T, Codec>>, LockError> {
        // Fast path: serve a valid cached value under a shared lock.
        {
            let cache_opt = self.cache.read().await;
            if matches!(&*cache_opt, Some(cache) if cache.is_valid()) {
                return Ok(tokio::sync::RwLockReadGuard::map(cache_opt, |co| co.as_ref().unwrap()));
            }
        }

        // Slow path: take the exclusive lock so only one task refreshes the cache.
        let mut cache_opt = self.cache.write().await;

        // Re-check. Another task may have refreshed the cache while we waited for the
        // write lock. Reusing its value avoids a redundant request to the owner and an
        // extra invalidation-watcher task.
        if matches!(&*cache_opt, Some(cache) if cache.is_valid()) {
            return Ok(tokio::sync::RwLockReadGuard::map(
                tokio::sync::RwLockWriteGuard::downgrade(cache_opt),
                |co| co.as_ref().unwrap(),
            ));
        }

        // Request the current value. The send result is ignored: a failed request is
        // dropped together with its `value_tx`, so the failure is expected to surface
        // as an error from `value_rx` below.
        let (value_tx, value_rx) = oneshot::channel();
        let _ = self.req_tx.send(ReadRequest { value_tx }).await;
        let value = value_rx.await?;

        // Watch the value's invalidation flag in the background. Once the flag is set,
        // or the flag channel errors, clear the cache.
        let mut invalid_rx = value.invalid_rx.clone();
        let cache_lock = self.cache.clone();
        exec::spawn(
            async move {
                loop {
                    match invalid_rx.borrow_and_update() {
                        Ok(invalid) if !*invalid => (),
                        _ => break,
                    }
                    if invalid_rx.changed().await.is_err() {
                        break;
                    }
                }

                // Clear only if the cached entry is actually invalid. A later fetch may
                // already have replaced it with a fresh, valid value.
                let mut cache_opt = cache_lock.write().await;
                match &*cache_opt {
                    Some(cache) if !cache.is_valid() => *cache_opt = None,
                    _ => (),
                }
            }
            .in_current_span(),
        );

        *cache_opt = Some(value);
        Ok(tokio::sync::RwLockReadGuard::map(tokio::sync::RwLockWriteGuard::downgrade(cache_opt), |co| {
            co.as_ref().unwrap()
        }))
    }

    /// Reads the value, using the local cache while it is still valid.
    ///
    /// # Errors
    /// Returns a [`LockError`] if a fresh value had to be fetched and could not be received.
    pub async fn read(&self) -> Result<ReadGuard<'_, T, Codec>, LockError> {
        let cache = self.fetch().await?;
        Ok(ReadGuard(cache))
    }
}
/// Guard holding a shared lock on a [`ReadLock`]'s cached value; dereferences to the value.
///
/// While any guard is alive, the cache cannot be refreshed or cleared, because both
/// operations need the cache's write lock.
pub struct ReadGuard<'a, T, Codec = codec::Default>(tokio::sync::RwLockReadGuard<'a, Value<T, Codec>>);
impl<T, Codec> ReadGuard<'_, T, Codec>
where
    Codec: codec::Codec,
{
    /// Waits until the value held by this guard has been invalidated.
    ///
    /// Returns immediately if it is already invalidated. Also returns if the invalidation
    /// channel reports an error while waiting for a change.
    pub async fn invalidated(&self) {
        let mut watch = self.0.invalid_rx.clone();
        loop {
            // A borrow error counts as "not yet invalid"; the `changed()` call below
            // then decides whether to keep waiting.
            let invalid = match watch.borrow_and_update() {
                Ok(flag) => *flag,
                Err(_) => false,
            };
            if invalid || watch.changed().await.is_err() {
                return;
            }
        }
    }

    /// Returns whether the value held by this guard has been invalidated.
    ///
    /// If the invalidation flag cannot be read, the value is reported as invalidated.
    pub fn is_invalidated(&self) -> bool {
        match self.0.invalid_rx.borrow() {
            Ok(flag) => *flag,
            Err(_) => true,
        }
    }
}
impl<T, Codec> Deref for ReadGuard<'_, T, Codec> {
    type Target = T;

    /// Returns the cached value held by this guard.
    fn deref(&self) -> &T {
        let held = &*self.0;
        &held.value
    }
}
impl<T, Codec> fmt::Debug for ReadGuard<'_, T, Codec>
where
    T: fmt::Debug,
{
    /// Formats the guarded value transparently.
    ///
    /// Delegates to `T`'s `Debug` impl with the caller's formatter, so flags such as
    /// `{:#?}` (pretty-printing) and width/precision are honoured. The previous
    /// `write!(f, "{:?}", ..)` discarded them.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}
impl<T, Codec> Drop for ReadGuard<'_, T, Codec> {
    fn drop(&mut self) {
        // Intentionally empty: the inner read guard releases the cache lock when it is
        // dropped. NOTE(review): the explicit `Drop` impl may be kept deliberately
        // (e.g. to reserve the ability to add cleanup later); confirm before removing it.
    }
}
/// Read-write handle to a value held by an owner.
///
/// Reads go through the embedded [`ReadLock`] and its cache. Writes are requested over
/// `req_tx` and yield a [`WriteGuard`].
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
pub struct RwLock<T, Codec = codec::Default> {
    /// Read side; also handed out by `read_lock()`.
    read: ReadLock<T, Codec>,
    /// Channel for sending write requests to the owner.
    req_tx: mpsc::Sender<WriteRequest<T, Codec>, Codec, 1>,
}
impl<T, Codec> Clone for RwLock<T, Codec> {
fn clone(&self) -> Self {
Self { read: self.read.clone(), req_tx: self.req_tx.clone() }
}
}
impl<T, Codec> fmt::Debug for RwLock<T, Codec> {
    /// Prints only the type name; channels and cached state are not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("RwLock")
    }
}
impl<T, Codec> RwLock<T, Codec>
where
T: RemoteSend + Sync,
Codec: codec::Codec,
{
pub(crate) fn new(
read_lock: ReadLock<T, Codec>, write_req_tx: mpsc::Sender<WriteRequest<T, Codec>, Codec, 1>,
) -> Self {
Self { read: read_lock, req_tx: write_req_tx }
}
pub async fn read(&self) -> Result<ReadGuard<'_, T, Codec>, LockError> {
self.read.read().await
}
pub async fn write(&self) -> Result<WriteGuard<T, Codec>, LockError> {
let (value_tx, value_rx) = oneshot::channel();
let (new_value_tx, new_value_rx) = oneshot::channel();
let (confirm_tx, confirm_rx) = oneshot::channel();
let _ = self.req_tx.send(WriteRequest { value_tx, new_value_rx, confirm_tx }).await;
let value = value_rx.await?;
Ok(WriteGuard { value: Some(value), new_value_tx: Some(new_value_tx), confirm_rx: Some(confirm_rx) })
}
pub fn read_lock(&self) -> ReadLock<T, Codec> {
self.read.clone()
}
}
/// Guard giving mutable access to a value obtained from the owner.
///
/// Changes are sent back to the owner only by [`commit`](WriteGuard::commit). If the
/// guard is dropped without committing, the modified value is never sent.
pub struct WriteGuard<T, Codec = codec::Default> {
    /// Value being edited. `Some` until `commit` takes it.
    value: Option<T>,
    /// Sends the modified value back to the owner. Taken by `commit`.
    new_value_tx: Option<oneshot::Sender<T, Codec>>,
    /// Receives the owner's confirmation of the commit. Taken by `commit`.
    confirm_rx: Option<oneshot::Receiver<(), Codec>>,
}
impl<T, Codec> WriteGuard<T, Codec>
where
T: RemoteSend,
Codec: codec::Codec,
{
pub async fn commit(mut self) -> Result<(), CommitError> {
let new_value = self.value.take().unwrap();
let new_value_tx = self.new_value_tx.take().unwrap();
new_value_tx.send(new_value)?;
let confirm_rx = self.confirm_rx.take().unwrap();
confirm_rx.await?;
Ok(())
}
}
impl<T, Codec> Deref for WriteGuard<T, Codec> {
    type Target = T;
    /// Returns the value being edited.
    fn deref(&self) -> &Self::Target {
        // `value` is only taken by `commit`, which consumes the guard, so it is `Some`
        // whenever the guard can still be dereferenced.
        self.value.as_ref().unwrap()
    }
}
impl<T, Codec> DerefMut for WriteGuard<T, Codec> {
    /// Returns mutable access to the value being edited; changes reach the owner only on `commit`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // `value` is only taken by `commit`, which consumes the guard.
        self.value.as_mut().unwrap()
    }
}
impl<T, Codec> fmt::Debug for WriteGuard<T, Codec>
where
    T: fmt::Debug,
{
    /// Formats the value being edited transparently.
    ///
    /// Delegates to `T`'s `Debug` impl with the caller's formatter, so flags such as
    /// `{:#?}` and width/precision are honoured. The previous `write!(f, "{:?}", ..)`
    /// discarded them.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}
impl<T, Codec> Drop for WriteGuard<T, Codec> {
    fn drop(&mut self) {
        // Intentionally empty. If the guard is dropped without `commit`, `new_value_tx`
        // is dropped unsent, so the modified value is never delivered.
        // NOTE(review): presumably the owner then keeps its previous value; confirm
        // against the owner-side handling of `new_value_rx`.
    }
}