iroh_blobs/api/
downloader.rs

1//! API for downloads from multiple nodes.
2use std::{
3    collections::{HashMap, HashSet},
4    fmt::Debug,
5    future::{Future, IntoFuture},
6    sync::Arc,
7};
8
9use anyhow::bail;
10use genawaiter::sync::Gen;
11use iroh::{Endpoint, NodeId};
12use irpc::{channel::mpsc, rpc_requests};
13use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt};
14use rand::seq::SliceRandom;
15use serde::{de::Error, Deserialize, Serialize};
16use tokio::task::JoinSet;
17use tracing::instrument::Instrument;
18
19use super::Store;
20use crate::{
21    protocol::{GetManyRequest, GetRequest},
22    util::{
23        connection_pool::ConnectionPool,
24        sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
25    },
26    BlobFormat, Hash, HashAndFormat,
27};
28
/// Handle for requesting downloads from one or more nodes.
///
/// Wraps an irpc client; requests sent through it are handled by the actor
/// task spawned in `Downloader::new`.
#[derive(Debug, Clone)]
pub struct Downloader {
    /// Client used to send `SwarmProtocol` requests to the actor.
    client: irpc::Client<SwarmProtocol>,
}
33
// Request protocol between `Downloader` and its actor.
//
// The `rpc_requests` attribute generates the `SwarmMsg` message enum (and
// the `...Msg` aliases, e.g. `DownloadMsg`) used by the actor loop.
#[rpc_requests(message = SwarmMsg, alias = "Msg")]
#[derive(Debug, Serialize, Deserialize)]
enum SwarmProtocol {
    // A download request; progress is streamed back as `DownloadProgressItem`s.
    #[rpc(tx = mpsc::Sender<DownloadProgressItem>)]
    Download(DownloadRequest),
}
40
/// Actor that receives download requests and runs each one as its own task.
struct DownloaderActor {
    /// Store that downloaded data is written to.
    store: Store,
    /// Pool of connections to remote nodes, shared by all download tasks.
    pool: ConnectionPool,
    /// Tasks spawned for the individual download requests.
    tasks: JoinSet<()>,
    /// Ids of the tasks spawned into `tasks`.
    running: HashSet<tokio::task::Id>,
}
47
/// Events reported to the caller while a download is running.
#[derive(Debug, Serialize, Deserialize)]
pub enum DownloadProgressItem {
    /// The download failed with the given error.
    ///
    /// Skipped by serde, so this variant is never serialized or deserialized.
    #[serde(skip)]
    Error(anyhow::Error),
    /// A provider is about to be tried for `request`.
    TryProvider {
        id: NodeId,
        request: Arc<GetRequest>,
    },
    /// Connecting to the provider, or fetching `request` from it, failed.
    ProviderFailed {
        id: NodeId,
        request: Arc<GetRequest>,
    },
    /// `request` was fetched successfully from a provider.
    PartComplete {
        request: Arc<GetRequest>,
    },
    /// Progress counter for the download; presumably a byte count, since it
    /// includes the locally available bytes reported by the store.
    /// For split downloads this is the sum over all sub-requests.
    Progress(u64),
    /// A sub-request of a split download could not be completed.
    DownloadError,
}
66
67impl DownloaderActor {
68    fn new(store: Store, endpoint: Endpoint) -> Self {
69        Self {
70            store,
71            pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()),
72            tasks: JoinSet::new(),
73            running: HashSet::new(),
74        }
75    }
76
77    async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
78        while let Some(msg) = rx.recv().await {
79            match msg {
80                SwarmMsg::Download(request) => {
81                    self.spawn(handle_download(
82                        self.store.clone(),
83                        self.pool.clone(),
84                        request,
85                    ));
86                }
87            }
88        }
89    }
90
91    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
92        let span = tracing::Span::current();
93        let id = self.tasks.spawn(fut.instrument(span)).id();
94        self.running.insert(id);
95    }
96}
97
98async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
99    let DownloadMsg { inner, mut tx, .. } = msg;
100    if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
101        tx.send(DownloadProgressItem::Error(cause)).await.ok();
102    }
103}
104
105async fn handle_download_impl(
106    store: Store,
107    pool: ConnectionPool,
108    request: DownloadRequest,
109    tx: &mut mpsc::Sender<DownloadProgressItem>,
110) -> anyhow::Result<()> {
111    match request.strategy {
112        SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
113        SplitStrategy::None => match request.request {
114            FiniteRequest::Get(get) => {
115                let sink = IrpcSenderRefSink(tx);
116                execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
117            }
118            FiniteRequest::GetMany(_) => {
119                handle_download_split_impl(store, pool, request, tx).await?
120            }
121        },
122    }
123    Ok(())
124}
125
/// Executes a request by splitting it into several `GetRequest`s that are
/// downloaded concurrently, merging their progress into a single stream.
///
/// Progress items of the sub-requests are forwarded to `tx`; `Progress`
/// values are replaced by the running total over all sub-requests. A failed
/// sub-request is reported as `DownloadProgressItem::DownloadError` and does
/// not stop the remaining sub-requests.
///
/// Returns an error if splitting the request fails or if sending to `tx`
/// fails.
async fn handle_download_split_impl(
    store: Store,
    pool: ConnectionPool,
    request: DownloadRequest,
    tx: &mut mpsc::Sender<DownloadProgressItem>,
) -> anyhow::Result<()> {
    let providers = request.providers;
    // Progress of any download done while splitting is discarded (`Drain`).
    let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
    // Channel over which each sub-request hands its progress receiver to the
    // aggregating stream below.
    let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
    let mut futs = stream::iter(requests.into_iter().enumerate())
        .map(|(id, request)| {
            let pool = pool.clone();
            let providers = providers.clone();
            let store = store.clone();
            let progress_tx = progress_tx.clone();
            async move {
                let hash = request.hash;
                let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgressItem)>(16);
                // Register this sub-request's progress channel; a failure
                // here only means nobody is listening for progress anymore.
                progress_tx.send(rx).await.ok();
                // Tag every item with the sub-request index.
                let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x));
                let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
                (hash, res)
            }
        })
        // At most 32 sub-requests are in flight at the same time.
        .buffered_unordered(32);
    let mut progress_stream = {
        // Last reported progress value per sub-request index.
        let mut offsets = HashMap::new();
        // Sum of the last reported values of all sub-requests.
        let mut total = 0;
        into_stream(progress_rx)
            .flat_map(into_stream)
            .map(move |(id, item)| match item {
                DownloadProgressItem::Progress(offset) => {
                    // Add the new value first, then remove the value it
                    // replaces, so the unsigned total never underflows.
                    total += offset;
                    if let Some(prev) = offsets.insert(id, offset) {
                        total -= prev;
                    }
                    DownloadProgressItem::Progress(total)
                }
                x => x,
            })
    };
    // NOTE(review): the loop ends as soon as `futs` is exhausted; progress
    // items still queued at that point look like they are dropped — confirm
    // that this is intended.
    loop {
        tokio::select! {
            Some(item) = progress_stream.next() => {
                tx.send(item).await?;
            },
            res = futs.next() => {
                match res {
                    Some((_hash, Ok(()))) => {
                    }
                    Some((_hash, Err(_e))) => {
                        tx.send(DownloadProgressItem::DownloadError).await?;
                    }
                    None => break,
                }
            }
            _ = tx.closed() => {
                // The sender has been closed, we should stop processing.
                break;
            }
        }
    }
    Ok(())
}
190
191fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
192    Gen::new(|co| async move {
193        while let Some(item) = recv.recv().await {
194            co.yield_(item).await;
195        }
196    })
197}
198
/// The request kinds the downloader can execute.
///
/// `From` is derived for both variants, so a `GetRequest` or a
/// `GetManyRequest` converts with `.into()`.
#[derive(Debug, Serialize, Deserialize, derive_more::From)]
pub enum FiniteRequest {
    /// A single get request.
    Get(GetRequest),
    /// A request for several hashes at once.
    GetMany(GetManyRequest),
}
204
/// Types that can be turned into a request the downloader understands.
pub trait SupportedRequest {
    /// Converts `self` into a [`FiniteRequest`].
    fn into_request(self) -> FiniteRequest;
}
208
209impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
210    fn into_request(self) -> FiniteRequest {
211        let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
212        FiniteRequest::GetMany(hashes)
213    }
214}
215
216impl SupportedRequest for GetRequest {
217    fn into_request(self) -> FiniteRequest {
218        self.into()
219    }
220}
221
222impl SupportedRequest for GetManyRequest {
223    fn into_request(self) -> FiniteRequest {
224        self.into()
225    }
226}
227
228impl SupportedRequest for Hash {
229    fn into_request(self) -> FiniteRequest {
230        GetRequest::blob(self).into()
231    }
232}
233
234impl SupportedRequest for HashAndFormat {
235    fn into_request(self) -> FiniteRequest {
236        (match self.format {
237            BlobFormat::Raw => GetRequest::blob(self.hash),
238            BlobFormat::HashSeq => GetRequest::all(self.hash),
239        })
240        .into()
241    }
242}
243
/// A hash together with a list of nodes that provide it.
///
/// NOTE(review): not referenced anywhere else in this file; presumably used
/// by callers outside of it — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct AddProviderRequest {
    /// The content hash.
    pub hash: Hash,
    /// Nodes to register as providers of `hash`.
    pub providers: Vec<NodeId>,
}
249
/// A download request: what to fetch, where to look for it and how to
/// schedule it.
///
/// The serde implementations for this type always fail, so it cannot be
/// serialized or deserialized.
#[derive(Debug)]
pub struct DownloadRequest {
    /// The content to download.
    pub request: FiniteRequest,
    /// Source of candidate provider nodes for the content.
    pub providers: Arc<dyn ContentDiscovery>,
    /// Whether the request is split into parts that download concurrently.
    pub strategy: SplitStrategy,
}
256
257impl DownloadRequest {
258    pub fn new(
259        request: impl SupportedRequest,
260        providers: impl ContentDiscovery,
261        strategy: SplitStrategy,
262    ) -> Self {
263        Self {
264            request: request.into_request(),
265            providers: Arc::new(providers),
266            strategy,
267        }
268    }
269}
270
/// Controls whether a download is split into concurrently executed parts.
#[derive(Debug, Serialize, Deserialize)]
pub enum SplitStrategy {
    /// Do not split a single `Get` request; it is downloaded sequentially.
    /// `GetMany` requests are split regardless of this setting.
    None,
    /// Split the request into parts and download them concurrently.
    Split,
}
276
277impl Serialize for DownloadRequest {
278    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
279    where
280        S: serde::Serializer,
281    {
282        Err(serde::ser::Error::custom(
283            "cannot serialize DownloadRequest",
284        ))
285    }
286}
287
288// Implement Deserialize to always fail
289impl<'de> Deserialize<'de> for DownloadRequest {
290    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
291    where
292        D: serde::Deserializer<'de>,
293    {
294        Err(D::Error::custom("cannot deserialize DownloadRequest"))
295    }
296}
297
/// Alias for [`DownloadRequest`], used as the options type of
/// [`Downloader::download_with_opts`].
pub type DownloadOptions = DownloadRequest;
299
/// A download that has been requested but not yet observed.
///
/// Either await it directly to wait for completion, or call `stream` to
/// observe the individual progress items.
pub struct DownloadProgress {
    /// Future that resolves to the receiver of progress items.
    fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>,
}
303
304impl DownloadProgress {
305    fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>) -> Self {
306        Self { fut }
307    }
308
309    pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgressItem> + Unpin> {
310        let rx = self.fut.await?;
311        Ok(Box::pin(rx.into_stream().map(|item| match item {
312            Ok(item) => item,
313            Err(e) => DownloadProgressItem::Error(e.into()),
314        })))
315    }
316
317    async fn complete(self) -> anyhow::Result<()> {
318        let rx = self.fut.await?;
319        let stream = rx.into_stream();
320        tokio::pin!(stream);
321        while let Some(item) = stream.next().await {
322            match item? {
323                DownloadProgressItem::Error(e) => Err(e)?,
324                DownloadProgressItem::DownloadError => anyhow::bail!("Download error"),
325                _ => {}
326            }
327        }
328        Ok(())
329    }
330}
331
332impl IntoFuture for DownloadProgress {
333    type Output = anyhow::Result<()>;
334    type IntoFuture = future::Boxed<Self::Output>;
335
336    fn into_future(self) -> Self::IntoFuture {
337        Box::pin(self.complete())
338    }
339}
340
341impl Downloader {
342    pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
343        let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
344        let actor = DownloaderActor::new(store.clone(), endpoint.clone());
345        tokio::spawn(actor.run(rx));
346        Self { client: tx.into() }
347    }
348
349    pub fn download(
350        &self,
351        request: impl SupportedRequest,
352        providers: impl ContentDiscovery,
353    ) -> DownloadProgress {
354        let request = request.into_request();
355        let providers = Arc::new(providers);
356        self.download_with_opts(DownloadOptions {
357            request,
358            providers,
359            strategy: SplitStrategy::None,
360        })
361    }
362
363    pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
364        let fut = self.client.server_streaming(options, 32);
365        DownloadProgress::new(Box::pin(fut))
366    }
367}
368
369/// Split a request into multiple requests that can be run in parallel.
370async fn split_request<'a>(
371    request: &'a FiniteRequest,
372    providers: &Arc<dyn ContentDiscovery>,
373    pool: &ConnectionPool,
374    store: &Store,
375    progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
376) -> anyhow::Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
377    Ok(match request {
378        FiniteRequest::Get(req) => {
379            let Some(_first) = req.ranges.iter_infinite().next() else {
380                return Ok(Box::new(std::iter::empty()));
381            };
382            let first = GetRequest::blob(req.hash);
383            execute_get(pool, Arc::new(first), providers, store, progress).await?;
384            let size = store.observe(req.hash).await?.size();
385            anyhow::ensure!(size % 32 == 0, "Size is not a multiple of 32");
386            let n = size / 32;
387            Box::new(
388                req.ranges
389                    .iter_infinite()
390                    .take(n as usize + 1)
391                    .enumerate()
392                    .filter_map(|(i, ranges)| {
393                        if i != 0 && !ranges.is_empty() {
394                            Some(
395                                GetRequest::builder()
396                                    .offset(i as u64, ranges.clone())
397                                    .build(req.hash),
398                            )
399                        } else {
400                            None
401                        }
402                    }),
403            )
404        }
405        FiniteRequest::GetMany(req) => Box::new(
406            req.hashes
407                .iter()
408                .enumerate()
409                .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
410        ),
411    })
412}
413
414/// Execute a get request sequentially for multiple providers.
415///
416/// It will try each provider in order
417/// until it finds one that can fulfill the request. When trying a new provider,
418/// it takes the progress from the previous providers into account, so e.g.
419/// if the first provider had the first 10% of the data, it will only ask the next
420/// provider for the remaining 90%.
421///
422/// This is fully sequential, so there will only be one request in flight at a time.
423///
424/// If the request is not complete after trying all providers, it will return an error.
425/// If the provider stream never ends, it will try indefinitely.
426async fn execute_get(
427    pool: &ConnectionPool,
428    request: Arc<GetRequest>,
429    providers: &Arc<dyn ContentDiscovery>,
430    store: &Store,
431    mut progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
432) -> anyhow::Result<()> {
433    let remote = store.remote();
434    let mut providers = providers.find_providers(request.content());
435    while let Some(provider) = providers.next().await {
436        progress
437            .send(DownloadProgressItem::TryProvider {
438                id: provider,
439                request: request.clone(),
440            })
441            .await?;
442        let conn = pool.get_or_connect(provider);
443        let local = remote.local_for_request(request.clone()).await?;
444        if local.is_complete() {
445            return Ok(());
446        }
447        let local_bytes = local.local_bytes();
448        let Ok(conn) = conn.await else {
449            progress
450                .send(DownloadProgressItem::ProviderFailed {
451                    id: provider,
452                    request: request.clone(),
453                })
454                .await?;
455            continue;
456        };
457        match remote
458            .execute_get_sink(
459                &conn,
460                local.missing(),
461                (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)),
462            )
463            .await
464        {
465            Ok(_stats) => {
466                progress
467                    .send(DownloadProgressItem::PartComplete {
468                        request: request.clone(),
469                    })
470                    .await?;
471                return Ok(());
472            }
473            Err(_cause) => {
474                progress
475                    .send(DownloadProgressItem::ProviderFailed {
476                        id: provider,
477                        request: request.clone(),
478                    })
479                    .await?;
480                continue;
481            }
482        }
483    }
484    bail!("Unable to download {}", request.hash);
485}
486
/// Trait for pluggable content discovery strategies.
pub trait ContentDiscovery: Debug + Send + Sync + 'static {
    /// Returns a stream of nodes to try for the given content.
    ///
    /// The downloader tries the nodes in the order the stream yields them.
    fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<NodeId>;
}
491
492impl<C, I> ContentDiscovery for C
493where
494    C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
495    C::IntoIter: Send + Sync + 'static,
496    I: Into<NodeId> + Send + Sync + 'static,
497{
498    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<NodeId> {
499        let providers = self.clone();
500        n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
501    }
502}
503
/// Content discovery over a fixed set of nodes, yielded in a fresh random
/// order on every lookup.
#[derive(derive_more::Debug)]
pub struct Shuffled {
    /// The candidate nodes.
    nodes: Vec<NodeId>,
}
508
509impl Shuffled {
510    pub fn new(nodes: Vec<NodeId>) -> Self {
511        Self { nodes }
512    }
513}
514
515impl ContentDiscovery for Shuffled {
516    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<NodeId> {
517        let mut nodes = self.nodes.clone();
518        nodes.shuffle(&mut rand::thread_rng());
519        n0_future::stream::iter(nodes).boxed()
520    }
521}
522
// Integration-style tests: each one sets up three nodes with file system
// stores, where nodes 1 and 2 hold data and node 3 downloads it.
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use iroh::Watcher;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
    };

    /// Downloads two small blobs, each held by a different node, with a
    /// single `GetMany` request and checks that both arrive in the local
    /// store. Currently ignored.
    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice("hello world").await?;
        let tt2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        // Make both providers dialable by node id from node 3.
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetManyRequest::builder()
            .hash(tt1.hash, ChunkRanges::all())
            .hash(tt2.hash, ChunkRanges::all())
            .build();
        let mut progress = swarm
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    /// Runs a split download of a hash sequence whose two children live on
    /// different nodes. Only checks that the download runs to the end of the
    /// progress stream; the downloaded data is not verified.
    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        // tracing_subscriber::fmt::try_init().ok();
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        // The hash sequence (root) is stored on node 1 only.
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        // Root plus both children, all ranges.
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        // Active variant: download through the downloader with splitting.
        if true {
            let mut progress = swarm
                .download_with_opts(DownloadOptions::new(
                    request,
                    [node1_id, node2_id],
                    SplitStrategy::Split,
                ))
                .stream()
                .await?;
            while let Some(item) = progress.next().await {
                println!("Got item: {item:?}");
            }
        }
        // Disabled variant: the same download done by hand over a single
        // connection to node 1. Never executed, but still type-checked.
        if false {
            let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
            let remote = store3.remote();
            let _rh = remote
                .execute_get(
                    conn.clone(),
                    GetRequest::builder()
                        .root(ChunkRanges::all())
                        .build(root.hash),
                )
                .await?;
            let h1 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(0, ChunkRanges::all())
                    .build(root.hash),
            );
            let h2 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(1, ChunkRanges::all())
                    .build(root.hash),
            );
            h1.await?;
            h2.await?;
        }
        Ok(())
    }

    /// Same setup as `downloader_get_smoke`, but requests the whole hash
    /// sequence via `GetRequest::all`. Only checks that the download runs to
    /// the end of the progress stream; the downloaded data is not verified.
    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().node_addr().initialized().await;
        let node1_id = node1_addr.node_id;
        let node2_addr = r2.endpoint().node_addr().initialized().await;
        let node2_id = node2_addr.node_id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        r3.endpoint().add_node_addr(node1_addr.clone())?;
        r3.endpoint().add_node_addr(node2_addr.clone())?;
        let request = GetRequest::all(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while let Some(item) = progress.next().await {
            println!("Got item: {item:?}");
        }
        Ok(())
    }
}