iroh_blobs/util/
progress.rs

1//! Utilities for reporting progress.
2//!
3//! The main entry point is the [ProgressSender] trait.
4use std::{future::Future, io, marker::PhantomData, ops::Deref, sync::Arc};
5
6use bytes::Bytes;
7use iroh_io::AsyncSliceWriter;
8
9/// A general purpose progress sender. This should be usable for reporting progress
10/// from both blocking and non-blocking contexts.
11///
12/// # Id generation
13///
14/// Any good progress protocol will refer to entities by means of a unique id.
15/// E.g. if you want to report progress about some file operation, including details
16/// such as the full path of the file would be very wasteful. It is better to
17/// introduce a unique id for the file and then report progress using that id.
18///
19/// The [IdGenerator] trait provides a method to generate such ids, [IdGenerator::new_id].
20///
21/// # Sending important messages
22///
23/// Some messages are important for the receiver to receive. E.g. start and end
24/// messages for some operation. If the receiver would miss one of these messages,
25/// it would lose the ability to make sense of the progress message stream.
26///
27/// This trait provides a method to send such important messages, in both blocking
28/// contexts where you have to block until the message is sent [ProgressSender::blocking_send],
29/// and non-blocking contexts where you have to yield until the message is sent [ProgressSender::send].
30///
31/// # Sending unimportant messages
32///
33/// Some messages are self-contained and not important for the receiver to receive.
34/// E.g. if you send millions of progress messages for copying a file that each
35/// contain an id and the number of bytes copied so far, it is not important for
36/// the receiver to receive every single one of these messages. In fact it is
37/// useful to drop some of these messages because waiting for the progress events
38/// to be sent can slow down the actual operation.
39///
40/// This trait provides a method to send such unimportant messages that can be
41/// used in both blocking and non-blocking contexts, [ProgressSender::try_send].
42///
43/// # Errors
44///
45/// When the receiver is dropped, sending a message will fail. This provides a way
46/// for the receiver to signal that the operation should be stopped.
47///
48/// E.g. for a blocking copy operation that reports frequent progress messages,
49/// as soon as the receiver is dropped, this is a signal to stop the copy operation.
50///
51/// The error type is [ProgressSendError], which can be converted to an [std::io::Error]
52/// for convenience.
53///
54/// # Transforming the message type
55///
56/// Sometimes you have a progress sender that sends a message of type `A` but an
57/// operation that reports progress of type `B`. If you have a transformation for
58/// every `B` to an `A`, you can use the [ProgressSender::with_map] method to transform the message.
59///
60/// This is similar to the `futures::SinkExt::with` method.
61///
62/// # Filtering the message type
63///
64/// Sometimes you have a progress sender that sends a message of enum `A` but an
65/// operation that reports progress of type `B`. You are interested only in some
66/// enum cases of `A` that can be transformed to `B`. You can use the [ProgressSender::with_filter_map]
67/// method to filter and transform the message.
68///
69/// # No-op progress sender
70///
71/// If you don't want to report progress, you can use the [IgnoreProgressSender] type.
72///
73/// # Async channel progress sender
74///
75/// If you want to use an async channel, you can use the [AsyncChannelProgressSender] type.
76///
77/// # Implementing your own progress sender
78///
79/// Progress senders will frequently be used in a multi-threaded context.
80///
81/// They must be **cheap** to clone and send between threads.
82/// They must also be thread safe, which is ensured by the [Send] and [Sync] bounds.
83/// They must also be unencumbered by lifetimes, which is ensured by the `'static` bound.
84///
85/// A typical implementation will wrap the sender part of a channel and an id generator.
86pub trait ProgressSender: std::fmt::Debug + Clone + Send + Sync + 'static {
87    /// The message being sent.
88    type Msg: Send + Sync + 'static;
89
90    /// Send a message and wait if the receiver is full.
91    ///
92    /// Use this to send important progress messages where delivery must be guaranteed.
93    #[must_use]
94    fn send(&self, msg: Self::Msg) -> impl Future<Output = ProgressSendResult<()>> + Send;
95
96    /// Try to send a message and drop it if the receiver is full.
97    ///
98    /// Use this to send progress messages where delivery is not important, e.g. a self contained progress message.
99    fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;
100
101    /// Send a message and block if the receiver is full.
102    ///
103    /// Use this to send important progress messages where delivery must be guaranteed.
104    fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;
105
106    /// Transform the message type by mapping to the type of this sender.
107    fn with_map<U: Send + Sync + 'static, F: Fn(U) -> Self::Msg + Send + Sync + Clone + 'static>(
108        self,
109        f: F,
110    ) -> WithMap<Self, U, F> {
111        WithMap(self, f, PhantomData)
112    }
113
114    /// Transform the message type by filter-mapping to the type of this sender.
115    fn with_filter_map<
116        U: Send + Sync + 'static,
117        F: Fn(U) -> Option<Self::Msg> + Send + Sync + Clone + 'static,
118    >(
119        self,
120        f: F,
121    ) -> WithFilterMap<Self, U, F> {
122        WithFilterMap(self, f, PhantomData)
123    }
124
125    /// Create a boxed progress sender to get rid of the concrete type.
126    fn boxed(self) -> BoxedProgressSender<Self::Msg>
127    where
128        Self: IdGenerator,
129    {
130        BoxedProgressSender(Arc::new(BoxableProgressSenderWrapper(self)))
131    }
132}
133
134/// A boxed progress sender
135pub struct BoxedProgressSender<T>(Arc<dyn BoxableProgressSender<T>>);
136
137impl<T> Clone for BoxedProgressSender<T> {
138    fn clone(&self) -> Self {
139        Self(self.0.clone())
140    }
141}
142
143impl<T> std::fmt::Debug for BoxedProgressSender<T> {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        f.debug_tuple("BoxedProgressSender").field(&self.0).finish()
146    }
147}
148
149type BoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + Send + 'a>>;
150
151/// Boxable progress sender
152trait BoxableProgressSender<T>: IdGenerator + std::fmt::Debug + Send + Sync + 'static {
153    /// Send a message and wait if the receiver is full.
154    ///
155    /// Use this to send important progress messages where delivery must be guaranteed.
156    #[must_use]
157    fn send(&self, msg: T) -> BoxFuture<'_, ProgressSendResult<()>>;
158
159    /// Try to send a message and drop it if the receiver is full.
160    ///
161    /// Use this to send progress messages where delivery is not important, e.g. a self contained progress message.
162    fn try_send(&self, msg: T) -> ProgressSendResult<()>;
163
164    /// Send a message and block if the receiver is full.
165    ///
166    /// Use this to send important progress messages where delivery must be guaranteed.
167    fn blocking_send(&self, msg: T) -> ProgressSendResult<()>;
168}
169
170impl<I: ProgressSender + IdGenerator> BoxableProgressSender<I::Msg>
171    for BoxableProgressSenderWrapper<I>
172{
173    fn send(&self, msg: I::Msg) -> BoxFuture<'_, ProgressSendResult<()>> {
174        Box::pin(self.0.send(msg))
175    }
176
177    fn try_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
178        self.0.try_send(msg)
179    }
180
181    fn blocking_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
182        self.0.blocking_send(msg)
183    }
184}
185
186/// Boxable progress sender wrapper, used internally.
187#[derive(Debug)]
188#[repr(transparent)]
189struct BoxableProgressSenderWrapper<I>(I);
190
191impl<I: ProgressSender + IdGenerator> IdGenerator for BoxableProgressSenderWrapper<I> {
192    fn new_id(&self) -> u64 {
193        self.0.new_id()
194    }
195}
196
197impl<T: Send + Sync + 'static> IdGenerator for Arc<dyn BoxableProgressSender<T>> {
198    fn new_id(&self) -> u64 {
199        self.deref().new_id()
200    }
201}
202
203impl<T: Send + Sync + 'static> ProgressSender for Arc<dyn BoxableProgressSender<T>> {
204    type Msg = T;
205
206    fn send(&self, msg: T) -> impl Future<Output = ProgressSendResult<()>> + Send {
207        self.deref().send(msg)
208    }
209
210    fn try_send(&self, msg: T) -> ProgressSendResult<()> {
211        self.deref().try_send(msg)
212    }
213
214    fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
215        self.deref().blocking_send(msg)
216    }
217}
218
219impl<T: Send + Sync + 'static> IdGenerator for BoxedProgressSender<T> {
220    fn new_id(&self) -> u64 {
221        self.0.new_id()
222    }
223}
224
225impl<T: Send + Sync + 'static> ProgressSender for BoxedProgressSender<T> {
226    type Msg = T;
227
228    async fn send(&self, msg: T) -> ProgressSendResult<()> {
229        self.0.send(msg).await
230    }
231
232    fn try_send(&self, msg: T) -> ProgressSendResult<()> {
233        self.0.try_send(msg)
234    }
235
236    fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
237        self.0.blocking_send(msg)
238    }
239}
240
241impl<T: ProgressSender> ProgressSender for Option<T> {
242    type Msg = T::Msg;
243
244    async fn send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
245        if let Some(inner) = self {
246            inner.send(msg).await
247        } else {
248            Ok(())
249        }
250    }
251
252    fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
253        if let Some(inner) = self {
254            inner.try_send(msg)
255        } else {
256            Ok(())
257        }
258    }
259
260    fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
261        if let Some(inner) = self {
262            inner.blocking_send(msg)
263        } else {
264            Ok(())
265        }
266    }
267}
268
/// Source of ids for entities referenced in progress messages, typically
/// implemented alongside a [ProgressSender].
pub trait IdGenerator {
    /// Returns a new id.
    ///
    /// How unique the ids are is up to the implementation; e.g. a no-op sender may
    /// return the same value every time.
    fn new_id(&self) -> u64;
}
274
/// A progress sender that discards every message.
///
/// Useful when an API requires a [ProgressSender] but nobody is interested in progress.
pub struct IgnoreProgressSender<T>(PhantomData<T>);

impl<T> Default for IgnoreProgressSender<T> {
    fn default() -> Self {
        IgnoreProgressSender(PhantomData)
    }
}

// Written by hand rather than derived so that no `T: Clone` bound is required.
impl<T> Clone for IgnoreProgressSender<T> {
    fn clone(&self) -> Self {
        Self::default()
    }
}

// Written by hand rather than derived so that no `T: Debug` bound is required.
impl<T> std::fmt::Debug for IgnoreProgressSender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("IgnoreProgressSender")
    }
}
295
296impl<T: Send + Sync + 'static> ProgressSender for IgnoreProgressSender<T> {
297    type Msg = T;
298
299    async fn send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
300        Ok(())
301    }
302
303    fn try_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
304        Ok(())
305    }
306
307    fn blocking_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
308        Ok(())
309    }
310}
311
312impl<T> IdGenerator for IgnoreProgressSender<T> {
313    fn new_id(&self) -> u64 {
314        0
315    }
316}
317
318/// Transform the message type by mapping to the type of this sender.
319///
320/// See [ProgressSender::with_map].
321pub struct WithMap<
322    I: ProgressSender,
323    U: Send + Sync + 'static,
324    F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
325>(I, F, PhantomData<U>);
326
327impl<
328        I: ProgressSender,
329        U: Send + Sync + 'static,
330        F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
331    > std::fmt::Debug for WithMap<I, U, F>
332{
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        f.debug_tuple("With").field(&self.0).finish()
335    }
336}
337
338impl<
339        I: ProgressSender,
340        U: Send + Sync + 'static,
341        F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
342    > Clone for WithMap<I, U, F>
343{
344    fn clone(&self) -> Self {
345        Self(self.0.clone(), self.1.clone(), PhantomData)
346    }
347}
348
349impl<
350        I: ProgressSender,
351        U: Send + Sync + 'static,
352        F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
353    > ProgressSender for WithMap<I, U, F>
354{
355    type Msg = U;
356
357    async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
358        let msg = (self.1)(msg);
359        self.0.send(msg).await
360    }
361
362    fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
363        let msg = (self.1)(msg);
364        self.0.try_send(msg)
365    }
366
367    fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
368        let msg = (self.1)(msg);
369        self.0.blocking_send(msg)
370    }
371}
372
373/// Transform the message type by filter-mapping to the type of this sender.
374///
375/// See [ProgressSender::with_filter_map].
376pub struct WithFilterMap<I, U, F>(I, F, PhantomData<U>);
377
378impl<
379        I: ProgressSender,
380        U: Send + Sync + 'static,
381        F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
382    > std::fmt::Debug for WithFilterMap<I, U, F>
383{
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        f.debug_tuple("FilterWith").field(&self.0).finish()
386    }
387}
388
389impl<
390        I: ProgressSender,
391        U: Send + Sync + 'static,
392        F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
393    > Clone for WithFilterMap<I, U, F>
394{
395    fn clone(&self) -> Self {
396        Self(self.0.clone(), self.1.clone(), PhantomData)
397    }
398}
399
400impl<I: IdGenerator, U, F> IdGenerator for WithFilterMap<I, U, F> {
401    fn new_id(&self) -> u64 {
402        self.0.new_id()
403    }
404}
405
406impl<
407        I: IdGenerator + ProgressSender,
408        U: Send + Sync + 'static,
409        F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
410    > IdGenerator for WithMap<I, U, F>
411{
412    fn new_id(&self) -> u64 {
413        self.0.new_id()
414    }
415}
416
417impl<
418        I: ProgressSender,
419        U: Send + Sync + 'static,
420        F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
421    > ProgressSender for WithFilterMap<I, U, F>
422{
423    type Msg = U;
424
425    async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
426        if let Some(msg) = (self.1)(msg) {
427            self.0.send(msg).await
428        } else {
429            Ok(())
430        }
431    }
432
433    fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
434        if let Some(msg) = (self.1)(msg) {
435            self.0.try_send(msg)
436        } else {
437            Ok(())
438        }
439    }
440
441    fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
442        if let Some(msg) = (self.1)(msg) {
443            self.0.blocking_send(msg)
444        } else {
445            Ok(())
446        }
447    }
448}
449
450/// A progress sender that uses an async channel.
451pub struct AsyncChannelProgressSender<T> {
452    sender: async_channel::Sender<T>,
453    id: std::sync::Arc<std::sync::atomic::AtomicU64>,
454}
455
456impl<T> std::fmt::Debug for AsyncChannelProgressSender<T> {
457    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458        f.debug_struct("AsyncChannelProgressSender")
459            .field("id", &self.id)
460            .field("sender", &self.sender)
461            .finish()
462    }
463}
464
465impl<T> Clone for AsyncChannelProgressSender<T> {
466    fn clone(&self) -> Self {
467        Self {
468            sender: self.sender.clone(),
469            id: self.id.clone(),
470        }
471    }
472}
473
474impl<T> AsyncChannelProgressSender<T> {
475    /// Create a new progress sender from an async channel sender.
476    pub fn new(sender: async_channel::Sender<T>) -> Self {
477        Self {
478            sender,
479            id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
480        }
481    }
482
483    /// Returns true if `other` sends on the same `async_channel` channel as `self`.
484    pub fn same_channel(&self, other: &AsyncChannelProgressSender<T>) -> bool {
485        same_channel(&self.sender, &other.sender)
486    }
487}
488
/// Reinterpret a pointer-sized, pointer-aligned value as a `usize`.
///
/// Returns `None` if `T` does not have exactly the size and alignment of `usize`.
///
/// NOTE(review): matching size and alignment is necessary but not sufficient for
/// soundness — `T` must also contain no padding or otherwise uninitialized bytes.
/// Only call this with pointer-like newtypes.
fn get_as_ptr<T>(value: &T) -> Option<usize> {
    use std::mem::{align_of, size_of, transmute_copy};

    let same_size = size_of::<T>() == size_of::<usize>();
    let same_align = align_of::<T>() == align_of::<usize>();
    if !(same_size && same_align) {
        return None;
    }
    // SAFETY: `T` has the size and alignment of `usize`, so reading a `usize` from `value`
    // stays in bounds and is properly aligned. The caller must additionally ensure that
    // `T` has no uninitialized (padding) bytes.
    Some(unsafe { transmute_copy::<T, usize>(value) })
}
502
503fn same_channel<T>(a: &async_channel::Sender<T>, b: &async_channel::Sender<T>) -> bool {
504    // This relies on async_channel::Sender being just a newtype wrapper around
505    // an Arc<Channel<T>>, so if two senders point to the same channel, the
506    // pointers will be the same.
507    get_as_ptr(a).unwrap() == get_as_ptr(b).unwrap()
508}
509
510impl<T> IdGenerator for AsyncChannelProgressSender<T> {
511    fn new_id(&self) -> u64 {
512        self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
513    }
514}
515
516impl<T: Send + Sync + 'static> ProgressSender for AsyncChannelProgressSender<T> {
517    type Msg = T;
518
519    async fn send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
520        self.sender
521            .send(msg)
522            .await
523            .map_err(|_| ProgressSendError::ReceiverDropped)
524    }
525
526    fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
527        match self.sender.try_send(msg) {
528            Ok(_) => Ok(()),
529            Err(async_channel::TrySendError::Full(_)) => Ok(()),
530            Err(async_channel::TrySendError::Closed(_)) => Err(ProgressSendError::ReceiverDropped),
531        }
532    }
533
534    fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
535        match self.sender.send_blocking(msg) {
536            Ok(_) => Ok(()),
537            Err(_) => Err(ProgressSendError::ReceiverDropped),
538        }
539    }
540}
541
542/// An error that can occur when sending progress messages.
543///
544/// Really the only error that can occur is if the receiver is dropped.
545#[derive(Debug, Clone, thiserror::Error)]
546pub enum ProgressSendError {
547    /// The receiver was dropped.
548    #[error("receiver dropped")]
549    ReceiverDropped,
550}
551
552/// A result type for progress sending.
553pub type ProgressSendResult<T> = std::result::Result<T, ProgressSendError>;
554
555impl From<ProgressSendError> for std::io::Error {
556    fn from(e: ProgressSendError) -> Self {
557        std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
558    }
559}
560
561/// A slice writer that adds a synchronous progress callback.
562///
563/// This wraps any `AsyncSliceWriter`, passes through all operations to the inner writer, and
564/// calls the passed `on_write` callback whenever data is written.
565#[derive(Debug)]
566pub struct ProgressSliceWriter<W, F>(W, F);
567
568impl<W: AsyncSliceWriter, F: FnMut(u64)> ProgressSliceWriter<W, F> {
569    /// Create a new `ProgressSliceWriter` from an inner writer and a progress callback
570    ///
571    /// The `on_write` function is called for each write, with the `offset` as the first and the
572    /// length of the data as the second param.
573    pub fn new(inner: W, on_write: F) -> Self {
574        Self(inner, on_write)
575    }
576
577    /// Return the inner writer
578    pub fn into_inner(self) -> W {
579        self.0
580    }
581}
582
583impl<W: AsyncSliceWriter + 'static, F: FnMut(u64, usize) + 'static> AsyncSliceWriter
584    for ProgressSliceWriter<W, F>
585{
586    async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
587        (self.1)(offset, data.len());
588        self.0.write_bytes_at(offset, data).await
589    }
590
591    async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
592        (self.1)(offset, data.len());
593        self.0.write_at(offset, data).await
594    }
595
596    async fn sync(&mut self) -> io::Result<()> {
597        self.0.sync().await
598    }
599
600    async fn set_len(&mut self, size: u64) -> io::Result<()> {
601        self.0.set_len(size).await
602    }
603}
604
605/// A slice writer that adds a fallible progress callback.
606///
607/// This wraps any `AsyncSliceWriter`, passes through all operations to the inner writer, and
608/// calls the passed `on_write` callback whenever data is written. `on_write` must return an
609/// `io::Result`, and can abort the download by returning an error.
610#[derive(Debug)]
611pub struct FallibleProgressSliceWriter<W, F>(W, F);
612
613impl<W: AsyncSliceWriter, F: Fn(u64, usize) -> io::Result<()> + 'static>
614    FallibleProgressSliceWriter<W, F>
615{
616    /// Create a new `ProgressSliceWriter` from an inner writer and a progress callback
617    ///
618    /// The `on_write` function is called for each write, with the `offset` as the first and the
619    /// length of the data as the second param. `on_write` must return a future which resolves to
620    /// an `io::Result`. If `on_write` returns an error, the download is aborted.
621    pub fn new(inner: W, on_write: F) -> Self {
622        Self(inner, on_write)
623    }
624
625    /// Return the inner writer.
626    pub fn into_inner(self) -> W {
627        self.0
628    }
629}
630
631impl<W: AsyncSliceWriter + 'static, F: Fn(u64, usize) -> io::Result<()> + 'static> AsyncSliceWriter
632    for FallibleProgressSliceWriter<W, F>
633{
634    async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
635        (self.1)(offset, data.len())?;
636        self.0.write_bytes_at(offset, data).await
637    }
638
639    async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
640        (self.1)(offset, data.len())?;
641        self.0.write_at(offset, data).await
642    }
643
644    async fn sync(&mut self) -> io::Result<()> {
645        self.0.sync().await
646    }
647
648    async fn set_len(&mut self, size: u64) -> io::Result<()> {
649        self.0.set_len(size).await
650    }
651}
652
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn get_as_ptr_works() {
        struct Wrapper(Arc<u64>);
        let wrapped = Wrapper(Arc::new(1u64));
        // Assumes std's `Arc` layout: the stored pointer addresses the allocation header
        // (strong and weak counters), while `Arc::as_ptr` addresses the data after it.
        let header = 2 * std::mem::size_of::<usize>();
        let expected = Arc::as_ptr(&wrapped.0) as usize - header;
        assert_eq!(get_as_ptr(&wrapped).unwrap(), expected);
    }

    #[test]
    fn get_as_ptr_wrong_use() {
        struct Wrapper(#[allow(dead_code)] u8);
        let too_small = Wrapper(1);
        assert!(get_as_ptr(&too_small).is_none());
    }

    #[test]
    fn test_sender_is_ptr() {
        // `same_channel` depends on `async_channel::Sender` being pointer-like.
        let word_size = std::mem::size_of::<usize>();
        let word_align = std::mem::align_of::<usize>();
        assert_eq!(word_size, std::mem::size_of::<async_channel::Sender<u8>>());
        assert_eq!(word_align, std::mem::align_of::<async_channel::Sender<u8>>());
    }
}