1use std::{
3 collections::{HashMap, HashSet},
4 fmt::Debug,
5 future::{Future, IntoFuture},
6 sync::Arc,
7};
8
9use anyhow::bail;
10use genawaiter::sync::Gen;
11use iroh::{Endpoint, NodeId};
12use irpc::{channel::mpsc, rpc_requests};
13use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt};
14use rand::seq::SliceRandom;
15use serde::{de::Error, Deserialize, Serialize};
16use tokio::task::JoinSet;
17use tracing::instrument::Instrument;
18
19use super::Store;
20use crate::{
21 protocol::{GetManyRequest, GetRequest},
22 util::{
23 connection_pool::ConnectionPool,
24 sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
25 },
26 BlobFormat, Hash, HashAndFormat,
27};
28
/// Handle for requesting downloads from a background download actor.
///
/// Created with [`Downloader::new`], which spawns the actor. Requests are sent to it through an
/// irpc client built from a tokio mpsc sender; clones presumably talk to the same actor.
#[derive(Debug, Clone)]
pub struct Downloader {
    // RPC client used to send `SwarmProtocol` requests to the actor.
    client: irpc::Client<SwarmProtocol>,
}
33
/// The RPC protocol spoken between [`Downloader`] and its actor.
///
/// The `rpc_requests` attribute generates the `SwarmMsg` message enum handled in
/// `DownloaderActor::run` and the `DownloadMsg` alias consumed by `handle_download`.
#[rpc_requests(message = SwarmMsg, alias = "Msg")]
#[derive(Debug, Serialize, Deserialize)]
enum SwarmProtocol {
    // Start a download; progress is streamed back on the attached sender.
    #[rpc(tx = mpsc::Sender<DownloadProgressItem>)]
    Download(DownloadRequest),
}
40
/// State of the background actor that executes download requests.
struct DownloaderActor {
    // Store that downloaded data is written to.
    store: Store,
    // Pool of connections to providers, shared (by clone) with every download task.
    pool: ConnectionPool,
    // Tasks spawned for individual download requests.
    tasks: JoinSet<()>,
    // Ids of the tasks spawned via `spawn`.
    running: HashSet<tokio::task::Id>,
}
47
/// A progress event emitted while a download runs.
///
/// Events are sent on the sender attached to a `Download` request and surfaced to callers by
/// [`DownloadProgress::stream`].
#[derive(Debug, Serialize, Deserialize)]
pub enum DownloadProgressItem {
    /// The download failed with this error.
    ///
    /// Skipped by serde, so this variant cannot be serialized. It is also produced locally by
    /// [`DownloadProgress::stream`] when receiving from the progress channel fails.
    #[serde(skip)]
    Error(anyhow::Error),
    /// About to try downloading `request` from provider `id`.
    TryProvider {
        id: NodeId,
        request: Arc<GetRequest>,
    },
    /// Connecting to, or downloading `request` from, provider `id` failed. The next provider is
    /// tried, if there is one.
    ProviderFailed {
        id: NodeId,
        request: Arc<GetRequest>,
    },
    /// `request` was downloaded successfully from a provider.
    PartComplete {
        request: Arc<GetRequest>,
    },
    /// Progress offset (presumably in bytes — TODO confirm). For a split download this is the
    /// sum of the most recent offsets reported by all parts.
    Progress(u64),
    /// One part of a split download failed. No further detail is attached.
    DownloadError,
}
66
67impl DownloaderActor {
68 fn new(store: Store, endpoint: Endpoint) -> Self {
69 Self {
70 store,
71 pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()),
72 tasks: JoinSet::new(),
73 running: HashSet::new(),
74 }
75 }
76
77 async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
78 while let Some(msg) = rx.recv().await {
79 match msg {
80 SwarmMsg::Download(request) => {
81 self.spawn(handle_download(
82 self.store.clone(),
83 self.pool.clone(),
84 request,
85 ));
86 }
87 }
88 }
89 }
90
91 fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
92 let span = tracing::Span::current();
93 let id = self.tasks.spawn(fut.instrument(span)).id();
94 self.running.insert(id);
95 }
96}
97
98async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
99 let DownloadMsg { inner, mut tx, .. } = msg;
100 if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
101 tx.send(DownloadProgressItem::Error(cause)).await.ok();
102 }
103}
104
105async fn handle_download_impl(
106 store: Store,
107 pool: ConnectionPool,
108 request: DownloadRequest,
109 tx: &mut mpsc::Sender<DownloadProgressItem>,
110) -> anyhow::Result<()> {
111 match request.strategy {
112 SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
113 SplitStrategy::None => match request.request {
114 FiniteRequest::Get(get) => {
115 let sink = IrpcSenderRefSink(tx);
116 execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
117 }
118 FiniteRequest::GetMany(_) => {
119 handle_download_split_impl(store, pool, request, tx).await?
120 }
121 },
122 }
123 Ok(())
124}
125
126async fn handle_download_split_impl(
127 store: Store,
128 pool: ConnectionPool,
129 request: DownloadRequest,
130 tx: &mut mpsc::Sender<DownloadProgressItem>,
131) -> anyhow::Result<()> {
132 let providers = request.providers;
133 let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
134 let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
135 let mut futs = stream::iter(requests.into_iter().enumerate())
136 .map(|(id, request)| {
137 let pool = pool.clone();
138 let providers = providers.clone();
139 let store = store.clone();
140 let progress_tx = progress_tx.clone();
141 async move {
142 let hash = request.hash;
143 let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgressItem)>(16);
144 progress_tx.send(rx).await.ok();
145 let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x));
146 let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
147 (hash, res)
148 }
149 })
150 .buffered_unordered(32);
151 let mut progress_stream = {
152 let mut offsets = HashMap::new();
153 let mut total = 0;
154 into_stream(progress_rx)
155 .flat_map(into_stream)
156 .map(move |(id, item)| match item {
157 DownloadProgressItem::Progress(offset) => {
158 total += offset;
159 if let Some(prev) = offsets.insert(id, offset) {
160 total -= prev;
161 }
162 DownloadProgressItem::Progress(total)
163 }
164 x => x,
165 })
166 };
167 loop {
168 tokio::select! {
169 Some(item) = progress_stream.next() => {
170 tx.send(item).await?;
171 },
172 res = futs.next() => {
173 match res {
174 Some((_hash, Ok(()))) => {
175 }
176 Some((_hash, Err(_e))) => {
177 tx.send(DownloadProgressItem::DownloadError).await?;
178 }
179 None => break,
180 }
181 }
182 _ = tx.closed() => {
183 break;
185 }
186 }
187 }
188 Ok(())
189}
190
191fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
192 Gen::new(|co| async move {
193 while let Some(item) = recv.recv().await {
194 co.yield_(item).await;
195 }
196 })
197}
198
/// A request describing a finite amount of content to download.
#[derive(Debug, Serialize, Deserialize, derive_more::From)]
pub enum FiniteRequest {
    /// A single get request (one root, optionally with children).
    Get(GetRequest),
    /// A request for several independent hashes, each with its own ranges.
    GetMany(GetManyRequest),
}
204
/// Types that can be turned into a [`FiniteRequest`] for [`Downloader::download`].
pub trait SupportedRequest {
    /// Converts `self` into the request to execute.
    fn into_request(self) -> FiniteRequest;
}
208
209impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
210 fn into_request(self) -> FiniteRequest {
211 let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
212 FiniteRequest::GetMany(hashes)
213 }
214}
215
216impl SupportedRequest for GetRequest {
217 fn into_request(self) -> FiniteRequest {
218 self.into()
219 }
220}
221
222impl SupportedRequest for GetManyRequest {
223 fn into_request(self) -> FiniteRequest {
224 self.into()
225 }
226}
227
228impl SupportedRequest for Hash {
229 fn into_request(self) -> FiniteRequest {
230 GetRequest::blob(self).into()
231 }
232}
233
234impl SupportedRequest for HashAndFormat {
235 fn into_request(self) -> FiniteRequest {
236 (match self.format {
237 BlobFormat::Raw => GetRequest::blob(self.hash),
238 BlobFormat::HashSeq => GetRequest::all(self.hash),
239 })
240 .into()
241 }
242}
243
/// Associates a set of provider nodes with a hash.
///
/// NOTE(review): not referenced anywhere in this file; presumably used elsewhere or reserved for
/// a future RPC — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct AddProviderRequest {
    pub hash: Hash,
    pub providers: Vec<NodeId>,
}
249
/// A fully specified download: what to fetch, where to look for it, and how to split it.
#[derive(Debug)]
pub struct DownloadRequest {
    /// The content to download.
    pub request: FiniteRequest,
    /// Source of candidate provider nodes.
    pub providers: Arc<dyn ContentDiscovery>,
    /// Whether to split the request into parallel per-blob requests.
    pub strategy: SplitStrategy,
}
256
257impl DownloadRequest {
258 pub fn new(
259 request: impl SupportedRequest,
260 providers: impl ContentDiscovery,
261 strategy: SplitStrategy,
262 ) -> Self {
263 Self {
264 request: request.into_request(),
265 providers: Arc::new(providers),
266 strategy,
267 }
268 }
269}
270
/// How a [`DownloadRequest`] is executed.
#[derive(Debug, Serialize, Deserialize)]
pub enum SplitStrategy {
    /// Execute a [`FiniteRequest::Get`] as one request.
    ///
    /// [`FiniteRequest::GetMany`] is still split, because it has no single-request form here.
    None,
    /// Split the request into per-blob requests and download them in parallel.
    Split,
}
276
277impl Serialize for DownloadRequest {
278 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
279 where
280 S: serde::Serializer,
281 {
282 Err(serde::ser::Error::custom(
283 "cannot serialize DownloadRequest",
284 ))
285 }
286}
287
288impl<'de> Deserialize<'de> for DownloadRequest {
290 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
291 where
292 D: serde::Deserializer<'de>,
293 {
294 Err(D::Error::custom("cannot deserialize DownloadRequest"))
295 }
296}
297
/// Alias for [`DownloadRequest`], used as the argument to [`Downloader::download_with_opts`].
pub type DownloadOptions = DownloadRequest;
299
/// A pending download returned by [`Downloader::download`].
///
/// Call [`DownloadProgress::stream`] to observe progress events, or `.await` it directly to wait
/// for completion.
pub struct DownloadProgress {
    // Resolves to the receiver of progress events once the request has been accepted.
    fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>,
}
303
304impl DownloadProgress {
305 fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>) -> Self {
306 Self { fut }
307 }
308
309 pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgressItem> + Unpin> {
310 let rx = self.fut.await?;
311 Ok(Box::pin(rx.into_stream().map(|item| match item {
312 Ok(item) => item,
313 Err(e) => DownloadProgressItem::Error(e.into()),
314 })))
315 }
316
317 async fn complete(self) -> anyhow::Result<()> {
318 let rx = self.fut.await?;
319 let stream = rx.into_stream();
320 tokio::pin!(stream);
321 while let Some(item) = stream.next().await {
322 match item? {
323 DownloadProgressItem::Error(e) => Err(e)?,
324 DownloadProgressItem::DownloadError => anyhow::bail!("Download error"),
325 _ => {}
326 }
327 }
328 Ok(())
329 }
330}
331
332impl IntoFuture for DownloadProgress {
333 type Output = anyhow::Result<()>;
334 type IntoFuture = future::Boxed<Self::Output>;
335
336 fn into_future(self) -> Self::IntoFuture {
337 Box::pin(self.complete())
338 }
339}
340
341impl Downloader {
342 pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
343 let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
344 let actor = DownloaderActor::new(store.clone(), endpoint.clone());
345 tokio::spawn(actor.run(rx));
346 Self { client: tx.into() }
347 }
348
349 pub fn download(
350 &self,
351 request: impl SupportedRequest,
352 providers: impl ContentDiscovery,
353 ) -> DownloadProgress {
354 let request = request.into_request();
355 let providers = Arc::new(providers);
356 self.download_with_opts(DownloadOptions {
357 request,
358 providers,
359 strategy: SplitStrategy::None,
360 })
361 }
362
363 pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
364 let fut = self.client.server_streaming(options, 32);
365 DownloadProgress::new(Box::pin(fut))
366 }
367}
368
/// Splits `request` into independent [`GetRequest`]s that can be downloaded in parallel.
///
/// - [`FiniteRequest::Get`]: the root blob is first downloaded in full, reporting to
///   `progress`, because its content determines the children. The root size must be a multiple
///   of 32; it is presumably a sequence of 32-byte hashes. With `n = size / 32`, one request is
///   produced for every child index in `1..=n` whose ranges in `req.ranges` are non-empty. The
///   root itself (index 0) is not part of the returned requests.
/// - [`FiniteRequest::GetMany`]: one blob request per hash, with that hash's ranges. No
///   filtering is applied, so hashes with empty ranges still produce a request.
///
/// # Errors
///
/// Fails if downloading the root fails, if the root's size cannot be observed, or if that size
/// is not a multiple of 32.
async fn split_request<'a>(
    request: &'a FiniteRequest,
    providers: &Arc<dyn ContentDiscovery>,
    pool: &ConnectionPool,
    store: &Store,
    progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> anyhow::Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
    Ok(match request {
        FiniteRequest::Get(req) => {
            // NOTE(review): the name `iter_infinite` suggests this iterator never ends, in which
            // case this early return is unreachable — confirm.
            let Some(_first) = req.ranges.iter_infinite().next() else {
                return Ok(Box::new(std::iter::empty()));
            };
            // The root is fetched completely, regardless of the ranges requested for it.
            let first = GetRequest::blob(req.hash);
            execute_get(pool, Arc::new(first), providers, store, progress).await?;
            let size = store.observe(req.hash).await?.size();
            anyhow::ensure!(size % 32 == 0, "Size is not a multiple of 32");
            let n = size / 32;
            Box::new(
                req.ranges
                    .iter_infinite()
                    // Index 0 is the root; indices 1..=n are the children.
                    .take(n as usize + 1)
                    .enumerate()
                    .filter_map(|(i, ranges)| {
                        if i != 0 && !ranges.is_empty() {
                            Some(
                                GetRequest::builder()
                                    .offset(i as u64, ranges.clone())
                                    .build(req.hash),
                            )
                        } else {
                            None
                        }
                    }),
            )
        }
        FiniteRequest::GetMany(req) => Box::new(
            req.hashes
                .iter()
                .enumerate()
                .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
        ),
    })
}
413
414async fn execute_get(
427 pool: &ConnectionPool,
428 request: Arc<GetRequest>,
429 providers: &Arc<dyn ContentDiscovery>,
430 store: &Store,
431 mut progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
432) -> anyhow::Result<()> {
433 let remote = store.remote();
434 let mut providers = providers.find_providers(request.content());
435 while let Some(provider) = providers.next().await {
436 progress
437 .send(DownloadProgressItem::TryProvider {
438 id: provider,
439 request: request.clone(),
440 })
441 .await?;
442 let conn = pool.get_or_connect(provider);
443 let local = remote.local_for_request(request.clone()).await?;
444 if local.is_complete() {
445 return Ok(());
446 }
447 let local_bytes = local.local_bytes();
448 let Ok(conn) = conn.await else {
449 progress
450 .send(DownloadProgressItem::ProviderFailed {
451 id: provider,
452 request: request.clone(),
453 })
454 .await?;
455 continue;
456 };
457 match remote
458 .execute_get_sink(
459 &conn,
460 local.missing(),
461 (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)),
462 )
463 .await
464 {
465 Ok(_stats) => {
466 progress
467 .send(DownloadProgressItem::PartComplete {
468 request: request.clone(),
469 })
470 .await?;
471 return Ok(());
472 }
473 Err(_cause) => {
474 progress
475 .send(DownloadProgressItem::ProviderFailed {
476 id: provider,
477 request: request.clone(),
478 })
479 .await?;
480 continue;
481 }
482 }
483 }
484 bail!("Unable to download {}", request.hash);
485}
486
/// A source of candidate provider nodes for a piece of content.
pub trait ContentDiscovery: Debug + Send + Sync + 'static {
    /// Returns a stream of nodes that may be able to provide `hash`.
    ///
    /// Download code tries them in the order yielded.
    fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<NodeId>;
}
491
492impl<C, I> ContentDiscovery for C
493where
494 C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
495 C::IntoIter: Send + Sync + 'static,
496 I: Into<NodeId> + Send + Sync + 'static,
497{
498 fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<NodeId> {
499 let providers = self.clone();
500 n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
501 }
502}
503
/// A static provider list that is returned in a freshly shuffled order on every lookup.
#[derive(derive_more::Debug)]
pub struct Shuffled {
    // The candidate providers.
    nodes: Vec<NodeId>,
}
508
509impl Shuffled {
510 pub fn new(nodes: Vec<NodeId>) -> Self {
511 Self { nodes }
512 }
513}
514
515impl ContentDiscovery for Shuffled {
516 fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<NodeId> {
517 let mut nodes = self.nodes.clone();
518 nodes.shuffle(&mut rand::thread_rng());
519 n0_future::stream::iter(nodes).boxed()
520 }
521}
522
// Integration tests for the downloader. Each test sets up three fs-backed nodes in a temp dir:
// two providers with data, and a third node that downloads from them.
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use iroh::Watcher;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
    };

    /// Downloads two blobs held by different providers with a single `GetManyRequest`, then
    /// checks that both arrived in the downloading node's store.
    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice("hello world").await?;
        let tt2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        // Make both providers dialable by node id from the downloading node.
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetManyRequest::builder()
            .hash(tt1.hash, ChunkRanges::all())
            .hash(tt2.hash, ChunkRanges::all())
            .build();
        let mut progress = swarm
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    /// Downloads a hash seq (root plus two 10 MB children, each child held by a different
    /// provider) with `SplitStrategy::Split`, printing progress.
    ///
    /// NOTE(review): only prints progress; nothing asserts that the data arrived. The
    /// `if false` branch is a disabled manual variant using `remote.execute_get` directly.
    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        // The hash seq root is stored on node 1 only.
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        // Root plus both children, all ranges.
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        if true {
            let mut progress = swarm
                .download_with_opts(DownloadOptions::new(
                    request,
                    [node1_id, node2_id],
                    SplitStrategy::Split,
                ))
                .stream()
                .await?;
            while let Some(item) = progress.next().await {
                println!("Got item: {item:?}");
            }
        }
        if false {
            let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
            let remote = store3.remote();
            let _rh = remote
                .execute_get(
                    conn.clone(),
                    GetRequest::builder()
                        .root(ChunkRanges::all())
                        .build(root.hash),
                )
                .await?;
            let h1 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(0, ChunkRanges::all())
                    .build(root.hash),
            );
            let h2 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(1, ChunkRanges::all())
                    .build(root.hash),
            );
            h1.await?;
            h2.await?;
        }
        Ok(())
    }

    /// Same setup as `downloader_get_smoke`, but requests the hash seq with
    /// `GetRequest::all` and `SplitStrategy::Split`, printing progress.
    ///
    /// NOTE(review): only prints progress; nothing asserts that the data arrived.
    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetRequest::all(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        Ok(())
    }
}