use futures::{FutureExt, TryFutureExt};
use std::marker::PhantomData;
/// A cloneable, thread-safe sink for progress messages.
///
/// Three ways of sending are offered: an asynchronous one ([`ProgressSender::send`]),
/// a non-waiting one ([`ProgressSender::try_send`]) and a synchronous one
/// ([`ProgressSender::blocking_send`]). Senders can be adapted to a different message
/// type with [`ProgressSender::with_map`] and [`ProgressSender::with_filter_map`].
pub trait ProgressSender: std::fmt::Debug + Clone + Send + Sync + 'static {
    /// The message type accepted by this sender.
    type Msg: Send + Sync + 'static;

    /// The future returned by [`ProgressSender::send`].
    type SendFuture<'a>: futures::Future<Output = std::result::Result<(), ProgressSendError>>
        + Send
        + 'a
    where
        Self: 'a;

    /// Sends `msg` asynchronously; the outcome is delivered by the returned future.
    #[must_use]
    fn send(&self, msg: Self::Msg) -> Self::SendFuture<'_>;

    /// Attempts to send `msg` without waiting.
    ///
    /// Implementations may report success even though the message was discarded;
    /// `FlumeProgressSender` does so when its channel is full.
    fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError>;

    /// Synchronous counterpart of [`ProgressSender::send`].
    fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError>;

    /// Wraps this sender so that it accepts messages of type `U`, converting each one
    /// with `f` before handing it to this sender.
    fn with_map<U, F>(self, f: F) -> WithMap<Self, U, F>
    where
        U: Send + Sync + 'static,
        F: Fn(U) -> Self::Msg + Send + Sync + Clone + 'static,
    {
        WithMap(self, f, PhantomData)
    }

    /// Wraps this sender so that it accepts messages of type `U`; messages for which
    /// `f` returns `None` are dropped, all others are converted and forwarded.
    fn with_filter_map<U, F>(self, f: F) -> WithFilterMap<Self, U, F>
    where
        U: Send + Sync + 'static,
        F: Fn(U) -> Option<Self::Msg> + Send + Sync + Clone + 'static,
    {
        WithFilterMap(self, f, PhantomData)
    }
}
/// A source of numeric ids.
///
/// Whether ids are distinct depends on the implementor: `FlumeProgressSender` hands
/// out successive values of a shared counter, while `IgnoreProgressSender` always
/// returns `0`.
pub trait IdGenerator {
/// Returns an id.
fn new_id(&self) -> u64;
}
/// A sender that carries no state at all; `T` only records the message type.
///
/// Its `ProgressSender` implementation discards every message.
pub struct IgnoreProgressSender<T>(PhantomData<T>);

// Implemented by hand so that no `T: Default` bound is required.
impl<T> Default for IgnoreProgressSender<T> {
    fn default() -> Self {
        IgnoreProgressSender(PhantomData)
    }
}

// Implemented by hand so that no `T: Clone` bound is required.
impl<T> Clone for IgnoreProgressSender<T> {
    fn clone(&self) -> Self {
        // `PhantomData<T>` is `Copy` for every `T`, so the marker can simply be copied.
        Self(self.0)
    }
}

// Implemented by hand so that no `T: Debug` bound is required.
impl<T> std::fmt::Debug for IgnoreProgressSender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // There are no fields worth showing, so the output is just the type name.
        f.write_str("IgnoreProgressSender")
    }
}
/// No-op sender: every send variant drops the message and reports success.
impl<T: Send + Sync + 'static> ProgressSender for IgnoreProgressSender<T> {
    type Msg = T;
    type SendFuture<'a> = futures::future::Ready<std::result::Result<(), ProgressSendError>>;

    /// Returns a future that is immediately ready with `Ok(())`.
    fn send(&self, _: Self::Msg) -> Self::SendFuture<'_> {
        futures::future::ok(())
    }

    /// Always succeeds; the message is discarded.
    fn try_send(&self, _: Self::Msg) -> std::result::Result<(), ProgressSendError> {
        Ok(())
    }

    /// Always succeeds without blocking; the message is discarded.
    fn blocking_send(&self, _: Self::Msg) -> std::result::Result<(), ProgressSendError> {
        Ok(())
    }
}
impl<T> IdGenerator for IgnoreProgressSender<T> {
    /// Always returns the same id, `0`.
    fn new_id(&self) -> u64 {
        // This sender keeps no counter, so every caller gets the same constant.
        const FIXED_ID: u64 = 0;
        FIXED_ID
    }
}
pub struct WithMap<
I: ProgressSender,
U: Send + Sync + 'static,
F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
>(I, F, PhantomData<U>);
/// Debug output shows the wrapped sender only; the closure and the type marker are
/// omitted because the closure type is not required to implement `Debug`.
impl<I, U, F> std::fmt::Debug for WithMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Label the output with the actual type name so it can be matched to the code.
        f.debug_tuple("WithMap").field(&self.0).finish()
    }
}
// Implemented by hand: a derive would add an unnecessary `U: Clone` bound.
impl<I, U, F> Clone for WithMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
{
    /// Clones the wrapped sender and the closure; the marker is recreated.
    fn clone(&self) -> Self {
        let Self(inner, map_fn, _) = self;
        Self(inner.clone(), map_fn.clone(), PhantomData)
    }
}
/// Accepts messages of type `U`, converts each with the stored closure and delegates
/// to the matching method of the wrapped sender.
impl<I, U, F> ProgressSender for WithMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
{
    type Msg = U;
    // No wrapping is needed: the wrapped sender's future is returned as is.
    type SendFuture<'a> = I::SendFuture<'a>;

    /// Converts `msg` and forwards it via the wrapped sender's `send`.
    fn send(&self, msg: U) -> Self::SendFuture<'_> {
        let Self(inner, convert, _) = self;
        inner.send(convert(msg))
    }

    /// Converts `msg` and forwards it via the wrapped sender's `try_send`.
    fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
        let Self(inner, convert, _) = self;
        inner.try_send(convert(msg))
    }

    /// Converts `msg` and forwards it via the wrapped sender's `blocking_send`.
    fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
        let Self(inner, convert, _) = self;
        inner.blocking_send(convert(msg))
    }
}
/// A sender adapter that runs every incoming message of type `U` through a closure
/// returning `Option`; `Some` results are forwarded to the wrapped sender `I`, `None`
/// results are dropped.
///
/// Tuple fields: the wrapped sender, the closure, and a marker recording `U`.
pub struct WithFilterMap<I, U, F>(I, F, PhantomData<U>);
/// Debug output shows the wrapped sender only; the closure and the type marker are
/// omitted because the closure type is not required to implement `Debug`.
impl<I, U, F> std::fmt::Debug for WithFilterMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Label the output with the actual type name so it can be matched to the code.
        f.debug_tuple("WithFilterMap").field(&self.0).finish()
    }
}
// Implemented by hand: a derive would add an unnecessary `U: Clone` bound.
impl<I, U, F> Clone for WithFilterMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
{
    /// Clones the wrapped sender and the closure; the marker is recreated.
    fn clone(&self) -> Self {
        let Self(inner, filter_map_fn, _) = self;
        Self(inner.clone(), filter_map_fn.clone(), PhantomData)
    }
}
/// Id generation is delegated to the wrapped sender.
impl<I, U, F> IdGenerator for WithFilterMap<I, U, F>
where
    I: IdGenerator,
{
    /// Returns whatever id the wrapped sender produces.
    fn new_id(&self) -> u64 {
        let Self(inner, ..) = self;
        inner.new_id()
    }
}
/// Accepts messages of type `U`. Each message is passed to the stored closure: a
/// `Some` result is forwarded to the wrapped sender, a `None` result is dropped and
/// reported as success.
impl<I, U, F> ProgressSender for WithFilterMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
{
    type Msg = U;
    // Left: the wrapped sender's future. Right: an already-completed future used
    // when the message was filtered out.
    type SendFuture<'a> = futures::future::Either<
        I::SendFuture<'a>,
        futures::future::Ready<std::result::Result<(), ProgressSendError>>,
    >;

    /// Forwards the converted message via the wrapped sender's `send`, or resolves
    /// immediately to `Ok(())` if the closure filtered the message out.
    fn send(&self, msg: U) -> Self::SendFuture<'_> {
        match (self.1)(msg) {
            Some(mapped) => futures::future::Either::Left(self.0.send(mapped)),
            None => futures::future::Either::Right(futures::future::ready(Ok(()))),
        }
    }

    /// Forwards the converted message via the wrapped sender's `try_send`, or returns
    /// `Ok(())` if the closure filtered the message out.
    fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
        match (self.1)(msg) {
            Some(mapped) => self.0.try_send(mapped),
            None => Ok(()),
        }
    }

    /// Forwards the converted message via the wrapped sender's `blocking_send`, or
    /// returns `Ok(())` if the closure filtered the message out.
    fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
        match (self.1)(msg) {
            Some(mapped) => self.0.blocking_send(mapped),
            None => Ok(()),
        }
    }
}
/// A progress sender backed by a `flume` channel, combined with an id counter.
pub struct FlumeProgressSender<T> {
// Sending half of the channel that messages are pushed into.
sender: flume::Sender<T>,
// Counter behind `new_id`; held in an `Arc`, so clones of this sender share it.
id: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
// Implemented by hand so that no `T: Debug` bound is required.
impl<T> std::fmt::Debug for FlumeProgressSender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { sender, id } = self;
        let mut out = f.debug_struct("FlumeProgressSender");
        // The id counter is printed before the channel handle.
        out.field("id", id);
        out.field("sender", sender);
        out.finish()
    }
}
// Implemented by hand so that no `T: Clone` bound is required.
impl<T> Clone for FlumeProgressSender<T> {
    /// Produces another handle to the same channel and the same id counter.
    fn clone(&self) -> Self {
        let Self { sender, id } = self;
        Self {
            sender: sender.clone(),
            // Only the `Arc` is cloned, so the clone shares the counter with `self`.
            id: std::sync::Arc::clone(id),
        }
    }
}
impl<T> FlumeProgressSender<T> {
    /// Creates a progress sender that pushes messages into `sender`.
    ///
    /// The id counter of the new sender starts at zero.
    pub fn new(sender: flume::Sender<T>) -> Self {
        use std::sync::{atomic::AtomicU64, Arc};

        let id = Arc::new(AtomicU64::new(0));
        Self { sender, id }
    }
}
impl<T> IdGenerator for FlumeProgressSender<T> {
    /// Returns the current value of the counter and increments it by one.
    fn new_id(&self) -> u64 {
        use std::sync::atomic::Ordering;

        // `fetch_add` yields the value from before the increment.
        self.id.fetch_add(1, Ordering::SeqCst)
    }
}
/// Sends progress messages through the wrapped `flume` channel. Every failure of the
/// channel to accept a message because it is disconnected is reported as
/// `ProgressSendError::ReceiverDropped`.
impl<T: Send + Sync + 'static> ProgressSender for FlumeProgressSender<T> {
    type Msg = T;
    type SendFuture<'a> =
        futures::future::BoxFuture<'a, std::result::Result<(), ProgressSendError>>;

    /// Sends `msg` with the channel's asynchronous send; the future is boxed.
    fn send(&self, msg: Self::Msg) -> Self::SendFuture<'_> {
        let pending = self.sender.send_async(msg);
        pending
            .map_err(|_| ProgressSendError::ReceiverDropped)
            .boxed()
    }

    /// Attempts to send `msg` without waiting.
    ///
    /// A full channel is not an error: the message is dropped and `Ok(())` is
    /// returned. Only a disconnected channel yields an error.
    fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
        match self.sender.try_send(msg) {
            Err(flume::TrySendError::Disconnected(_)) => Err(ProgressSendError::ReceiverDropped),
            // Delivered, or dropped because the channel is full: both count as success.
            Ok(()) | Err(flume::TrySendError::Full(_)) => Ok(()),
        }
    }

    /// Sends `msg` with the channel's synchronous send.
    fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
        self.sender
            .send(msg)
            .map_err(|_| ProgressSendError::ReceiverDropped)
    }
}
/// The error returned by the sending methods of `ProgressSender`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProgressSendError {
/// The message could not be delivered because the receiving side is gone.
#[error("receiver dropped")]
ReceiverDropped,
}
/// Lets a `ProgressSendError` be used where a `std::io::Error` is expected, for
/// example with `?`.
impl From<ProgressSendError> for std::io::Error {
    /// Wraps the error in an I/O error of kind `BrokenPipe`.
    fn from(err: ProgressSendError) -> Self {
        use std::io::{Error, ErrorKind};

        Error::new(ErrorKind::BrokenPipe, err)
    }
}