iroh_blobs/
get.rs

1//! The client side API
2//!
3//! To get data, create a connection using [iroh-net] or use any quinn
4//! connection that was obtained in another way.
5//!
6//! Create a request describing the data you want to get.
7//!
8//! Then create a state machine using [fsm::start] and
9//! drive it to completion by calling next on each state.
10//!
11//! For some states you have to provide additional arguments when calling next,
12//! or you can choose to finish early.
13//!
14//! [iroh-net]: https://docs.rs/iroh-net
15use std::{
16    error::Error,
17    fmt::{self, Debug},
18    time::{Duration, Instant},
19};
20
21use anyhow::Result;
22use bao_tree::{io::fsm::BaoContentItem, ChunkNum};
23use iroh::endpoint::{self, ClosedStream, RecvStream, SendStream, WriteError};
24use serde::{Deserialize, Serialize};
25use tracing::{debug, error};
26
27use crate::{
28    protocol::RangeSpecSeq,
29    util::io::{TrackingReader, TrackingWriter},
30    Hash, IROH_BLOCK_SIZE,
31};
32
33pub mod db;
34pub mod error;
35pub mod progress;
36pub mod request;
37
38/// Stats about the transfer.
39#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
40pub struct Stats {
41    /// The number of bytes written
42    pub bytes_written: u64,
43    /// The number of bytes read
44    pub bytes_read: u64,
45    /// The time it took to transfer the data
46    pub elapsed: Duration,
47}
48
49impl Stats {
50    /// Transfer rate in megabits per second
51    pub fn mbits(&self) -> f64 {
52        let data_len_bit = self.bytes_read * 8;
53        data_len_bit as f64 / (1000. * 1000.) / self.elapsed.as_secs_f64()
54    }
55}
56
57/// Finite state machine for get responses.
58///
59/// This is the low level API for getting data from a peer.
60#[doc = include_str!("../docs/img/get_machine.drawio.svg")]
61pub mod fsm {
62    use std::{io, result};
63
64    use bao_tree::{
65        io::fsm::{OutboardMut, ResponseDecoder, ResponseDecoderNext},
66        BaoTree, ChunkRanges, TreeNode,
67    };
68    use derive_more::From;
69    use iroh::endpoint::Connection;
70    use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader};
71    use tokio::io::AsyncWriteExt;
72
73    use super::*;
74    use crate::{
75        protocol::{GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE},
76        store::BaoBatchWriter,
77    };
78
    /// Receive half of the bidi stream, as held by every reading state.
    ///
    /// The [`TrackingReader`] wrapper counts the bytes read; the count is taken
    /// out via `into_parts` when the machine reaches [`AtClosing`] and reported
    /// as [`Stats::bytes_read`].
    type WrappedRecvStream = TrackingReader<TokioStreamReader<RecvStream>>;

    // Self-referential cell that owns the request's `RangeSpecSeq` together with
    // a `NonEmptyRequestRangeSpecIter` borrowing from it, so that the iterator
    // can live inside the (owned) fsm states without a lifetime parameter.
    // See `RangesIter`, which wraps this cell.
    self_cell::self_cell! {
        struct RangesIterInner {
            owner: RangeSpecSeq,
            #[covariant]
            dependent: NonEmptyRequestRangeSpecIter,
        }
    }
88
89    /// The entry point of the get response machine
90    pub fn start(connection: Connection, request: GetRequest) -> AtInitial {
91        AtInitial::new(connection, request)
92    }
93
94    /// Owned iterator for the ranges in a request
95    ///
96    /// We need an owned iterator for a fsm style API, otherwise we would have
97    /// to drag a lifetime around every single state.
98    struct RangesIter(RangesIterInner);
99
100    impl fmt::Debug for RangesIter {
101        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102            f.debug_struct("RangesIter").finish()
103        }
104    }
105
106    impl RangesIter {
107        pub fn new(owner: RangeSpecSeq) -> Self {
108            Self(RangesIterInner::new(owner, |owner| owner.iter_non_empty()))
109        }
110
111        pub fn offset(&self) -> u64 {
112            self.0.with_dependent(|_owner, iter| iter.offset())
113        }
114    }
115
116    impl Iterator for RangesIter {
117        type Item = (u64, ChunkRanges);
118
119        fn next(&mut self) -> Option<Self::Item> {
120            self.0.with_dependent_mut(|_owner, iter| {
121                iter.next()
122                    .map(|(offset, ranges)| (offset, ranges.to_chunk_ranges()))
123            })
124        }
125    }
126
127    /// Initial state of the get response machine
128    #[derive(Debug)]
129    pub struct AtInitial {
130        connection: Connection,
131        request: GetRequest,
132    }
133
134    impl AtInitial {
135        /// Create a new get response
136        ///
137        /// `connection` is an existing connection
138        /// `request` is the request to be sent
139        pub fn new(connection: Connection, request: GetRequest) -> Self {
140            Self {
141                connection,
142                request,
143            }
144        }
145
146        /// Initiate a new bidi stream to use for the get response
147        pub async fn next(self) -> Result<AtConnected, endpoint::ConnectionError> {
148            let start = Instant::now();
149            let (writer, reader) = self.connection.open_bi().await?;
150            let reader = TrackingReader::new(TokioStreamReader::new(reader));
151            let writer = TrackingWriter::new(writer);
152            Ok(AtConnected {
153                start,
154                reader,
155                writer,
156                request: self.request,
157            })
158        }
159    }
160
161    /// State of the get response machine after the handshake has been sent
162    #[derive(Debug)]
163    pub struct AtConnected {
164        start: Instant,
165        reader: WrappedRecvStream,
166        writer: TrackingWriter<SendStream>,
167        request: GetRequest,
168    }
169
170    /// Possible next states after the handshake has been sent
171    #[derive(Debug, From)]
172    pub enum ConnectedNext {
173        /// First response is either a collection or a single blob
174        StartRoot(AtStartRoot),
175        /// First response is a child
176        StartChild(AtStartChild),
177        /// Request is empty
178        Closing(AtClosing),
179    }
180
181    /// Error that you can get from [`AtConnected::next`]
182    #[derive(Debug, thiserror::Error)]
183    pub enum ConnectedNextError {
184        /// Error when serializing the request
185        #[error("postcard ser: {0}")]
186        PostcardSer(postcard::Error),
187        /// The serialized request is too long to be sent
188        #[error("request too big")]
189        RequestTooBig,
190        /// Error when writing the request to the [`SendStream`].
191        #[error("write: {0}")]
192        Write(#[from] WriteError),
193        /// Quic connection is closed.
194        #[error("closed")]
195        Closed(#[from] ClosedStream),
196        /// A generic io error
197        #[error("io {0}")]
198        Io(io::Error),
199    }
200
201    impl ConnectedNextError {
202        fn from_io(cause: io::Error) -> Self {
203            if let Some(inner) = cause.get_ref() {
204                if let Some(e) = inner.downcast_ref::<endpoint::WriteError>() {
205                    Self::Write(e.clone())
206                } else {
207                    Self::Io(cause)
208                }
209            } else {
210                Self::Io(cause)
211            }
212        }
213    }
214
215    impl From<ConnectedNextError> for io::Error {
216        fn from(cause: ConnectedNextError) -> Self {
217            match cause {
218                ConnectedNextError::Write(cause) => cause.into(),
219                ConnectedNextError::Io(cause) => cause,
220                ConnectedNextError::PostcardSer(cause) => {
221                    io::Error::new(io::ErrorKind::Other, cause)
222                }
223                _ => io::Error::new(io::ErrorKind::Other, cause),
224            }
225        }
226    }
227
228    impl AtConnected {
229        /// Send the request and move to the next state
230        ///
231        /// The next state will be either `StartRoot` or `StartChild` depending on whether
232        /// the request requests part of the collection or not.
233        ///
234        /// If the request is empty, this can also move directly to `Finished`.
235        pub async fn next(self) -> Result<ConnectedNext, ConnectedNextError> {
236            let Self {
237                start,
238                reader,
239                mut writer,
240                mut request,
241            } = self;
242            // 1. Send Request
243            {
244                debug!("sending request");
245                let wrapped = Request::Get(request);
246                let request_bytes =
247                    postcard::to_stdvec(&wrapped).map_err(ConnectedNextError::PostcardSer)?;
248                let Request::Get(x) = wrapped;
249                request = x;
250
251                if request_bytes.len() > MAX_MESSAGE_SIZE {
252                    return Err(ConnectedNextError::RequestTooBig);
253                }
254
255                // write the request itself
256                writer
257                    .write_all(&request_bytes)
258                    .await
259                    .map_err(ConnectedNextError::from_io)?;
260            }
261
262            // 2. Finish writing before expecting a response
263            let (mut writer, bytes_written) = writer.into_parts();
264            writer.finish()?;
265
266            let hash = request.hash;
267            let ranges_iter = RangesIter::new(request.ranges);
268            // this is in a box so we don't have to memcpy it on every state transition
269            let mut misc = Box::new(Misc {
270                start,
271                bytes_written,
272                ranges_iter,
273            });
274            Ok(match misc.ranges_iter.next() {
275                Some((offset, ranges)) => {
276                    if offset == 0 {
277                        AtStartRoot {
278                            reader,
279                            ranges,
280                            misc,
281                            hash,
282                        }
283                        .into()
284                    } else {
285                        AtStartChild {
286                            reader,
287                            ranges,
288                            misc,
289                            child_offset: offset - 1,
290                        }
291                        .into()
292                    }
293                }
294                None => AtClosing::new(misc, reader, true).into(),
295            })
296        }
297    }
298
299    /// State of the get response when we start reading a collection
300    #[derive(Debug)]
301    pub struct AtStartRoot {
302        ranges: ChunkRanges,
303        reader: TrackingReader<TokioStreamReader<RecvStream>>,
304        misc: Box<Misc>,
305        hash: Hash,
306    }
307
308    /// State of the get response when we start reading a child
309    #[derive(Debug)]
310    pub struct AtStartChild {
311        ranges: ChunkRanges,
312        reader: TrackingReader<TokioStreamReader<RecvStream>>,
313        misc: Box<Misc>,
314        child_offset: u64,
315    }
316
317    impl AtStartChild {
318        /// The offset of the child we are currently reading
319        ///
320        /// This must be used to determine the hash needed to call next.
321        /// If this is larger than the number of children in the collection,
322        /// you can call finish to stop reading the response.
323        pub fn child_offset(&self) -> u64 {
324            self.child_offset
325        }
326
327        /// The ranges we have requested for the child
328        pub fn ranges(&self) -> &ChunkRanges {
329            &self.ranges
330        }
331
332        /// Go into the next state, reading the header
333        ///
334        /// This requires passing in the hash of the child for validation
335        pub fn next(self, hash: Hash) -> AtBlobHeader {
336            AtBlobHeader {
337                reader: self.reader,
338                ranges: self.ranges,
339                misc: self.misc,
340                hash,
341            }
342        }
343
344        /// Finish the get response without reading further
345        ///
346        /// This is used if you know that there are no more children from having
347        /// read the collection, or when you want to stop reading the response
348        /// early.
349        pub fn finish(self) -> AtClosing {
350            AtClosing::new(self.misc, self.reader, false)
351        }
352    }
353
354    impl AtStartRoot {
355        /// The ranges we have requested for the child
356        pub fn ranges(&self) -> &ChunkRanges {
357            &self.ranges
358        }
359
360        /// Hash of the root blob
361        pub fn hash(&self) -> Hash {
362            self.hash
363        }
364
365        /// Go into the next state, reading the header
366        ///
367        /// For the collection we already know the hash, since it was part of the request
368        pub fn next(self) -> AtBlobHeader {
369            AtBlobHeader {
370                reader: self.reader,
371                ranges: self.ranges,
372                hash: self.hash,
373                misc: self.misc,
374            }
375        }
376
377        /// Finish the get response without reading further
378        pub fn finish(self) -> AtClosing {
379            AtClosing::new(self.misc, self.reader, false)
380        }
381    }
382
383    /// State before reading a size header
384    #[derive(Debug)]
385    pub struct AtBlobHeader {
386        ranges: ChunkRanges,
387        reader: TrackingReader<TokioStreamReader<RecvStream>>,
388        misc: Box<Misc>,
389        hash: Hash,
390    }
391
392    /// Error that you can get from [`AtBlobHeader::next`]
393    #[derive(Debug, thiserror::Error)]
394    pub enum AtBlobHeaderNextError {
395        /// Eof when reading the size header
396        ///
397        /// This indicates that the provider does not have the requested data.
398        #[error("not found")]
399        NotFound,
400        /// Quinn read error when reading the size header
401        #[error("read: {0}")]
402        Read(endpoint::ReadError),
403        /// Generic io error
404        #[error("io: {0}")]
405        Io(io::Error),
406    }
407
408    impl From<AtBlobHeaderNextError> for io::Error {
409        fn from(cause: AtBlobHeaderNextError) -> Self {
410            match cause {
411                AtBlobHeaderNextError::NotFound => {
412                    io::Error::new(io::ErrorKind::UnexpectedEof, cause)
413                }
414                AtBlobHeaderNextError::Read(cause) => cause.into(),
415                AtBlobHeaderNextError::Io(cause) => cause,
416            }
417        }
418    }
419
420    impl AtBlobHeader {
421        /// Read the size header, returning it and going into the `Content` state.
422        pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> {
423            let size = self.reader.read::<8>().await.map_err(|cause| {
424                if cause.kind() == io::ErrorKind::UnexpectedEof {
425                    AtBlobHeaderNextError::NotFound
426                } else if let Some(e) = cause
427                    .get_ref()
428                    .and_then(|x| x.downcast_ref::<endpoint::ReadError>())
429                {
430                    AtBlobHeaderNextError::Read(e.clone())
431                } else {
432                    AtBlobHeaderNextError::Io(cause)
433                }
434            })?;
435            let size = u64::from_le_bytes(size);
436            let stream = ResponseDecoder::new(
437                self.hash.into(),
438                self.ranges,
439                BaoTree::new(size, IROH_BLOCK_SIZE),
440                self.reader,
441            );
442            Ok((
443                AtBlobContent {
444                    stream,
445                    misc: self.misc,
446                },
447                size,
448            ))
449        }
450
451        /// Drain the response and throw away the result
452        pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
453            let (content, _size) = self.next().await?;
454            content.drain().await
455        }
456
457        /// Concatenate the entire response into a vec
458        ///
459        /// For a request that does not request the complete blob, this will just
460        /// concatenate the ranges that were requested.
461        pub async fn concatenate_into_vec(
462            self,
463        ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
464            let (content, _size) = self.next().await?;
465            content.concatenate_into_vec().await
466        }
467
468        /// Write the entire blob to a slice writer.
469        pub async fn write_all<D: AsyncSliceWriter>(
470            self,
471            data: D,
472        ) -> result::Result<AtEndBlob, DecodeError> {
473            let (content, _size) = self.next().await?;
474            let res = content.write_all(data).await?;
475            Ok(res)
476        }
477
478        /// Write the entire blob to a slice writer and to an optional outboard.
479        ///
480        /// The outboard is only written to if the blob is larger than a single
481        /// chunk group.
482        pub async fn write_all_with_outboard<D, O>(
483            self,
484            outboard: Option<O>,
485            data: D,
486        ) -> result::Result<AtEndBlob, DecodeError>
487        where
488            D: AsyncSliceWriter,
489            O: OutboardMut,
490        {
491            let (content, _size) = self.next().await?;
492            let res = content.write_all_with_outboard(outboard, data).await?;
493            Ok(res)
494        }
495
496        /// Write the entire stream for this blob to a batch writer.
497        pub async fn write_all_batch<B>(self, batch: B) -> result::Result<AtEndBlob, DecodeError>
498        where
499            B: BaoBatchWriter,
500        {
501            let (content, _size) = self.next().await?;
502            let res = content.write_all_batch(batch).await?;
503            Ok(res)
504        }
505
506        /// The hash of the blob we are reading.
507        pub fn hash(&self) -> Hash {
508            self.hash
509        }
510
511        /// The ranges we have requested for the current hash.
512        pub fn ranges(&self) -> &ChunkRanges {
513            &self.ranges
514        }
515
516        /// The current offset of the blob we are reading.
517        pub fn offset(&self) -> u64 {
518            self.misc.ranges_iter.offset()
519        }
520    }
521
522    /// State while we are reading content
523    #[derive(Debug)]
524    pub struct AtBlobContent {
525        stream: ResponseDecoder<WrappedRecvStream>,
526        misc: Box<Misc>,
527    }
528
529    /// Decode error that you can get once you have sent the request and are
530    /// decoding the response, e.g. from [`AtBlobContent::next`].
531    ///
532    /// This is similar to [`bao_tree::io::DecodeError`], but takes into account
533    /// that we are reading from a [`RecvStream`], so read errors will be
534    /// propagated as [`DecodeError::Read`], containing a [`ReadError`].
535    /// This carries more concrete information about the error than an [`io::Error`].
536    ///
537    /// When the provider finds that it does not have a chunk that we requested,
538    /// or that the chunk is invalid, it will stop sending data without producing
539    /// an error. This is indicated by either the [`DecodeError::ParentNotFound`] or
540    /// [`DecodeError::LeafNotFound`] variant, which can be used to detect that data
541    /// is missing but the connection as well that the provider is otherwise healthy.
542    ///
543    /// The [`DecodeError::ParentHashMismatch`] and [`DecodeError::LeafHashMismatch`]
544    /// variants indicate that the provider has sent us invalid data. A well-behaved
545    /// provider should never do this, so this is an indication that the provider is
546    /// not behaving correctly.
547    ///
548    /// The [`DecodeError::Io`] variant is just a fallback for any other io error that
549    /// is not actually a [`ReadError`].
550    ///
551    /// [`ReadError`]: endpoint::ReadError
552    #[derive(Debug, thiserror::Error)]
553    pub enum DecodeError {
554        /// A chunk was not found or invalid, so the provider stopped sending data
555        #[error("not found")]
556        NotFound,
557        /// A parent was not found or invalid, so the provider stopped sending data
558        #[error("parent not found {0:?}")]
559        ParentNotFound(TreeNode),
560        /// A parent was not found or invalid, so the provider stopped sending data
561        #[error("chunk not found {0}")]
562        LeafNotFound(ChunkNum),
563        /// The hash of a parent did not match the expected hash
564        #[error("parent hash mismatch: {0:?}")]
565        ParentHashMismatch(TreeNode),
566        /// The hash of a leaf did not match the expected hash
567        #[error("leaf hash mismatch: {0}")]
568        LeafHashMismatch(ChunkNum),
569        /// Error when reading from the stream
570        #[error("read: {0}")]
571        Read(endpoint::ReadError),
572        /// A generic io error
573        #[error("io: {0}")]
574        Io(#[from] io::Error),
575    }
576
577    impl From<AtBlobHeaderNextError> for DecodeError {
578        fn from(cause: AtBlobHeaderNextError) -> Self {
579            match cause {
580                AtBlobHeaderNextError::NotFound => Self::NotFound,
581                AtBlobHeaderNextError::Read(cause) => Self::Read(cause),
582                AtBlobHeaderNextError::Io(cause) => Self::Io(cause),
583            }
584        }
585    }
586
587    impl From<DecodeError> for io::Error {
588        fn from(cause: DecodeError) -> Self {
589            match cause {
590                DecodeError::ParentNotFound(_) => {
591                    io::Error::new(io::ErrorKind::UnexpectedEof, cause)
592                }
593                DecodeError::LeafNotFound(_) => io::Error::new(io::ErrorKind::UnexpectedEof, cause),
594                DecodeError::Read(cause) => cause.into(),
595                DecodeError::Io(cause) => cause,
596                _ => io::Error::new(io::ErrorKind::Other, cause),
597            }
598        }
599    }
600
601    impl From<bao_tree::io::DecodeError> for DecodeError {
602        fn from(value: bao_tree::io::DecodeError) -> Self {
603            match value {
604                bao_tree::io::DecodeError::ParentNotFound(x) => Self::ParentNotFound(x),
605                bao_tree::io::DecodeError::LeafNotFound(x) => Self::LeafNotFound(x),
606                bao_tree::io::DecodeError::ParentHashMismatch(node) => {
607                    Self::ParentHashMismatch(node)
608                }
609                bao_tree::io::DecodeError::LeafHashMismatch(chunk) => Self::LeafHashMismatch(chunk),
610                bao_tree::io::DecodeError::Io(cause) => {
611                    if let Some(inner) = cause.get_ref() {
612                        if let Some(e) = inner.downcast_ref::<endpoint::ReadError>() {
613                            Self::Read(e.clone())
614                        } else {
615                            Self::Io(cause)
616                        }
617                    } else {
618                        Self::Io(cause)
619                    }
620                }
621            }
622        }
623    }
624
625    /// The next state after reading a content item
626    #[derive(Debug, From)]
627    pub enum BlobContentNext {
628        /// We expect more content
629        More((AtBlobContent, result::Result<BaoContentItem, DecodeError>)),
630        /// We are done with this blob
631        Done(AtEndBlob),
632    }
633
634    impl AtBlobContent {
635        /// Read the next item, either content, an error, or the end of the blob
636        pub async fn next(self) -> BlobContentNext {
637            match self.stream.next().await {
638                ResponseDecoderNext::More((stream, res)) => {
639                    let next = Self { stream, ..self };
640                    let res = res.map_err(DecodeError::from);
641                    BlobContentNext::More((next, res))
642                }
643                ResponseDecoderNext::Done(stream) => BlobContentNext::Done(AtEndBlob {
644                    stream,
645                    misc: self.misc,
646                }),
647            }
648        }
649
650        /// The geometry of the tree we are currently reading.
651        pub fn tree(&self) -> bao_tree::BaoTree {
652            self.stream.tree()
653        }
654
655        /// The hash of the blob we are reading.
656        pub fn hash(&self) -> Hash {
657            (*self.stream.hash()).into()
658        }
659
660        /// The current offset of the blob we are reading.
661        pub fn offset(&self) -> u64 {
662            self.misc.ranges_iter.offset()
663        }
664
665        /// Drain the response and throw away the result
666        pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
667            let mut content = self;
668            loop {
669                match content.next().await {
670                    BlobContentNext::More((content1, res)) => {
671                        let _ = res?;
672                        content = content1;
673                    }
674                    BlobContentNext::Done(end) => {
675                        break Ok(end);
676                    }
677                }
678            }
679        }
680
681        /// Concatenate the entire response into a vec
682        pub async fn concatenate_into_vec(
683            self,
684        ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
685            let mut res = Vec::with_capacity(1024);
686            let mut curr = self;
687            let done = loop {
688                match curr.next().await {
689                    BlobContentNext::More((next, data)) => {
690                        if let BaoContentItem::Leaf(leaf) = data? {
691                            res.extend_from_slice(&leaf.data);
692                        }
693                        curr = next;
694                    }
695                    BlobContentNext::Done(done) => {
696                        // we are done with the root blob
697                        break done;
698                    }
699                }
700            };
701            Ok((done, res))
702        }
703
704        /// Write the entire stream for this blob to a batch writer.
705        pub async fn write_all_batch<B>(self, writer: B) -> result::Result<AtEndBlob, DecodeError>
706        where
707            B: BaoBatchWriter,
708        {
709            let mut writer = writer;
710            let mut buf = Vec::new();
711            let mut content = self;
712            let size = content.tree().size();
713            loop {
714                match content.next().await {
715                    BlobContentNext::More((next, item)) => {
716                        let item = item?;
717                        match &item {
718                            BaoContentItem::Parent(_) => {
719                                buf.push(item);
720                            }
721                            BaoContentItem::Leaf(_) => {
722                                buf.push(item);
723                                let batch = std::mem::take(&mut buf);
724                                writer.write_batch(size, batch).await?;
725                            }
726                        }
727                        content = next;
728                    }
729                    BlobContentNext::Done(end) => {
730                        assert!(buf.is_empty());
731                        return Ok(end);
732                    }
733                }
734            }
735        }
736
737        /// Write the entire blob to a slice writer and to an optional outboard.
738        ///
739        /// The outboard is only written to if the blob is larger than a single
740        /// chunk group.
741        pub async fn write_all_with_outboard<D, O>(
742            self,
743            mut outboard: Option<O>,
744            mut data: D,
745        ) -> result::Result<AtEndBlob, DecodeError>
746        where
747            D: AsyncSliceWriter,
748            O: OutboardMut,
749        {
750            let mut content = self;
751            loop {
752                match content.next().await {
753                    BlobContentNext::More((content1, item)) => {
754                        content = content1;
755                        match item? {
756                            BaoContentItem::Parent(parent) => {
757                                if let Some(outboard) = outboard.as_mut() {
758                                    outboard.save(parent.node, &parent.pair).await?;
759                                }
760                            }
761                            BaoContentItem::Leaf(leaf) => {
762                                data.write_bytes_at(leaf.offset, leaf.data).await?;
763                            }
764                        }
765                    }
766                    BlobContentNext::Done(end) => {
767                        return Ok(end);
768                    }
769                }
770            }
771        }
772
773        /// Write the entire blob to a slice writer.
774        pub async fn write_all<D>(self, mut data: D) -> result::Result<AtEndBlob, DecodeError>
775        where
776            D: AsyncSliceWriter,
777        {
778            let mut content = self;
779            loop {
780                match content.next().await {
781                    BlobContentNext::More((content1, item)) => {
782                        content = content1;
783                        match item? {
784                            BaoContentItem::Parent(_) => {}
785                            BaoContentItem::Leaf(leaf) => {
786                                data.write_bytes_at(leaf.offset, leaf.data).await?;
787                            }
788                        }
789                    }
790                    BlobContentNext::Done(end) => {
791                        return Ok(end);
792                    }
793                }
794            }
795        }
796
797        /// Immediately finish the get response without reading further
798        pub fn finish(self) -> AtClosing {
799            AtClosing::new(self.misc, self.stream.finish(), false)
800        }
801    }
802
803    /// State after we have read all the content for a blob
804    #[derive(Debug)]
805    pub struct AtEndBlob {
806        stream: WrappedRecvStream,
807        misc: Box<Misc>,
808    }
809
810    /// The next state after the end of a blob
811    #[derive(Debug, From)]
812    pub enum EndBlobNext {
813        /// Response is expected to have more children
814        MoreChildren(AtStartChild),
815        /// No more children expected
816        Closing(AtClosing),
817    }
818
819    impl AtEndBlob {
820        /// Read the next child, or finish
821        pub fn next(mut self) -> EndBlobNext {
822            if let Some((offset, ranges)) = self.misc.ranges_iter.next() {
823                AtStartChild {
824                    reader: self.stream,
825                    child_offset: offset - 1,
826                    ranges,
827                    misc: self.misc,
828                }
829                .into()
830            } else {
831                AtClosing::new(self.misc, self.stream, true).into()
832            }
833        }
834    }
835
836    /// State when finishing the get response
837    #[derive(Debug)]
838    pub struct AtClosing {
839        misc: Box<Misc>,
840        reader: WrappedRecvStream,
841        check_extra_data: bool,
842    }
843
844    impl AtClosing {
845        fn new(misc: Box<Misc>, reader: WrappedRecvStream, check_extra_data: bool) -> Self {
846            Self {
847                misc,
848                reader,
849                check_extra_data,
850            }
851        }
852
853        /// Finish the get response, returning statistics
854        pub async fn next(self) -> result::Result<Stats, endpoint::ReadError> {
855            // Shut down the stream
856            let (reader, bytes_read) = self.reader.into_parts();
857            let mut reader = reader.into_inner();
858            if self.check_extra_data {
859                if let Some(chunk) = reader.read_chunk(8, false).await? {
860                    reader.stop(0u8.into()).ok();
861                    error!("Received unexpected data from the provider: {chunk:?}");
862                }
863            } else {
864                reader.stop(0u8.into()).ok();
865            }
866            Ok(Stats {
867                elapsed: self.misc.start.elapsed(),
868                bytes_written: self.misc.bytes_written,
869                bytes_read,
870            })
871        }
872    }
873
874    /// Stuff we need to hold on to while going through the machine states
875    #[derive(Debug)]
876    struct Misc {
877        /// start time for statistics
878        start: Instant,
879        /// bytes written for statistics
880        bytes_written: u64,
881        /// iterator over the ranges of the collection and the children
882        ranges_iter: RangesIter,
883    }
884}
885
886/// Error when processing a response
887#[derive(thiserror::Error, Debug)]
888pub enum GetResponseError {
889    /// Error when opening a stream
890    #[error("connection: {0}")]
891    Connection(#[from] endpoint::ConnectionError),
892    /// Error when writing the handshake or request to the stream
893    #[error("write: {0}")]
894    Write(#[from] endpoint::WriteError),
895    /// Error when reading from the stream
896    #[error("read: {0}")]
897    Read(#[from] endpoint::ReadError),
898    /// Error when decoding, e.g. hash mismatch
899    #[error("decode: {0}")]
900    Decode(bao_tree::io::DecodeError),
901    /// A generic error
902    #[error("generic: {0}")]
903    Generic(anyhow::Error),
904}
905
906impl From<postcard::Error> for GetResponseError {
907    fn from(cause: postcard::Error) -> Self {
908        Self::Generic(cause.into())
909    }
910}
911
912impl From<bao_tree::io::DecodeError> for GetResponseError {
913    fn from(cause: bao_tree::io::DecodeError) -> Self {
914        match cause {
915            bao_tree::io::DecodeError::Io(cause) => {
916                // try to downcast to specific quinn errors
917                if let Some(source) = cause.source() {
918                    if let Some(error) = source.downcast_ref::<endpoint::ConnectionError>() {
919                        return Self::Connection(error.clone());
920                    }
921                    if let Some(error) = source.downcast_ref::<endpoint::ReadError>() {
922                        return Self::Read(error.clone());
923                    }
924                    if let Some(error) = source.downcast_ref::<endpoint::WriteError>() {
925                        return Self::Write(error.clone());
926                    }
927                }
928                Self::Generic(cause.into())
929            }
930            _ => Self::Decode(cause),
931        }
932    }
933}
934
935impl From<anyhow::Error> for GetResponseError {
936    fn from(cause: anyhow::Error) -> Self {
937        Self::Generic(cause)
938    }
939}
940
941impl From<GetResponseError> for std::io::Error {
942    fn from(cause: GetResponseError) -> Self {
943        Self::new(std::io::ErrorKind::Other, cause)
944    }
945}