1use std::{
3 collections::{HashMap, HashSet},
4 fmt::Debug,
5 future::{Future, IntoFuture},
6 io,
7 ops::Deref,
8 sync::Arc,
9 time::{Duration, SystemTime},
10};
11
12use anyhow::bail;
13use genawaiter::sync::Gen;
14use iroh::{endpoint::Connection, Endpoint, NodeId};
15use irpc::{channel::mpsc, rpc_requests};
16use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt};
17use rand::seq::SliceRandom;
18use serde::{de::Error, Deserialize, Serialize};
19use tokio::{sync::Mutex, task::JoinSet};
20use tokio_util::time::FutureExt;
21use tracing::{info, instrument::Instrument, warn};
22
23use super::{remote::GetConnection, Store};
24use crate::{
25 protocol::{GetManyRequest, GetRequest},
26 util::sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
27 BlobFormat, Hash, HashAndFormat,
28};
29
30#[derive(Debug, Clone)]
31pub struct Downloader {
32 client: irpc::Client<SwarmMsg, SwarmProtocol, DownloaderService>,
33}
34
35#[derive(Debug, Clone)]
36pub struct DownloaderService;
37
38impl irpc::Service for DownloaderService {}
39
40#[rpc_requests(DownloaderService, message = SwarmMsg, alias = "Msg")]
41#[derive(Debug, Serialize, Deserialize)]
42enum SwarmProtocol {
43 #[rpc(tx = mpsc::Sender<DownloadProgessItem>)]
44 Download(DownloadRequest),
45}
46
47struct DownloaderActor {
48 store: Store,
49 pool: ConnectionPool,
50 tasks: JoinSet<()>,
51 running: HashSet<tokio::task::Id>,
52}
53
54#[derive(Debug, Serialize, Deserialize)]
55pub enum DownloadProgessItem {
56 #[serde(skip)]
57 Error(anyhow::Error),
58 TryProvider {
59 id: NodeId,
60 request: Arc<GetRequest>,
61 },
62 ProviderFailed {
63 id: NodeId,
64 request: Arc<GetRequest>,
65 },
66 PartComplete {
67 request: Arc<GetRequest>,
68 },
69 Progress(u64),
70 DownloadError,
71}
72
73impl DownloaderActor {
74 fn new(store: Store, endpoint: Endpoint) -> Self {
75 Self {
76 store,
77 pool: ConnectionPool::new(endpoint, crate::ALPN.to_vec()),
78 tasks: JoinSet::new(),
79 running: HashSet::new(),
80 }
81 }
82
83 async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
84 while let Some(msg) = rx.recv().await {
85 match msg {
86 SwarmMsg::Download(request) => {
87 self.spawn(handle_download(
88 self.store.clone(),
89 self.pool.clone(),
90 request,
91 ));
92 }
93 }
94 }
95 }
96
97 fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
98 let span = tracing::Span::current();
99 let id = self.tasks.spawn(fut.instrument(span)).id();
100 self.running.insert(id);
101 }
102}
103
104async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
105 let DownloadMsg { inner, mut tx, .. } = msg;
106 if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
107 tx.send(DownloadProgessItem::Error(cause)).await.ok();
108 }
109}
110
111async fn handle_download_impl(
112 store: Store,
113 pool: ConnectionPool,
114 request: DownloadRequest,
115 tx: &mut mpsc::Sender<DownloadProgessItem>,
116) -> anyhow::Result<()> {
117 match request.strategy {
118 SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
119 SplitStrategy::None => match request.request {
120 FiniteRequest::Get(get) => {
121 let sink = IrpcSenderRefSink(tx).with_map_err(io::Error::other);
122 execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
123 }
124 FiniteRequest::GetMany(_) => {
125 handle_download_split_impl(store, pool, request, tx).await?
126 }
127 },
128 }
129 Ok(())
130}
131
132async fn handle_download_split_impl(
133 store: Store,
134 pool: ConnectionPool,
135 request: DownloadRequest,
136 tx: &mut mpsc::Sender<DownloadProgessItem>,
137) -> anyhow::Result<()> {
138 let providers = request.providers;
139 let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
140 let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
141 let mut futs = stream::iter(requests.into_iter().enumerate())
142 .map(|(id, request)| {
143 let pool = pool.clone();
144 let providers = providers.clone();
145 let store = store.clone();
146 let progress_tx = progress_tx.clone();
147 async move {
148 let hash = request.hash;
149 let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16);
150 progress_tx.send(rx).await.ok();
151 let sink = TokioMpscSenderSink(tx)
152 .with_map_err(io::Error::other)
153 .with_map(move |x| (id, x));
154 let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
155 (hash, res)
156 }
157 })
158 .buffered_unordered(32);
159 let mut progress_stream = {
160 let mut offsets = HashMap::new();
161 let mut total = 0;
162 into_stream(progress_rx)
163 .flat_map(into_stream)
164 .map(move |(id, item)| match item {
165 DownloadProgessItem::Progress(offset) => {
166 total += offset;
167 if let Some(prev) = offsets.insert(id, offset) {
168 total -= prev;
169 }
170 DownloadProgessItem::Progress(total)
171 }
172 x => x,
173 })
174 };
175 loop {
176 tokio::select! {
177 Some(item) = progress_stream.next() => {
178 tx.send(item).await?;
179 },
180 res = futs.next() => {
181 match res {
182 Some((_hash, Ok(()))) => {
183 }
184 Some((_hash, Err(_e))) => {
185 tx.send(DownloadProgessItem::DownloadError).await?;
186 }
187 None => break,
188 }
189 }
190 _ = tx.closed() => {
191 break;
193 }
194 }
195 }
196 Ok(())
197}
198
199fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
200 Gen::new(|co| async move {
201 while let Some(item) = recv.recv().await {
202 co.yield_(item).await;
203 }
204 })
205}
206
207#[derive(Debug, Serialize, Deserialize, derive_more::From)]
208pub enum FiniteRequest {
209 Get(GetRequest),
210 GetMany(GetManyRequest),
211}
212
213pub trait SupportedRequest {
214 fn into_request(self) -> FiniteRequest;
215}
216
217impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
218 fn into_request(self) -> FiniteRequest {
219 let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
220 FiniteRequest::GetMany(hashes)
221 }
222}
223
224impl SupportedRequest for GetRequest {
225 fn into_request(self) -> FiniteRequest {
226 self.into()
227 }
228}
229
230impl SupportedRequest for GetManyRequest {
231 fn into_request(self) -> FiniteRequest {
232 self.into()
233 }
234}
235
236impl SupportedRequest for Hash {
237 fn into_request(self) -> FiniteRequest {
238 GetRequest::blob(self).into()
239 }
240}
241
242impl SupportedRequest for HashAndFormat {
243 fn into_request(self) -> FiniteRequest {
244 (match self.format {
245 BlobFormat::Raw => GetRequest::blob(self.hash),
246 BlobFormat::HashSeq => GetRequest::all(self.hash),
247 })
248 .into()
249 }
250}
251
252#[derive(Debug, Serialize, Deserialize)]
253pub struct AddProviderRequest {
254 pub hash: Hash,
255 pub providers: Vec<NodeId>,
256}
257
258#[derive(Debug)]
259pub struct DownloadRequest {
260 pub request: FiniteRequest,
261 pub providers: Arc<dyn ContentDiscovery>,
262 pub strategy: SplitStrategy,
263}
264
265impl DownloadRequest {
266 pub fn new(
267 request: impl SupportedRequest,
268 providers: impl ContentDiscovery,
269 strategy: SplitStrategy,
270 ) -> Self {
271 Self {
272 request: request.into_request(),
273 providers: Arc::new(providers),
274 strategy,
275 }
276 }
277}
278
279#[derive(Debug, Serialize, Deserialize)]
280pub enum SplitStrategy {
281 None,
282 Split,
283}
284
285impl Serialize for DownloadRequest {
286 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
287 where
288 S: serde::Serializer,
289 {
290 Err(serde::ser::Error::custom(
291 "cannot serialize DownloadRequest",
292 ))
293 }
294}
295
296impl<'de> Deserialize<'de> for DownloadRequest {
298 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
299 where
300 D: serde::Deserializer<'de>,
301 {
302 Err(D::Error::custom("cannot deserialize DownloadRequest"))
303 }
304}
305
306pub type DownloadOptions = DownloadRequest;
307
308pub struct DownloadProgress {
309 fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgessItem>>>,
310}
311
312impl DownloadProgress {
313 fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgessItem>>>) -> Self {
314 Self { fut }
315 }
316
317 pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgessItem> + Unpin> {
318 let rx = self.fut.await?;
319 Ok(Box::pin(rx.into_stream().map(|item| match item {
320 Ok(item) => item,
321 Err(e) => DownloadProgessItem::Error(e.into()),
322 })))
323 }
324
325 async fn complete(self) -> anyhow::Result<()> {
326 let rx = self.fut.await?;
327 let stream = rx.into_stream();
328 tokio::pin!(stream);
329 while let Some(item) = stream.next().await {
330 match item? {
331 DownloadProgessItem::Error(e) => Err(e)?,
332 DownloadProgessItem::DownloadError => anyhow::bail!("Download error"),
333 _ => {}
334 }
335 }
336 Ok(())
337 }
338}
339
340impl IntoFuture for DownloadProgress {
341 type Output = anyhow::Result<()>;
342 type IntoFuture = future::Boxed<Self::Output>;
343
344 fn into_future(self) -> Self::IntoFuture {
345 Box::pin(self.complete())
346 }
347}
348
349impl Downloader {
350 pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
351 let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
352 let actor = DownloaderActor::new(store.clone(), endpoint.clone());
353 tokio::spawn(actor.run(rx));
354 Self { client: tx.into() }
355 }
356
357 pub fn download(
358 &self,
359 request: impl SupportedRequest,
360 providers: impl ContentDiscovery,
361 ) -> DownloadProgress {
362 let request = request.into_request();
363 let providers = Arc::new(providers);
364 self.download_with_opts(DownloadOptions {
365 request,
366 providers,
367 strategy: SplitStrategy::Split,
368 })
369 }
370
371 pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
372 let fut = self.client.server_streaming(options, 32);
373 DownloadProgress::new(Box::pin(fut))
374 }
375}
376
377async fn split_request<'a>(
379 request: &'a FiniteRequest,
380 providers: &Arc<dyn ContentDiscovery>,
381 pool: &ConnectionPool,
382 store: &Store,
383 progress: impl Sink<DownloadProgessItem, Error = io::Error>,
384) -> anyhow::Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
385 Ok(match request {
386 FiniteRequest::Get(req) => {
387 let Some(_first) = req.ranges.iter_infinite().next() else {
388 return Ok(Box::new(std::iter::empty()));
389 };
390 let first = GetRequest::blob(req.hash);
391 execute_get(pool, Arc::new(first), providers, store, progress).await?;
392 let size = store.observe(req.hash).await?.size();
393 anyhow::ensure!(size % 32 == 0, "Size is not a multiple of 32");
394 let n = size / 32;
395 Box::new(
396 req.ranges
397 .iter_infinite()
398 .take(n as usize + 1)
399 .enumerate()
400 .filter_map(|(i, ranges)| {
401 if i != 0 && !ranges.is_empty() {
402 Some(
403 GetRequest::builder()
404 .offset(i as u64, ranges.clone())
405 .build(req.hash),
406 )
407 } else {
408 None
409 }
410 }),
411 )
412 }
413 FiniteRequest::GetMany(req) => Box::new(
414 req.hashes
415 .iter()
416 .enumerate()
417 .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
418 ),
419 })
420}
421
422#[derive(Debug)]
423struct ConnectionPoolInner {
424 alpn: Vec<u8>,
425 endpoint: Endpoint,
426 connections: Mutex<HashMap<NodeId, Arc<Mutex<SlotState>>>>,
427 retry_delay: Duration,
428 connect_timeout: Duration,
429}
430
431#[derive(Debug, Clone)]
432struct ConnectionPool(Arc<ConnectionPoolInner>);
433
434#[derive(Debug, Default)]
435enum SlotState {
436 #[default]
437 Initial,
438 Connected(Connection),
439 AttemptFailed(SystemTime),
440 #[allow(dead_code)]
441 Evil(String),
442}
443
444impl ConnectionPool {
445 fn new(endpoint: Endpoint, alpn: Vec<u8>) -> Self {
446 Self(
447 ConnectionPoolInner {
448 endpoint,
449 alpn,
450 connections: Default::default(),
451 retry_delay: Duration::from_secs(5),
452 connect_timeout: Duration::from_secs(2),
453 }
454 .into(),
455 )
456 }
457
458 pub fn alpn(&self) -> &[u8] {
459 &self.0.alpn
460 }
461
462 pub fn endpoint(&self) -> &Endpoint {
463 &self.0.endpoint
464 }
465
466 pub fn retry_delay(&self) -> Duration {
467 self.0.retry_delay
468 }
469
470 fn dial(&self, id: NodeId) -> DialNode {
471 DialNode {
472 pool: self.clone(),
473 id,
474 }
475 }
476
477 #[allow(dead_code)]
478 async fn mark_evil(&self, id: NodeId, reason: String) {
479 let slot = self
480 .0
481 .connections
482 .lock()
483 .await
484 .entry(id)
485 .or_default()
486 .clone();
487 let mut t = slot.lock().await;
488 *t = SlotState::Evil(reason)
489 }
490
491 #[allow(dead_code)]
492 async fn mark_closed(&self, id: NodeId) {
493 let slot = self
494 .0
495 .connections
496 .lock()
497 .await
498 .entry(id)
499 .or_default()
500 .clone();
501 let mut t = slot.lock().await;
502 *t = SlotState::Initial
503 }
504}
505
506async fn execute_get(
519 pool: &ConnectionPool,
520 request: Arc<GetRequest>,
521 providers: &Arc<dyn ContentDiscovery>,
522 store: &Store,
523 mut progress: impl Sink<DownloadProgessItem, Error = io::Error>,
524) -> anyhow::Result<()> {
525 let remote = store.remote();
526 let mut providers = providers.find_providers(request.content());
527 while let Some(provider) = providers.next().await {
528 progress
529 .send(DownloadProgessItem::TryProvider {
530 id: provider,
531 request: request.clone(),
532 })
533 .await?;
534 let mut conn = pool.dial(provider);
535 let local = remote.local_for_request(request.clone()).await?;
536 if local.is_complete() {
537 return Ok(());
538 }
539 let local_bytes = local.local_bytes();
540 let Ok(conn) = conn.connection().await else {
541 progress
542 .send(DownloadProgessItem::ProviderFailed {
543 id: provider,
544 request: request.clone(),
545 })
546 .await?;
547 continue;
548 };
549 match remote
550 .execute_get_sink(
551 conn,
552 local.missing(),
553 (&mut progress).with_map(move |x| DownloadProgessItem::Progress(x + local_bytes)),
554 )
555 .await
556 {
557 Ok(_stats) => {
558 progress
559 .send(DownloadProgessItem::PartComplete {
560 request: request.clone(),
561 })
562 .await?;
563 return Ok(());
564 }
565 Err(_cause) => {
566 progress
567 .send(DownloadProgessItem::ProviderFailed {
568 id: provider,
569 request: request.clone(),
570 })
571 .await?;
572 continue;
573 }
574 }
575 }
576 bail!("Unable to download {}", request.hash);
577}
578
579#[derive(Debug, Clone)]
580struct DialNode {
581 pool: ConnectionPool,
582 id: NodeId,
583}
584
585impl DialNode {
586 async fn connection_impl(&self) -> anyhow::Result<Connection> {
587 info!("Getting connection for node {}", self.id);
588 let slot = self
589 .pool
590 .0
591 .connections
592 .lock()
593 .await
594 .entry(self.id)
595 .or_default()
596 .clone();
597 info!("Dialing node {}", self.id);
598 let mut guard = slot.lock().await;
599 match guard.deref() {
600 SlotState::Connected(conn) => {
601 return Ok(conn.clone());
602 }
603 SlotState::AttemptFailed(time) => {
604 let elapsed = time.elapsed().unwrap_or_default();
605 if elapsed <= self.pool.retry_delay() {
606 bail!(
607 "Connection attempt failed {} seconds ago",
608 elapsed.as_secs_f64()
609 );
610 }
611 }
612 SlotState::Evil(reason) => {
613 bail!("Node is banned due to evil behavior: {reason}");
614 }
615 SlotState::Initial => {}
616 }
617 let res = self
618 .pool
619 .endpoint()
620 .connect(self.id, self.pool.alpn())
621 .timeout(self.pool.0.connect_timeout)
622 .await;
623 match res {
624 Ok(Ok(conn)) => {
625 info!("Connected to node {}", self.id);
626 *guard = SlotState::Connected(conn.clone());
627 Ok(conn)
628 }
629 Ok(Err(e)) => {
630 warn!("Failed to connect to node {}: {}", self.id, e);
631 *guard = SlotState::AttemptFailed(SystemTime::now());
632 Err(e.into())
633 }
634 Err(e) => {
635 warn!("Failed to connect to node {}: {}", self.id, e);
636 *guard = SlotState::AttemptFailed(SystemTime::now());
637 bail!("Failed to connect to node: {}", e);
638 }
639 }
640 }
641}
642
643impl GetConnection for DialNode {
644 fn connection(&mut self) -> impl Future<Output = Result<Connection, anyhow::Error>> + '_ {
645 let this = self.clone();
646 async move { this.connection_impl().await }
647 }
648}
649
650pub trait ContentDiscovery: Debug + Send + Sync + 'static {
652 fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<NodeId>;
653}
654
655impl<C, I> ContentDiscovery for C
656where
657 C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
658 C::IntoIter: Send + Sync + 'static,
659 I: Into<NodeId> + Send + Sync + 'static,
660{
661 fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<NodeId> {
662 let providers = self.clone();
663 n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
664 }
665}
666
667#[derive(derive_more::Debug)]
668pub struct Shuffled {
669 nodes: Vec<NodeId>,
670}
671
672impl Shuffled {
673 pub fn new(nodes: Vec<NodeId>) -> Self {
674 Self { nodes }
675 }
676}
677
678impl ContentDiscovery for Shuffled {
679 fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<NodeId> {
680 let mut nodes = self.nodes.clone();
681 nodes.shuffle(&mut rand::thread_rng());
682 n0_future::stream::iter(nodes).boxed()
683 }
684}
685
#[cfg(test)]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use iroh::Watcher;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
    };

    /// Downloads two raw blobs, one stored on each of two nodes, with a single
    /// [`GetManyRequest`], and asserts both arrive intact.
    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice("hello world").await?;
        let tt2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await?;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await?;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetManyRequest::builder()
            .hash(tt1.hash, ChunkRanges::all())
            .hash(tt2.hash, ChunkRanges::all())
            .build();
        let mut progress = swarm
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    /// Requests the root and both children of a hash sequence with
    /// [`SplitStrategy::Split`] and prints every progress item.
    ///
    /// NOTE(review): the root hash seq is only added to node 1 and `tt2` only to
    /// node 2. Split parts are addressed relative to the root, so the second
    /// child presumably cannot be fetched from either node. That may be why this
    /// test makes no assertions; confirm before adding any.
    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await?;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await?;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        Ok(())
    }

    /// Same setup as `downloader_get_smoke`, but requests the whole hash sequence
    /// with `GetRequest::all`. It makes no assertions for the same (presumed)
    /// reason: `tt2` lives only on a node that does not have the root.
    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await?;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await?;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetRequest::all(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        Ok(())
    }
}