1use std::{future::Future, io, marker::PhantomData, ops::Deref, sync::Arc};
5
6use bytes::Bytes;
7use iroh_io::AsyncSliceWriter;
8
9pub trait ProgressSender: std::fmt::Debug + Clone + Send + Sync + 'static {
87 type Msg: Send + Sync + 'static;
89
90 #[must_use]
94 fn send(&self, msg: Self::Msg) -> impl Future<Output = ProgressSendResult<()>> + Send;
95
96 fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;
100
101 fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()>;
105
106 fn with_map<U: Send + Sync + 'static, F: Fn(U) -> Self::Msg + Send + Sync + Clone + 'static>(
108 self,
109 f: F,
110 ) -> WithMap<Self, U, F> {
111 WithMap(self, f, PhantomData)
112 }
113
114 fn with_filter_map<
116 U: Send + Sync + 'static,
117 F: Fn(U) -> Option<Self::Msg> + Send + Sync + Clone + 'static,
118 >(
119 self,
120 f: F,
121 ) -> WithFilterMap<Self, U, F> {
122 WithFilterMap(self, f, PhantomData)
123 }
124
125 fn boxed(self) -> BoxedProgressSender<Self::Msg>
127 where
128 Self: IdGenerator,
129 {
130 BoxedProgressSender(Arc::new(BoxableProgressSenderWrapper(self)))
131 }
132}
133
134pub struct BoxedProgressSender<T>(Arc<dyn BoxableProgressSender<T>>);
136
137impl<T> Clone for BoxedProgressSender<T> {
138 fn clone(&self) -> Self {
139 Self(self.0.clone())
140 }
141}
142
143impl<T> std::fmt::Debug for BoxedProgressSender<T> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_tuple("BoxedProgressSender").field(&self.0).finish()
146 }
147}
148
149type BoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + Send + 'a>>;
150
151trait BoxableProgressSender<T>: IdGenerator + std::fmt::Debug + Send + Sync + 'static {
153 #[must_use]
157 fn send(&self, msg: T) -> BoxFuture<'_, ProgressSendResult<()>>;
158
159 fn try_send(&self, msg: T) -> ProgressSendResult<()>;
163
164 fn blocking_send(&self, msg: T) -> ProgressSendResult<()>;
168}
169
170impl<I: ProgressSender + IdGenerator> BoxableProgressSender<I::Msg>
171 for BoxableProgressSenderWrapper<I>
172{
173 fn send(&self, msg: I::Msg) -> BoxFuture<'_, ProgressSendResult<()>> {
174 Box::pin(self.0.send(msg))
175 }
176
177 fn try_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
178 self.0.try_send(msg)
179 }
180
181 fn blocking_send(&self, msg: I::Msg) -> ProgressSendResult<()> {
182 self.0.blocking_send(msg)
183 }
184}
185
186#[derive(Debug)]
188#[repr(transparent)]
189struct BoxableProgressSenderWrapper<I>(I);
190
191impl<I: ProgressSender + IdGenerator> IdGenerator for BoxableProgressSenderWrapper<I> {
192 fn new_id(&self) -> u64 {
193 self.0.new_id()
194 }
195}
196
197impl<T: Send + Sync + 'static> IdGenerator for Arc<dyn BoxableProgressSender<T>> {
198 fn new_id(&self) -> u64 {
199 self.deref().new_id()
200 }
201}
202
203impl<T: Send + Sync + 'static> ProgressSender for Arc<dyn BoxableProgressSender<T>> {
204 type Msg = T;
205
206 fn send(&self, msg: T) -> impl Future<Output = ProgressSendResult<()>> + Send {
207 self.deref().send(msg)
208 }
209
210 fn try_send(&self, msg: T) -> ProgressSendResult<()> {
211 self.deref().try_send(msg)
212 }
213
214 fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
215 self.deref().blocking_send(msg)
216 }
217}
218
219impl<T: Send + Sync + 'static> IdGenerator for BoxedProgressSender<T> {
220 fn new_id(&self) -> u64 {
221 self.0.new_id()
222 }
223}
224
225impl<T: Send + Sync + 'static> ProgressSender for BoxedProgressSender<T> {
226 type Msg = T;
227
228 async fn send(&self, msg: T) -> ProgressSendResult<()> {
229 self.0.send(msg).await
230 }
231
232 fn try_send(&self, msg: T) -> ProgressSendResult<()> {
233 self.0.try_send(msg)
234 }
235
236 fn blocking_send(&self, msg: T) -> ProgressSendResult<()> {
237 self.0.blocking_send(msg)
238 }
239}
240
241impl<T: ProgressSender> ProgressSender for Option<T> {
242 type Msg = T::Msg;
243
244 async fn send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
245 if let Some(inner) = self {
246 inner.send(msg).await
247 } else {
248 Ok(())
249 }
250 }
251
252 fn try_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
253 if let Some(inner) = self {
254 inner.try_send(msg)
255 } else {
256 Ok(())
257 }
258 }
259
260 fn blocking_send(&self, msg: Self::Msg) -> ProgressSendResult<()> {
261 if let Some(inner) = self {
262 inner.blocking_send(msg)
263 } else {
264 Ok(())
265 }
266 }
267}
268
/// A source of `u64` ids.
///
/// Implementations in this module differ:
/// - `AsyncChannelProgressSender` returns successive values of a counter shared by all of its
///   clones.
/// - `IgnoreProgressSender` always returns `0`.
/// - The wrapper and boxed senders delegate to the sender they wrap.
///
/// Callers therefore should not assume ids are unique for every implementor.
pub trait IdGenerator {
    /// Return a new id.
    fn new_id(&self) -> u64;
}
274
275pub struct IgnoreProgressSender<T>(PhantomData<T>);
277
278impl<T> Default for IgnoreProgressSender<T> {
279 fn default() -> Self {
280 Self(PhantomData)
281 }
282}
283
284impl<T> Clone for IgnoreProgressSender<T> {
285 fn clone(&self) -> Self {
286 Self(PhantomData)
287 }
288}
289
290impl<T> std::fmt::Debug for IgnoreProgressSender<T> {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 f.debug_struct("IgnoreProgressSender").finish()
293 }
294}
295
296impl<T: Send + Sync + 'static> ProgressSender for IgnoreProgressSender<T> {
297 type Msg = T;
298
299 async fn send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
300 Ok(())
301 }
302
303 fn try_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
304 Ok(())
305 }
306
307 fn blocking_send(&self, _msg: T) -> std::result::Result<(), ProgressSendError> {
308 Ok(())
309 }
310}
311
312impl<T> IdGenerator for IgnoreProgressSender<T> {
313 fn new_id(&self) -> u64 {
314 0
315 }
316}
317
318pub struct WithMap<
322 I: ProgressSender,
323 U: Send + Sync + 'static,
324 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
325>(I, F, PhantomData<U>);
326
327impl<
328 I: ProgressSender,
329 U: Send + Sync + 'static,
330 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
331 > std::fmt::Debug for WithMap<I, U, F>
332{
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 f.debug_tuple("With").field(&self.0).finish()
335 }
336}
337
338impl<
339 I: ProgressSender,
340 U: Send + Sync + 'static,
341 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
342 > Clone for WithMap<I, U, F>
343{
344 fn clone(&self) -> Self {
345 Self(self.0.clone(), self.1.clone(), PhantomData)
346 }
347}
348
349impl<
350 I: ProgressSender,
351 U: Send + Sync + 'static,
352 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
353 > ProgressSender for WithMap<I, U, F>
354{
355 type Msg = U;
356
357 async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
358 let msg = (self.1)(msg);
359 self.0.send(msg).await
360 }
361
362 fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
363 let msg = (self.1)(msg);
364 self.0.try_send(msg)
365 }
366
367 fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
368 let msg = (self.1)(msg);
369 self.0.blocking_send(msg)
370 }
371}
372
373pub struct WithFilterMap<I, U, F>(I, F, PhantomData<U>);
377
378impl<
379 I: ProgressSender,
380 U: Send + Sync + 'static,
381 F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
382 > std::fmt::Debug for WithFilterMap<I, U, F>
383{
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 f.debug_tuple("FilterWith").field(&self.0).finish()
386 }
387}
388
389impl<
390 I: ProgressSender,
391 U: Send + Sync + 'static,
392 F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
393 > Clone for WithFilterMap<I, U, F>
394{
395 fn clone(&self) -> Self {
396 Self(self.0.clone(), self.1.clone(), PhantomData)
397 }
398}
399
400impl<I: IdGenerator, U, F> IdGenerator for WithFilterMap<I, U, F> {
401 fn new_id(&self) -> u64 {
402 self.0.new_id()
403 }
404}
405
406impl<
407 I: IdGenerator + ProgressSender,
408 U: Send + Sync + 'static,
409 F: Fn(U) -> I::Msg + Clone + Send + Sync + 'static,
410 > IdGenerator for WithMap<I, U, F>
411{
412 fn new_id(&self) -> u64 {
413 self.0.new_id()
414 }
415}
416
417impl<
418 I: ProgressSender,
419 U: Send + Sync + 'static,
420 F: Fn(U) -> Option<I::Msg> + Clone + Send + Sync + 'static,
421 > ProgressSender for WithFilterMap<I, U, F>
422{
423 type Msg = U;
424
425 async fn send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
426 if let Some(msg) = (self.1)(msg) {
427 self.0.send(msg).await
428 } else {
429 Ok(())
430 }
431 }
432
433 fn try_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
434 if let Some(msg) = (self.1)(msg) {
435 self.0.try_send(msg)
436 } else {
437 Ok(())
438 }
439 }
440
441 fn blocking_send(&self, msg: U) -> std::result::Result<(), ProgressSendError> {
442 if let Some(msg) = (self.1)(msg) {
443 self.0.blocking_send(msg)
444 } else {
445 Ok(())
446 }
447 }
448}
449
450pub struct AsyncChannelProgressSender<T> {
452 sender: async_channel::Sender<T>,
453 id: std::sync::Arc<std::sync::atomic::AtomicU64>,
454}
455
456impl<T> std::fmt::Debug for AsyncChannelProgressSender<T> {
457 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458 f.debug_struct("AsyncChannelProgressSender")
459 .field("id", &self.id)
460 .field("sender", &self.sender)
461 .finish()
462 }
463}
464
465impl<T> Clone for AsyncChannelProgressSender<T> {
466 fn clone(&self) -> Self {
467 Self {
468 sender: self.sender.clone(),
469 id: self.id.clone(),
470 }
471 }
472}
473
474impl<T> AsyncChannelProgressSender<T> {
475 pub fn new(sender: async_channel::Sender<T>) -> Self {
477 Self {
478 sender,
479 id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0)),
480 }
481 }
482
483 pub fn same_channel(&self, other: &AsyncChannelProgressSender<T>) -> bool {
485 same_channel(&self.sender, &other.sender)
486 }
487}
488
/// Reinterpret the bytes of `value` as a `usize`, if `T` has exactly the size and alignment
/// of `usize`; otherwise return `None`.
///
/// Intended for types that are a thin wrapper around a single pointer (e.g. a newtype around
/// an `Arc` or `Box`): the result is then the address stored in that pointer. This is a layout
/// hack that inspects private representation; use the result only for identity comparisons,
/// never to access the pointee.
///
/// NOTE(review): the size/alignment check does not rule out padding bytes (a type can be
/// `usize`-sized and `usize`-aligned yet contain padding). Only pass pointer-like types whose
/// bytes are all initialized.
fn get_as_ptr<T>(value: &T) -> Option<usize> {
    use std::mem;

    let same_layout = mem::size_of::<T>() == mem::size_of::<usize>()
        && mem::align_of::<T>() == mem::align_of::<usize>();
    if !same_layout {
        return None;
    }
    // SAFETY: `T` was just checked to be exactly `size_of::<usize>()` bytes, so
    // `transmute_copy` reads only bytes that belong to `*value`. Every bit pattern is a valid
    // `usize`. Soundness additionally requires that those bytes are initialized (no padding);
    // callers uphold this by passing only pointer-like types, see the note above.
    Some(unsafe { mem::transmute_copy::<T, usize>(value) })
}
502
503fn same_channel<T>(a: &async_channel::Sender<T>, b: &async_channel::Sender<T>) -> bool {
504 get_as_ptr(a).unwrap() == get_as_ptr(b).unwrap()
508}
509
510impl<T> IdGenerator for AsyncChannelProgressSender<T> {
511 fn new_id(&self) -> u64 {
512 self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
513 }
514}
515
516impl<T: Send + Sync + 'static> ProgressSender for AsyncChannelProgressSender<T> {
517 type Msg = T;
518
519 async fn send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
520 self.sender
521 .send(msg)
522 .await
523 .map_err(|_| ProgressSendError::ReceiverDropped)
524 }
525
526 fn try_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
527 match self.sender.try_send(msg) {
528 Ok(_) => Ok(()),
529 Err(async_channel::TrySendError::Full(_)) => Ok(()),
530 Err(async_channel::TrySendError::Closed(_)) => Err(ProgressSendError::ReceiverDropped),
531 }
532 }
533
534 fn blocking_send(&self, msg: Self::Msg) -> std::result::Result<(), ProgressSendError> {
535 match self.sender.send_blocking(msg) {
536 Ok(_) => Ok(()),
537 Err(_) => Err(ProgressSendError::ReceiverDropped),
538 }
539 }
540}
541
542#[derive(Debug, Clone, thiserror::Error)]
546pub enum ProgressSendError {
547 #[error("receiver dropped")]
549 ReceiverDropped,
550}
551
552pub type ProgressSendResult<T> = std::result::Result<T, ProgressSendError>;
554
555impl From<ProgressSendError> for std::io::Error {
556 fn from(e: ProgressSendError) -> Self {
557 std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)
558 }
559}
560
561#[derive(Debug)]
566pub struct ProgressSliceWriter<W, F>(W, F);
567
568impl<W: AsyncSliceWriter, F: FnMut(u64)> ProgressSliceWriter<W, F> {
569 pub fn new(inner: W, on_write: F) -> Self {
574 Self(inner, on_write)
575 }
576
577 pub fn into_inner(self) -> W {
579 self.0
580 }
581}
582
583impl<W: AsyncSliceWriter + 'static, F: FnMut(u64, usize) + 'static> AsyncSliceWriter
584 for ProgressSliceWriter<W, F>
585{
586 async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
587 (self.1)(offset, data.len());
588 self.0.write_bytes_at(offset, data).await
589 }
590
591 async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
592 (self.1)(offset, data.len());
593 self.0.write_at(offset, data).await
594 }
595
596 async fn sync(&mut self) -> io::Result<()> {
597 self.0.sync().await
598 }
599
600 async fn set_len(&mut self, size: u64) -> io::Result<()> {
601 self.0.set_len(size).await
602 }
603}
604
605#[derive(Debug)]
611pub struct FallibleProgressSliceWriter<W, F>(W, F);
612
613impl<W: AsyncSliceWriter, F: Fn(u64, usize) -> io::Result<()> + 'static>
614 FallibleProgressSliceWriter<W, F>
615{
616 pub fn new(inner: W, on_write: F) -> Self {
622 Self(inner, on_write)
623 }
624
625 pub fn into_inner(self) -> W {
627 self.0
628 }
629}
630
631impl<W: AsyncSliceWriter + 'static, F: Fn(u64, usize) -> io::Result<()> + 'static> AsyncSliceWriter
632 for FallibleProgressSliceWriter<W, F>
633{
634 async fn write_bytes_at(&mut self, offset: u64, data: Bytes) -> io::Result<()> {
635 (self.1)(offset, data.len())?;
636 self.0.write_bytes_at(offset, data).await
637 }
638
639 async fn write_at(&mut self, offset: u64, data: &[u8]) -> io::Result<()> {
640 (self.1)(offset, data.len())?;
641 self.0.write_at(offset, data).await
642 }
643
644 async fn sync(&mut self) -> io::Result<()> {
645 self.0.sync().await
646 }
647
648 async fn set_len(&mut self, size: u64) -> io::Result<()> {
649 self.0.set_len(size).await
650 }
651}
652
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn get_as_ptr_works() {
        struct Wrapper(Arc<u64>);
        let x = Wrapper(Arc::new(1u64));
        assert_eq!(
            get_as_ptr(&x).unwrap(),
            Arc::as_ptr(&x.0) as usize - 2 * std::mem::size_of::<usize>()
        );
    }

    #[test]
    fn get_as_ptr_wrong_use() {
        struct Wrapper(#[allow(dead_code)] u8);
        let x = Wrapper(1);
        assert!(get_as_ptr(&x).is_none());
    }

    #[test]
    fn test_sender_is_ptr() {
        assert_eq!(
            std::mem::size_of::<usize>(),
            std::mem::size_of::<async_channel::Sender<u8>>()
        );
        assert_eq!(
            std::mem::align_of::<usize>(),
            std::mem::align_of::<async_channel::Sender<u8>>()
        );
    }

    /// `same_channel` must treat clones of one sender as equal and senders of different
    /// channels as distinct, both directly and via `AsyncChannelProgressSender`.
    #[test]
    fn same_channel_works() {
        let (a, _ra) = async_channel::bounded::<u8>(1);
        let (c, _rc) = async_channel::bounded::<u8>(1);
        let b = a.clone();
        assert!(same_channel(&a, &b));
        assert!(!same_channel(&a, &c));

        let pa = AsyncChannelProgressSender::new(a);
        let pb = AsyncChannelProgressSender::new(b);
        let pc = AsyncChannelProgressSender::new(c);
        assert!(pa.same_channel(&pb));
        assert!(!pa.same_channel(&pc));
    }

    /// Clones of an `AsyncChannelProgressSender` draw ids from one shared counter.
    #[test]
    fn ids_are_shared_between_clones() {
        let (tx, _rx) = async_channel::bounded::<u8>(1);
        let first = AsyncChannelProgressSender::new(tx);
        let second = first.clone();
        assert_eq!(first.new_id(), 0);
        assert_eq!(second.new_id(), 1);
        assert_eq!(first.new_id(), 2);
    }

    /// `try_send` drops messages when the channel is full and errors once it is closed.
    #[test]
    fn try_send_drops_when_full_and_errors_when_closed() {
        let (tx, rx) = async_channel::bounded::<u8>(1);
        let sender = AsyncChannelProgressSender::new(tx);
        assert!(sender.try_send(1).is_ok());
        // The channel is full: the second message is silently dropped.
        assert!(sender.try_send(2).is_ok());
        assert_eq!(rx.len(), 1);
        assert_eq!(rx.try_recv().ok(), Some(1));
        drop(rx);
        assert!(matches!(
            sender.try_send(3),
            Err(ProgressSendError::ReceiverDropped)
        ));
    }
}