use futures::{
Future, FutureExt,
future::{self, BoxFuture, MaybeDone},
};
use serde::{Deserialize, Serialize};
use std::{error::Error, fmt, marker::PhantomData, ops::Deref, pin::Pin, sync::Arc};
use tokio::sync::Mutex;
use crate::{
RemoteSend, chmux, codec, exec,
rch::{base, mpsc, oneshot},
};
/// Error returned when the value of a [`Lazy`] could not be fetched.
///
/// NOTE(review): this type is serialized; presumably some codecs encode the
/// variant by index, so do not reorder variants without confirming.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum FetchError {
    /// The reply channel was closed before a value arrived (converted from
    /// `oneshot::RecvError::Closed`), e.g. because the lazy provider was dropped.
    Dropped,
    /// Wraps `oneshot::RecvError::RemoteReceive`: receiving the value failed.
    RemoteReceive(base::RecvError),
    /// Wraps `oneshot::RecvError::RemoteConnect`: connecting to transfer the value failed.
    RemoteConnect(chmux::ConnectError),
    /// Wraps `oneshot::RecvError::RemoteListen`: listening to transfer the value failed.
    RemoteListen(chmux::ListenerError),
}
impl fmt::Display for FetchError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Dropped => write!(f, "lazy provider dropped"),
Self::RemoteReceive(err) => write!(f, "receive error: {err}"),
Self::RemoteConnect(err) => write!(f, "connect error: {err}"),
Self::RemoteListen(err) => write!(f, "listen error: {err}"),
}
}
}
impl From<oneshot::RecvError> for FetchError {
fn from(err: oneshot::RecvError) -> Self {
match err {
oneshot::RecvError::Closed => Self::Dropped,
oneshot::RecvError::RemoteReceive(err) => Self::RemoteReceive(err),
oneshot::RecvError::RemoteConnect(err) => Self::RemoteConnect(err),
oneshot::RecvError::RemoteListen(err) => Self::RemoteListen(err),
}
}
}
/// Makes `FetchError` a standard error type.
///
/// Its message comes from its `Display` impl. The default `source()` is kept,
/// so no underlying cause is reported.
impl Error for FetchError {}
/// Handle that controls whether the value of an associated [`Lazy`] stays available.
///
/// Returned by `Lazy::provided` / `Lazy::provided_future`. Suppose it is dropped
/// without calling [`Provider::keep`] before the value has been requested. The
/// background task that would serve the value then stops waiting for a request.
pub struct Provider {
    /// Signal channel to the background task.
    ///
    /// Sending `()` (done by `keep`) lets the task keep waiting for a request.
    /// Dropping the sender unsent makes the task give up. It is `Some` from
    /// construction until `keep` takes it.
    keep_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
impl fmt::Debug for Provider {
    /// Prints only the type name; the internal keep channel is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("Provider");
        builder.finish()
    }
}
impl Provider {
    /// Keeps the lazy value available after this provider is gone.
    ///
    /// Consumes the provider and sends the keep signal to the background task
    /// started by `Lazy::provided_future`. On receiving it, the task stops
    /// listening for the provider and keeps waiting for a request for the value.
    pub fn keep(mut self) {
        let keep_tx = self.keep_tx.take().unwrap();
        // A send error only means the task has already dropped its receiver.
        let _ = keep_tx.send(());
    }

    /// Waits until the background task no longer listens for the keep signal.
    ///
    /// Resolves once the task's keep receiver has been dropped. At the latest
    /// this happens when the task ends: after a request has been served, or
    /// after the request channel has closed.
    pub async fn done(&mut self) {
        // `keep_tx` is only taken by `keep`, which consumes `self`, so it is
        // always present while a `&mut Provider` exists.
        let keep_tx = self.keep_tx.as_mut().unwrap();
        keep_tx.closed().await
    }
}
/// Explicit, empty `Drop` impl for [`Provider`].
///
/// Dropping a `Provider` still drops `keep_tx` through normal field drop glue.
/// If `keep` was never called, the background task of `Lazy::provided_future`
/// then sees its keep receiver fail and stops, unless a request was already
/// being served.
///
/// NOTE(review): the impl itself adds no behaviour. One visible effect is that
/// fields cannot be moved out of a `Provider` by destructuring, which is why
/// `keep` uses `Option::take`. Confirm whether that, or API stability, is why
/// it exists.
impl Drop for Provider {
    fn drop(&mut self) {
        // Intentionally empty.
    }
}
/// A value that is only transferred when it is first requested.
///
/// The value is produced by a future that is awaited only once a request
/// arrives on `request_tx`. Only `request_tx` is serialized; the fetch state
/// is local to each instance.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
pub struct Lazy<T, Codec = codec::Default> {
    /// Request channel with a buffer of 1. A fetch sends a `oneshot::Sender` on
    /// it, and the value is replied through that sender.
    request_tx: mpsc::Sender<oneshot::Sender<T, Codec>, Codec, 1>,
    /// Cached fetch future and, once complete, its result.
    ///
    /// Skipped by serde, so a deserialized `Lazy` starts with `None` (via
    /// `Default`). The tokio `Mutex` is held across `.await` while the fetch runs.
    #[serde(skip)]
    #[serde(default)]
    // NOTE(review): lint silenced for this single nested type; a type alias
    // would be an alternative.
    #[allow(clippy::type_complexity)]
    fetch_task: Arc<Mutex<Option<Pin<Box<MaybeDone<BoxFuture<'static, Result<Arc<T>, FetchError>>>>>>>>,
}
impl<T, Codec> fmt::Debug for Lazy<T, Codec> {
    /// Prints only the type name; neither the value nor the fetch state is shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("Lazy");
        builder.finish()
    }
}
impl<T, Codec> Lazy<T, Codec>
where
    T: RemoteSend,
    Codec: codec::Codec,
{
    /// Creates a lazy value from a value that is already known.
    ///
    /// No [`Provider`] is handed out; the value stays available on its own.
    pub fn new(value: T) -> Self {
        Self::new_future(future::ready(value))
    }

    /// Creates a lazy value whose contents are produced by `value_fut`.
    ///
    /// `value_fut` is only awaited once a request for the value arrives. The
    /// internal provider is marked as kept, so its availability is not tied to
    /// any handle.
    pub fn new_future<F>(value_fut: F) -> Self
    where
        F: Future<Output = T> + Send + 'static,
    {
        let (lazy, provider) = Self::provided_future(value_fut);
        // Nobody else holds the provider, so keep the value available ourselves.
        provider.keep();
        lazy
    }

    /// Creates a lazy value from a known value, together with its [`Provider`].
    ///
    /// Behaves like [`Lazy::provided_future`] with an already-completed future.
    pub fn provided(value: T) -> (Self, Provider) {
        Self::provided_future(future::ready(value))
    }

    /// Creates a lazy value computed by `value_fut`, together with its [`Provider`].
    ///
    /// A background task (started via `exec::spawn`) waits for the first
    /// request on the internal channel. It then awaits `value_fut` and replies
    /// with the result.
    ///
    /// The provider matters only until a request arrives. If it is dropped
    /// without [`Provider::keep`] before then, the task stops and the reply
    /// sender is never answered. A later fetch then presumably fails with
    /// [`FetchError::Dropped`].
    pub fn provided_future<F>(value_fut: F) -> (Self, Provider)
    where
        F: Future<Output = T> + Send + 'static,
    {
        let (raw_request_tx, raw_request_rx) = mpsc::channel::<oneshot::Sender<T, Codec>, _>(1);
        let request_tx = raw_request_tx.set_buffer::<1>();
        let mut request_rx = raw_request_rx.set_buffer::<1>();
        let (keep_tx, keep_rx) = tokio::sync::oneshot::channel();

        exec::spawn(async move {
            tokio::select! {
                // First request: compute the value and reply through the sender
                // that came with the request.
                res = request_rx.recv() => {
                    if let Ok(Some(value_tx)) = res {
                        let value = value_fut.await;
                        let _ = value_tx.send(value);
                    }
                },
                // `Err` means the provider was dropped without `keep()`: give up.
                // `Ok(())` (sent by `keep()`) does not match, which only disables
                // this branch and leaves the task waiting for a request.
                Err(_) = keep_rx => (),
            }
        });

        (
            Lazy { request_tx, fetch_task: Default::default() },
            Provider { keep_tx: Some(keep_tx) },
        )
    }

    /// Ensures the value has been fetched, caching the outcome in `fetch_task`.
    ///
    /// The first call creates a single fetch future. That future sends a reply
    /// channel through `request_tx` and waits for the answer. Every call drives
    /// this same future while holding the lock, so concurrent callers share one
    /// request. A cancelled call leaves the future in place for the next caller
    /// to resume.
    async fn fetch(&self) {
        let mut slot = self.fetch_task.lock().await;
        let task = slot.get_or_insert_with(|| {
            let request_tx = self.request_tx.clone();
            Box::pin(future::maybe_done(
                async move {
                    let (value_tx, value_rx) = oneshot::channel();
                    // The send result is deliberately ignored. If sending fails,
                    // `value_tx` is dropped and awaiting `value_rx` is expected to
                    // report an error instead.
                    let _ = request_tx.send(value_tx).await;
                    let value = value_rx.await?;
                    Ok::<_, FetchError>(Arc::new(value))
                }
                .boxed(),
            ))
        });
        task.await;
    }

    /// Fetches the value if necessary and returns a shared reference to it.
    ///
    /// # Errors
    /// Returns a clone of the cached [`FetchError`] if fetching failed. The
    /// outcome is cached, so later calls report the same error.
    pub async fn get(&self) -> Result<Ref<'_, T>, FetchError> {
        self.fetch().await;
        let mut slot = self.fetch_task.lock().await;
        // `fetch` drove the task to completion, so its output is present.
        let outcome: &Result<Arc<T>, FetchError> =
            slot.as_mut().unwrap().as_mut().output_mut().unwrap();
        outcome
            .as_ref()
            .map(|value| Ref { value: Arc::clone(value), _lifetime: PhantomData })
            .map_err(|err| err.clone())
    }

    /// Fetches the value if necessary and returns it by value, consuming `self`.
    ///
    /// # Errors
    /// Returns the [`FetchError`] if fetching failed.
    ///
    /// # Panics
    /// Panics (via `unreachable!`) if the value's `Arc` is still shared, e.g.
    /// because a `Ref` was leaked with `std::mem::forget`. A live `Ref` borrows
    /// the `Lazy`, which ordinarily prevents calling this method while one exists.
    pub async fn into_inner(self) -> Result<T, FetchError> {
        self.fetch().await;
        let mut slot = self.fetch_task.lock().await;
        let outcome = slot.as_mut().unwrap().as_mut().take_output().unwrap();
        let shared = outcome?;
        Ok(Arc::try_unwrap(shared).unwrap_or_else(|_| unreachable!("no other reference can exist")))
    }
}
/// A shared reference to a value fetched by [`Lazy::get`].
///
/// Dereferences to the value, which is held through an `Arc`. The lifetime
/// `'a` ties the `Ref` to the borrow of the `Lazy` it was obtained from.
pub struct Ref<'a, T> {
    value: Arc<T>,
    _lifetime: PhantomData<&'a ()>,
}

impl<T> fmt::Debug for Ref<'_, T>
where
    T: fmt::Debug,
{
    /// Formats the referenced value transparently.
    ///
    /// All formatter options (`{:#?}` pretty-printing, width, precision, ...)
    /// are forwarded to `T`'s `Debug` impl. Re-formatting through
    /// `write!(f, "{:?}", ..)` would discard them.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}

impl<T> Deref for Ref<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T> Drop for Ref<'_, T> {
    /// Intentionally empty.
    ///
    /// Because `Ref` implements `Drop`, the borrow checker requires `'a` to stay
    /// live until the `Ref` is dropped, not just until its last use. The `Lazy`
    /// it came from therefore cannot be moved (e.g. into `Lazy::into_inner`)
    /// while a `Ref` still holds a clone of the value's `Arc`. Do not remove.
    fn drop(&mut self) {}
}