iroh_blobs/downloader/
progress.rs1use std::{
2 collections::HashMap,
3 sync::{
4 atomic::{AtomicU64, Ordering},
5 Arc,
6 },
7};
8
9use anyhow::anyhow;
10use parking_lot::Mutex;
11
12use super::DownloadKind;
13use crate::{
14 get::{db::DownloadProgress, progress::TransferState},
15 util::progress::{AsyncChannelProgressSender, IdGenerator, ProgressSendError, ProgressSender},
16};
17
/// Channel sender type through which a subscriber receives [`DownloadProgress`] events for a
/// tracked download.
pub type ProgressSubscriber = AsyncChannelProgressSender<DownloadProgress>;
20
/// Tracks the progress of running downloads and fans progress events out to subscribers.
///
/// [`ProgressTracker::track`] registers a download and returns a [`BroadcastProgressSender`].
/// Every message sent through that sender is applied to a per-download [`TransferState`] and
/// forwarded to the [`ProgressSubscriber`]s registered for that download. Subscribers added
/// later via [`ProgressTracker::subscribe`] are first sent a
/// [`DownloadProgress::InitialState`] message carrying a snapshot of that state.
#[derive(Debug, Default)]
pub struct ProgressTracker {
    /// Shared progress state and subscriber list for each tracked download.
    running: HashMap<DownloadKind, Shared>,
    /// Id counter handed to (and shared by) every [`BroadcastProgressSender`] this tracker
    /// creates.
    id_gen: Arc<AtomicU64>,
}
38
39impl ProgressTracker {
40 pub fn new() -> Self {
41 Self::default()
42 }
43
44 pub fn track(
50 &mut self,
51 kind: DownloadKind,
52 subscribers: impl IntoIterator<Item = ProgressSubscriber>,
53 ) -> BroadcastProgressSender {
54 let inner = Inner {
55 subscribers: subscribers.into_iter().collect(),
56 state: TransferState::new(kind.hash()),
57 };
58 let shared = Arc::new(Mutex::new(inner));
59 self.running.insert(kind, Arc::clone(&shared));
60 let id_gen = Arc::clone(&self.id_gen);
61 BroadcastProgressSender { shared, id_gen }
62 }
63
64 pub async fn subscribe(
68 &mut self,
69 kind: DownloadKind,
70 sender: ProgressSubscriber,
71 ) -> anyhow::Result<()> {
72 let initial_msg = self
73 .running
74 .get_mut(&kind)
75 .ok_or_else(|| anyhow!("state for download {kind:?} not found"))?
76 .lock()
77 .subscribe(sender.clone());
78 sender.send(initial_msg).await?;
79 Ok(())
80 }
81
82 pub fn unsubscribe(&mut self, kind: &DownloadKind, sender: &ProgressSubscriber) {
84 if let Some(shared) = self.running.get_mut(kind) {
85 shared.lock().unsubscribe(sender)
86 }
87 }
88
89 pub fn remove(&mut self, kind: &DownloadKind) {
91 self.running.remove(kind);
92 }
93}
94
/// Per-download state shared between the [`ProgressTracker`] and the
/// [`BroadcastProgressSender`]s it creates.
type Shared = Arc<Mutex<Inner>>;
96
/// State of one tracked download, guarded by the mutex in [`Shared`].
#[derive(Debug)]
struct Inner {
    /// Channels to which progress events for this download are forwarded.
    subscribers: Vec<ProgressSubscriber>,
    /// Transfer state, updated by applying each progress event in `on_progress`.
    state: TransferState,
}
102
103impl Inner {
104 fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress {
105 let msg = DownloadProgress::InitialState(self.state.clone());
106 self.subscribers.push(subscriber);
107 msg
108 }
109
110 fn unsubscribe(&mut self, sender: &ProgressSubscriber) {
111 self.subscribers.retain(|s| !s.same_channel(sender));
112 }
113
114 fn on_progress(&mut self, progress: DownloadProgress) {
115 self.state.on_progress(progress);
116 }
117}
118
/// Progress sender returned by [`ProgressTracker::track`].
///
/// Each message sent through it is applied to the download's [`TransferState`] and then
/// forwarded to the download's current subscribers. Clones share the same state and id counter.
#[derive(Debug, Clone)]
pub struct BroadcastProgressSender {
    /// State and subscribers of the download this sender reports on.
    shared: Shared,
    /// Id counter shared with the tracker and every other sender it created.
    id_gen: Arc<AtomicU64>,
}
124
125impl IdGenerator for BroadcastProgressSender {
126 fn new_id(&self) -> u64 {
127 self.id_gen.fetch_add(1, Ordering::SeqCst)
128 }
129}
130
131impl ProgressSender for BroadcastProgressSender {
132 type Msg = DownloadProgress;
133
134 async fn send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
135 let futs = {
137 let mut inner = self.shared.lock();
138 inner.on_progress(msg.clone());
139 let futs = inner
140 .subscribers
141 .iter_mut()
142 .map(|sender| {
143 let sender = sender.clone();
144 let msg = msg.clone();
145 async move {
146 match sender.send(msg).await {
147 Ok(()) => None,
148 Err(ProgressSendError::ReceiverDropped) => Some(sender),
149 }
150 }
151 })
152 .collect::<Vec<_>>();
153 drop(inner);
154 futs
155 };
156
157 let failed_senders = futures_buffered::join_all(futs).await;
158 if failed_senders.iter().any(|s| s.is_some()) {
160 let mut inner = self.shared.lock();
161 for sender in failed_senders.into_iter().flatten() {
162 inner.unsubscribe(&sender);
163 }
164 drop(inner);
165 }
166 Ok(())
167 }
168
169 fn try_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
170 let mut inner = self.shared.lock();
171 inner.on_progress(msg.clone());
172 inner
174 .subscribers
175 .retain_mut(|sender| match sender.try_send(msg.clone()) {
176 Err(ProgressSendError::ReceiverDropped) => false,
177 Ok(()) => true,
178 });
179 Ok(())
180 }
181
182 fn blocking_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
183 let mut inner = self.shared.lock();
184 inner.on_progress(msg.clone());
185 inner
187 .subscribers
188 .retain_mut(|sender| match sender.blocking_send(msg.clone()) {
189 Err(ProgressSendError::ReceiverDropped) => false,
190 Ok(()) => true,
191 });
192 Ok(())
193 }
194}