use serde::{Deserialize, Serialize};
use std::{
fmt,
marker::PhantomData,
mem,
ops::{Deref, DerefMut},
sync::Arc,
};
use tokio::sync::{OwnedRwLockMappedWriteGuard, OwnedRwLockReadGuard, OwnedRwLockWriteGuard};
use tracing::Instrument;
use uuid::Uuid;
use crate::{
chmux::{AnyBox, AnyEntry},
codec, exec,
rch::{
base::{PortDeserializer, PortSerializer},
mpsc,
},
};
/// Error returned when the value behind a [`Handle`] cannot be accessed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HandleError {
    /// The handle refers to no value on this endpoint: it is remote, its value was
    /// already taken, or the id is unknown.
    Unknown,
    /// The value behind the handle is not of the handle's type.
    ///
    /// Carries the `Debug` rendering of the stored value's `TypeId`.
    MismatchedType(String),
}

impl fmt::Display for HandleError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Unknown => f.write_str("unknown, taken or non-local handle"),
            Self::MismatchedType(type_desc) => {
                write!(f, "mismatched handle type: {}", type_desc)
            }
        }
    }
}

impl std::error::Error for HandleError {}
/// Companion of a handle created with [`Handle::provided`].
///
/// Holds the sending half of a watch channel whose receiver is stored in the handle's
/// state. The flag starts out `false`; [`Provider::keep`] sets it to `true`. When the
/// cleanup task of a serialized handle sees this channel close while the flag is still
/// `false`, it removes the value from the port storage again.
pub struct Provider {
    // Only `keep` takes the sender out, and `keep` consumes the provider, so this is
    // `Some` whenever a method can observe it.
    keep_tx: Option<tokio::sync::watch::Sender<bool>>,
}

impl fmt::Debug for Provider {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // An empty `debug_struct` prints just the name; write it directly.
        f.write_str("Provider")
    }
}

impl Provider {
    /// Requests that the value be kept, consuming the provider.
    ///
    /// Sets the keep flag to `true`. Cleanup tasks of serialized handles that observe
    /// `true` keep the storage entry until their drop-notification channel yields.
    pub fn keep(mut self) {
        let keep_tx = self.keep_tx.take().unwrap();
        // Deliberately ignored: a send error only means no receiver is listening anymore.
        let _ = keep_tx.send(true);
    }

    /// Waits until every receiver of the keep channel has been dropped.
    ///
    /// Receivers live in the handle's state (and clones of it) and in the cleanup tasks
    /// spawned when the handle is serialized.
    pub async fn done(&mut self) {
        let keep_tx = self.keep_tx.as_mut().unwrap();
        keep_tx.closed().await
    }
}

impl Drop for Provider {
    /// Intentionally empty; dropping `keep_tx` closes the keep channel.
    // NOTE(review): an explicit `Drop` impl forbids moving fields out of `Provider`
    // (E0509), presumably why `keep` uses `Option::take`. Confirm before removing.
    fn drop(&mut self) {}
}
/// Internal state of a [`Handle`].
#[derive(Clone, Default)]
enum State<Codec> {
    /// Placeholder left behind after the state has been moved out with `mem::take`.
    #[default]
    Empty,
    /// Value created on this endpoint via `Handle::provided` / `Handle::new`.
    LocalCreated {
        /// Shared, lock-protected slot with the boxed value; `None` once taken.
        entry: AnyEntry,
        /// Receiving half of the provider's keep flag.
        keep_rx: tokio::sync::watch::Receiver<bool>,
    },
    /// Deserialized handle whose id was found in the deserializer's port storage.
    LocalReceived {
        /// Slot with the boxed value, removed from the port storage on receipt.
        entry: AnyEntry,
        /// Id the handle was transported with.
        id: Uuid,
        /// Drop-notification sender forwarded when the handle is serialized again.
        dropped_tx: mpsc::Sender<(), Codec, 1>,
    },
    /// Deserialized handle whose value is not stored on this endpoint.
    Remote {
        /// Id the handle was transported with.
        id: Uuid,
        /// Drop-notification sender forwarded when the handle is serialized again.
        dropped_tx: mpsc::Sender<(), Codec, 1>,
    },
}

impl<Codec> fmt::Debug for State<Codec> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Only the id is shown; entries and channels have no useful debug output.
        let (variant, id) = match self {
            Self::Empty => return f.write_str("Empty"),
            Self::LocalCreated { .. } => return f.write_str("LocalCreated"),
            Self::LocalReceived { id, .. } => ("LocalReceived", id),
            Self::Remote { id, .. } => ("Remote", id),
        };
        f.debug_struct(variant).field("id", id).finish()
    }
}
/// A typed handle to a value that is stored either on this or on a remote endpoint.
///
/// `T` is only a type marker: the value itself lives in a shared, type-erased entry,
/// and access methods check at runtime that it really is a `T`.
///
/// NOTE(review): the derived `Clone` requires `T: Clone` and `Codec: Clone`, because
/// derive adds bounds for all type parameters even when they only appear in
/// `PhantomData`.
#[derive(Clone)]
pub struct Handle<T, Codec = codec::Default> {
    state: State<Codec>,
    _data: PhantomData<T>,
}

impl<T, Codec> fmt::Debug for Handle<T, Codec> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Show only the state, so `T` does not have to implement `Debug`.
        let state = &self.state;
        write!(f, "{state:?}")
    }
}
impl<T, Codec> Handle<T, Codec>
where
T: Send + Sync + 'static,
Codec: codec::Codec,
{
pub fn new(value: T) -> Self {
let (handle, provider) = Self::provided(value);
provider.keep();
handle
}
pub fn provided(value: T) -> (Self, Provider) {
let (keep_tx, keep_rx) = tokio::sync::watch::channel(false);
let handle = Self {
state: State::LocalCreated {
entry: Arc::new(tokio::sync::RwLock::new(Some(Box::new(value)))),
keep_rx,
},
_data: PhantomData,
};
let provider = Provider { keep_tx: Some(keep_tx) };
(handle, provider)
}
pub async fn into_inner(mut self) -> Result<T, HandleError> {
let entry = match mem::take(&mut self.state) {
State::LocalCreated { entry, .. } => entry,
State::LocalReceived { entry, .. } => entry,
_ => return Err(HandleError::Unknown),
};
let mut entry = entry.write().await;
match entry.take() {
Some(any) => match any.downcast::<T>() {
Ok(value) => Ok(*value),
Err(any) => Err(HandleError::MismatchedType(format!("{:?}", (*any).type_id()))),
},
None => Err(HandleError::Unknown),
}
}
pub async fn as_ref(&self) -> Result<Ref<T>, HandleError> {
let entry = match &self.state {
State::LocalCreated { entry, .. } | State::LocalReceived { entry, .. } => entry.clone(),
_ => return Err(HandleError::Unknown),
};
let entry = entry.read_owned().await;
match &*entry {
Some(any) => {
if !any.is::<T>() {
return Err(HandleError::MismatchedType(format!("{:?}", (**any).type_id())));
}
let value_ref = OwnedRwLockReadGuard::map(entry, |entry| {
entry.as_ref().unwrap().downcast_ref::<T>().unwrap()
});
Ok(Ref(value_ref))
}
None => Err(HandleError::Unknown),
}
}
pub async fn as_mut(&mut self) -> Result<RefMut<T>, HandleError> {
let entry = match &self.state {
State::LocalCreated { entry, .. } | State::LocalReceived { entry, .. } => entry.clone(),
_ => return Err(HandleError::Unknown),
};
let entry = entry.write_owned().await;
match &*entry {
Some(any) => {
if !any.is::<T>() {
return Err(HandleError::MismatchedType(format!("{:?}", (**any).type_id())));
}
let value_ref = OwnedRwLockWriteGuard::map(entry, |entry| {
entry.as_mut().unwrap().downcast_mut::<T>().unwrap()
});
Ok(RefMut(value_ref))
}
None => Err(HandleError::Unknown),
}
}
pub fn cast<TNew>(self) -> Handle<TNew, Codec> {
Handle { state: self.state.clone(), _data: PhantomData }
}
}
/// Intentionally empty drop implementation.
///
/// NOTE(review): an explicit `Drop` impl forbids moving fields out of `Handle`
/// (E0509), which is presumably why methods consuming `self` take the `state` out
/// with `mem::take` (see `into_inner`) instead of destructuring. Confirm before removing.
impl<T, Codec> Drop for Handle<T, Codec> {
    fn drop(&mut self) {
        // Nothing to do: dropping `state` releases the entry and channel endpoints.
    }
}
/// Serialized form of a [`Handle`]; the value itself is never part of it.
///
/// NOTE(review): field order is part of the serialized layout for positional codecs,
/// so do not reorder fields.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound(serialize = "Codec: codec::Codec"))]
#[serde(bound(deserialize = "Codec: codec::Codec"))]
pub(crate) struct TransportedHandle<T, Codec> {
    /// Key under which the value's entry was inserted into `PortSerializer::storage()`.
    id: Uuid,
    /// Drop-notification sender; the serializing side's cleanup task awaits its receiver.
    dropped_tx: mpsc::Sender<(), Codec, 1>,
    /// Value type marker.
    data: PhantomData<T>,
    /// Codec type marker.
    codec: PhantomData<Codec>,
}
/// Serializes the handle as a `TransportedHandle`: a storage id plus a drop-notification
/// sender. The value itself is never serialized.
///
/// * `LocalCreated`: the entry is inserted into `PortSerializer::storage()` under a new id,
///   and a cleanup task is spawned that removes it again (see comments below).
///   Serializing the same handle twice registers the entry twice under different ids.
/// * `LocalReceived` / `Remote`: the existing id and sender are forwarded unchanged.
///
/// # Errors
/// Fails if `PortSerializer::storage()` fails or the inner serialization fails.
///
/// # Panics
/// Panics on the `Empty` state, which is treated as unreachable.
impl<T, Codec> Serialize for Handle<T, Codec>
where
    Codec: codec::Codec,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Work on a clone so that `&self` is left untouched.
        let (id, dropped_tx) = match self.state.clone() {
            State::LocalCreated { entry, mut keep_rx, .. } => {
                // Make the entry resolvable by id on the deserializing side.
                let handle_storage = PortSerializer::storage()?;
                let id = handle_storage.insert(entry.clone());
                // `dropped_tx` travels with the handle; `dropped_rx` is moved into the
                // cleanup task below.
                let (dropped_tx, dropped_rx) = mpsc::channel(1);
                let dropped_tx = dropped_tx.set_buffer::<1>();
                let mut dropped_rx = dropped_rx.set_buffer::<1>();
                // Cleanup task: decides how long the storage entry stays registered.
                exec::spawn(
                    async move {
                        loop {
                            if *keep_rx.borrow_and_update() {
                                // `Provider::keep` was called: keep the entry until the
                                // drop channel yields (a message or the end of the channel).
                                let _ = dropped_rx.recv().await;
                                break;
                            } else {
                                tokio::select! {
                                    // Poll the keep channel before the drop channel.
                                    biased;
                                    res = keep_rx.changed() => {
                                        // Stop once the keep channel is closed (provider
                                        // gone) while the flag is still `false`; otherwise
                                        // loop and re-evaluate the flag.
                                        if !*keep_rx.borrow_and_update() && res.is_err() {
                                            break;
                                        }
                                    },
                                    // The drop channel yielded before `keep` was requested.
                                    _ = dropped_rx.recv() => break,
                                }
                            }
                        }
                        // Deregister so later lookups of `id` find nothing.
                        handle_storage.remove(id);
                    }
                    .in_current_span(),
                );
                (id, dropped_tx)
            }
            // Handle was received: forward the original id and drop notification.
            State::LocalReceived { id, dropped_tx, .. } | State::Remote { id, dropped_tx } => (id, dropped_tx),
            State::Empty => unreachable!("state is only empty when dropping"),
        };
        let transported = TransportedHandle::<T, Codec> { id, dropped_tx, data: PhantomData, codec: PhantomData };
        transported.serialize(serializer)
    }
}
/// Rebuilds a handle from its transported form.
///
/// If the transported id is present in `PortDeserializer::storage()`, the entry is removed
/// from there and the handle becomes `LocalReceived`. Otherwise the value lives elsewhere
/// and the handle becomes `Remote`.
impl<'de, T, Codec> Deserialize<'de> for Handle<T, Codec>
where
    Codec: codec::Codec,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let transported = TransportedHandle::<T, Codec>::deserialize(deserializer)?;
        let TransportedHandle { id, dropped_tx, .. } = transported;
        let storage = PortDeserializer::storage()?;

        let state = if let Some(entry) = storage.remove(id) {
            State::LocalReceived { entry, id, dropped_tx }
        } else {
            State::Remote { id, dropped_tx }
        };

        Ok(Self { state, _data: PhantomData })
    }
}
/// Shared access to the value behind a [`Handle`], obtained from `Handle::as_ref`.
///
/// Wraps the mapped read guard, so the entry's read lock is held until this is dropped.
pub struct Ref<T>(OwnedRwLockReadGuard<Option<AnyBox>, T>);

impl<T> Deref for Ref<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self(guard) = self;
        guard
    }
}

impl<T> fmt::Debug for Ref<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Format the referenced value, not the guard.
        let value: &T = self;
        write!(f, "{value:?}")
    }
}
/// Exclusive access to the value behind a [`Handle`], obtained from `Handle::as_mut`.
///
/// Wraps the mapped write guard, so the entry's write lock is held until this is dropped.
pub struct RefMut<T>(OwnedRwLockMappedWriteGuard<Option<AnyBox>, T>);

impl<T> Deref for RefMut<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self(guard) = self;
        guard
    }
}

impl<T> DerefMut for RefMut<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(guard) = self;
        guard
    }
}

impl<T> fmt::Debug for RefMut<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Format the referenced value, not the guard.
        let value: &T = self;
        write!(f, "{value:?}")
    }
}