iroh_blobs/downloader/progress.rs

1use std::{
2    collections::HashMap,
3    sync::{
4        atomic::{AtomicU64, Ordering},
5        Arc,
6    },
7};
8
9use anyhow::anyhow;
10use parking_lot::Mutex;
11
12use super::DownloadKind;
13use crate::{
14    get::{db::DownloadProgress, progress::TransferState},
15    util::progress::{AsyncChannelProgressSender, IdGenerator, ProgressSendError, ProgressSender},
16};
17
18/// The channel that can be used to subscribe to progress updates.
19pub type ProgressSubscriber = AsyncChannelProgressSender<DownloadProgress>;
20
21/// Track the progress of downloads.
22///
23/// This struct allows to create [`ProgressSender`] structs to be passed to
24/// [`crate::get::db::get_to_db`]. Each progress sender can be subscribed to by any number of
25/// [`ProgressSubscriber`] channel senders, which will receive each progress update (if they have
26/// capacity). Additionally, the [`ProgressTracker`] maintains a [`TransferState`] for each
27/// transfer, applying each progress update to update this state. When subscribing to an already
28/// running transfer, the subscriber will receive a [`DownloadProgress::InitialState`] message
29/// containing the state at the time of the subscription, and then receive all further progress
30/// events directly.
31#[derive(Debug, Default)]
32pub struct ProgressTracker {
33    /// Map of shared state for each tracked download.
34    running: HashMap<DownloadKind, Shared>,
35    /// Shared [`IdGenerator`] for all progress senders created by the tracker.
36    id_gen: Arc<AtomicU64>,
37}
38
39impl ProgressTracker {
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Track a new download with a list of initial subscribers.
45    ///
46    /// Note that this should only be called for *new* downloads. If a download for the `kind` is
47    /// already tracked in this [`ProgressTracker`], calling `track` will replace all existing
48    /// state and subscribers (equal to calling [`Self::remove`] first).
49    pub fn track(
50        &mut self,
51        kind: DownloadKind,
52        subscribers: impl IntoIterator<Item = ProgressSubscriber>,
53    ) -> BroadcastProgressSender {
54        let inner = Inner {
55            subscribers: subscribers.into_iter().collect(),
56            state: TransferState::new(kind.hash()),
57        };
58        let shared = Arc::new(Mutex::new(inner));
59        self.running.insert(kind, Arc::clone(&shared));
60        let id_gen = Arc::clone(&self.id_gen);
61        BroadcastProgressSender { shared, id_gen }
62    }
63
64    /// Subscribe to a tracked download.
65    ///
66    /// Will return an error if `kind` is not yet tracked.
67    pub async fn subscribe(
68        &mut self,
69        kind: DownloadKind,
70        sender: ProgressSubscriber,
71    ) -> anyhow::Result<()> {
72        let initial_msg = self
73            .running
74            .get_mut(&kind)
75            .ok_or_else(|| anyhow!("state for download {kind:?} not found"))?
76            .lock()
77            .subscribe(sender.clone());
78        sender.send(initial_msg).await?;
79        Ok(())
80    }
81
82    /// Unsubscribe `sender` from `kind`.
83    pub fn unsubscribe(&mut self, kind: &DownloadKind, sender: &ProgressSubscriber) {
84        if let Some(shared) = self.running.get_mut(kind) {
85            shared.lock().unsubscribe(sender)
86        }
87    }
88
89    /// Remove all state for a download.
90    pub fn remove(&mut self, kind: &DownloadKind) {
91        self.running.remove(kind);
92    }
93}
94
95type Shared = Arc<Mutex<Inner>>;
96
97#[derive(Debug)]
98struct Inner {
99    subscribers: Vec<ProgressSubscriber>,
100    state: TransferState,
101}
102
103impl Inner {
104    fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress {
105        let msg = DownloadProgress::InitialState(self.state.clone());
106        self.subscribers.push(subscriber);
107        msg
108    }
109
110    fn unsubscribe(&mut self, sender: &ProgressSubscriber) {
111        self.subscribers.retain(|s| !s.same_channel(sender));
112    }
113
114    fn on_progress(&mut self, progress: DownloadProgress) {
115        self.state.on_progress(progress);
116    }
117}
118
119#[derive(Debug, Clone)]
120pub struct BroadcastProgressSender {
121    shared: Shared,
122    id_gen: Arc<AtomicU64>,
123}
124
125impl IdGenerator for BroadcastProgressSender {
126    fn new_id(&self) -> u64 {
127        self.id_gen.fetch_add(1, Ordering::SeqCst)
128    }
129}
130
131impl ProgressSender for BroadcastProgressSender {
132    type Msg = DownloadProgress;
133
134    async fn send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
135        // making sure that the lock is not held across an await point.
136        let futs = {
137            let mut inner = self.shared.lock();
138            inner.on_progress(msg.clone());
139            let futs = inner
140                .subscribers
141                .iter_mut()
142                .map(|sender| {
143                    let sender = sender.clone();
144                    let msg = msg.clone();
145                    async move {
146                        match sender.send(msg).await {
147                            Ok(()) => None,
148                            Err(ProgressSendError::ReceiverDropped) => Some(sender),
149                        }
150                    }
151                })
152                .collect::<Vec<_>>();
153            drop(inner);
154            futs
155        };
156
157        let failed_senders = futures_buffered::join_all(futs).await;
158        // remove senders where the receiver is dropped
159        if failed_senders.iter().any(|s| s.is_some()) {
160            let mut inner = self.shared.lock();
161            for sender in failed_senders.into_iter().flatten() {
162                inner.unsubscribe(&sender);
163            }
164            drop(inner);
165        }
166        Ok(())
167    }
168
169    fn try_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
170        let mut inner = self.shared.lock();
171        inner.on_progress(msg.clone());
172        // remove senders where the receiver is dropped
173        inner
174            .subscribers
175            .retain_mut(|sender| match sender.try_send(msg.clone()) {
176                Err(ProgressSendError::ReceiverDropped) => false,
177                Ok(()) => true,
178            });
179        Ok(())
180    }
181
182    fn blocking_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
183        let mut inner = self.shared.lock();
184        inner.on_progress(msg.clone());
185        // remove senders where the receiver is dropped
186        inner
187            .subscribers
188            .retain_mut(|sender| match sender.blocking_send(msg.clone()) {
189                Err(ProgressSendError::ReceiverDropped) => false,
190                Ok(()) => true,
191            });
192        Ok(())
193    }
194}