// iroh_blobs/api/downloader.rs
1//! API for downloads from multiple nodes.
2use std::{
3    collections::HashMap,
4    fmt::Debug,
5    future::{Future, IntoFuture},
6    sync::Arc,
7};
8
9use genawaiter::sync::Gen;
10use iroh::{Endpoint, EndpointId};
11use irpc::{
12    channel::{mpsc, oneshot},
13    rpc_requests,
14};
15use n0_error::{anyerr, Result};
16use n0_future::{
17    future, stream,
18    task::{JoinError, JoinSet},
19    BufferedStreamExt, Stream, StreamExt,
20};
21use rand::seq::SliceRandom;
22use serde::{de::Error, Deserialize, Serialize};
23use tracing::instrument::Instrument;
24
25use super::Store;
26use crate::{
27    protocol::{GetManyRequest, GetRequest},
28    util::{
29        connection_pool::ConnectionPool,
30        sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
31    },
32    BlobFormat, Hash, HashAndFormat,
33};
34
/// Client handle for downloading blobs from multiple provider nodes.
///
/// Cheap to clone; every clone talks to the same background `DownloaderActor`
/// via the wrapped irpc client.
#[derive(Debug, Clone)]
pub struct Downloader {
    client: irpc::Client<SwarmProtocol>,
}
39
/// Message protocol between the [`Downloader`] handle and the `DownloaderActor`.
#[rpc_requests(message = SwarmMsg, alias = "Msg", rpc_feature = "rpc")]
#[derive(Debug, Serialize, Deserialize)]
enum SwarmProtocol {
    /// Start a download; progress items are streamed back to the caller.
    #[rpc(tx = mpsc::Sender<DownloadProgressItem>)]
    Download(DownloadRequest),
    /// Resolves once the actor has no in-flight download tasks.
    #[rpc(tx = oneshot::Sender<()>)]
    WaitIdle(WaitIdleRequest),
}
48
/// Request payload for [`Downloader::wait_idle`].
#[derive(Debug, Serialize, Deserialize)]
pub struct WaitIdleRequest;
51
/// Background actor that owns the connection pool and runs download tasks.
struct DownloaderActor {
    // Local blob store that downloaded data is written to.
    store: Store,
    // Pooled connections to provider endpoints (crate ALPN).
    pool: ConnectionPool,
    // One task per in-flight download.
    tasks: JoinSet<()>,
    // Waiters to notify the next time `tasks` becomes empty.
    idle_waiters: Vec<irpc::channel::oneshot::Sender<()>>,
}
58
/// Progress items streamed back to the caller during a download.
#[derive(Debug, Serialize, Deserialize)]
pub enum DownloadProgressItem {
    /// Fatal error; the download is over. Skipped by serde because the error
    /// type is not serializable.
    #[serde(skip)]
    Error(n0_error::AnyError),
    /// About to try provider `id` for `request`.
    TryProvider {
        id: EndpointId,
        request: Arc<GetRequest>,
    },
    /// Provider `id` failed (connection or transfer error) for `request`.
    ProviderFailed {
        id: EndpointId,
        request: Arc<GetRequest>,
    },
    /// One (sub-)request completed successfully.
    PartComplete {
        request: Arc<GetRequest>,
    },
    /// Cumulative number of payload bytes downloaded so far.
    Progress(u64),
    /// A sub-download failed; in split mode the remaining parts keep running.
    DownloadError,
}
77
impl DownloaderActor {
    /// Create an actor that downloads into `store`, dialing providers through
    /// a fresh [`ConnectionPool`] over `endpoint` with the given options.
    fn new_with_opts(
        store: Store,
        endpoint: Endpoint,
        pool_options: crate::util::connection_pool::Options,
    ) -> Self {
        Self {
            store,
            pool: ConnectionPool::new(endpoint, crate::ALPN, pool_options),
            tasks: JoinSet::new(),
            idle_waiters: Vec::new(),
        }
    }

    /// Actor main loop: handle incoming messages and reap finished tasks.
    ///
    /// Exits when the message channel closes (all `Downloader` clones dropped),
    /// then drains any still-running download tasks before returning.
    async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
        loop {
            tokio::select! {
                msg = rx.recv() => {
                    // Channel closed -> shut down.
                    let Some(msg) = msg else { break };
                    match msg {
                        SwarmMsg::Download(request) => {
                            self.spawn(handle_download(
                                self.store.clone(),
                                self.pool.clone(),
                                request,
                            ));
                        }
                        SwarmMsg::WaitIdle(WaitIdleMsg { tx, .. }) => {
                            if self.tasks.is_empty() {
                                // Fast path: already idle, answer immediately.
                                tx.send(()).await.ok();
                            } else {
                                // Park the waiter until the JoinSet drains.
                                self.idle_waiters.push(tx);
                            }
                        }
                    }
                }
                // Reap completed tasks as they finish so the JoinSet does not
                // accumulate finished-but-unjoined tasks. The guard prevents
                // polling `join_next` on an empty set, which would resolve to
                // None immediately and busy-loop this arm.
                Some(res) = self.tasks.join_next(), if !self.tasks.is_empty() => {
                    Self::log_task_result(res);
                    if self.tasks.is_empty() {
                        // Became idle: notify everyone waiting in wait_idle.
                        for tx in self.idle_waiters.drain(..) {
                            tx.send(()).await.ok();
                        }
                    }
                }
            }
        }
        // On shutdown, wait for outstanding downloads so results get logged.
        while let Some(res) = self.tasks.join_next().await {
            Self::log_task_result(res);
        }
    }

    /// Log the outcome of a finished download task.
    fn log_task_result(res: std::result::Result<(), JoinError>) {
        match res {
            Ok(()) => {}
            Err(e) if e.is_cancelled() => tracing::trace!("download task cancelled: {e}"),
            Err(e) => tracing::error!("download task failed: {e}"),
        }
    }

    /// Spawn a download task into the JoinSet, carrying the current tracing span.
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
        let span = tracing::Span::current();
        self.tasks.spawn(fut.instrument(span));
    }
}
142
143async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
144    let DownloadMsg { inner, mut tx, .. } = msg;
145    if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
146        tx.send(DownloadProgressItem::Error(cause)).await.ok();
147    }
148}
149
150async fn handle_download_impl(
151    store: Store,
152    pool: ConnectionPool,
153    request: DownloadRequest,
154    tx: &mut mpsc::Sender<DownloadProgressItem>,
155) -> Result<()> {
156    match request.strategy {
157        SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
158        SplitStrategy::None => match request.request {
159            FiniteRequest::Get(get) => {
160                let sink = IrpcSenderRefSink(tx);
161                execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
162            }
163            FiniteRequest::GetMany(_) => {
164                handle_download_split_impl(store, pool, request, tx).await?
165            }
166        },
167    }
168    Ok(())
169}
170
/// Split the request into per-blob sub-requests, run them concurrently, and
/// merge their per-request progress into one aggregated progress stream.
async fn handle_download_split_impl(
    store: Store,
    pool: ConnectionPool,
    request: DownloadRequest,
    tx: &mut mpsc::Sender<DownloadProgressItem>,
) -> Result<()> {
    let providers = request.providers;
    // Splitting may itself download the root blob; that phase's progress is
    // discarded via the Drain sink.
    let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
    let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
    // Run up to 32 sub-requests concurrently; each reports progress tagged
    // with its index over its own channel.
    let mut futs = stream::iter(requests.into_iter().enumerate())
        .map(|(id, request)| {
            let pool = pool.clone();
            let providers = providers.clone();
            let store = store.clone();
            let progress_tx = progress_tx.clone();
            async move {
                let hash = request.hash;
                let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgressItem)>(16);
                progress_tx.send(rx).await.ok();
                let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x));
                let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
                (hash, res)
            }
        })
        .buffered_unordered(32);
    // Flatten the per-request progress channels into one stream, rewriting
    // each request's absolute byte offset into a single running total.
    let mut progress_stream = {
        let mut offsets = HashMap::new();
        let mut total = 0;
        into_stream(progress_rx)
            .flat_map(into_stream)
            .map(move |(id, item)| match item {
                DownloadProgressItem::Progress(offset) => {
                    // Replace this request's previous contribution with its
                    // new offset, keeping `total` a sum over all requests.
                    total += offset;
                    if let Some(prev) = offsets.insert(id, offset) {
                        total -= prev;
                    }
                    DownloadProgressItem::Progress(total)
                }
                x => x,
            })
    };
    loop {
        tokio::select! {
            Some(item) = progress_stream.next() => {
                tx.send(item).await?;
            },
            res = futs.next() => {
                match res {
                    Some((_hash, Ok(()))) => {
                    }
                    Some((_hash, Err(_e))) => {
                        // One sub-download failed: report it, but let the
                        // remaining sub-downloads continue.
                        tx.send(DownloadProgressItem::DownloadError).await?;
                    }
                    None => break,
                }
            }
            _ = tx.closed() => {
                // The sender has been closed, we should stop processing.
                break;
            }
        }
    }
    Ok(())
}
235
236fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
237    Gen::new(|co| async move {
238        while let Some(item) = recv.recv().await {
239            co.yield_(item).await;
240        }
241    })
242}
243
/// A download request in one of the two supported finite shapes: a single
/// [`GetRequest`] or a [`GetManyRequest`] over several hashes.
#[derive(Debug, Serialize, Deserialize, derive_more::From)]
pub enum FiniteRequest {
    Get(GetRequest),
    GetMany(GetManyRequest),
}
249
/// Conversion trait for everything [`Downloader::download`] accepts as a request.
pub trait SupportedRequest {
    /// Convert `self` into a concrete [`FiniteRequest`].
    fn into_request(self) -> FiniteRequest;
}
253
254impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
255    fn into_request(self) -> FiniteRequest {
256        let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
257        FiniteRequest::GetMany(hashes)
258    }
259}
260
261impl SupportedRequest for GetRequest {
262    fn into_request(self) -> FiniteRequest {
263        self.into()
264    }
265}
266
267impl SupportedRequest for GetManyRequest {
268    fn into_request(self) -> FiniteRequest {
269        self.into()
270    }
271}
272
273impl SupportedRequest for Hash {
274    fn into_request(self) -> FiniteRequest {
275        GetRequest::blob(self).into()
276    }
277}
278
279impl SupportedRequest for HashAndFormat {
280    fn into_request(self) -> FiniteRequest {
281        (match self.format {
282            BlobFormat::Raw => GetRequest::blob(self.hash),
283            BlobFormat::HashSeq => GetRequest::all(self.hash),
284        })
285        .into()
286    }
287}
288
/// Message associating a hash with a list of provider endpoints.
///
/// NOTE(review): not referenced elsewhere in this view of the file — confirm
/// its consumer before changing.
#[derive(Debug, Serialize, Deserialize)]
pub struct AddProviderRequest {
    pub hash: Hash,
    pub providers: Vec<EndpointId>,
}
294
/// Options for a download: what to fetch, from whom, and whether to split.
#[derive(Debug)]
pub struct DownloadRequest {
    /// The request to execute.
    pub request: FiniteRequest,
    /// Strategy that yields candidate provider endpoints.
    pub providers: Arc<dyn ContentDiscovery>,
    /// Whether to split the request into parallel sub-requests.
    pub strategy: SplitStrategy,
}
301
302impl DownloadRequest {
303    pub fn new(
304        request: impl SupportedRequest,
305        providers: impl ContentDiscovery,
306        strategy: SplitStrategy,
307    ) -> Self {
308        Self {
309            request: request.into_request(),
310            providers: Arc::new(providers),
311            strategy,
312        }
313    }
314}
315
/// Whether a download runs as one sequential request or is split into
/// per-blob requests executed in parallel.
#[derive(Debug, Serialize, Deserialize)]
pub enum SplitStrategy {
    /// Execute the request as-is, trying one provider at a time.
    None,
    /// Split into sub-requests (see `split_request`) and run them concurrently.
    Split,
}
321
// DownloadRequest contains a trait object and is in-process only; Serialize
// is implemented to always fail so the protocol derive still compiles.
impl Serialize for DownloadRequest {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        Err(serde::ser::Error::custom(
            "cannot serialize DownloadRequest",
        ))
    }
}
332
// Implement Deserialize to always fail: DownloadRequest cannot cross a
// process boundary (see the matching Serialize impl).
impl<'de> Deserialize<'de> for DownloadRequest {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        Err(D::Error::custom("cannot deserialize DownloadRequest"))
    }
}
342
/// The request struct doubles as the options argument of
/// [`Downloader::download_with_opts`].
pub type DownloadOptions = DownloadRequest;
344
/// Handle for an in-progress download.
///
/// Either await it directly (via [`IntoFuture`]) for just the final result,
/// or call [`DownloadProgress::stream`] to observe individual progress items.
pub struct DownloadProgress {
    // Resolves to the receiver of progress items once the request is accepted.
    fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>,
}
348
349impl DownloadProgress {
350    fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>) -> Self {
351        Self { fut }
352    }
353
354    pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgressItem> + Unpin> {
355        let rx = self.fut.await?;
356        Ok(Box::pin(rx.into_stream().map(|item| match item {
357            Ok(item) => item,
358            Err(e) => DownloadProgressItem::Error(e.into()),
359        })))
360    }
361
362    async fn complete(self) -> Result<()> {
363        let rx = self.fut.await?;
364        let stream = rx.into_stream();
365        tokio::pin!(stream);
366        while let Some(item) = stream.next().await {
367            match item? {
368                DownloadProgressItem::Error(e) => Err(e)?,
369                DownloadProgressItem::DownloadError => {
370                    n0_error::bail_any!("Download error");
371                }
372                _ => {}
373            }
374        }
375        Ok(())
376    }
377}
378
379impl IntoFuture for DownloadProgress {
380    type Output = Result<()>;
381    type IntoFuture = future::Boxed<Self::Output>;
382
383    fn into_future(self) -> Self::IntoFuture {
384        Box::pin(self.complete())
385    }
386}
387
impl Downloader {
    /// Create a downloader for `store` that dials providers via `endpoint`,
    /// using default connection-pool options.
    pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
        Self::new_with_opts(store, endpoint, Default::default())
    }

    /// Create a downloader with explicit connection-pool options.
    ///
    /// Spawns the `DownloaderActor` as a background task; the actor runs
    /// until all clones of the returned `Downloader` are dropped.
    pub fn new_with_opts(
        store: &Store,
        endpoint: &Endpoint,
        pool_options: crate::util::connection_pool::Options,
    ) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
        let actor = DownloaderActor::new_with_opts(store.clone(), endpoint.clone(), pool_options);
        n0_future::task::spawn(actor.run(rx));
        Self { client: tx.into() }
    }

    /// Download `request` from `providers` without splitting
    /// ([`SplitStrategy::None`]).
    pub fn download(
        &self,
        request: impl SupportedRequest,
        providers: impl ContentDiscovery,
    ) -> DownloadProgress {
        let request = request.into_request();
        let providers = Arc::new(providers);
        self.download_with_opts(DownloadOptions {
            request,
            providers,
            strategy: SplitStrategy::None,
        })
    }

    /// Start a download with full control over request, providers and split
    /// strategy. Progress items are buffered in a channel of size 32.
    pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
        let fut = self.client.server_streaming(options, 32);
        DownloadProgress::new(Box::pin(fut))
    }

    /// Wait until the downloader has no in-flight download tasks.
    ///
    /// This is mostly useful for tests, where you want to confirm that all
    /// previously-issued downloads have settled (whether they completed
    /// successfully or errored out). Note that the downloader is not
    /// guaranteed to become idle if it is being interacted with concurrently;
    /// in that case this might wait forever. Also note that once you get the
    /// callback, the downloader is not guaranteed to still be idle — all this
    /// tells you is that there was a point in time, between the call and the
    /// response, where it was idle.
    pub async fn wait_idle(&self) -> irpc::Result<()> {
        self.client.rpc(WaitIdleRequest).await?;
        Ok(())
    }
}
438
/// Split a request into multiple requests that can be run in parallel.
///
/// For a plain get this downloads the root blob first (its progress goes to
/// `progress`), then derives one request per referenced child; for a get-many
/// it simply yields one request per hash.
async fn split_request<'a>(
    request: &'a FiniteRequest,
    providers: &Arc<dyn ContentDiscovery>,
    pool: &ConnectionPool,
    store: &Store,
    progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
    Ok(match request {
        FiniteRequest::Get(req) => {
            // Nothing requested at all -> nothing to split.
            let Some(_first) = req.ranges.iter_infinite().next() else {
                return Ok(Box::new(std::iter::empty()));
            };
            // Fetch the root blob first; it determines how many children exist.
            let first = GetRequest::blob(req.hash);
            execute_get(pool, Arc::new(first), providers, store, progress).await?;
            let size = store.observe(req.hash).await?.size();
            // The root is expected to be a sequence of 32-byte hashes.
            n0_error::ensure_any!(size % 32 == 0, "Size is not a multiple of 32");
            let n = size / 32;
            // One request per non-empty child range. Index 0 is the root
            // itself (already downloaded above) and is skipped; take(n + 1)
            // bounds the infinite range iterator to root + n children.
            Box::new(
                req.ranges
                    .iter_infinite()
                    .take(n as usize + 1)
                    .enumerate()
                    .filter_map(|(i, ranges)| {
                        if i != 0 && !ranges.is_empty() {
                            Some(
                                GetRequest::builder()
                                    .offset(i as u64, ranges.clone())
                                    .build(req.hash),
                            )
                        } else {
                            None
                        }
                    }),
            )
        }
        // A GetMany splits trivially into one single-blob request per hash.
        FiniteRequest::GetMany(req) => Box::new(
            req.hashes
                .iter()
                .enumerate()
                .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
        ),
    })
}
483
/// Execute a get request sequentially for multiple providers.
///
/// It will try each provider in order
/// until it finds one that can fulfill the request. When trying a new provider,
/// it takes the progress from the previous providers into account, so e.g.
/// if the first provider had the first 10% of the data, it will only ask the next
/// provider for the remaining 90%.
///
/// This is fully sequential, so there will only be one request in flight at a time.
///
/// If the request is not complete after trying all providers, it will return an error.
/// If the provider stream never ends, it will try indefinitely.
async fn execute_get(
    pool: &ConnectionPool,
    request: Arc<GetRequest>,
    providers: &Arc<dyn ContentDiscovery>,
    store: &Store,
    mut progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> Result<()> {
    let remote = store.remote();
    let mut providers = providers.find_providers(request.content());
    while let Some(provider) = providers.next().await {
        progress
            .send(DownloadProgressItem::TryProvider {
                id: provider,
                request: request.clone(),
            })
            .await?;
        // Start dialing, but do not await yet: the local check below may make
        // the connection unnecessary.
        let conn = pool.get_or_connect(provider);
        let local = remote.local_for_request(request.clone()).await?;
        if local.is_complete() {
            // Everything is already present locally; nothing to download.
            return Ok(());
        }
        let local_bytes = local.local_bytes();
        let Ok(conn) = conn.await else {
            // Could not connect to this provider; report and try the next one.
            progress
                .send(DownloadProgressItem::ProviderFailed {
                    id: provider,
                    request: request.clone(),
                })
                .await?;
            continue;
        };
        // Only request what is still missing, and shift reported progress by
        // the bytes already present so the caller sees a cumulative count.
        match remote
            .execute_get_sink(
                conn.clone(),
                local.missing(),
                (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)),
            )
            .await
        {
            Ok(_stats) => {
                progress
                    .send(DownloadProgressItem::PartComplete {
                        request: request.clone(),
                    })
                    .await?;
                return Ok(());
            }
            Err(_cause) => {
                // Transfer failed part-way; fall through to the next provider,
                // which will resume from whatever was persisted.
                progress
                    .send(DownloadProgressItem::ProviderFailed {
                        id: provider,
                        request: request.clone(),
                    })
                    .await?;
                continue;
            }
        }
    }
    Err(anyerr!("Unable to download {}", request.hash))
}
556
/// Trait for pluggable content discovery strategies.
pub trait ContentDiscovery: Debug + Send + Sync + 'static {
    /// Stream of candidate providers for `hash`; may be infinite.
    fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<EndpointId>;
}
561
562impl<C, I> ContentDiscovery for C
563where
564    C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
565    C::IntoIter: Send + Sync + 'static,
566    I: Into<EndpointId> + Send + Sync + 'static,
567{
568    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
569        let providers = self.clone();
570        n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
571    }
572}
573
/// Content discovery over a fixed node set, yielded in random order.
#[derive(derive_more::Debug)]
pub struct Shuffled {
    nodes: Vec<EndpointId>,
}
578
579impl Shuffled {
580    pub fn new(nodes: Vec<EndpointId>) -> Self {
581        Self { nodes }
582    }
583}
584
585impl ContentDiscovery for Shuffled {
586    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
587        let mut nodes = self.nodes.clone();
588        nodes.shuffle(&mut rand::rng());
589        n0_future::stream::iter(nodes).boxed()
590    }
591}
592
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
        Hash,
    };

    /// Download two blobs, each stored on a different node, via one
    /// GetManyRequest issued from a third node.
    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        // Nodes 1 and 2 each hold one blob; node 3 downloads both.
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice("hello world").await?;
        let tt2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        // Make the provider addresses dialable from node 3.
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetManyRequest::builder()
            .hash(tt1.hash, ChunkRanges::all())
            .hash(tt2.hash, ChunkRanges::all())
            .build();
        let mut progress = swarm
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        // Drain progress until the download completes.
        while progress.next().await.is_some() {}
        assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    /// Download a hash sequence (root on node 1, children spread over nodes
    /// 1 and 2) with SplitStrategy::Split.
    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        // tracing_subscriber::fmt::try_init().ok();
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        // Root hash sequence referencing both blobs, stored on node 1.
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        if true {
            let mut progress = swarm
                .download_with_opts(DownloadOptions::new(
                    request,
                    [node1_id, node2_id],
                    SplitStrategy::Split,
                ))
                .stream()
                .await?;
            while progress.next().await.is_some() {}
        }
        // Disabled alternative path exercising the remote API directly.
        if false {
            let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
            let remote = store3.remote();
            let _rh = remote
                .execute_get(
                    conn.clone(),
                    GetRequest::builder()
                        .root(ChunkRanges::all())
                        .build(root.hash),
                )
                .await?;
            let h1 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(0, ChunkRanges::all())
                    .build(root.hash),
            );
            let h2 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(1, ChunkRanges::all())
                    .build(root.hash),
            );
            h1.await?;
            h2.await?;
        }
        Ok(())
    }

    /// Same setup as downloader_get_smoke, but using GetRequest::all on the
    /// root instead of enumerating the children explicitly.
    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::all(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while progress.next().await.is_some() {}
        Ok(())
    }

    /// Invariant: `DownloaderActor` must reap each `handle_download` task
    /// from its `JoinSet` as that task finishes, not only on shutdown.
    /// If steady-state reaping is skipped, every completed download
    /// retains its tokio task header for the lifetime of the actor and
    /// the heap grows linearly with download volume — invisible at the
    /// API surface, since downloads still report progress and finish.
    ///
    /// We submit many downloads with an empty provider list so each
    /// `handle_download` finishes in microseconds with no I/O, then
    /// assert the actor reaches an idle state via [`Downloader::wait_idle`].
    /// If the `tasks.join_next()` arm is removed from
    /// `DownloaderActor::run`'s `select!`, the JoinSet stays full of
    /// completed-but-not-joined tasks, `tasks.is_empty()` never becomes
    /// true, and the timeout below fires.
    ///
    /// The complementary `idle_waiters` notification path is not
    /// exercised here: by the time this test calls `wait_idle`, the
    /// actor has already drained the JoinSet during the per-stream
    /// awaits, so `wait_idle`'s fast path answers directly.
    #[tokio::test]
    async fn downloader_drains_completed_tasks() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r, store, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let swarm = Downloader::new(&store, r.endpoint());

        let n = 1_000;
        let bogus_hash = Hash::new(b"this hash is not stored anywhere");
        let mut streams = Vec::with_capacity(n);
        for _ in 0..n {
            streams.push(
                swarm
                    .download(GetRequest::all(bogus_hash), Shuffled::new(vec![]))
                    .stream()
                    .await?,
            );
        }
        for mut s in streams {
            while s.next().await.is_some() {}
        }

        tokio::time::timeout(std::time::Duration::from_secs(5), swarm.wait_idle())
            .await
            .map_err(|_| {
                "wait_idle did not resolve within 5s — DownloaderActor JoinSet not draining"
            })??;

        Ok(())
    }
}