iroh_blobs/api/
remote.rs

1//! API for downloading blobs from a single remote node.
2//!
3//! The entry point is the [`Remote`] struct.
4use genawaiter::sync::{Co, Gen};
5use iroh::endpoint::SendStream;
6use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
7use n0_future::{io, Stream, StreamExt};
8use n0_snafu::SpanTrace;
9use nested_enum_utils::common_fields;
10use ref_cast::RefCast;
11use snafu::{Backtrace, IntoError, Snafu};
12
13use super::blobs::{Bitfield, ExportBaoOptions};
14use crate::{
15    api::{blobs::WriteProgress, ApiClient},
16    get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats},
17    protocol::{
18        GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
19        MAX_MESSAGE_SIZE,
20    },
21    util::sink::{Sink, TokioMpscSenderSink},
22};
23
/// API to compute requests and to download from remote nodes.
///
/// Usually you want to first find out what, if any, data you have locally.
/// This can be done using [`Remote::local`], which inspects the local store
/// and returns a [`LocalInfo`].
///
/// From this you can compute various values such as the number of locally present
/// bytes. You can also compute a request to get the missing data using [`LocalInfo::missing`].
///
/// Once you have a request, you can execute it using [`Remote::execute_get`].
/// Executing a request will store to the local store, but otherwise does not take
/// the available data into account.
///
/// If you are not interested in the details and just want your data, you can use
/// [`Remote::fetch`]. This will internally do the dance described above.
// `RefCast` + `repr(transparent)` allow `Remote::ref_from_sender` to turn a
// `&ApiClient` into a `&Remote` without copying.
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    /// The API client; the local [`Store`] view is derived from it.
    client: ApiClient,
}
44
/// An item in the progress stream of a get operation.
///
/// The producers in this file emit any number of [`GetProgressItem::Progress`]
/// items and then one final item, which is either [`GetProgressItem::Done`] or
/// [`GetProgressItem::Error`].
#[derive(Debug)]
pub enum GetProgressItem {
    /// Progress on the payload bytes read.
    Progress(u64),
    /// The request was completed.
    Done(Stats),
    /// The request was closed, but not completed.
    Error(GetError),
}
54
55impl From<GetResult<Stats>> for GetProgressItem {
56    fn from(res: GetResult<Stats>) -> Self {
57        match res {
58            Ok(stats) => GetProgressItem::Done(stats),
59            Err(e) => GetProgressItem::Error(e),
60        }
61    }
62}
63
64impl TryFrom<GetProgressItem> for GetResult<Stats> {
65    type Error = &'static str;
66
67    fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
68        match item {
69            GetProgressItem::Done(stats) => Ok(Ok(stats)),
70            GetProgressItem::Error(e) => Ok(Err(e)),
71            GetProgressItem::Progress(_) => Err("not a final item"),
72        }
73    }
74}
75
/// Handle to a get operation that has been set up but not yet driven.
///
/// Either await it (via [`IntoFuture`]) to get just the final result, or call
/// [`GetProgress::stream`] to also observe the progress items.
pub struct GetProgress {
    /// Receives progress items and the final result produced by `fut`.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    /// The future that performs the operation and feeds `rx`.
    fut: n0_future::boxed::BoxFuture<()>,
}
80
81impl IntoFuture for GetProgress {
82    type Output = GetResult<Stats>;
83    type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
84
85    fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
86        Box::pin(self.complete())
87    }
88}
89
90impl GetProgress {
91    pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
92        into_stream(self.rx, self.fut)
93    }
94
95    pub async fn complete(self) -> GetResult<Stats> {
96        just_result(self.stream()).await.unwrap_or_else(|| {
97            Err(LocalFailureSnafu
98                .into_error(anyhow::anyhow!("stream closed without result").into()))
99        })
100    }
101}
102
/// An item in the progress stream of a push operation.
///
/// The producer in this file emits any number of [`PushProgressItem::Progress`]
/// items and then one final item, which is either [`PushProgressItem::Done`] or
/// [`PushProgressItem::Error`].
#[derive(Debug)]
pub enum PushProgressItem {
    /// Progress on the payload bytes sent.
    Progress(u64),
    /// The request was completed.
    Done(Stats),
    /// The request was closed, but not completed.
    Error(anyhow::Error),
}
112
113impl From<anyhow::Result<Stats>> for PushProgressItem {
114    fn from(res: anyhow::Result<Stats>) -> Self {
115        match res {
116            Ok(stats) => Self::Done(stats),
117            Err(e) => Self::Error(e),
118        }
119    }
120}
121
122impl TryFrom<PushProgressItem> for anyhow::Result<Stats> {
123    type Error = &'static str;
124
125    fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
126        match item {
127            PushProgressItem::Done(stats) => Ok(Ok(stats)),
128            PushProgressItem::Error(e) => Ok(Err(e)),
129            PushProgressItem::Progress(_) => Err("not a final item"),
130        }
131    }
132}
133
/// Handle to a push operation that has been set up but not yet driven.
///
/// Either await it (via [`IntoFuture`]) to get just the final result, or call
/// [`PushProgress::stream`] to also observe the progress items.
pub struct PushProgress {
    /// Receives progress items and the final result produced by `fut`.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    /// The future that performs the operation and feeds `rx`.
    fut: n0_future::boxed::BoxFuture<()>,
}
138
139impl IntoFuture for PushProgress {
140    type Output = anyhow::Result<Stats>;
141    type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
142
143    fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
144        Box::pin(self.complete())
145    }
146}
147
148impl PushProgress {
149    pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
150        into_stream(self.rx, self.fut)
151    }
152
153    pub async fn complete(self) -> anyhow::Result<Stats> {
154        just_result(self.stream())
155            .await
156            .unwrap_or_else(|| Err(anyhow::anyhow!("stream closed without result")))
157    }
158}
159
160async fn just_result<S, R>(stream: S) -> Option<R>
161where
162    S: Stream<Item: std::fmt::Debug>,
163    R: TryFrom<S::Item>,
164{
165    tokio::pin!(stream);
166    while let Some(item) = stream.next().await {
167        if let Ok(res) = R::try_from(item) {
168            return Some(res);
169        }
170    }
171    None
172}
173
/// Combines a receiver and the future that feeds it into a single stream.
///
/// `fut` is only polled from within the returned stream, so the work it
/// represents makes progress only while the stream is being consumed. Items
/// received on `rx` are yielded in order. Once `fut` completes or `rx` reports
/// that it is closed, the remaining items are drained from `rx` until it is
/// closed, and the stream ends.
fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
where
    F: Future,
{
    Gen::new(move |co| async move {
        tokio::pin!(fut);
        loop {
            tokio::select! {
                // Poll the branches in source order: forward available items
                // before checking whether `fut` has finished.
                biased;
                item = rx.recv() => {
                    if let Some(item) = item {
                        co.yield_(item).await;
                    } else {
                        // channel closed, nothing more will arrive
                        break;
                    }
                }
                _ = &mut fut => {
                    // the producer is done; fall through to drain the channel
                    break;
                }
            }
        }
        // Forward whatever is still buffered, e.g. the final item sent just
        // before `fut` completed.
        while let Some(item) = rx.recv().await {
            co.yield_(item).await;
        }
    })
}
200
/// Local info for a blob or hash sequence.
///
/// This can be used to get the amount of missing data, and to construct a
/// request to get the missing data.
#[derive(Debug)]
pub struct LocalInfo {
    /// The request for which this is the local info
    request: Arc<GetRequest>,
    /// The bitfield for the root hash
    bitfield: Bitfield,
    /// Optional - the hash sequence info if this was a request for a hash sequence
    children: Option<NonRawLocalInfo>,
}
214
215impl LocalInfo {
216    /// The number of bytes we have locally
217    pub fn local_bytes(&self) -> u64 {
218        let Some(root_requested) = self.requested_root_ranges() else {
219            // empty request requests 0 bytes
220            return 0;
221        };
222        let mut local = self.bitfield.clone();
223        local.ranges.intersection_with(root_requested);
224        let mut res = local.total_bytes();
225        if let Some(children) = &self.children {
226            let Some(max_local_index) = children.hash_seq.keys().next_back() else {
227                // no children
228                return res;
229            };
230            for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
231                if offset == 0 {
232                    // skip the root hash
233                    continue;
234                }
235                let child = offset - 1;
236                if child > *max_local_index {
237                    // we are done
238                    break;
239                }
240                let Some(hash) = children.hash_seq.get(&child) else {
241                    continue;
242                };
243                let bitfield = &children.bitfields[hash];
244                let mut local = bitfield.clone();
245                local.ranges.intersection_with(ranges);
246                res += local.total_bytes();
247            }
248        }
249        res
250    }
251
252    /// Number of children in this hash sequence
253    pub fn children(&self) -> Option<u64> {
254        if self.children.is_some() {
255            self.bitfield.validated_size().map(|x| x / 32)
256        } else {
257            Some(0)
258        }
259    }
260
    /// The requested root ranges.
    ///
    /// This will return `None` if the request is empty, and an empty
    /// [`ChunkRanges`] if no ranges were requested for the root hash.
    fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
        self.request.ranges.iter().next()
    }
268
269    /// True if the data is complete.
270    ///
271    /// For a blob, this is true if the blob is complete.
272    /// For a hash sequence, this is true if the hash sequence is complete and
273    /// all its children are complete.
274    pub fn is_complete(&self) -> bool {
275        let Some(root_requested) = self.requested_root_ranges() else {
276            // empty request is complete
277            return true;
278        };
279        if !self.bitfield.ranges.is_superset(root_requested) {
280            return false;
281        }
282        if let Some(children) = self.children.as_ref() {
283            let mut iter = self.request.ranges.iter_non_empty_infinite();
284            let max_child = self.bitfield.validated_size().map(|x| x / 32);
285            loop {
286                let Some((offset, range)) = iter.next() else {
287                    break;
288                };
289                if offset == 0 {
290                    // skip the root hash
291                    continue;
292                }
293                let child = offset - 1;
294                if let Some(hash) = children.hash_seq.get(&child) {
295                    let bitfield = &children.bitfields[hash];
296                    if !bitfield.ranges.is_superset(range) {
297                        // we don't have the requested ranges
298                        return false;
299                    }
300                } else {
301                    if let Some(max_child) = max_child {
302                        if child >= max_child {
303                            // reading after the end of the request
304                            return true;
305                        }
306                    }
307                    return false;
308                }
309            }
310        }
311        true
312    }
313
    /// A request to get the missing data to complete this request
    ///
    /// The root ranges are the requested root ranges minus the locally present
    /// ones. If there is hash sequence info, child ranges are computed in two
    /// phases that share one iterator over the requested ranges: first the
    /// children up to the last locally known one, where local bitfields can be
    /// subtracted, then the remaining children, which are requested as is.
    pub fn missing(&self) -> GetRequest {
        let Some(root_requested) = self.requested_root_ranges() else {
            // empty request is complete
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        // one past the highest locally known child index, or 0 if there is none
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        // number of children, if the size of the root blob has been validated
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        // Phase 1: children for which we may have local data.
        loop {
            let Some((offset, requested)) = iter.next() else {
                break;
            };
            if offset == 0 {
                // skip the root hash
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                // we can't do anything clever anymore
                break;
            }
        }
        // Phase 2: children beyond the locally known ones, continuing with the
        // same iterator.
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                // skip the root hash
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                // the number of children is known, stop at the end of the sequence
                if child >= *max_offset {
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                // The number of children is unknown. Once the iterator reports
                // that it is at its end, build a closed request if nothing
                // follows and an open one otherwise.
                if iter.is_at_end() {
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
378}
379
/// Local info about the children of a hash sequence.
///
/// Every hash stored in `hash_seq` also has an entry in `bitfields`; both are
/// filled together when the info is gathered.
#[derive(Debug)]
struct NonRawLocalInfo {
    /// the available and relevant part of the hash sequence, keyed by child index
    hash_seq: BTreeMap<u64, Hash>,
    /// For each hash relevant to the request, the local bitfield.
    bitfields: BTreeMap<Hash, Bitfield>,
}
388
389// fn iter_without_gaps<'a, T: Copy + 'a>(
390//     iter: impl IntoIterator<Item = &'a (u64, T)> + 'a,
391// ) -> impl Iterator<Item = (u64, Option<T>)> + 'a {
392//     let mut prev = 0;
393//     iter.into_iter().flat_map(move |(i, hash)| {
394//         let start = prev + 1;
395//         let curr = *i;
396//         prev = *i;
397//         (start..curr)
398//             .map(|i| (i, None))
399//             .chain(std::iter::once((curr, Some(*hash))))
400//     })
401// }
402
403impl Remote {
    /// Reinterprets a reference to an [`ApiClient`] as a reference to a [`Remote`].
    ///
    /// This goes through `RefCast`, so no data is copied.
    pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
        Self::ref_cast(sender)
    }
407
    /// The [`Store`] view of the client this remote API wraps.
    fn store(&self) -> &Store {
        Store::ref_from_sender(&self.client)
    }
411
412    pub async fn local_for_request(
413        &self,
414        request: impl Into<Arc<GetRequest>>,
415    ) -> anyhow::Result<LocalInfo> {
416        let request = request.into();
417        let root = request.hash;
418        let bitfield = self.store().observe(root).await?;
419        let children = if !request.ranges.is_blob() {
420            let opts = ExportBaoOptions {
421                hash: root,
422                ranges: bitfield.ranges.clone(),
423            };
424            let bao = self.store().export_bao_with_opts(opts, 32);
425            let mut by_index = BTreeMap::new();
426            let mut stream = bao.hashes_with_index();
427            while let Some(item) = stream.next().await {
428                if let Ok((index, hash)) = item {
429                    by_index.insert(index, hash);
430                }
431            }
432            let mut bitfields = BTreeMap::new();
433            let mut hash_seq = BTreeMap::new();
434            let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
435            for (index, _) in request.ranges.iter_non_empty_infinite() {
436                if index == 0 {
437                    // skip the root hash
438                    continue;
439                }
440                let child = index - 1;
441                if child > max {
442                    // we are done
443                    break;
444                }
445                let Some(hash) = by_index.get(&child) else {
446                    // we don't have the hash, so we can't store the bitfield
447                    continue;
448                };
449                let bitfield = self.store().observe(*hash).await?;
450                bitfields.insert(*hash, bitfield);
451                hash_seq.insert(child, *hash);
452            }
453            Some(NonRawLocalInfo {
454                hash_seq,
455                bitfields,
456            })
457        } else {
458            None
459        };
460        Ok(LocalInfo {
461            request: request.clone(),
462            bitfield,
463            children,
464        })
465    }
466
467    /// Get the local info for a given blob or hash sequence, at the present time.
468    pub async fn local(&self, content: impl Into<HashAndFormat>) -> anyhow::Result<LocalInfo> {
469        let request = GetRequest::from(content.into());
470        self.local_for_request(request).await
471    }
472
473    pub fn fetch(
474        &self,
475        conn: impl GetConnection + Send + 'static,
476        content: impl Into<HashAndFormat>,
477    ) -> GetProgress {
478        let content = content.into();
479        let (tx, rx) = tokio::sync::mpsc::channel(64);
480        let tx2 = tx.clone();
481        let sink = TokioMpscSenderSink(tx)
482            .with_map(GetProgressItem::Progress)
483            .with_map_err(io::Error::other);
484        let this = self.clone();
485        let fut = async move {
486            let res = this.fetch_sink(conn, content, sink).await.into();
487            tx2.send(res).await.ok();
488        };
489        GetProgress {
490            rx,
491            fut: Box::pin(fut),
492        }
493    }
494
495    /// Get a blob or hash sequence from the given connection, taking the locally available
496    /// ranges into account.
497    ///
498    /// You can provide a progress channel to get updates on the download progress. This progress
499    /// is the aggregated number of downloaded payload bytes in the request.
500    ///
501    /// This will return the stats of the download.
502    pub async fn fetch_sink(
503        &self,
504        mut conn: impl GetConnection,
505        content: impl Into<HashAndFormat>,
506        progress: impl Sink<u64, Error = io::Error>,
507    ) -> GetResult<Stats> {
508        let content = content.into();
509        let local = self
510            .local(content)
511            .await
512            .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
513        if local.is_complete() {
514            return Ok(Default::default());
515        }
516        let request = local.missing();
517        let conn = conn
518            .connection()
519            .await
520            .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
521        let stats = self.execute_get_sink(conn, request, progress).await?;
522        Ok(stats)
523    }
524
525    pub fn observe(
526        &self,
527        conn: Connection,
528        request: ObserveRequest,
529    ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
530        Gen::new(|co| async move {
531            if let Err(cause) = Self::observe_impl(conn, request, &co).await {
532                co.yield_(Err(cause)).await
533            }
534        })
535    }
536
    /// Performs a single observe exchange and yields every received bitfield through `co`.
    ///
    /// Opens a bidirectional stream on `conn`, writes the request, finishes the
    /// send side and then reads length-prefixed [`ObserveItem`]s, passing
    /// `MAX_MESSAGE_SIZE` to each read.
    ///
    /// The read loop has no `Ok` exit: this function returns only when opening
    /// the stream, writing the request or reading an item fails, and it returns
    /// that error.
    async fn observe_impl(
        conn: Connection,
        request: ObserveRequest,
        co: &Co<io::Result<Bitfield>>,
    ) -> io::Result<()> {
        let hash = request.hash;
        debug!(%hash, "observing");
        let (mut send, mut recv) = conn.open_bi().await?;
        // write the request. Unlike for reading, we can just serialize it sync using postcard.
        write_observe_request(request, &mut send).await?;
        // nothing more to send after the request
        send.finish()?;
        loop {
            let msg = recv
                .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
                .await?;
            co.yield_(Ok(Bitfield::from(&msg))).await;
        }
    }
555
556    pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
557        let (tx, rx) = tokio::sync::mpsc::channel(64);
558        let tx2 = tx.clone();
559        let sink = TokioMpscSenderSink(tx)
560            .with_map(PushProgressItem::Progress)
561            .with_map_err(io::Error::other);
562        let this = self.clone();
563        let fut = async move {
564            let res = this.execute_push_sink(conn, request, sink).await.into();
565            tx2.send(res).await.ok();
566        };
567        PushProgress {
568            rx,
569            fut: Box::pin(fut),
570        }
571    }
572
    /// Push the given blob or hash sequence to a remote node.
    ///
    /// Note that many nodes will reject push requests. Also, this is an experimental feature for now.
    ///
    /// Writes the request, then the requested ranges of the root, and, unless
    /// the request is for a single blob, the requested ranges of each child of
    /// the locally stored hash sequence. Progress is reported through `progress`.
    ///
    /// On success this currently returns default [`Stats`], not measured ones.
    ///
    /// # Errors
    ///
    /// Returns an error if opening the stream, writing to it, exporting data from
    /// the local store, or parsing the root as a hash sequence fails.
    pub async fn execute_push_sink(
        &self,
        conn: Connection,
        request: PushRequest,
        progress: impl Sink<u64, Error = io::Error>,
    ) -> anyhow::Result<Stats> {
        let hash = request.hash;
        debug!(%hash, "pushing");
        let (mut send, mut recv) = conn.open_bi().await?;
        let mut context = StreamContext {
            payload_bytes_sent: 0,
            sender: progress,
        };
        // we are not going to need this!
        recv.stop(0u32.into())?;
        // write the request. Unlike for reading, we can just serialize it sync using postcard.
        let request = write_push_request(request, &mut send).await?;
        let mut request_ranges = request.ranges.iter_infinite();
        let root = request.hash;
        // the first element of the ranges sequence belongs to the root
        let root_ranges = request_ranges.next().expect("infinite iterator");
        if !root_ranges.is_empty() {
            self.store()
                .export_bao(root, root_ranges.clone())
                .write_quinn_with_progress(&mut send, &mut context, &root, 0)
                .await?;
        }
        if request.ranges.is_blob() {
            // we are done
            send.finish()?;
            return Ok(Default::default());
        }
        let hash_seq = self.store().get_bytes(root).await?;
        let hash_seq = HashSeq::try_from(hash_seq)?;
        // `zip` stops at the end of the hash sequence, since `request_ranges` is infinite
        for (child, (child_hash, child_ranges)) in
            hash_seq.into_iter().zip(request_ranges).enumerate()
        {
            if !child_ranges.is_empty() {
                self.store()
                    .export_bao(child_hash, child_ranges.clone())
                    .write_quinn_with_progress(
                        &mut send,
                        &mut context,
                        &child_hash,
                        // offset 0 is the root, so children start at 1
                        (child + 1) as u64,
                    )
                    .await?;
            }
        }
        send.finish()?;
        Ok(Default::default())
    }
627
    /// Execute a get request *without* taking the locally available ranges into account.
    ///
    /// Returns a [`GetProgress`]; this simply forwards to
    /// [`Remote::execute_get_with_opts`].
    pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress {
        self.execute_get_with_opts(conn, request)
    }
631
632    pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress {
633        let (tx, rx) = tokio::sync::mpsc::channel(64);
634        let tx2 = tx.clone();
635        let sink = TokioMpscSenderSink(tx)
636            .with_map(GetProgressItem::Progress)
637            .with_map_err(io::Error::other);
638        let this = self.clone();
639        let fut = async move {
640            let res = this.execute_get_sink(conn, request, sink).await.into();
641            tx2.send(res).await.ok();
642        };
643        GetProgress {
644            rx,
645            fut: Box::pin(fut),
646        }
647    }
648
    /// Execute a get request *without* taking the locally available ranges into account.
    ///
    /// You can provide a progress channel to get updates on the download progress. This progress
    /// is the aggregated number of downloaded payload bytes in the request.
    ///
    /// This will download the data again even if the data is locally present.
    ///
    /// If the response continues with children, the root blob is read back from
    /// the local store and parsed as a hash sequence to look up each child's hash.
    ///
    /// This will return the stats of the download.
    pub async fn execute_get_sink(
        &self,
        conn: Connection,
        request: GetRequest,
        mut progress: impl Sink<u64, Error = io::Error>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let root = request.hash;
        let start = crate::get::fsm::start(conn, request, Default::default());
        let connected = start.next().await?;
        trace!("Getting header");
        // read the header
        // `Ok` means children follow, `Err` means the response is closing
        let next_child = match connected.next().await? {
            ConnectedNext::StartRoot(at_start_root) => {
                let header = at_start_root.next();
                let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
                match end.next() {
                    EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                    EndBlobNext::Closing(at_closing) => Err(at_closing),
                }
            }
            ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
            ConnectedNext::Closing(at_closing) => Err(at_closing),
        };
        // read the rest, if any
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                // the root must be available locally at this point to resolve child hashes
                let hash_seq = HashSeq::try_from(
                    store
                        .get_bytes(root)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?,
                )
                .map_err(|source| BadRequestSnafu.into_error(source.into()))?;
                // let mut hash_seq = LazyHashSeq::new(store.blobs().clone(), root);
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    // the fsm offset counts the root as 0, the hash sequence index does not
                    let offset = at_start_child.offset() - 1;
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        // no such child in the hash sequence, stop reading
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        // finish the request and collect the stats
        let stats = at_closing.next().await?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }
718
719    pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
720        let (tx, rx) = tokio::sync::mpsc::channel(64);
721        let tx2 = tx.clone();
722        let sink = TokioMpscSenderSink(tx)
723            .with_map(GetProgressItem::Progress)
724            .with_map_err(io::Error::other);
725        let this = self.clone();
726        let fut = async move {
727            let res = this.execute_get_many_sink(conn, request, sink).await.into();
728            tx2.send(res).await.ok();
729        };
730        GetProgress {
731            rx,
732            fut: Box::pin(fut),
733        }
734    }
735
736    /// Execute a get request *without* taking the locally available ranges into account.
737    ///
738    /// You can provide a progress channel to get updates on the download progress. This progress
739    /// is the aggregated number of downloaded payload bytes in the request.
740    ///
741    /// This will download the data again even if the data is locally present.
742    ///
743    /// This will return the stats of the download.
744    pub async fn execute_get_many_sink(
745        &self,
746        conn: Connection,
747        request: GetManyRequest,
748        mut progress: impl Sink<u64, Error = io::Error>,
749    ) -> GetResult<Stats> {
750        let store = self.store();
751        let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
752        let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
753        // read all children.
754        let at_closing = match next_child {
755            Ok(at_start_child) => {
756                let mut next_child = Ok(at_start_child);
757                loop {
758                    let at_start_child = match next_child {
759                        Ok(at_start_child) => at_start_child,
760                        Err(at_closing) => break at_closing,
761                    };
762                    let offset = at_start_child.offset();
763                    println!("offset {offset}");
764                    let Some(hash) = hash_seq.get(offset as usize) else {
765                        break at_start_child.finish();
766                    };
767                    trace!("getting child {offset} {}", hash.fmt_short());
768                    let header = at_start_child.next(hash);
769                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
770                    next_child = match end.next() {
771                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
772                        EndBlobNext::Closing(at_closing) => Err(at_closing),
773                    }
774                }
775            }
776            Err(at_closing) => at_closing,
777        };
778        // read the rest, if any
779        let stats = at_closing.next().await?;
780        trace!(?stats, "get hash seq done");
781        Ok(stats)
782    }
783}
784
/// Failures for a get operation
// NOTE(review): `common_fields` presumably adds the listed fields to every
// variant - confirm against the `nested_enum_utils` documentation.
#[common_fields({
    backtrace: Option<Backtrace>,
    #[snafu(implicit)]
    span_trace: SpanTrace,
})]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Debug, Snafu)]
pub enum ExecuteError {
    /// Network or IO operation failed.
    #[snafu(display("Unable to open bidi stream"))]
    Connection {
        source: iroh::endpoint::ConnectionError,
    },
    /// Reading from the remote failed.
    #[snafu(display("Unable to read from the remote"))]
    Read { source: iroh::endpoint::ReadError },
    /// Sending the request failed.
    #[snafu(display("Error sending the request"))]
    Send {
        source: crate::get::fsm::ConnectedNextError,
    },
    /// Reading the size from the blob header failed.
    #[snafu(display("Unable to read size"))]
    Size {
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    /// Decoding the received data failed.
    #[snafu(display("Error while decoding the data"))]
    Decode {
        source: crate::get::fsm::DecodeError,
    },
    /// Exporting the hash sequence from the local store failed.
    #[snafu(display("Internal error while reading the hash sequence"))]
    ExportBao { source: api::ExportBaoError },
    /// The hash sequence could not be parsed.
    #[snafu(display("Hash sequence has an invalid length"))]
    InvalidHashSeq { source: anyhow::Error },
    /// Importing the data into the local store failed.
    #[snafu(display("Internal error importing the data"))]
    ImportBao { source: crate::api::RequestError },
    /// The receiver of the download progress was closed.
    #[snafu(display("Error sending download progress - receiver closed"))]
    SendDownloadProgress { source: irpc::channel::SendError },
    /// Sending a content item over an mpsc channel failed.
    // NOTE(review): the display message is identical to the one of `ImportBao`;
    // looks like a copy-paste - confirm whether a distinct message is wanted.
    #[snafu(display("Internal error importing the data"))]
    MpscSend {
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
827
828use std::{
829    collections::BTreeMap,
830    future::{Future, IntoFuture},
831    num::NonZeroU64,
832    sync::Arc,
833};
834
835use bao_tree::{
836    io::{BaoContentItem, Leaf},
837    ChunkNum, ChunkRanges,
838};
839use iroh::endpoint::Connection;
840use tracing::{debug, trace};
841
842use crate::{
843    api::{self, blobs::Blobs, Store},
844    get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext},
845    hashseq::{HashSeq, HashSeqIter},
846    protocol::{ChunkRangesSeq, GetRequest},
847    store::IROH_BLOCK_SIZE,
848    Hash, HashAndFormat,
849};
850
/// Trait to lazily get a connection
///
/// Implementors hand out a [`Connection`] on demand. The returned future
/// borrows `self` mutably for as long as it is alive and must be `Send`.
pub trait GetConnection {
    /// Returns a future that resolves to a [`Connection`], or to an
    /// [`anyhow::Error`] if no connection could be provided.
    fn connection(&mut self)
        -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_;
}
856
857/// If we already have a connection, the impl is trivial
858impl GetConnection for Connection {
859    fn connection(
860        &mut self,
861    ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
862        let conn = self.clone();
863        async { Ok(conn) }
864    }
865}
866
867/// If we already have a connection, the impl is trivial
868impl GetConnection for &Connection {
869    fn connection(
870        &mut self,
871    ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
872        let conn = self.clone();
873        async { Ok(conn) }
874    }
875}
876
877fn get_buffer_size(size: NonZeroU64) -> usize {
878    (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
879}
880
/// Receives the content of a single blob and feeds it into the store.
///
/// Starting from `header`, this reads the blob size, opens a bao import for
/// `hash` on `store`, and forwards every received content item to that import
/// while reporting the running `payload_bytes_read` count to `progress`.
///
/// A reported size of zero is only accepted for [`Hash::EMPTY`]; in that case
/// the content is drained and the store is not touched. For any other hash a
/// leaf hash mismatch at chunk 0 is returned.
///
/// Returns the [`AtEndBlob`] state reached once the blob content is exhausted.
/// Failures of the import or of the progress sink are mapped to local failures.
async fn get_blob_ranges_impl(
    header: AtBlobHeader,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = io::Error>,
) -> GetResult<AtEndBlob> {
    let (mut content, size) = header.next().await?;
    // A zero size cannot go through the import below, which takes a `NonZeroU64`.
    let Some(size) = NonZeroU64::new(size) else {
        return if hash == Hash::EMPTY {
            let end = content.drain().await?;
            Ok(end)
        } else {
            Err(DecodeError::leaf_hash_mismatch(ChunkNum(0)).into())
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
    // Reads items from the remote and pushes them into the import via `handle.tx`.
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res?;
                    // Progress is reported before the item is handed to the import.
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
                    handle.tx.send(item).await?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // Release the sender before finishing. Presumably this lets the
                    // import observe the end of input — confirm against `import_bao`.
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    // Awaits `handle.rx`; presumably this yields the outcome of the import — confirm.
    let complete = async move {
        handle.rx.await.map_err(|e| {
            LocalFailureSnafu
                .into_error(anyhow::anyhow!("error reading from import stream: {e}").into())
        })
    };
    // Drive both sides together; an error from either one ends the join early.
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
930
/// Looks up hashes of a hash sequence on demand, loading one exported chunk at a time.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    /// Blobs API used to export chunks of the hash sequence blob.
    blobs: Blobs,
    /// Hash of the hash sequence blob itself.
    hash: Hash,
    /// Most recently loaded chunk, reused for lookups that fall inside it.
    current_chunk: Option<HashSeqChunk>,
}
937
/// The hashes contained in one leaf of a hash sequence blob, together with
/// the position of that leaf.
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// The offset of the first hash in this chunk, in bytes.
    offset: u64,
    /// The hashes in this chunk.
    chunk: HashSeq,
}
945
946impl TryFrom<Leaf> for HashSeqChunk {
947    type Error = anyhow::Error;
948
949    fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
950        let offset = leaf.offset;
951        let chunk = HashSeq::try_from(leaf.data)?;
952        Ok(Self { offset, chunk })
953    }
954}
955
956impl IntoIterator for HashSeqChunk {
957    type Item = Hash;
958    type IntoIter = HashSeqIter;
959
960    fn into_iter(self) -> Self::IntoIter {
961        self.chunk.into_iter()
962    }
963}
964
965impl HashSeqChunk {
966    pub fn base(&self) -> u64 {
967        self.offset / 32
968    }
969
970    #[allow(dead_code)]
971    fn get(&self, offset: u64) -> Option<Hash> {
972        let start = self.offset;
973        let end = start + self.chunk.len() as u64;
974        if offset >= start && offset < end {
975            let o = (offset - start) as usize;
976            self.chunk.get(o)
977        } else {
978            None
979        }
980    }
981}
982
983impl LazyHashSeq {
984    #[allow(dead_code)]
985    pub fn new(blobs: Blobs, hash: Hash) -> Self {
986        Self {
987            blobs,
988            hash,
989            current_chunk: None,
990        }
991    }
992
993    #[allow(dead_code)]
994    pub async fn get_from_offset(&mut self, offset: u64) -> anyhow::Result<Option<Hash>> {
995        if offset == 0 {
996            Ok(Some(self.hash))
997        } else {
998            self.get(offset - 1).await
999        }
1000    }
1001
1002    #[allow(dead_code)]
1003    pub async fn get(&mut self, child_offset: u64) -> anyhow::Result<Option<Hash>> {
1004        // check if we have the hash in the current chunk
1005        if let Some(chunk) = &self.current_chunk {
1006            if let Some(hash) = chunk.get(child_offset) {
1007                return Ok(Some(hash));
1008            }
1009        }
1010        // load the chunk covering the offset
1011        let leaf = self
1012            .blobs
1013            .export_chunk(self.hash, child_offset * 32)
1014            .await?;
1015        // return the hash if it is in the chunk, otherwise we are behind the end
1016        let hs = HashSeqChunk::try_from(leaf)?;
1017        Ok(hs.get(child_offset).inspect(|_hash| {
1018            self.current_chunk = Some(hs);
1019        }))
1020    }
1021}
1022
1023async fn write_push_request(
1024    request: PushRequest,
1025    stream: &mut SendStream,
1026) -> anyhow::Result<PushRequest> {
1027    let mut request_bytes = Vec::new();
1028    request_bytes.push(RequestType::Push as u8);
1029    request_bytes.write_length_prefixed(&request).unwrap();
1030    stream.write_all(&request_bytes).await?;
1031    Ok(request)
1032}
1033
1034async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> {
1035    let request = Request::Observe(request);
1036    let request_bytes = postcard::to_allocvec(&request)
1037        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1038    stream.write_all(&request_bytes).await?;
1039    Ok(())
1040}
1041
/// Progress reporter that forwards the running total of written payload bytes to a sink.
struct StreamContext<S> {
    /// Total number of payload bytes reported so far.
    payload_bytes_sent: u64,
    /// Sink that receives the updated total after each payload write.
    sender: S,
}
1046
1047impl<S> WriteProgress for StreamContext<S>
1048where
1049    S: Sink<u64, Error = io::Error>,
1050{
1051    async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) {
1052        self.payload_bytes_sent += len as u64;
1053        self.sender.send(self.payload_bytes_sent).await.ok();
1054    }
1055
1056    fn log_other_write(&mut self, _len: usize) {}
1057
1058    async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
1059}
1060
1061#[cfg(test)]
1062mod tests {
1063    use bao_tree::{ChunkNum, ChunkRanges};
1064    use testresult::TestResult;
1065
1066    use crate::{
1067        protocol::{ChunkRangesSeq, GetRequest},
1068        store::fs::{tests::INTERESTING_SIZES, FsStore},
1069        tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
1070        util::ChunkRangesExt,
1071    };
1072
1073    #[tokio::test]
1074    async fn test_local_info_raw() -> TestResult<()> {
1075        let td = tempfile::tempdir()?;
1076        let store = FsStore::load(td.path().join("blobs.db")).await?;
1077        let blobs = store.blobs();
1078        let tt = blobs.add_slice(b"test").temp_tag().await?;
1079        let hash = *tt.hash();
1080        let info = store.remote().local(hash).await?;
1081        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
1082        assert_eq!(info.local_bytes(), 4);
1083        assert!(info.is_complete());
1084        assert_eq!(
1085            info.missing(),
1086            GetRequest::new(hash, ChunkRangesSeq::empty())
1087        );
1088        Ok(())
1089    }
1090
    /// Checks the local info of a large hash sequence of which only chunks
    /// 16..32 of the sequence itself are present.
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // Sizes of the children at indices 32 * 16 .. 32 * 32; presumably these
        // are the children whose hashes live in chunks 16..32 of the hash seq
        // (32 hashes per chunk) — confirm.
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // only add chunks 16..32 of the hash seq itself, and only the first
            // chunk of each child
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // the children counted above plus 16 * 1024 bytes for the hash seq part
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }
1119
    /// Checks local info and the computed missing request for a hash sequence
    /// in three states: no child data, first chunk of each child, everything.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        // one 32 byte hash per child
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // only add the hash seq itself, none of the children
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), // we have the hash seq itself
                        ChunkRanges::empty(), // we always have the empty blob
                        ChunkRanges::all(),   // we miss all the remaining blobs (sizes.len() - 1)
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            // remove all tags before the next scenario adds its own content
            store.tags().delete_all().await?;
        }
        {
            // only add the hash seq itself, and only the first chunk of the children
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            // each child contributes at most 1024 bytes, the size of its first chunk
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), // we have the hash seq itself
                        ChunkRanges::empty(), // we always have the empty blob
                        ChunkRanges::empty(), // size=1
                        ChunkRanges::empty(), // size=1024
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            // add the hash seq and all children in full
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }
1206
    /// Checks `local_for_request` against requests that cover only part of a
    /// partially present hash sequence.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        // one 32 byte hash per child
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        // only add the hash seq itself, and only the first two chunks of the
        // children (`chunks(..2)`)
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // request only the root, i.e. the hash seq itself
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // request the root and all of the first child
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // request the root and all of the first two children
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // request the root and chunk 0 of the children; `build_open`
            // presumably applies the last range to all remaining children —
            // confirm. The expectation below sums over every child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
1275}