use bytes::Bytes;
use futures::{
future,
future::{BoxFuture, MaybeDone},
FutureExt,
};
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, fmt, pin::Pin, sync::Arc};
use tokio::sync::Mutex;
use crate::{
chmux,
chmux::DataBuf,
codec,
rch::{mpsc, ConnectError},
};
mod fw_bin;
/// Error indicating that the length of binary data does not fit into a
/// `usize` on this platform.
///
/// The wrapped value is the length of the data in bytes.
#[derive(Clone, Debug)]
pub struct UsizeExceeded(pub u64);

impl fmt::Display for UsizeExceeded {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let UsizeExceeded(size) = self;
        write!(f, "binary data ({} bytes) exceeds maximum array size", size)
    }
}

impl std::error::Error for UsizeExceeded {}
/// Error that can occur while fetching the data of a lazy blob.
#[derive(Clone, Debug)]
pub enum FetchError {
    /// The data could not be obtained because its provider is gone
    /// (displayed as "provider was dropped").
    Dropped,
    /// The length of the data does not fit into a `usize` on this platform.
    Size(UsizeExceeded),
    /// Receiving the data over the remote channel failed.
    RemoteReceive(chmux::RecvError),
    /// Connecting the remote channel failed.
    RemoteConnect(ConnectError),
}

impl From<UsizeExceeded> for FetchError {
    fn from(size_err: UsizeExceeded) -> Self {
        FetchError::Size(size_err)
    }
}

impl fmt::Display for FetchError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            FetchError::Dropped => f.write_str("provider was dropped"),
            FetchError::Size(inner) => write!(f, "{}", inner),
            FetchError::RemoteReceive(inner) => write!(f, "receive error: {}", inner),
            FetchError::RemoteConnect(inner) => write!(f, "connect error: {}", inner),
        }
    }
}

impl std::error::Error for FetchError {}
/// Controls how long the data of a lazy blob stays available for fetching.
///
/// Dropping the provider without calling [`keep`](Provider::keep) stops the
/// task that serves the data.
pub struct Provider {
    /// Signal to the serving task: dropping it stops the task, sending `()`
    /// through it lets the task keep running. Only `None` once `keep` has
    /// taken it.
    keep_tx: Option<tokio::sync::oneshot::Sender<()>>,
}

impl fmt::Debug for Provider {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.debug_struct("Provider").finish()
    }
}

impl Provider {
    /// Releases the provider while keeping the data available.
    ///
    /// The serving task then continues answering fetch requests on its own.
    pub fn keep(mut self) {
        let keep_tx = self.keep_tx.take().unwrap();
        // The send result is ignored: if the serving task has already ended,
        // there is nothing left to keep alive.
        let _ = keep_tx.send(());
    }

    /// Waits until the receiving end of the keep signal has been dropped,
    /// i.e. until the serving task that holds it has finished.
    pub async fn done(&mut self) {
        let keep_tx = self.keep_tx.as_mut().unwrap();
        keep_tx.closed().await;
    }
}

impl Drop for Provider {
    fn drop(&mut self) {
        // Nothing to do explicitly: dropping `keep_tx` (still present unless
        // `keep` consumed it) closes the oneshot channel, and the serving task
        // treats that as the signal to stop.
    }
}
/// Shared, lazily started fetch of a blob's data.
///
/// The slot is empty until the first fetch. It then holds the (possibly only
/// partially polled) transfer future, and finally its cached outcome.
type SharedFetchTask =
    Arc<Mutex<Option<Pin<Box<MaybeDone<BoxFuture<'static, Result<DataBuf, FetchError>>>>>>>>;

/// Handle to binary data that is only transferred when it is requested.
///
/// Clones share a single fetch slot, so the data is transferred at most once
/// for all clones of a handle. A deserialized handle starts with an empty slot.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "Codec: codec::Codec", deserialize = "Codec: codec::Codec"))]
pub struct LazyBlob<Codec = codec::Default> {
    /// Channel for sending fetch requests to the task that holds the data.
    req_tx: mpsc::Sender<fw_bin::Sender, Codec, 1>,
    /// Length of the data in bytes.
    len: u64,
    /// Local fetch state; never serialized, default-initialized (empty) on
    /// deserialization.
    #[serde(skip, default)]
    fetch_task: SharedFetchTask,
}

impl<Codec> fmt::Debug for LazyBlob<Codec> {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let mut out = formatter.debug_struct("LazyBlob");
        out.field("len", &self.len);
        out.finish()
    }
}
impl<Codec> LazyBlob<Codec>
where
Codec: codec::Codec,
{
pub fn new(data: Bytes) -> Self {
let (lazy_blob, provider) = Self::provided(data);
provider.keep();
lazy_blob
}
pub fn provided(data: Bytes) -> (Self, Provider) {
let (keep_tx, keep_rx) = tokio::sync::oneshot::channel();
let (req_tx, req_rx) = mpsc::channel(1);
let req_tx = req_tx.set_buffer();
let mut req_rx = req_rx.set_buffer::<1>();
let len = data.len() as _;
tokio::spawn(async move {
let do_send = async move {
loop {
let fw_tx: fw_bin::Sender = match req_rx.recv().await {
Ok(Some(fw_tx)) => fw_tx,
Ok(None) => break,
Err(err) if err.is_final() => break,
Err(_) => continue,
};
let data = data.clone();
tokio::spawn(async move {
let bin_tx = if let Some(tx) = fw_tx.into_inner() { tx } else { return };
let mut tx = if let Ok(tx) = bin_tx.into_inner().await { tx } else { return };
let _ = tx.send(data).await;
});
}
};
tokio::select! {
() = do_send => (),
Err(_) = keep_rx => (),
}
});
let lazy_blob = LazyBlob { req_tx, len, fetch_task: Default::default() };
let provider = Provider { keep_tx: Some(keep_tx) };
(lazy_blob, provider)
}
pub fn is_empty(&self) -> bool {
matches!(self.len(), Ok(0))
}
pub fn len(&self) -> Result<usize, UsizeExceeded> {
usize::try_from(self.len).map_err(|_| UsizeExceeded(self.len))
}
async fn fetch(&self) -> Result<(), FetchError> {
let mut fetch_task = self.fetch_task.lock().await;
if fetch_task.is_none() {
let req_tx = self.req_tx.clone();
let len = self.len()?;
*fetch_task = Some(Box::pin(future::maybe_done(
async move {
let (fw_tx, fw_rx) = fw_bin::channel();
let _ = req_tx.send(fw_tx).await;
let bin_rx = fw_rx.into_inner().await.ok_or(FetchError::Dropped)?;
let mut rx = bin_rx.into_inner().await.map_err(FetchError::RemoteConnect)?;
rx.set_max_data_size(len);
rx.recv().await.map_err(FetchError::RemoteReceive)?.ok_or(FetchError::Dropped)
}
.boxed(),
)));
}
fetch_task.as_mut().unwrap().await;
Ok(())
}
pub async fn get(&self) -> Result<DataBuf, FetchError> {
self.fetch().await?;
let mut res = self.fetch_task.lock().await;
res.as_mut().unwrap().as_mut().output_mut().unwrap().clone()
}
pub async fn into_inner(mut self) -> Result<DataBuf, FetchError> {
self.fetch().await?;
match Arc::try_unwrap(self.fetch_task) {
Ok(fetch_task) => {
let mut res = fetch_task.lock().await;
res.as_mut().unwrap().as_mut().take_output().unwrap()
}
Err(shared_fetch_task) => {
self.fetch_task = shared_fetch_task;
self.get().await
}
}
}
}