1use std::{future::Future, io, marker::PhantomData, ops::Deref, sync::Arc};
5
6use bytes::Bytes;
7use iroh_io::AsyncSliceWriter;
8
9pub trait ProgressSender: std::fmt::Debug + Clone + Send + Sync + 'static {
87 type Msg: Send + Sync + 'static;
89
90 #[must_use]
94 fn send(&self, msg: Self::Msg) -> impl Future<Output = ProgressSendResult<()>> + Send;
95
96 fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;
100
101 fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;
105
106 fn with_map<U: Send + Sync + 'static, F: Fn(U) -> Self::Msg + Send + Sync + Clone + 'static>(
108 self,
109 f: F,
110 ) -> WithMap<Self, U, F> {
111 WithMap(self, f, PhantomData)
112 }
113
114 fn with_filter_map<
116 U: Send + Sync + 'static,
117 F: Fn(U) -> Option<Self::Msg> + Send + Sync + Clone + 'static,
118 >(
119 self,
120 f: F,
121 ) -> WithFilterMap<Self, U, F> {
122 WithFilterMap(self, f, PhantomData)
123 }
124
125 fn boxed(self) -> BoxedProgressSender<Self::Msg>
127 where
128 Self: IdGenerator,
129 {
130 BoxedProgressSender(Arc::new(BoxableProgressSenderWrapper(self)))
131 }
132}
133
134pub struct BoxedProgressSender<T>(Arc<dyn BoxableProgressSender<T>>);
136
137impl<T> Clone for BoxedProgressSender<T> {
138 fn clone(&self) -> Self {
139 Self(self.0.clone())
140 }
141}
142
143impl<T> std::fmt::Debug for BoxedProgressSender<T> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_tuple("BoxedProgressSender").field(&self.0).finish()
146 }
147}
148
149type BoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + Send + 'a>>;
150
151trait BoxableProgressSender<T>: IdGenerator + std::fmt::Debug + Send + Sync + 'static {
153 #[must_use]
157 fn send(&self, msg: T) -> BoxFuture<'_, ProgressSendResult<()>>;
158
159 fn try_send(&self, msg: T) -> ProgressSendResult<()>;
163
164 fn blocking_send(&self, msg: T) -> ProgressSendResult<()>;
168}
169
170impl<I: ProgressSender + IdGenerator> BoxableProgressSender<I::Msg>
171 for BoxableProgressSenderWrapper<I>
172{
173 fn send(&self, msg: I::Msg) -> BoxFuture<'_, ProgressSendResult<()>> {
174 Box::pin(self.0.send(msg))
175 }
176
177 fn try_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
178 self.0.try_send(msg)
179 }
180
181 fn blocking_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
182 self.0.blocking_send(msg)
183 }
184}
185
186#[derive(Debug)]
188#[repr(transparent)]
189struct BoxableProgressSenderWrapper<I>(I);
190
191impl<I: ProgressSender + IdGenerator> IdGenerator for BoxableProgressSenderWrapper<I> {
192 fn new_id(&self) -> u64 {
193 self.0.new_id()
194 }
195}
196
197impl<T: Send + Sync + 'static> IdGenerator for Arc<dyn BoxableProgressSender<T>> {
198 fn new_id(&self) -> u64 {
199 self.deref().new_id()
200 }
201}
202
203impl<T: Send + Sync + 'static> ProgressSender for Arc<dyn BoxableProgressSender<T>> {
204 type Msg = T;
205
206 fn send(&self, msg: T) -> impl Future<Output = ProgressSendResult<()>> + Send {
207 self.deref().send(msg)
208 }
209
210 fn try_send(&self, msg: T) -> ProgressSendResult<()> {
211 self.deref().try_send(msg)
212 }
213
214 fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
215 self.deref().blocking_send(msg)
216 }
217}
218
219impl<T: Send + Sync + 'static> IdGenerator for BoxedProgressSender<T> {
220 fn new_id(&self) -> u64 {
221 self.0.new_id()
222 }
223}
224
225impl<T: Send + Sync + 'static> ProgressSender for BoxedProgressSender<T> {
226 type Msg = T;
227
228 async fn send(&self, msg: T) -> ProgressSendResult<()> {
229 self.0.send(msg).await
230 }
231
232 fn try_send(&self, msg: T) -> ProgressSendResult<()> {
233 self.0.try_send(msg)
234 }
235
236 fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
237 self.0.blocking_send(msg)
238 }
239}
240
241impl<T: ProgressSender> ProgressSender for Option<T> {
242 type Msg = T::Msg;
243
244 async fn send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
245 if let Some(inner) = self {
246 inner.send(msg).await
247 } else {
248 Ok(())
249 }
250 }
251
252 fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
253 if let Some(inner) = self {
254 inner.try_send(msg)
255 } else {
256 Ok(())
257 }
258 }
259
260 fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
261 if let Some(inner) = self {
262 inner.blocking_send(msg)
263 } else {
264 Ok(())
265 }
266 }
267}
268
/// A source of `u64` identifiers.
pub trait IdGenerator {
    /// Produce an identifier.
    ///
    /// Uniqueness is up to the implementation: some implementations hand
    /// out increasing values, others always return the same value.
    fn new_id(&self) -> u64;
}
274
/// A [`ProgressSender`] that discards every message it is given.
pub struct IgnoreProgressSender<T>(PhantomData<T>);

impl<T> Default for IgnoreProgressSender<T> {
    fn default() -> Self {
        IgnoreProgressSender(PhantomData)
    }
}

// Written by hand (instead of derived) so that `T` is not required to be `Clone`.
impl<T> Clone for IgnoreProgressSender<T> {
    fn clone(&self) -> Self {
        Self::default()
    }
}

// Written by hand so that `T` is not required to be `Debug`.
impl<T> std::fmt::Debug for IgnoreProgressSender<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IgnoreProgressSender").finish()
    }
}
295
296impl<T: Send + Sync + 'static> ProgressSender for IgnoreProgressSender<T> {
297 type Msg = T;
298
299 async fn send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
300 Ok(())
301 }
302
303 fn try_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
304 Ok(())
305 }
306
307 fn blocking_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
308 Ok(())
309 }
310}
311
312impl<T> IdGenerator for IgnoreProgressSender<T> {
313 fn new_id(&self) -> u64 {
314 0
315 }
316}
317
/// A [`ProgressSender`] adapter that converts every message with a function
/// before forwarding it to the inner sender.
///
/// Created by [`ProgressSender::with_map`]. Trait bounds live on the impls
/// rather than on the struct, as for [`WithFilterMap`].
pub struct WithMap<I, U, F>(I, F, PhantomData<U>);

/// Shows only the inner sender; the mapping function is usually an opaque
/// closure without a `Debug` impl.
impl<I: std::fmt::Debug, U, F> std::fmt::Debug for WithMap<I, U, F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use the actual type name (previously printed the stale name "With").
        f.debug_tuple("WithMap").field(&self.0).finish()
    }
}
337
338impl<
339 I: ProgressSender,
340 U: Send + Sync + 'static,
341 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
342 > Clone for WithMap<I, U, F>
343{
344 fn clone(&self) -> Self {
345 Self(self.0.clone(), self.1.clone(), PhantomData)
346 }
347}
348
349impl<
350 I: ProgressSender,
351 U: Send + Sync + 'static,
352 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
353 > ProgressSender for WithMap<I, U, F>
354{
355 type Msg = U;
356
357 async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
358 let msg = (self.1)(msg);
359 self.0.send(msg).await
360 }
361
362 fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
363 let msg = (self.1)(msg);
364 self.0.try_send(msg)
365 }
366
367 fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
368 let msg = (self.1)(msg);
369 self.0.blocking_send(msg)
370 }
371}
372
/// A [`ProgressSender`] adapter that converts every message with a function
/// and forwards only the messages for which the function returns `Some`.
///
/// Created by [`ProgressSender::with_filter_map`].
pub struct WithFilterMap<I, U, F>(I, F, PhantomData<U>);

/// Shows only the inner sender; the filter function is usually an opaque
/// closure without a `Debug` impl.
impl<I: std::fmt::Debug, U, F> std::fmt::Debug for WithFilterMap<I, U, F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Use the actual type name (previously printed the stale name "FilterWith").
        f.debug_tuple("WithFilterMap").field(&self.0).finish()
    }
}
388
389impl<
390 I: ProgressSender,
391 U: Send + Sync + 'static,
392 F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
393 > Clone for WithFilterMap<I, U, F>
394{
395 fn clone(&self) -> Self {
396 Self(self.0.clone(), self.1.clone(), PhantomData)
397 }
398}
399
400impl<I: IdGenerator, U, F> IdGenerator for WithFilterMap<I, U, F> {
401 fn new_id(&self) -> u64 {
402 self.0.new_id()
403 }
404}
405
406impl<
407 I: IdGenerator + ProgressSender,
408 U: Send + Sync + 'static,
409 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
410 > IdGenerator for WithMap<I, U, F>
411{
412 fn new_id(&self) -> u64 {
413 self.0.new_id()
414 }
415}
416
417impl<
418 I: ProgressSender,
419 U: Send + Sync + 'static,
420 F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
421 > ProgressSender for WithFilterMap<I, U, F>
422{
423 type Msg = U;
424
425 async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
426 if let Some(msg) = (self.1)(msg) {
427 self.0.send(msg).await
428 } else {
429 Ok(())
430 }
431 }
432
433 fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
434 if let Some(msg) = (self.1)(msg) {
435 self.0.try_send(msg)
436 } else {
437 Ok(())
438 }
439 }
440
441 fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
442 if let Some(msg) = (self.1)(msg) {
443 self.0.blocking_send(msg)
444 } else {
445 Ok(())
446 }
447 }
448}
449
450pub struct FlumeProgressSender<T> {
452 sender: flume::Sender<T>,
453 id: std::sync::Arc<std::sync::atomic::AtomicU64>,
454}
455
456impl<T> std::fmt::Debug for FlumeProgressSender<T> {
457 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458 f.debug_struct("FlumeProgressSender")
459 .field("id", &self.id)
460 .field("sender", &self.sender)
461 .finish()
462 }
463}
464
465impl<T> Clone for FlumeProgressSender<T> {
466 fn clone(&self) -> Self {
467 Self {
468 sender: self.sender.clone(),
469 id: self.id.clone(),
470 }
471 }
472}
473
474impl<T> FlumeProgressSender<T> {
475 pub fn new(sender: flume::Sender<T>) -> Self {
477 Self {
478 sender,
479 id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
480 }
481 }
482
483 pub fn same_channel(&self, other: &FlumeProgressSender<T>) -> bool {
485 self.sender.same_channel(&other.sender)
486 }
487}
488
489impl<T> IdGenerator for FlumeProgressSender<T> {
490 fn new_id(&self) -> u64 {
491 self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
492 }
493}
494
495impl<T: Send + Sync + 'static> ProgressSender for FlumeProgressSender<T> {
496 type Msg = T;
497
498 async fn send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
499 self.sender
500 .send_async(msg)
501 .await
502 .map_err(|_| ProgressSendError::ReceiverDropped)
503 }
504
505 fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
506 match self.sender.try_send(msg) {
507 Ok(_) => Ok(()),
508 Err(flume::TrySendError::Full(_)) => Ok(()),
509 Err(flume::TrySendError::Disconnected(_)) => Err(ProgressSendError::ReceiverDropped),
510 }
511 }
512
513 fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
514 match self.sender.send(msg) {
515 Ok(_) => Ok(()),
516 Err(_) => Err(ProgressSendError::ReceiverDropped),
517 }
518 }
519}
520
521#[derive(Debug, Clone, thiserror::Error)]
525pub enum ProgressSendError {
526 #[error("receiver dropped")]
528 ReceiverDropped,
529}
530
531pub type ProgressSendResult<T> = std::result::Result<T, ProgressSendError>;
533
534impl From<ProgressSendError> for std::io::Error {
535 fn from(e: ProgressSendError) -> Self {
536 std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
537 }
538}
539
540#[derive(Debug)]
545pub struct ProgressSliceWriter<W, F>(W, F);
546
547impl<W: AsyncSliceWriter, F: FnMut(u64)> ProgressSliceWriter<W, F> {
548 pub fn new(inner: W, on_write: F) -> Self {
553 Self(inner, on_write)
554 }
555
556 pub fn into_inner(self) -> W {
558 self.0
559 }
560}
561
562impl<W: AsyncSliceWriter + 'static, F: FnMut(u64, usize) + 'static> AsyncSliceWriter
563 for ProgressSliceWriter<W, F>
564{
565 async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
566 (self.1)(offset, data.len());
567 self.0.write_bytes_at(offset, data).await
568 }
569
570 async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
571 (self.1)(offset, data.len());
572 self.0.write_at(offset, data).await
573 }
574
575 async fn sync(&mut self) -> io::Result<()> {
576 self.0.sync().await
577 }
578
579 async fn set_len(&mut self, size: u64) -> io::Result<()> {
580 self.0.set_len(size).await
581 }
582}
583
584#[derive(Debug)]
590pub struct FallibleProgressSliceWriter<W, F>(W, F);
591
592impl<W: AsyncSliceWriter, F: Fn(u64, usize) -> io::Result<()> + 'static>
593 FallibleProgressSliceWriter<W, F>
594{
595 pub fn new(inner: W, on_write: F) -> Self {
601 Self(inner, on_write)
602 }
603
604 pub fn into_inner(self) -> W {
606 self.0
607 }
608}
609
610impl<W: AsyncSliceWriter + 'static, F: Fn(u64, usize) -> io::Result<()> + 'static> AsyncSliceWriter
611 for FallibleProgressSliceWriter<W, F>
612{
613 async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
614 (self.1)(offset, data.len())?;
615 self.0.write_bytes_at(offset, data).await
616 }
617
618 async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
619 (self.1)(offset, data.len())?;
620 self.0.write_at(offset, data).await
621 }
622
623 async fn sync(&mut self) -> io::Result<()> {
624 self.0.sync().await
625 }
626
627 async fn set_len(&mut self, size: u64) -> io::Result<()> {
628 self.0.set_len(size).await
629 }
630}