iroh_bytes/
get.rs

1//! The client side API
2//!
3//! To get data, create a connection using [iroh-net] or use any quinn
4//! connection that was obtained in another way.
5//!
6//! Create a request describing the data you want to get.
7//!
8//! Then create a state machine using [fsm::start] and
9//! drive it to completion by calling next on each state.
10//!
11//! For some states you have to provide additional arguments when calling next,
12//! or you can choose to finish early.
13//!
14//! [iroh-net]: https://docs.rs/iroh-net
15use std::error::Error;
16use std::fmt::{self, Debug};
17use std::time::{Duration, Instant};
18
19use crate::Hash;
20use anyhow::Result;
21use bao_tree::io::fsm::BaoContentItem;
22use bao_tree::ChunkNum;
23use quinn::RecvStream;
24use serde::{Deserialize, Serialize};
25use tracing::{debug, error};
26
27use crate::protocol::RangeSpecSeq;
28use crate::util::io::{TrackingReader, TrackingWriter};
29use crate::IROH_BLOCK_SIZE;
30
31pub mod db;
32pub mod error;
33pub mod progress;
34pub mod request;
35
/// Stats about the transfer.
///
/// Returned by [`fsm::AtClosing::next`] when the get response is finished.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Stats {
    /// The number of bytes written
    ///
    /// This is the byte count reported by the tracking writer that wraps the
    /// send stream the request is written to.
    pub bytes_written: u64,
    /// The number of bytes read
    ///
    /// This is the byte count reported by the tracking reader that wraps the
    /// receive stream.
    pub bytes_read: u64,
    /// The time it took to transfer the data
    ///
    /// Measured from just before the bidirectional stream is opened until the
    /// response is closed.
    pub elapsed: Duration,
}
46
47impl Stats {
48    /// Transfer rate in megabits per second
49    pub fn mbits(&self) -> f64 {
50        let data_len_bit = self.bytes_read * 8;
51        data_len_bit as f64 / (1000. * 1000.) / self.elapsed.as_secs_f64()
52    }
53}
54
55/// Finite state machine for get responses.
56///
57/// This is the low level API for getting data from a peer.
58#[doc = include_str!("../docs/img/get_machine.drawio.svg")]
59pub mod fsm {
60    use std::{io, result};
61
62    use crate::{
63        protocol::{GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE},
64        store::BaoBatchWriter,
65    };
66
67    use super::*;
68
69    use bao_tree::{
70        io::fsm::{OutboardMut, ResponseDecoder, ResponseDecoderNext},
71        BaoTree, ChunkRanges, TreeNode,
72    };
73    use derive_more::From;
74    use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader};
75    use tokio::io::AsyncWriteExt;
76
    /// The receive half of the stream, wrapped in a [`TrackingReader`] so that the
    /// number of bytes read can be reported in [`Stats`].
    type WrappedRecvStream = TrackingReader<TokioStreamReader<RecvStream>>;
78
    // Self-referential cell that stores a `RangeSpecSeq` (the owner) together with
    // a `NonEmptyRequestRangeSpecIter` that borrows from it (the dependent).
    // `RangesIter` wraps this cell to get an iterator without a lifetime parameter.
    self_cell::self_cell! {
        struct RangesIterInner {
            owner: RangeSpecSeq,
            #[covariant]
            dependent: NonEmptyRequestRangeSpecIter,
        }
    }
86
87    /// The entry point of the get response machine
88    pub fn start(connection: quinn::Connection, request: GetRequest) -> AtInitial {
89        AtInitial::new(connection, request)
90    }
91
    /// Owned iterator for the ranges in a request
    ///
    /// We need an owned iterator for a fsm style API, otherwise we would have
    /// to drag a lifetime around every single state.
    ///
    /// Yields the `(offset, ranges)` pairs produced by
    /// `RangeSpecSeq::iter_non_empty`, with the ranges converted to
    /// [`ChunkRanges`].
    struct RangesIter(RangesIterInner);
97
98    impl fmt::Debug for RangesIter {
99        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100            f.debug_struct("RangesIter").finish()
101        }
102    }
103
104    impl RangesIter {
105        pub fn new(owner: RangeSpecSeq) -> Self {
106            Self(RangesIterInner::new(owner, |owner| owner.iter_non_empty()))
107        }
108
109        pub fn offset(&self) -> u64 {
110            self.0.with_dependent(|_owner, iter| iter.offset())
111        }
112    }
113
114    impl Iterator for RangesIter {
115        type Item = (u64, ChunkRanges);
116
117        fn next(&mut self) -> Option<Self::Item> {
118            self.0.with_dependent_mut(|_owner, iter| {
119                iter.next()
120                    .map(|(offset, ranges)| (offset, ranges.to_chunk_ranges()))
121            })
122        }
123    }
124
    /// Initial state of the get response machine
    ///
    /// Holds the connection and the request; no stream has been opened yet.
    #[derive(Debug)]
    pub struct AtInitial {
        /// Connection on which the bidirectional stream will be opened
        connection: quinn::Connection,
        /// Request to send once the stream is open
        request: GetRequest,
    }
131
132    impl AtInitial {
133        /// Create a new get response
134        ///
135        /// `connection` is an existing connection
136        /// `request` is the request to be sent
137        pub fn new(connection: quinn::Connection, request: GetRequest) -> Self {
138            Self {
139                connection,
140                request,
141            }
142        }
143
144        /// Initiate a new bidi stream to use for the get response
145        pub async fn next(self) -> Result<AtConnected, quinn::ConnectionError> {
146            let start = Instant::now();
147            let (writer, reader) = self.connection.open_bi().await?;
148            let reader = TrackingReader::new(TokioStreamReader::new(reader));
149            let writer = TrackingWriter::new(writer);
150            Ok(AtConnected {
151                start,
152                reader,
153                writer,
154                request: self.request,
155            })
156        }
157    }
158
    /// State of the get response machine after the bidirectional stream has been opened
    ///
    /// The request has not been written yet; that happens in [`AtConnected::next`].
    #[derive(Debug)]
    pub struct AtConnected {
        /// Time at which the machine started opening the stream
        start: Instant,
        /// Byte-counting reader for the response
        reader: WrappedRecvStream,
        /// Byte-counting writer for the request
        writer: TrackingWriter<quinn::SendStream>,
        /// The request to send
        request: GetRequest,
    }
167
    /// Possible next states after the request has been sent
    #[derive(Debug, From)]
    pub enum ConnectedNext {
        /// First response is either a collection or a single blob
        ///
        /// Chosen when the first non-empty range of the request has offset 0.
        StartRoot(AtStartRoot),
        /// First response is a child
        ///
        /// Chosen when the first non-empty range of the request has a non-zero offset.
        StartChild(AtStartChild),
        /// Request is empty
        ///
        /// Chosen when the request contains no non-empty ranges.
        Closing(AtClosing),
    }
178
    /// Error that you can get from [`AtConnected::next`]
    #[derive(Debug, thiserror::Error)]
    pub enum ConnectedNextError {
        /// Error when serializing the request
        #[error("postcard ser: {0}")]
        PostcardSer(postcard::Error),
        /// The serialized request is too long to be sent
        ///
        /// Returned when the serialized request is longer than `MAX_MESSAGE_SIZE`.
        #[error("request too big")]
        RequestTooBig,
        /// Error when writing the request to the [`quinn::SendStream`]
        #[error("write: {0}")]
        Write(#[from] quinn::WriteError),
        /// A generic io error
        ///
        /// Used for io errors that do not wrap a [`quinn::WriteError`].
        #[error("io {0}")]
        Io(io::Error),
    }
195
196    impl ConnectedNextError {
197        fn from_io(cause: io::Error) -> Self {
198            if let Some(inner) = cause.get_ref() {
199                if let Some(e) = inner.downcast_ref::<quinn::WriteError>() {
200                    Self::Write(e.clone())
201                } else {
202                    Self::Io(cause)
203                }
204            } else {
205                Self::Io(cause)
206            }
207        }
208    }
209
210    impl From<ConnectedNextError> for io::Error {
211        fn from(cause: ConnectedNextError) -> Self {
212            match cause {
213                ConnectedNextError::Write(cause) => cause.into(),
214                ConnectedNextError::Io(cause) => cause,
215                ConnectedNextError::PostcardSer(cause) => {
216                    io::Error::new(io::ErrorKind::Other, cause)
217                }
218                _ => io::Error::new(io::ErrorKind::Other, cause),
219            }
220        }
221    }
222
    impl AtConnected {
        /// Send the request and move to the next state
        ///
        /// The next state will be either `StartRoot` or `StartChild` depending on whether
        /// the first non-empty range of the request is at offset 0 or not.
        ///
        /// If the request contains no non-empty ranges, this moves directly to `Closing`.
        ///
        /// # Errors
        ///
        /// - [`ConnectedNextError::PostcardSer`] if the request cannot be serialized.
        /// - [`ConnectedNextError::RequestTooBig`] if the serialized request is longer
        ///   than `MAX_MESSAGE_SIZE`.
        /// - [`ConnectedNextError::Write`] or [`ConnectedNextError::Io`] if writing the
        ///   request or finishing the send stream fails.
        pub async fn next(self) -> Result<ConnectedNext, ConnectedNextError> {
            let Self {
                start,
                reader,
                mut writer,
                mut request,
            } = self;
            // 1. Send Request
            {
                debug!("sending request");
                // `Request::Get` takes ownership of the request, so it is wrapped for
                // serialization and then destructured again to get it back without cloning.
                let wrapped = Request::Get(request);
                let request_bytes =
                    postcard::to_stdvec(&wrapped).map_err(ConnectedNextError::PostcardSer)?;
                let Request::Get(x) = wrapped;
                request = x;

                // refuse to send requests that exceed the maximum message size
                if request_bytes.len() > MAX_MESSAGE_SIZE {
                    return Err(ConnectedNextError::RequestTooBig);
                }

                // write the request itself
                writer
                    .write_all(&request_bytes)
                    .await
                    .map_err(ConnectedNextError::from_io)?;
            }

            // 2. Finish writing before expecting a response
            // `into_parts` also yields the number of bytes written, kept for the stats.
            let (mut writer, bytes_written) = writer.into_parts();
            writer.finish().await?;

            let hash = request.hash;
            let ranges_iter = RangesIter::new(request.ranges);
            // this is in a box so we don't have to memcpy it on every state transition
            let mut misc = Box::new(Misc {
                start,
                bytes_written,
                ranges_iter,
            });
            // The first non-empty range decides the next state.
            Ok(match misc.ranges_iter.next() {
                Some((offset, ranges)) => {
                    if offset == 0 {
                        // offset 0 is the blob identified by the request hash
                        AtStartRoot {
                            reader,
                            ranges,
                            misc,
                            hash,
                        }
                        .into()
                    } else {
                        // offsets > 0 are children; child offsets are zero based
                        AtStartChild {
                            reader,
                            ranges,
                            misc,
                            child_offset: offset - 1,
                        }
                        .into()
                    }
                }
                // no ranges requested at all: nothing to read, only check for extra data
                None => AtClosing::new(misc, reader, true).into(),
            })
        }
    }
293
    /// State of the get response when we start reading a collection
    ///
    /// This is the state for the blob identified by the request hash, entered
    /// when the first non-empty range of the request is at offset 0.
    #[derive(Debug)]
    pub struct AtStartRoot {
        /// Requested chunk ranges for the root blob
        ranges: ChunkRanges,
        /// Byte-counting reader for the response stream
        reader: TrackingReader<TokioStreamReader<quinn::RecvStream>>,
        /// State carried through all states of the machine
        misc: Box<Misc>,
        /// Hash of the root blob, taken from the request
        hash: Hash,
    }
302
    /// State of the get response when we start reading a child
    ///
    /// The hash of the child is not known to the machine; it has to be passed
    /// to [`AtStartChild::next`].
    #[derive(Debug)]
    pub struct AtStartChild {
        /// Requested chunk ranges for this child
        ranges: ChunkRanges,
        /// Byte-counting reader for the response stream
        reader: TrackingReader<TokioStreamReader<quinn::RecvStream>>,
        /// State carried through all states of the machine
        misc: Box<Misc>,
        /// Zero based offset of the child (request offset minus one)
        child_offset: u64,
    }
311
312    impl AtStartChild {
313        /// The offset of the child we are currently reading
314        ///
315        /// This must be used to determine the hash needed to call next.
316        /// If this is larger than the number of children in the collection,
317        /// you can call finish to stop reading the response.
318        pub fn child_offset(&self) -> u64 {
319            self.child_offset
320        }
321
322        /// The ranges we have requested for the child
323        pub fn ranges(&self) -> &ChunkRanges {
324            &self.ranges
325        }
326
327        /// Go into the next state, reading the header
328        ///
329        /// This requires passing in the hash of the child for validation
330        pub fn next(self, hash: Hash) -> AtBlobHeader {
331            AtBlobHeader {
332                reader: self.reader,
333                ranges: self.ranges,
334                misc: self.misc,
335                hash,
336            }
337        }
338
339        /// Finish the get response without reading further
340        ///
341        /// This is used if you know that there are no more children from having
342        /// read the collection, or when you want to stop reading the response
343        /// early.
344        pub fn finish(self) -> AtClosing {
345            AtClosing::new(self.misc, self.reader, false)
346        }
347    }
348
349    impl AtStartRoot {
350        /// The ranges we have requested for the child
351        pub fn ranges(&self) -> &ChunkRanges {
352            &self.ranges
353        }
354
355        /// Hash of the root blob
356        pub fn hash(&self) -> Hash {
357            self.hash
358        }
359
360        /// Go into the next state, reading the header
361        ///
362        /// For the collection we already know the hash, since it was part of the request
363        pub fn next(self) -> AtBlobHeader {
364            AtBlobHeader {
365                reader: self.reader,
366                ranges: self.ranges,
367                hash: self.hash,
368                misc: self.misc,
369            }
370        }
371
372        /// Finish the get response without reading further
373        pub fn finish(self) -> AtClosing {
374            AtClosing::new(self.misc, self.reader, false)
375        }
376    }
377
    /// State before reading a size header
    ///
    /// The size header is the 8 byte little-endian blob size that precedes the
    /// content of each blob in the response.
    #[derive(Debug)]
    pub struct AtBlobHeader {
        /// Requested chunk ranges for this blob
        ranges: ChunkRanges,
        /// Byte-counting reader for the response stream
        reader: TrackingReader<TokioStreamReader<quinn::RecvStream>>,
        /// State carried through all states of the machine
        misc: Box<Misc>,
        /// Hash the content of this blob is validated against
        hash: Hash,
    }
386
    /// Error that you can get from [`AtBlobHeader::next`]
    #[derive(Debug, thiserror::Error)]
    pub enum AtBlobHeaderNextError {
        /// Eof when reading the size header
        ///
        /// This indicates that the provider does not have the requested data.
        #[error("not found")]
        NotFound,
        /// Quinn read error when reading the size header
        ///
        /// Extracted from the io error when it wraps a [`quinn::ReadError`].
        #[error("read: {0}")]
        Read(quinn::ReadError),
        /// Generic io error
        ///
        /// Any io error that is neither an eof nor a wrapped [`quinn::ReadError`].
        #[error("io: {0}")]
        Io(io::Error),
    }
402
403    impl From<AtBlobHeaderNextError> for io::Error {
404        fn from(cause: AtBlobHeaderNextError) -> Self {
405            match cause {
406                AtBlobHeaderNextError::NotFound => {
407                    io::Error::new(io::ErrorKind::UnexpectedEof, cause)
408                }
409                AtBlobHeaderNextError::Read(cause) => cause.into(),
410                AtBlobHeaderNextError::Io(cause) => cause,
411            }
412        }
413    }
414
415    impl AtBlobHeader {
416        /// Read the size header, returning it and going into the `Content` state.
417        pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> {
418            let size = self.reader.read::<8>().await.map_err(|cause| {
419                if cause.kind() == io::ErrorKind::UnexpectedEof {
420                    AtBlobHeaderNextError::NotFound
421                } else if let Some(e) = cause
422                    .get_ref()
423                    .and_then(|x| x.downcast_ref::<quinn::ReadError>())
424                {
425                    AtBlobHeaderNextError::Read(e.clone())
426                } else {
427                    AtBlobHeaderNextError::Io(cause)
428                }
429            })?;
430            let size = u64::from_le_bytes(size);
431            let stream = ResponseDecoder::new(
432                self.hash.into(),
433                self.ranges,
434                BaoTree::new(size, IROH_BLOCK_SIZE),
435                self.reader,
436            );
437            Ok((
438                AtBlobContent {
439                    stream,
440                    misc: self.misc,
441                },
442                size,
443            ))
444        }
445
446        /// Drain the response and throw away the result
447        pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
448            let (content, _size) = self.next().await?;
449            content.drain().await
450        }
451
452        /// Concatenate the entire response into a vec
453        ///
454        /// For a request that does not request the complete blob, this will just
455        /// concatenate the ranges that were requested.
456        pub async fn concatenate_into_vec(
457            self,
458        ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
459            let (content, _size) = self.next().await?;
460            content.concatenate_into_vec().await
461        }
462
463        /// Write the entire blob to a slice writer.
464        pub async fn write_all<D: AsyncSliceWriter>(
465            self,
466            data: D,
467        ) -> result::Result<AtEndBlob, DecodeError> {
468            let (content, _size) = self.next().await?;
469            let res = content.write_all(data).await?;
470            Ok(res)
471        }
472
473        /// Write the entire blob to a slice writer and to an optional outboard.
474        ///
475        /// The outboard is only written to if the blob is larger than a single
476        /// chunk group.
477        pub async fn write_all_with_outboard<D, O>(
478            self,
479            outboard: Option<O>,
480            data: D,
481        ) -> result::Result<AtEndBlob, DecodeError>
482        where
483            D: AsyncSliceWriter,
484            O: OutboardMut,
485        {
486            let (content, _size) = self.next().await?;
487            let res = content.write_all_with_outboard(outboard, data).await?;
488            Ok(res)
489        }
490
491        /// Write the entire stream for this blob to a batch writer.
492        pub async fn write_all_batch<B>(self, batch: B) -> result::Result<AtEndBlob, DecodeError>
493        where
494            B: BaoBatchWriter,
495        {
496            let (content, _size) = self.next().await?;
497            let res = content.write_all_batch(batch).await?;
498            Ok(res)
499        }
500
501        /// The hash of the blob we are reading.
502        pub fn hash(&self) -> Hash {
503            self.hash
504        }
505
506        /// The ranges we have requested for the current hash.
507        pub fn ranges(&self) -> &ChunkRanges {
508            &self.ranges
509        }
510
511        /// The current offset of the blob we are reading.
512        pub fn offset(&self) -> u64 {
513            self.misc.ranges_iter.offset()
514        }
515    }
516
    /// State while we are reading content
    ///
    /// Entered from [`AtBlobHeader::next`] once the size header has been read.
    #[derive(Debug)]
    pub struct AtBlobContent {
        /// Decoder that yields the content items of the current blob
        stream: ResponseDecoder<WrappedRecvStream>,
        /// State carried through all states of the machine
        misc: Box<Misc>,
    }
523
    /// Decode error that you can get once you have sent the request and are
    /// decoding the response, e.g. from [`AtBlobContent::next`].
    ///
    /// This is similar to [`bao_tree::io::DecodeError`], but takes into account
    /// that we are reading from a [`quinn::RecvStream`], so read errors will be
    /// propagated as [`DecodeError::Read`], containing a [`quinn::ReadError`].
    /// This carries more concrete information about the error than an [`io::Error`].
    ///
    /// When the provider finds that it does not have a chunk that we requested,
    /// or that the chunk is invalid, it will stop sending data without producing
    /// an error. This is indicated by either the [`DecodeError::ParentNotFound`] or
    /// [`DecodeError::LeafNotFound`] variant, which can be used to detect that data
    /// is missing while the connection and the provider are otherwise healthy.
    ///
    /// The [`DecodeError::ParentHashMismatch`] and [`DecodeError::LeafHashMismatch`]
    /// variants indicate that the provider has sent us invalid data. A well-behaved
    /// provider should never do this, so this is an indication that the provider is
    /// not behaving correctly.
    ///
    /// The [`DecodeError::Io`] variant is just a fallback for any other io error that
    /// is not actually a [`quinn::ReadError`].
    #[derive(Debug, thiserror::Error)]
    pub enum DecodeError {
        /// The stream ended while reading the size header
        ///
        /// This is what [`AtBlobHeaderNextError::NotFound`] is converted to and
        /// indicates that the provider does not have the requested data.
        #[error("not found")]
        NotFound,
        /// A parent was not found or invalid, so the provider stopped sending data
        #[error("parent not found {0:?}")]
        ParentNotFound(TreeNode),
        /// A chunk was not found or invalid, so the provider stopped sending data
        #[error("chunk not found {0}")]
        LeafNotFound(ChunkNum),
        /// The hash of a parent did not match the expected hash
        #[error("parent hash mismatch: {0:?}")]
        ParentHashMismatch(TreeNode),
        /// The hash of a leaf did not match the expected hash
        #[error("leaf hash mismatch: {0}")]
        LeafHashMismatch(ChunkNum),
        /// Error when reading from the stream
        #[error("read: {0}")]
        Read(quinn::ReadError),
        /// A generic io error
        #[error("io: {0}")]
        Io(#[from] io::Error),
    }
569
570    impl From<AtBlobHeaderNextError> for DecodeError {
571        fn from(cause: AtBlobHeaderNextError) -> Self {
572            match cause {
573                AtBlobHeaderNextError::NotFound => Self::NotFound,
574                AtBlobHeaderNextError::Read(cause) => Self::Read(cause),
575                AtBlobHeaderNextError::Io(cause) => Self::Io(cause),
576            }
577        }
578    }
579
580    impl From<DecodeError> for io::Error {
581        fn from(cause: DecodeError) -> Self {
582            match cause {
583                DecodeError::ParentNotFound(_) => {
584                    io::Error::new(io::ErrorKind::UnexpectedEof, cause)
585                }
586                DecodeError::LeafNotFound(_) => io::Error::new(io::ErrorKind::UnexpectedEof, cause),
587                DecodeError::Read(cause) => cause.into(),
588                DecodeError::Io(cause) => cause,
589                _ => io::Error::new(io::ErrorKind::Other, cause),
590            }
591        }
592    }
593
594    impl From<bao_tree::io::DecodeError> for DecodeError {
595        fn from(value: bao_tree::io::DecodeError) -> Self {
596            match value {
597                bao_tree::io::DecodeError::ParentNotFound(x) => Self::ParentNotFound(x),
598                bao_tree::io::DecodeError::LeafNotFound(x) => Self::LeafNotFound(x),
599                bao_tree::io::DecodeError::ParentHashMismatch(node) => {
600                    Self::ParentHashMismatch(node)
601                }
602                bao_tree::io::DecodeError::LeafHashMismatch(chunk) => Self::LeafHashMismatch(chunk),
603                bao_tree::io::DecodeError::Io(cause) => {
604                    if let Some(inner) = cause.get_ref() {
605                        if let Some(e) = inner.downcast_ref::<quinn::ReadError>() {
606                            Self::Read(e.clone())
607                        } else {
608                            Self::Io(cause)
609                        }
610                    } else {
611                        Self::Io(cause)
612                    }
613                }
614            }
615        }
616    }
617
    /// The next state after reading a content item
    #[derive(Debug, From)]
    pub enum BlobContentNext {
        /// We expect more content
        ///
        /// Contains the state to continue with and the item that was read, or
        /// the error that occurred while decoding it.
        More((AtBlobContent, result::Result<BaoContentItem, DecodeError>)),
        /// We are done with this blob
        Done(AtEndBlob),
    }
626
627    impl AtBlobContent {
628        /// Read the next item, either content, an error, or the end of the blob
629        pub async fn next(self) -> BlobContentNext {
630            match self.stream.next().await {
631                ResponseDecoderNext::More((stream, res)) => {
632                    let next = Self { stream, ..self };
633                    let res = res.map_err(DecodeError::from);
634                    BlobContentNext::More((next, res))
635                }
636                ResponseDecoderNext::Done(stream) => BlobContentNext::Done(AtEndBlob {
637                    stream,
638                    misc: self.misc,
639                }),
640            }
641        }
642
643        /// The geometry of the tree we are currently reading.
644        pub fn tree(&self) -> bao_tree::BaoTree {
645            self.stream.tree()
646        }
647
648        /// The hash of the blob we are reading.
649        pub fn hash(&self) -> Hash {
650            (*self.stream.hash()).into()
651        }
652
653        /// The current offset of the blob we are reading.
654        pub fn offset(&self) -> u64 {
655            self.misc.ranges_iter.offset()
656        }
657
658        /// Drain the response and throw away the result
659        pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
660            let mut content = self;
661            loop {
662                match content.next().await {
663                    BlobContentNext::More((content1, res)) => {
664                        let _ = res?;
665                        content = content1;
666                    }
667                    BlobContentNext::Done(end) => {
668                        break Ok(end);
669                    }
670                }
671            }
672        }
673
674        /// Concatenate the entire response into a vec
675        pub async fn concatenate_into_vec(
676            self,
677        ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
678            let mut res = Vec::with_capacity(1024);
679            let mut curr = self;
680            let done = loop {
681                match curr.next().await {
682                    BlobContentNext::More((next, data)) => {
683                        if let BaoContentItem::Leaf(leaf) = data? {
684                            res.extend_from_slice(&leaf.data);
685                        }
686                        curr = next;
687                    }
688                    BlobContentNext::Done(done) => {
689                        // we are done with the root blob
690                        break done;
691                    }
692                }
693            };
694            Ok((done, res))
695        }
696
697        /// Write the entire stream for this blob to a batch writer.
698        pub async fn write_all_batch<B>(self, writer: B) -> result::Result<AtEndBlob, DecodeError>
699        where
700            B: BaoBatchWriter,
701        {
702            let mut writer = writer;
703            let mut buf = Vec::new();
704            let mut content = self;
705            let size = content.tree().size();
706            loop {
707                match content.next().await {
708                    BlobContentNext::More((next, item)) => {
709                        let item = item?;
710                        match &item {
711                            BaoContentItem::Parent(_) => {
712                                buf.push(item);
713                            }
714                            BaoContentItem::Leaf(_) => {
715                                buf.push(item);
716                                let batch = std::mem::take(&mut buf);
717                                writer.write_batch(size, batch).await?;
718                            }
719                        }
720                        content = next;
721                    }
722                    BlobContentNext::Done(end) => {
723                        assert!(buf.is_empty());
724                        return Ok(end);
725                    }
726                }
727            }
728        }
729
730        /// Write the entire blob to a slice writer and to an optional outboard.
731        ///
732        /// The outboard is only written to if the blob is larger than a single
733        /// chunk group.
734        pub async fn write_all_with_outboard<D, O>(
735            self,
736            mut outboard: Option<O>,
737            mut data: D,
738        ) -> result::Result<AtEndBlob, DecodeError>
739        where
740            D: AsyncSliceWriter,
741            O: OutboardMut,
742        {
743            let mut content = self;
744            loop {
745                match content.next().await {
746                    BlobContentNext::More((content1, item)) => {
747                        content = content1;
748                        match item? {
749                            BaoContentItem::Parent(parent) => {
750                                if let Some(outboard) = outboard.as_mut() {
751                                    outboard.save(parent.node, &parent.pair).await?;
752                                }
753                            }
754                            BaoContentItem::Leaf(leaf) => {
755                                data.write_bytes_at(leaf.offset, leaf.data).await?;
756                            }
757                        }
758                    }
759                    BlobContentNext::Done(end) => {
760                        return Ok(end);
761                    }
762                }
763            }
764        }
765
766        /// Write the entire blob to a slice writer.
767        pub async fn write_all<D>(self, mut data: D) -> result::Result<AtEndBlob, DecodeError>
768        where
769            D: AsyncSliceWriter,
770        {
771            let mut content = self;
772            loop {
773                match content.next().await {
774                    BlobContentNext::More((content1, item)) => {
775                        content = content1;
776                        match item? {
777                            BaoContentItem::Parent(_) => {}
778                            BaoContentItem::Leaf(leaf) => {
779                                data.write_bytes_at(leaf.offset, leaf.data).await?;
780                            }
781                        }
782                    }
783                    BlobContentNext::Done(end) => {
784                        return Ok(end);
785                    }
786                }
787            }
788        }
789
790        /// Immediately finish the get response without reading further
791        pub fn finish(self) -> AtClosing {
792            AtClosing::new(self.misc, self.stream.finish(), false)
793        }
794    }
795
    /// State after we have read all the content for a blob
    #[derive(Debug)]
    pub struct AtEndBlob {
        /// The reader handed back by the decoder once the blob is done
        stream: WrappedRecvStream,
        /// State carried through all states of the machine
        misc: Box<Misc>,
    }
802
    /// The next state after the end of a blob
    #[derive(Debug, From)]
    pub enum EndBlobNext {
        /// Response is expected to have more children
        ///
        /// Chosen when the request contains another non-empty range.
        MoreChildren(AtStartChild),
        /// No more children expected
        ///
        /// Chosen when the ranges of the request are exhausted.
        Closing(AtClosing),
    }
811
812    impl AtEndBlob {
813        /// Read the next child, or finish
814        pub fn next(mut self) -> EndBlobNext {
815            if let Some((offset, ranges)) = self.misc.ranges_iter.next() {
816                AtStartChild {
817                    reader: self.stream,
818                    child_offset: offset - 1,
819                    ranges,
820                    misc: self.misc,
821                }
822                .into()
823            } else {
824                AtClosing::new(self.misc, self.stream, true).into()
825            }
826        }
827    }
828
    /// State when finishing the get response
    #[derive(Debug)]
    pub struct AtClosing {
        /// State carried through all states of the machine
        misc: Box<Misc>,
        /// Byte-counting reader for the response stream
        reader: WrappedRecvStream,
        /// Whether to check if the provider sends data after the expected end
        /// of the response. This is `false` when the response is finished early.
        check_extra_data: bool,
    }
836
837    impl AtClosing {
838        fn new(misc: Box<Misc>, reader: WrappedRecvStream, check_extra_data: bool) -> Self {
839            Self {
840                misc,
841                reader,
842                check_extra_data,
843            }
844        }
845
846        /// Finish the get response, returning statistics
847        pub async fn next(self) -> result::Result<Stats, quinn::ReadError> {
848            // Shut down the stream
849            let (reader, bytes_read) = self.reader.into_parts();
850            let mut reader = reader.into_inner();
851            if self.check_extra_data {
852                if let Some(chunk) = reader.read_chunk(8, false).await? {
853                    reader.stop(0u8.into()).ok();
854                    error!("Received unexpected data from the provider: {chunk:?}");
855                }
856            } else {
857                reader.stop(0u8.into()).ok();
858            }
859            Ok(Stats {
860                elapsed: self.misc.start.elapsed(),
861                bytes_written: self.misc.bytes_written,
862                bytes_read,
863            })
864        }
865    }
866
    /// Stuff we need to hold on to while going through the machine states
    ///
    /// The states keep this in a `Box`, so that a state transition only moves
    /// a pointer.
    #[derive(Debug)]
    struct Misc {
        /// start time for statistics
        start: Instant,
        /// bytes written for statistics
        bytes_written: u64,
        /// iterator over the ranges of the collection and the children
        ranges_iter: RangesIter,
    }
877}
878
/// Error when processing a response
#[derive(thiserror::Error, Debug)]
pub enum GetResponseError {
    /// Error when opening a stream
    #[error("connection: {0}")]
    Connection(#[from] quinn::ConnectionError),
    /// Error when writing the handshake or request to the stream
    #[error("write: {0}")]
    Write(#[from] quinn::WriteError),
    /// Error when reading from the stream
    #[error("read: {0}")]
    Read(#[from] quinn::ReadError),
    /// Error when decoding, e.g. hash mismatch
    ///
    /// Io errors reported by the decoder are not stored here; they are
    /// converted to one of the other variants.
    #[error("decode: {0}")]
    Decode(bao_tree::io::DecodeError),
    /// A generic error
    #[error("generic: {0}")]
    Generic(anyhow::Error),
}
898
899impl From<postcard::Error> for GetResponseError {
900    fn from(cause: postcard::Error) -> Self {
901        Self::Generic(cause.into())
902    }
903}
904
905impl From<bao_tree::io::DecodeError> for GetResponseError {
906    fn from(cause: bao_tree::io::DecodeError) -> Self {
907        match cause {
908            bao_tree::io::DecodeError::Io(cause) => {
909                // try to downcast to specific quinn errors
910                if let Some(source) = cause.source() {
911                    if let Some(error) = source.downcast_ref::<quinn::ConnectionError>() {
912                        return Self::Connection(error.clone());
913                    }
914                    if let Some(error) = source.downcast_ref::<quinn::ReadError>() {
915                        return Self::Read(error.clone());
916                    }
917                    if let Some(error) = source.downcast_ref::<quinn::WriteError>() {
918                        return Self::Write(error.clone());
919                    }
920                }
921                Self::Generic(cause.into())
922            }
923            _ => Self::Decode(cause),
924        }
925    }
926}
927
928impl From<anyhow::Error> for GetResponseError {
929    fn from(cause: anyhow::Error) -> Self {
930        Self::Generic(cause)
931    }
932}
933
934impl From<GetResponseError> for std::io::Error {
935    fn from(cause: GetResponseError) -> Self {
936        Self::new(std::io::ErrorKind::Other, cause)
937    }
938}