iroh_blobs/api/remote.rs

//! API for downloading blobs from a single remote node.
//!
//! The entry point is the [`Remote`] struct.
use genawaiter::sync::{Co, Gen};
use iroh::endpoint::SendStream;
use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
use n0_future::{io, Stream, StreamExt};
use n0_snafu::SpanTrace;
use nested_enum_utils::common_fields;
use ref_cast::RefCast;
use snafu::{Backtrace, IntoError, Snafu};

use super::blobs::{Bitfield, ExportBaoOptions};
use crate::{
    api::{blobs::WriteProgress, ApiClient},
    get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats},
    protocol::{
        GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
        MAX_MESSAGE_SIZE,
    },
    provider::events::{ClientResult, ProgressError},
    util::sink::{Sink, TokioMpscSenderSink},
};

/// API to compute requests and to download from remote nodes.
///
/// Usually you want to first find out what data, if any, you have locally.
/// This can be done using [`Remote::local`], which inspects the local store
/// and returns a [`LocalInfo`].
///
/// From this you can compute various values, such as the number of locally present
/// bytes. You can also compute a request to get the missing data using [`LocalInfo::missing`].
///
/// Once you have a request, you can execute it using [`Remote::execute_get`].
/// Executing a request will write the received data to the local store, but otherwise
/// does not take the locally available data into account.
///
/// If you are not interested in the details and just want your data, you can use
/// [`Remote::fetch`]. This will internally do the dance described above.
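///
/// A minimal sketch of that dance (the `remote`, `conn` and `content` bindings
/// are illustrative, not part of this API):
///
/// ```ignore
/// let info = remote.local(content).await?;
/// if !info.is_complete() {
///     // request only the ranges we do not have yet
///     let request = info.missing();
///     // `GetProgress` implements `IntoFuture`, so it can be awaited directly
///     let stats = remote.execute_get(conn, request).await?;
/// }
/// ```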
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    client: ApiClient,
}

#[derive(Debug)]
pub enum GetProgressItem {
    /// Progress on the payload bytes read.
    Progress(u64),
    /// The request was completed.
    Done(Stats),
    /// The request was closed, but not completed.
    Error(GetError),
}

impl From<GetResult<Stats>> for GetProgressItem {
    fn from(res: GetResult<Stats>) -> Self {
        match res {
            Ok(stats) => GetProgressItem::Done(stats),
            Err(e) => GetProgressItem::Error(e),
        }
    }
}

impl TryFrom<GetProgressItem> for GetResult<Stats> {
    type Error = &'static str;

    fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
        match item {
            GetProgressItem::Done(stats) => Ok(Ok(stats)),
            GetProgressItem::Error(e) => Ok(Err(e)),
            GetProgressItem::Progress(_) => Err("not a final item"),
        }
    }
}

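/// Progress handle for a get request.
///
/// This can either be consumed as a stream of [`GetProgressItem`]s via
/// [`GetProgress::stream`], or awaited directly (it implements [`IntoFuture`])
/// to get just the final [`GetResult`] with the transfer [`Stats`].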
pub struct GetProgress {
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    fut: n0_future::boxed::BoxFuture<()>,
}

impl IntoFuture for GetProgress {
    type Output = GetResult<Stats>;
    type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;

    fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
        Box::pin(self.complete())
    }
}

impl GetProgress {
    pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
        into_stream(self.rx, self.fut)
    }

    pub async fn complete(self) -> GetResult<Stats> {
        just_result(self.stream()).await.unwrap_or_else(|| {
            Err(LocalFailureSnafu
                .into_error(anyhow::anyhow!("stream closed without result").into()))
        })
    }
}

#[derive(Debug)]
pub enum PushProgressItem {
    /// Progress on the payload bytes read.
    Progress(u64),
    /// The request was completed.
    Done(Stats),
    /// The request was closed, but not completed.
    Error(anyhow::Error),
}

impl From<anyhow::Result<Stats>> for PushProgressItem {
    fn from(res: anyhow::Result<Stats>) -> Self {
        match res {
            Ok(stats) => Self::Done(stats),
            Err(e) => Self::Error(e),
        }
    }
}

impl TryFrom<PushProgressItem> for anyhow::Result<Stats> {
    type Error = &'static str;

    fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
        match item {
            PushProgressItem::Done(stats) => Ok(Ok(stats)),
            PushProgressItem::Error(e) => Ok(Err(e)),
            PushProgressItem::Progress(_) => Err("not a final item"),
        }
    }
}

pub struct PushProgress {
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    fut: n0_future::boxed::BoxFuture<()>,
}

impl IntoFuture for PushProgress {
    type Output = anyhow::Result<Stats>;
    type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;

    fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
        Box::pin(self.complete())
    }
}

impl PushProgress {
    pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
        into_stream(self.rx, self.fut)
    }

    pub async fn complete(self) -> anyhow::Result<Stats> {
        just_result(self.stream())
            .await
            .unwrap_or_else(|| Err(anyhow::anyhow!("stream closed without result")))
    }
}

async fn just_result<S, R>(stream: S) -> Option<R>
where
    S: Stream<Item: std::fmt::Debug>,
    R: TryFrom<S::Item>,
{
    tokio::pin!(stream);
    while let Some(item) = stream.next().await {
        if let Ok(res) = R::try_from(item) {
            return Some(res);
        }
    }
    None
}

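/// Combine a progress receiver with the future that drives the operation.
///
/// The future is polled alongside the receiver (`biased` towards the receiver,
/// so buffered items are drained first). Once the future completes, any items
/// still in the channel are yielded before the stream ends.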
fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
where
    F: Future,
{
    Gen::new(move |co| async move {
        tokio::pin!(fut);
        loop {
            tokio::select! {
                biased;
                item = rx.recv() => {
                    if let Some(item) = item {
                        co.yield_(item).await;
                    } else {
                        break;
                    }
                }
                _ = &mut fut => {
                    break;
                }
            }
        }
        while let Some(item) = rx.recv().await {
            co.yield_(item).await;
        }
    })
}

/// Local info for a blob or hash sequence.
///
/// This can be used to get the amount of missing data, and to construct a
/// request to get the missing data.
#[derive(Debug)]
pub struct LocalInfo {
    /// The request for which this is the local info
    request: Arc<GetRequest>,
    /// The bitfield for the root hash
    bitfield: Bitfield,
    /// Optional - the hash sequence info if this was a request for a hash sequence
    children: Option<NonRawLocalInfo>,
}

impl LocalInfo {
    /// The number of bytes we have locally
    pub fn local_bytes(&self) -> u64 {
        let Some(root_requested) = self.requested_root_ranges() else {
            // empty request requests 0 bytes
            return 0;
        };
        let mut local = self.bitfield.clone();
        local.ranges.intersection_with(root_requested);
        let mut res = local.total_bytes();
        if let Some(children) = &self.children {
            let Some(max_local_index) = children.hash_seq.keys().next_back() else {
                // no children
                return res;
            };
            for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
                if offset == 0 {
                    // skip the root hash
                    continue;
                }
                let child = offset - 1;
                if child > *max_local_index {
                    // we are done
                    break;
                }
                let Some(hash) = children.hash_seq.get(&child) else {
                    continue;
                };
                let bitfield = &children.bitfields[hash];
                let mut local = bitfield.clone();
                local.ranges.intersection_with(ranges);
                res += local.total_bytes();
            }
        }
        res
    }

    /// Number of children in this hash sequence
    pub fn children(&self) -> Option<u64> {
        if self.children.is_some() {
            self.bitfield.validated_size().map(|x| x / 32)
        } else {
            Some(0)
        }
    }

    /// The requested root ranges.
    ///
    /// This will return None if the request is empty, and an empty ChunkRanges
    /// if no ranges were requested for the root hash.
    fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
        self.request.ranges.iter().next()
    }

    /// True if the data is complete.
    ///
    /// For a blob, this is true if the blob is complete.
    /// For a hash sequence, this is true if the hash sequence is complete and
    /// all its children are complete.
    pub fn is_complete(&self) -> bool {
        let Some(root_requested) = self.requested_root_ranges() else {
            // empty request is complete
            return true;
        };
        if !self.bitfield.ranges.is_superset(root_requested) {
            return false;
        }
        if let Some(children) = self.children.as_ref() {
            let mut iter = self.request.ranges.iter_non_empty_infinite();
            let max_child = self.bitfield.validated_size().map(|x| x / 32);
            loop {
                let Some((offset, range)) = iter.next() else {
                    break;
                };
                if offset == 0 {
                    // skip the root hash
                    continue;
                }
                let child = offset - 1;
                if let Some(hash) = children.hash_seq.get(&child) {
                    let bitfield = &children.bitfields[hash];
                    if !bitfield.ranges.is_superset(range) {
                        // we don't have the requested ranges
                        return false;
                    }
                } else {
                    if let Some(max_child) = max_child {
                        if child >= max_child {
                            // the requested child is past the end of the
                            // validated hash sequence, so there is nothing
                            // left that could be missing
                            return true;
                        }
                    }
                    return false;
                }
            }
        }
        true
    }

    /// A request to get the missing data to complete this request
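    ///
    /// A minimal sketch of typical usage (`remote` and `request` are
    /// illustrative bindings, not part of this API):
    ///
    /// ```ignore
    /// let info = remote.local_for_request(request).await?;
    /// if !info.is_complete() {
    ///     // a request covering exactly the ranges that are still missing
    ///     let missing = info.missing();
    ///     // `missing` can then be passed to `Remote::execute_get`.
    /// }
    /// ```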
    pub fn missing(&self) -> GetRequest {
        let Some(root_requested) = self.requested_root_ranges() else {
            // empty request is complete
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        loop {
            let Some((offset, requested)) = iter.next() else {
                break;
            };
            if offset == 0 {
                // skip the root hash
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                // we can't do anything clever anymore
                break;
            }
        }
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                // skip the root hash
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                if child >= *max_offset {
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                if iter.is_at_end() {
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
}

#[derive(Debug)]
struct NonRawLocalInfo {
    /// the available and relevant part of the hash sequence
    hash_seq: BTreeMap<u64, Hash>,
    /// the local bitfield for each hash relevant to the request
    bitfields: BTreeMap<Hash, Bitfield>,
}

// fn iter_without_gaps<'a, T: Copy + 'a>(
//     iter: impl IntoIterator<Item = &'a (u64, T)> + 'a,
// ) -> impl Iterator<Item = (u64, Option<T>)> + 'a {
//     let mut prev = 0;
//     iter.into_iter().flat_map(move |(i, hash)| {
//         let start = prev + 1;
//         let curr = *i;
//         prev = *i;
//         (start..curr)
//             .map(|i| (i, None))
//             .chain(std::iter::once((curr, Some(*hash))))
//     })
// }

impl Remote {
    pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
        Self::ref_cast(sender)
    }

    fn store(&self) -> &Store {
        Store::ref_from_sender(&self.client)
    }

    pub async fn local_for_request(
        &self,
        request: impl Into<Arc<GetRequest>>,
    ) -> anyhow::Result<LocalInfo> {
        let request = request.into();
        let root = request.hash;
        let bitfield = self.store().observe(root).await?;
        let children = if !request.ranges.is_blob() {
            let opts = ExportBaoOptions {
                hash: root,
                ranges: bitfield.ranges.clone(),
            };
            let bao = self.store().export_bao_with_opts(opts, 32);
            let mut by_index = BTreeMap::new();
            let mut stream = bao.hashes_with_index();
            while let Some(item) = stream.next().await {
                if let Ok((index, hash)) = item {
                    by_index.insert(index, hash);
                }
            }
            let mut bitfields = BTreeMap::new();
            let mut hash_seq = BTreeMap::new();
            let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
            for (index, _) in request.ranges.iter_non_empty_infinite() {
                if index == 0 {
                    // skip the root hash
                    continue;
                }
                let child = index - 1;
                if child > max {
                    // we are done
                    break;
                }
                let Some(hash) = by_index.get(&child) else {
                    // we don't have the hash, so we can't store the bitfield
                    continue;
                };
                let bitfield = self.store().observe(*hash).await?;
                bitfields.insert(*hash, bitfield);
                hash_seq.insert(child, *hash);
            }
            Some(NonRawLocalInfo {
                hash_seq,
                bitfields,
            })
        } else {
            None
        };
        Ok(LocalInfo {
            request: request.clone(),
            bitfield,
            children,
        })
    }

    /// Get the local info for a given blob or hash sequence, at the present time.
    pub async fn local(&self, content: impl Into<HashAndFormat>) -> anyhow::Result<LocalInfo> {
        let request = GetRequest::from(content.into());
        self.local_for_request(request).await
    }

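    /// Fetch a blob or hash sequence from a remote node, downloading only the
    /// parts that are missing locally.
    ///
    /// This combines [`Remote::local`], [`LocalInfo::missing`] and the internal
    /// get machinery into a single call. The returned [`GetProgress`] can be
    /// awaited for the final result, or consumed as a progress stream.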
    pub fn fetch(
        &self,
        conn: impl GetConnection + Send + 'static,
        content: impl Into<HashAndFormat>,
    ) -> GetProgress {
        let content = content.into();
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.fetch_sink(conn, content, sink).await.into();
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Get a blob or hash sequence from the given connection, taking the locally available
    /// ranges into account.
    ///
    /// You can provide a progress channel to get updates on the download progress. This progress
    /// is the aggregated number of downloaded payload bytes in the request.
    ///
    /// This will return the stats of the download.
    pub(crate) async fn fetch_sink(
        &self,
        mut conn: impl GetConnection,
        content: impl Into<HashAndFormat>,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let content = content.into();
        let local = self
            .local(content)
            .await
            .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
        if local.is_complete() {
            return Ok(Default::default());
        }
        let request = local.missing();
        let conn = conn
            .connection()
            .await
            .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
        let stats = self.execute_get_sink(&conn, request, progress).await?;
        Ok(stats)
    }

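    /// Observe the remote node's bitfield for the blob in `request`.
    ///
    /// This opens a bidirectional stream, writes the [`ObserveRequest`], and
    /// then yields one [`Bitfield`] per observe message from the remote until
    /// the stream ends or fails.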
    pub fn observe(
        &self,
        conn: Connection,
        request: ObserveRequest,
    ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
        Gen::new(|co| async move {
            if let Err(cause) = Self::observe_impl(conn, request, &co).await {
                co.yield_(Err(cause)).await
            }
        })
    }

    async fn observe_impl(
        conn: Connection,
        request: ObserveRequest,
        co: &Co<io::Result<Bitfield>>,
    ) -> io::Result<()> {
        let hash = request.hash;
        debug!(%hash, "observing");
        let (mut send, mut recv) = conn.open_bi().await?;
        // write the request. Unlike for reading, we can serialize it synchronously using postcard.
        write_observe_request(request, &mut send).await?;
        send.finish()?;
        loop {
            let msg = recv
                .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
                .await?;
            co.yield_(Ok(Bitfield::from(&msg))).await;
        }
    }

    pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_push_sink(conn, request, sink).await.into();
            tx2.send(res).await.ok();
        };
        PushProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Push the given blob or hash sequence to a remote node.
    ///
    /// Note that many nodes will reject push requests. Also, this is an experimental feature for now.
    pub(crate) async fn execute_push_sink(
        &self,
        conn: Connection,
        request: PushRequest,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> anyhow::Result<Stats> {
        let hash = request.hash;
        debug!(%hash, "pushing");
        let (mut send, mut recv) = conn.open_bi().await?;
        let mut context = StreamContext {
            payload_bytes_sent: 0,
            sender: progress,
        };
        // we are not going to need this!
        recv.stop(0u32.into())?;
        // write the request. Unlike for reading, we can serialize it synchronously using postcard.
        let request = write_push_request(request, &mut send).await?;
        let mut request_ranges = request.ranges.iter_infinite();
        let root = request.hash;
        let root_ranges = request_ranges.next().expect("infinite iterator");
        if !root_ranges.is_empty() {
            self.store()
                .export_bao(root, root_ranges.clone())
                .write_quinn_with_progress(&mut send, &mut context, &root, 0)
                .await?;
        }
        if request.ranges.is_blob() {
            // we are done
            send.finish()?;
            return Ok(Default::default());
        }
        let hash_seq = self.store().get_bytes(root).await?;
        let hash_seq = HashSeq::try_from(hash_seq)?;
        for (child, (child_hash, child_ranges)) in
            hash_seq.into_iter().zip(request_ranges).enumerate()
        {
            if !child_ranges.is_empty() {
                self.store()
                    .export_bao(child_hash, child_ranges.clone())
                    .write_quinn_with_progress(
                        &mut send,
                        &mut context,
                        &child_hash,
                        (child + 1) as u64,
                    )
                    .await?;
            }
        }
        send.finish()?;
        Ok(Default::default())
    }

    pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress {
        self.execute_get_with_opts(conn, request)
    }

    pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_get_sink(&conn, request, sink).await.into();
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Execute a get request *without* taking the locally available ranges into account.
    ///
    /// You can provide a progress channel to get updates on the download progress. This progress
    /// is the aggregated number of downloaded payload bytes in the request.
    ///
    /// This will download the data again even if the data is locally present.
    ///
    /// This will return the stats of the download.
    pub(crate) async fn execute_get_sink(
        &self,
        conn: &Connection,
        request: GetRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let root = request.hash;
        // I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive
        // for the duration of the operation.
        let start = crate::get::fsm::start(conn.clone(), request, Default::default());
        let connected = start.next().await?;
        trace!("Getting header");
        // read the header
        let next_child = match connected.next().await? {
            ConnectedNext::StartRoot(at_start_root) => {
                let header = at_start_root.next();
                let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
                match end.next() {
                    EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                    EndBlobNext::Closing(at_closing) => Err(at_closing),
                }
            }
            ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
            ConnectedNext::Closing(at_closing) => Err(at_closing),
        };
        // read the rest, if any
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                let hash_seq = HashSeq::try_from(
                    store
                        .get_bytes(root)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?,
                )
                .map_err(|source| BadRequestSnafu.into_error(source.into()))?;
                // let mut hash_seq = LazyHashSeq::new(store.blobs().clone(), root);
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    let offset = at_start_child.offset() - 1;
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        // finish the request and get the stats
        let stats = at_closing.next().await?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }

    pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_get_many_sink(conn, request, sink).await.into();
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Execute a get many request *without* taking the locally available ranges into account.
    ///
    /// You can provide a progress channel to get updates on the download progress. This progress
    /// is the aggregated number of downloaded payload bytes in the request.
    ///
    /// This will download the data again even if the data is locally present.
    ///
    /// This will return the stats of the download.
    pub async fn execute_get_many_sink(
        &self,
        conn: Connection,
        request: GetManyRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
        let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
        // read all children.
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    let offset = at_start_child.offset();
                    trace!("offset {offset}");
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        // finish the request and get the stats
        let stats = at_closing.next().await?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }
}

/// Failures for a get operation
#[common_fields({
    backtrace: Option<Backtrace>,
    #[snafu(implicit)]
    span_trace: SpanTrace,
})]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Debug, Snafu)]
pub enum ExecuteError {
    /// Network or IO operation failed.
    #[snafu(display("Unable to open bidi stream"))]
    Connection {
        source: iroh::endpoint::ConnectionError,
    },
    #[snafu(display("Unable to read from the remote"))]
    Read { source: iroh::endpoint::ReadError },
    #[snafu(display("Error sending the request"))]
    Send {
        source: crate::get::fsm::ConnectedNextError,
    },
    #[snafu(display("Unable to read size"))]
    Size {
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    #[snafu(display("Error while decoding the data"))]
    Decode {
        source: crate::get::fsm::DecodeError,
    },
    #[snafu(display("Internal error while reading the hash sequence"))]
    ExportBao { source: api::ExportBaoError },
    #[snafu(display("Hash sequence has an invalid length"))]
    InvalidHashSeq { source: anyhow::Error },
    #[snafu(display("Internal error importing the data"))]
    ImportBao { source: crate::api::RequestError },
    #[snafu(display("Error sending download progress - receiver closed"))]
    SendDownloadProgress { source: irpc::channel::SendError },
    #[snafu(display("Internal error sending data to the importer"))]
    MpscSend {
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}

use std::{
    collections::BTreeMap,
    future::{Future, IntoFuture},
    num::NonZeroU64,
    sync::Arc,
};

use bao_tree::{
    io::{BaoContentItem, Leaf},
    ChunkNum, ChunkRanges,
};
use iroh::endpoint::Connection;
use tracing::{debug, trace};

use crate::{
    api::{self, blobs::Blobs, Store},
    get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext},
    hashseq::{HashSeq, HashSeqIter},
    protocol::{ChunkRangesSeq, GetRequest},
    store::IROH_BLOCK_SIZE,
    Hash, HashAndFormat,
};

/// Trait to lazily get a connection
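///
/// Besides the trivial impls for [`Connection`] below, this allows passing in
/// something that only dials when the connection is actually needed. A minimal
/// sketch (the endpoint/address wiring and the ALPN constant are illustrative
/// assumptions, not part of this API):
///
/// ```ignore
/// struct LazyDial {
///     endpoint: iroh::Endpoint,
///     addr: iroh::NodeAddr,
/// }
///
/// impl GetConnection for LazyDial {
///     fn connection(
///         &mut self,
///     ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
///         async move {
///             // dial only when a download actually runs
///             Ok(self.endpoint.connect(self.addr.clone(), crate::protocol::ALPN).await?)
///         }
///     }
/// }
/// ```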
pub trait GetConnection {
    fn connection(&mut self)
        -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_;
}

/// If we already have a connection, the impl is trivial
impl GetConnection for Connection {
    fn connection(
        &mut self,
    ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
        let conn = self.clone();
        async { Ok(conn) }
    }
}

/// If we already have a connection, the impl is trivial
impl GetConnection for &Connection {
    fn connection(
        &mut self,
    ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
        let conn = self.clone();
        async { Ok(conn) }
    }
}

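/// Heuristic buffer size (in items) for importing a blob of the given size:
/// one slot per `IROH_BLOCK_SIZE` block plus some slack, capped at 64.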
fn get_buffer_size(size: NonZeroU64) -> usize {
    (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
}

async fn get_blob_ranges_impl(
    header: AtBlobHeader,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
) -> GetResult<AtEndBlob> {
    let (mut content, size) = header.next().await?;
    let Some(size) = NonZeroU64::new(size) else {
        return if hash == Hash::EMPTY {
            let end = content.drain().await?;
            Ok(end)
        } else {
            Err(DecodeError::leaf_hash_mismatch(ChunkNum(0)).into())
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
                    handle.tx.send(item).await?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    let complete = async move {
        handle.rx.await.map_err(|e| {
            LocalFailureSnafu
                .into_error(anyhow::anyhow!("error reading from import stream: {e}").into())
        })
    };
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}

#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    blobs: Blobs,
    hash: Hash,
    current_chunk: Option<HashSeqChunk>,
}

#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// the offset of the first hash in this chunk, in bytes
    offset: u64,
    /// the hashes in this chunk
    chunk: HashSeq,
}

impl TryFrom<Leaf> for HashSeqChunk {
    type Error = anyhow::Error;

    fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
        let offset = leaf.offset;
        let chunk = HashSeq::try_from(leaf.data)?;
        Ok(Self { offset, chunk })
    }
}

impl IntoIterator for HashSeqChunk {
    type Item = Hash;
    type IntoIter = HashSeqIter;

    fn into_iter(self) -> Self::IntoIter {
        self.chunk.into_iter()
    }
}

impl HashSeqChunk {
    pub fn base(&self) -> u64 {
        self.offset / 32
    }

    #[allow(dead_code)]
    fn get(&self, offset: u64) -> Option<Hash> {
        // self.offset is a byte offset, but we are indexed by hash position,
        // so convert via base() before comparing
        let start = self.base();
        let end = start + self.chunk.len() as u64;
        if offset >= start && offset < end {
            let o = (offset - start) as usize;
            self.chunk.get(o)
        } else {
            None
        }
    }
}

impl LazyHashSeq {
    #[allow(dead_code)]
    pub fn new(blobs: Blobs, hash: Hash) -> Self {
        Self {
            blobs,
            hash,
            current_chunk: None,
        }
    }

    #[allow(dead_code)]
    pub async fn get_from_offset(&mut self, offset: u64) -> anyhow::Result<Option<Hash>> {
        if offset == 0 {
            Ok(Some(self.hash))
        } else {
            self.get(offset - 1).await
        }
    }

    #[allow(dead_code)]
    pub async fn get(&mut self, child_offset: u64) -> anyhow::Result<Option<Hash>> {
        // check if we have the hash in the current chunk
        if let Some(chunk) = &self.current_chunk {
            if let Some(hash) = chunk.get(child_offset) {
                return Ok(Some(hash));
            }
        }
        // load the chunk covering the offset
        let leaf = self
            .blobs
            .export_chunk(self.hash, child_offset * 32)
            .await?;
        // return the hash if it is in the chunk, otherwise we are past the end
        let hs = HashSeqChunk::try_from(leaf)?;
        Ok(hs.get(child_offset).inspect(|_hash| {
            self.current_chunk = Some(hs);
        }))
    }
}

async fn write_push_request(
    request: PushRequest,
    stream: &mut SendStream,
) -> anyhow::Result<PushRequest> {
    let mut request_bytes = Vec::new();
    request_bytes.push(RequestType::Push as u8);
    request_bytes.write_length_prefixed(&request).unwrap();
    stream.write_all(&request_bytes).await?;
    Ok(request)
}

async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> {
    let request = Request::Observe(request);
    let request_bytes = postcard::to_allocvec(&request)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    stream.write_all(&request_bytes).await?;
    Ok(())
}

struct StreamContext<S> {
    payload_bytes_sent: u64,
    sender: S,
}

impl<S> WriteProgress for StreamContext<S>
where
    S: Sink<u64, Error = irpc::channel::SendError>,
{
    async fn notify_payload_write(
        &mut self,
        _index: u64,
        _offset: u64,
        len: usize,
    ) -> ClientResult {
        self.payload_bytes_sent += len as u64;
        self.sender
            .send(self.payload_bytes_sent)
            .await
            .map_err(|e| ProgressError::Internal { source: e.into() })?;
        Ok(())
    }

    fn log_other_write(&mut self, _len: usize) {}

    async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
}

#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use bao_tree::{ChunkNum, ChunkRanges};
    use testresult::TestResult;

    use crate::{
        api::blobs::Blobs,
        protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
        store::{
            fs::{
                tests::{create_n0_bao, test_data, INTERESTING_SIZES},
                FsStore,
            },
            mem::MemStore,
        },
        tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
    };

    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = *tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // only add the hash seq itself, and only the first chunk of the children
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }

    async fn test_observe_partial(blobs: &Blobs) -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        for size in sizes {
            let data = test_data(size);
            let ranges = ChunkRanges::chunk(0);
            let (hash, bao) = create_n0_bao(&data, &ranges)?;
            blobs.import_bao_bytes(hash, ranges.clone(), bao).await?;
            let bitfield = blobs.observe(hash).await?;
            if size > 1024 {
                assert_eq!(bitfield.ranges, ranges);
            } else {
                assert_eq!(bitfield.ranges, ChunkRanges::all());
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_observe_partial_mem() -> TestResult<()> {
        let store = MemStore::new();
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_observe_partial_fs() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path()).await?;
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // only add the hash seq itself, none of the children
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), // we have the hash seq itself
                        ChunkRanges::empty(), // we always have the empty blob
                        ChunkRanges::all(),   // we miss all the remaining blobs (sizes.len() - 1)
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            // only add the hash seq itself, and only the first chunk of the children
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), // we have the hash seq itself
                        ChunkRanges::empty(), // we always have the empty blob
                        ChunkRanges::empty(), // size=1
                        ChunkRanges::empty(), // size=1024
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        // only add the hash seq itself, and only the first chunk of the children
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
}