use std::{future::Future, io, marker::PhantomData, ops::Deref, sync::Arc};
use bytes::Bytes;
use iroh_io::AsyncSliceWriter;
/// A general purpose progress sender.
///
/// Implementations deliver messages of type [`ProgressSender::Msg`] to some receiver and offer
/// three flavours of sending: async ([`send`](ProgressSender::send)), non-blocking
/// ([`try_send`](ProgressSender::try_send)) and thread-blocking
/// ([`blocking_send`](ProgressSender::blocking_send)).
pub trait ProgressSender: std::fmt::Debug + Clone + Send + Sync + 'static {
    /// The message type accepted by this sender.
    type Msg: Send + Sync + 'static;

    /// Send a message asynchronously. The returned future must be awaited to take effect.
    #[must_use]
    fn send(&self, msg: Self::Msg) -> impl Future<Output = ProgressSendResult<()>> + Send;

    /// Attempt to send a message without waiting.
    ///
    /// What happens to a message that cannot be delivered immediately (dropped or reported as an
    /// error) is decided by the implementation.
    fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;

    /// Send a message, blocking the current thread if the implementation needs to wait.
    fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;

    /// Adapt this sender to accept messages of type `U`, converting each one with `f` before
    /// forwarding it.
    fn with_map<U, F>(self, f: F) -> WithMap<Self, U, F>
    where
        U: Send + Sync + 'static,
        F: Fn(U) -> Self::Msg + Send + Sync + Clone + 'static,
    {
        WithMap(self, f, PhantomData)
    }

    /// Adapt this sender to accept messages of type `U`. `f` converts each message, or returns
    /// `None` to discard it instead of forwarding it.
    fn with_filter_map<U, F>(self, f: F) -> WithFilterMap<Self, U, F>
    where
        U: Send + Sync + 'static,
        F: Fn(U) -> Option<Self::Msg> + Send + Sync + Clone + 'static,
    {
        WithFilterMap(self, f, PhantomData)
    }

    /// Erase the concrete sender type behind a reference-counted trait object.
    fn boxed(self) -> BoxedProgressSender<Self::Msg>
    where
        Self: IdGenerator,
    {
        let wrapped = BoxableProgressSenderWrapper(self);
        BoxedProgressSender(Arc::new(wrapped))
    }
}
/// A type-erased [`ProgressSender`] for messages of type `T`.
///
/// Cloning is cheap: every clone shares the same reference-counted inner sender.
pub struct BoxedProgressSender<T>(Arc<dyn BoxableProgressSender<T>>);

impl<T> Clone for BoxedProgressSender<T> {
    fn clone(&self) -> Self {
        // Only the reference count is bumped; the inner sender is shared.
        BoxedProgressSender(Arc::clone(&self.0))
    }
}

impl<T> std::fmt::Debug for BoxedProgressSender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("BoxedProgressSender");
        tuple.field(&self.0);
        tuple.finish()
    }
}
/// A pinned, heap-allocated `Send` future that may borrow for `'a`.
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Object-safe mirror of [`ProgressSender`].
///
/// [`ProgressSender::send`] returns an opaque `impl Future`, which a trait object cannot expose.
/// This trait returns a [`BoxFuture`] instead, so it can be used as `dyn BoxableProgressSender<T>`.
trait BoxableProgressSender<T>: IdGenerator + std::fmt::Debug + Send + Sync + 'static {
    /// Boxed counterpart of [`ProgressSender::send`].
    #[must_use]
    fn send(&self, msg: T) -> BoxFuture<'_, ProgressSendResult<()>>;

    /// Counterpart of [`ProgressSender::try_send`].
    fn try_send(&self, msg: T) -> ProgressSendResult<()>;

    /// Counterpart of [`ProgressSender::blocking_send`].
    fn blocking_send(&self, msg: T) -> ProgressSendResult<()>;
}
/// Bridges any concrete [`ProgressSender`] to the object-safe [`BoxableProgressSender`].
impl<I> BoxableProgressSender<I::Msg> for BoxableProgressSenderWrapper<I>
where
    I: ProgressSender + IdGenerator,
{
    fn send(&self, msg: I::Msg) -> BoxFuture<'_, ProgressSendResult<()>> {
        // Box the opaque future so it can be returned through a trait object.
        let fut = self.0.send(msg);
        Box::pin(fut)
    }

    fn try_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
        let Self(inner) = self;
        inner.try_send(msg)
    }

    fn blocking_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
        let Self(inner) = self;
        inner.blocking_send(msg)
    }
}
/// Newtype adapter that lets a concrete [`ProgressSender`] + [`IdGenerator`] implement the
/// object-safe [`BoxableProgressSender`].
#[repr(transparent)]
#[derive(Debug)]
struct BoxableProgressSenderWrapper<I>(I);

impl<I> IdGenerator for BoxableProgressSenderWrapper<I>
where
    I: ProgressSender + IdGenerator,
{
    /// Delegates to the wrapped sender's id generator.
    fn new_id(&self) -> u64 {
        let Self(inner) = self;
        inner.new_id()
    }
}
/// Delegates id generation to the shared trait object.
impl<T: Send + Sync + 'static> IdGenerator for Arc<dyn BoxableProgressSender<T>> {
    fn new_id(&self) -> u64 {
        Deref::deref(self).new_id()
    }
}

/// Lets a shared trait object be used wherever a [`ProgressSender`] is expected; each method
/// forwards to the corresponding [`BoxableProgressSender`] method on the pointee.
impl<T: Send + Sync + 'static> ProgressSender for Arc<dyn BoxableProgressSender<T>> {
    type Msg = T;

    fn send(&self, msg: T) -> impl Future<Output = ProgressSendResult<()>> + Send {
        Deref::deref(self).send(msg)
    }

    fn try_send(&self, msg: T) -> ProgressSendResult<()> {
        Deref::deref(self).try_send(msg)
    }

    fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
        Deref::deref(self).blocking_send(msg)
    }
}
/// Delegates id generation to the boxed sender.
impl<T: Send + Sync + 'static> IdGenerator for BoxedProgressSender<T> {
    fn new_id(&self) -> u64 {
        let Self(inner) = self;
        inner.new_id()
    }
}

/// Every send operation is forwarded to the shared inner sender.
impl<T: Send + Sync + 'static> ProgressSender for BoxedProgressSender<T> {
    type Msg = T;

    async fn send(&self, msg: T) -> ProgressSendResult<()> {
        let Self(inner) = self;
        inner.send(msg).await
    }

    fn try_send(&self, msg: T) -> ProgressSendResult<()> {
        let Self(inner) = self;
        inner.try_send(msg)
    }

    fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
        let Self(inner) = self;
        inner.blocking_send(msg)
    }
}
/// An optional sender: `Some` forwards to the inner sender, while `None` accepts and discards
/// every message, always reporting success.
impl<T: ProgressSender> ProgressSender for Option<T> {
    type Msg = T::Msg;

    async fn send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
        match self {
            Some(inner) => inner.send(msg).await,
            None => Ok(()),
        }
    }

    fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
        self.as_ref().map_or(Ok(()), |inner| inner.try_send(msg))
    }

    fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
        self.as_ref().map_or(Ok(()), |inner| inner.blocking_send(msg))
    }
}
/// A source of `u64` identifiers.
///
/// Uniqueness is up to the implementation; `IgnoreProgressSender`, for example, always returns
/// `0`.
pub trait IdGenerator {
    /// Produce an identifier.
    fn new_id(&self) -> u64;
}
pub struct IgnoreProgressSender<T>(PhantomData<T>);
impl<T> Default for IgnoreProgressSender<T> {
fn default() -> Self {
Self(PhantomData)
}
}
impl<T> Clone for IgnoreProgressSender<T> {
fn clone(&self) -> Self {
Self(PhantomData)
}
}
impl<T> std::fmt::Debug for IgnoreProgressSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("IgnoreProgressSender").finish()
}
}
impl<T: Send + Sync + 'static> ProgressSender for IgnoreProgressSender<T> {
type Msg = T;
async fn send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
Ok(())
}
fn try_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
Ok(())
}
fn blocking_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
Ok(())
}
}
impl<T> IdGenerator for IgnoreProgressSender<T> {
fn new_id(&self) -> u64 {
0
}
}
/// A [`ProgressSender`] adapter that converts every message with a mapping function before
/// forwarding it to the wrapped sender.
///
/// Created by [`ProgressSender::with_map`]. Trait bounds live on the `impl` blocks rather than on
/// the struct itself, matching [`WithFilterMap`].
pub struct WithMap<I, U, F>(I, F, PhantomData<U>);

impl<I: std::fmt::Debug, U, F> std::fmt::Debug for WithMap<I, U, F> {
    /// Prints the wrapped sender under the type's real name. The mapping function is omitted
    /// because `F` is not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("WithMap").field(&self.0).finish()
    }
}
/// Clones both the wrapped sender and the mapping function.
impl<I, U, F> Clone for WithMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        let Self(inner, map, _) = self;
        Self(inner.clone(), map.clone(), PhantomData)
    }
}
impl<
I: ProgressSender,
U: Send + Sync + 'static,
F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
> ProgressSender for WithMap<I, U, F>
{
type Msg = U;
async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
let msg = (self.1)(msg);
self.0.send(msg).await
}
fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
let msg = (self.1)(msg);
self.0.try_send(msg)
}
fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
let msg = (self.1)(msg);
self.0.blocking_send(msg)
}
}
/// A [`ProgressSender`] adapter that passes every message through a function which either
/// converts it (`Some`) or discards it (`None`) before forwarding to the wrapped sender.
///
/// Created by [`ProgressSender::with_filter_map`].
pub struct WithFilterMap<I, U, F>(I, F, PhantomData<U>);

impl<I: std::fmt::Debug, U, F> std::fmt::Debug for WithFilterMap<I, U, F> {
    /// Prints the wrapped sender under the type's real name. The filter-map function is omitted
    /// because `F` is not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("WithFilterMap").field(&self.0).finish()
    }
}
/// Clones both the wrapped sender and the filter-map function.
impl<I, U, F> Clone for WithFilterMap<I, U, F>
where
    I: ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        let Self(inner, filter_map, _) = self;
        Self(inner.clone(), filter_map.clone(), PhantomData)
    }
}
/// Delegates id generation to the wrapped sender.
impl<I, U, F> IdGenerator for WithFilterMap<I, U, F>
where
    I: IdGenerator,
{
    fn new_id(&self) -> u64 {
        let Self(inner, ..) = self;
        inner.new_id()
    }
}

/// Delegates id generation to the wrapped sender.
impl<I, U, F> IdGenerator for WithMap<I, U, F>
where
    I: IdGenerator + ProgressSender,
    U: Send + Sync + 'static,
    F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
{
    fn new_id(&self) -> u64 {
        let Self(inner, ..) = self;
        inner.new_id()
    }
}
impl<
I: ProgressSender,
U: Send + Sync + 'static,
F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
> ProgressSender for WithFilterMap<I, U, F>
{
type Msg = U;
async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
if let Some(msg) = (self.1)(msg) {
self.0.send(msg).await
} else {
Ok(())
}
}
fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
if let Some(msg) = (self.1)(msg) {
self.0.try_send(msg)
} else {
Ok(())
}
}
fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
if let Some(msg) = (self.1)(msg) {
self.0.blocking_send(msg)
} else {
Ok(())
}
}
}
pub struct FlumeProgressSender<T> {
sender: flume::Sender<T>,
id: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
impl<T> std::fmt::Debug for FlumeProgressSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FlumeProgressSender")
.field("id", &self.id)
.field("sender", &self.sender)
.finish()
}
}
impl<T> Clone for FlumeProgressSender<T> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
id: self.id.clone(),
}
}
}
impl<T> FlumeProgressSender<T> {
pub fn new(sender: flume::Sender<T>) -> Self {
Self {
sender,
id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
}
}
pub fn same_channel(&self, other: &FlumeProgressSender<T>) -> bool {
self.sender.same_channel(&other.sender)
}
}
impl<T> IdGenerator for FlumeProgressSender<T> {
fn new_id(&self) -> u64 {
self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
}
impl<T: Send + Sync + 'static> ProgressSender for FlumeProgressSender<T> {
type Msg = T;
async fn send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
self.sender
.send_async(msg)
.await
.map_err(|_| ProgressSendError::ReceiverDropped)
}
fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
match self.sender.try_send(msg) {
Ok(_) => Ok(()),
Err(flume::TrySendError::Full(_)) => Ok(()),
Err(flume::TrySendError::Disconnected(_)) => Err(ProgressSendError::ReceiverDropped),
}
}
fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
match self.sender.send(msg) {
Ok(_) => Ok(()),
Err(_) => Err(ProgressSendError::ReceiverDropped),
}
}
}
pub struct AsyncChannelProgressSender<T> {
sender: async_channel::Sender<T>,
id: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
impl<T> std::fmt::Debug for AsyncChannelProgressSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncChannelProgressSender")
.field("id", &self.id)
.field("sender", &self.sender)
.finish()
}
}
impl<T> Clone for AsyncChannelProgressSender<T> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
id: self.id.clone(),
}
}
}
impl<T> AsyncChannelProgressSender<T> {
pub fn new(sender: async_channel::Sender<T>) -> Self {
Self {
sender,
id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
}
}
pub fn same_channel(&self, other: &AsyncChannelProgressSender<T>) -> bool {
same_channel(&self.sender, &other.sender)
}
}
/// Reinterprets the bytes of `value` as a `usize`. This only happens when `T` has exactly the
/// size and alignment of `usize`; otherwise the function returns `None`.
///
/// `same_channel` uses this to compare `async_channel::Sender`s by the pointer they hold. For an
/// `Arc`, the result is the address of the shared allocation (its header), not of the payload.
fn get_as_ptr<T>(value: &T) -> Option<usize> {
    use std::mem::{align_of, size_of, transmute_copy};

    let layout_matches =
        size_of::<T>() == size_of::<usize>() && align_of::<T>() == align_of::<usize>();
    if !layout_matches {
        return None;
    }
    // SAFETY: `T` occupies exactly `size_of::<usize>()` bytes, so `transmute_copy` only reads
    // bytes that belong to `*value`.
    // NOTE(review): this is only sound if none of those bytes are padding/uninitialized. That
    // holds for the pointer-sized types this is used on, but it is not checked here.
    Some(unsafe { transmute_copy::<T, usize>(value) })
}
/// Returns `true` if both senders refer to the same `async_channel` channel.
///
/// The senders are compared by the single pointer each one wraps (extracted with `get_as_ptr`).
///
/// # Panics
///
/// Panics if `async_channel::Sender<T>` is not pointer-sized and pointer-aligned. The
/// `test_sender_is_ptr` unit test guards that layout assumption.
fn same_channel<T>(a: &async_channel::Sender<T>, b: &async_channel::Sender<T>) -> bool {
    const LAYOUT_MSG: &str = "async_channel::Sender is expected to be a single pointer";
    let a_ptr = get_as_ptr(a).expect(LAYOUT_MSG);
    let b_ptr = get_as_ptr(b).expect(LAYOUT_MSG);
    a_ptr == b_ptr
}
impl<T> IdGenerator for AsyncChannelProgressSender<T> {
fn new_id(&self) -> u64 {
self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
}
impl<T: Send + Sync + 'static> ProgressSender for AsyncChannelProgressSender<T> {
type Msg = T;
async fn send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
self.sender
.send(msg)
.await
.map_err(|_| ProgressSendError::ReceiverDropped)
}
fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
match self.sender.try_send(msg) {
Ok(_) => Ok(()),
Err(async_channel::TrySendError::Full(_)) => Ok(()),
Err(async_channel::TrySendError::Closed(_)) => Err(ProgressSendError::ReceiverDropped),
}
}
fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
match self.sender.send_blocking(msg) {
Ok(_) => Ok(()),
Err(_) => Err(ProgressSendError::ReceiverDropped),
}
}
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProgressSendError {
#[error("receiver dropped")]
ReceiverDropped,
}
pub type ProgressSendResult<T> = std::result::Result<T, ProgressSendError>;
impl From<ProgressSendError> for std::io::Error {
fn from(e: ProgressSendError) -> Self {
std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
}
}
/// An [`AsyncSliceWriter`] wrapper that invokes a synchronous callback on every write.
///
/// All operations are forwarded to the inner writer. Before each `write_bytes_at` / `write_at`,
/// `on_write(offset, len)` is called with the write offset and the number of bytes about to be
/// written.
#[derive(Debug)]
pub struct ProgressSliceWriter<W, F>(W, F);

// The callback bound must be `FnMut(u64, usize)` to match the `AsyncSliceWriter` impl below.
// A closure cannot implement both `FnMut(u64)` and `FnMut(u64, usize)`, so any narrower bound
// here would make `new` produce writers that can never be written to.
impl<W: AsyncSliceWriter, F: FnMut(u64, usize)> ProgressSliceWriter<W, F> {
    /// Wrap `inner`, calling `on_write(offset, len)` before each write.
    pub fn new(inner: W, on_write: F) -> Self {
        Self(inner, on_write)
    }

    /// Return the inner writer, dropping the callback.
    pub fn into_inner(self) -> W {
        self.0
    }
}

impl<W: AsyncSliceWriter + 'static, F: FnMut(u64, usize) + 'static> AsyncSliceWriter
    for ProgressSliceWriter<W, F>
{
    async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
        // Progress is reported before the write, so the callback runs even if the write fails.
        (self.1)(offset, data.len());
        self.0.write_bytes_at(offset, data).await
    }

    async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
        // Progress is reported before the write, so the callback runs even if the write fails.
        (self.1)(offset, data.len());
        self.0.write_at(offset, data).await
    }

    async fn sync(&mut self) -> io::Result<()> {
        self.0.sync().await
    }

    async fn set_len(&mut self, size: u64) -> io::Result<()> {
        self.0.set_len(size).await
    }
}
/// An [`AsyncSliceWriter`] wrapper that invokes a fallible callback before every write.
///
/// `on_write(offset, len)` runs before each `write_bytes_at` / `write_at`. If it returns an
/// error, that error is returned and the inner writer is not called. All other operations are
/// forwarded unchanged.
#[derive(Debug)]
pub struct FallibleProgressSliceWriter<W, F>(W, F);

impl<W, F> FallibleProgressSliceWriter<W, F>
where
    W: AsyncSliceWriter,
    F: Fn(u64, usize) -> io::Result<()> + 'static,
{
    /// Wrap `inner`, calling `on_write(offset, len)` before each write.
    pub fn new(inner: W, on_write: F) -> Self {
        FallibleProgressSliceWriter(inner, on_write)
    }

    /// Return the inner writer, dropping the callback.
    pub fn into_inner(self) -> W {
        let Self(inner, _) = self;
        inner
    }
}

impl<W, F> AsyncSliceWriter for FallibleProgressSliceWriter<W, F>
where
    W: AsyncSliceWriter + 'static,
    F: Fn(u64, usize) -> io::Result<()> + 'static,
{
    async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
        let len = data.len();
        (self.1)(offset, len)?;
        self.0.write_bytes_at(offset, data).await
    }

    async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
        let len = data.len();
        (self.1)(offset, len)?;
        self.0.write_at(offset, data).await
    }

    async fn sync(&mut self) -> io::Result<()> {
        self.0.sync().await
    }

    async fn set_len(&mut self, size: u64) -> io::Result<()> {
        self.0.set_len(size).await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// For an `Arc`, `get_as_ptr` yields the allocation address, which lies two words (the
    /// strong and weak counts) before the payload pointer.
    #[test]
    fn get_as_ptr_works() {
        struct Wrapper(Arc<u64>);
        let x = Wrapper(Arc::new(1u64));
        assert_eq!(
            get_as_ptr(&x).unwrap(),
            Arc::as_ptr(&x.0) as usize - 2 * std::mem::size_of::<usize>()
        );
    }

    #[test]
    fn get_as_ptr_wrong_use() {
        // The field is never read; it only exists to give the wrapper a non-pointer layout.
        struct Wrapper(#[allow(dead_code)] u8);
        let x = Wrapper(1);
        assert!(get_as_ptr(&x).is_none());
    }

    /// Guards the layout assumption `same_channel` relies on.
    #[test]
    fn test_sender_is_ptr() {
        assert_eq!(
            std::mem::size_of::<usize>(),
            std::mem::size_of::<async_channel::Sender<u8>>()
        );
        assert_eq!(
            std::mem::align_of::<usize>(),
            std::mem::align_of::<async_channel::Sender<u8>>()
        );
    }

    /// End-to-end check that the pointer comparison distinguishes channels, both through the
    /// free function and through `AsyncChannelProgressSender::same_channel`.
    #[test]
    fn same_channel_distinguishes_channels() {
        let (a, _ra) = async_channel::unbounded::<u8>();
        let (b, _rb) = async_channel::unbounded::<u8>();
        assert!(same_channel(&a, &a.clone()));
        assert!(!same_channel(&a, &b));

        let sa = AsyncChannelProgressSender::new(a);
        let sb = AsyncChannelProgressSender::new(b);
        assert!(sa.same_channel(&sa.clone()));
        assert!(!sa.same_channel(&sb));
    }
}