1use std::error::Error;
16use std::fmt::{self, Debug};
17use std::time::{Duration, Instant};
18
19use crate::Hash;
20use anyhow::Result;
21use bao_tree::io::fsm::BaoContentItem;
22use bao_tree::ChunkNum;
23use quinn::RecvStream;
24use serde::{Deserialize, Serialize};
25use tracing::{debug, error};
26
27use crate::protocol::RangeSpecSeq;
28use crate::util::io::{TrackingReader, TrackingWriter};
29use crate::IROH_BLOCK_SIZE;
30
31pub mod db;
32pub mod error;
33pub mod progress;
34pub mod request;
35
/// Statistics about a finished transfer, produced by `fsm::AtClosing::next`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Stats {
    /// Number of bytes written to the send stream (the serialized request).
    pub bytes_written: u64,
    /// Number of bytes read from the receive stream.
    pub bytes_read: u64,
    /// Time elapsed between just before the stream was opened and closing the transfer.
    pub elapsed: Duration,
}
46
47impl Stats {
48    pub fn mbits(&self) -> f64 {
50        let data_len_bit = self.bytes_read * 8;
51        data_len_bit as f64 / (1000. * 1000.) / self.elapsed.as_secs_f64()
52    }
53}
54
55#[doc = include_str!("../docs/img/get_machine.drawio.svg")]
59pub mod fsm {
60    use std::{io, result};
61
62    use crate::{
63        protocol::{GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE},
64        store::BaoBatchWriter,
65    };
66
67    use super::*;
68
69    use bao_tree::{
70        io::fsm::{OutboardMut, ResponseDecoder, ResponseDecoderNext},
71        BaoTree, ChunkRanges, TreeNode,
72    };
73    use derive_more::From;
74    use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader};
75    use tokio::io::AsyncWriteExt;
76
    /// Receive half of the stream, wrapped so the number of bytes read is tracked
    /// (see `TrackingReader::into_parts` in `AtClosing::next`).
    type WrappedRecvStream = TrackingReader<TokioStreamReader<RecvStream>>;

    // Self-referential pair: an owned `RangeSpecSeq` together with the
    // `NonEmptyRequestRangeSpecIter` that borrows from it. This lets `RangesIter`
    // own its spec sequence without carrying a lifetime parameter.
    self_cell::self_cell! {
        struct RangesIterInner {
            owner: RangeSpecSeq,
            #[covariant]
            dependent: NonEmptyRequestRangeSpecIter,
        }
    }
86
87    pub fn start(connection: quinn::Connection, request: GetRequest) -> AtInitial {
89        AtInitial::new(connection, request)
90    }
91
92    struct RangesIter(RangesIterInner);
97
98    impl fmt::Debug for RangesIter {
99        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100            f.debug_struct("RangesIter").finish()
101        }
102    }
103
104    impl RangesIter {
105        pub fn new(owner: RangeSpecSeq) -> Self {
106            Self(RangesIterInner::new(owner, |owner| owner.iter_non_empty()))
107        }
108
109        pub fn offset(&self) -> u64 {
110            self.0.with_dependent(|_owner, iter| iter.offset())
111        }
112    }
113
114    impl Iterator for RangesIter {
115        type Item = (u64, ChunkRanges);
116
117        fn next(&mut self) -> Option<Self::Item> {
118            self.0.with_dependent_mut(|_owner, iter| {
119                iter.next()
120                    .map(|(offset, ranges)| (offset, ranges.to_chunk_ranges()))
121            })
122        }
123    }
124
125    #[derive(Debug)]
127    pub struct AtInitial {
128        connection: quinn::Connection,
129        request: GetRequest,
130    }
131
132    impl AtInitial {
133        pub fn new(connection: quinn::Connection, request: GetRequest) -> Self {
138            Self {
139                connection,
140                request,
141            }
142        }
143
144        pub async fn next(self) -> Result<AtConnected, quinn::ConnectionError> {
146            let start = Instant::now();
147            let (writer, reader) = self.connection.open_bi().await?;
148            let reader = TrackingReader::new(TokioStreamReader::new(reader));
149            let writer = TrackingWriter::new(writer);
150            Ok(AtConnected {
151                start,
152                reader,
153                writer,
154                request: self.request,
155            })
156        }
157    }
158
    /// State after the bidirectional stream was opened, before the request is sent.
    #[derive(Debug)]
    pub struct AtConnected {
        /// When the transfer started; used for `Stats::elapsed`.
        start: Instant,
        /// Receive half of the stream, tracking the number of bytes read.
        reader: WrappedRecvStream,
        /// Send half of the stream, tracking the number of bytes written.
        writer: TrackingWriter<quinn::SendStream>,
        /// The request to send.
        request: GetRequest,
    }
167
    /// Possible next states after the request has been sent.
    #[derive(Debug, From)]
    pub enum ConnectedNext {
        /// The first non-empty request entry is the root blob (offset 0).
        StartRoot(AtStartRoot),
        /// The first non-empty request entry is a child (offset > 0); the root is not requested.
        StartChild(AtStartChild),
        /// The request has no non-empty entries; go straight to closing.
        Closing(AtClosing),
    }
178
    /// Error when sending the request in [`AtConnected::next`].
    #[derive(Debug, thiserror::Error)]
    pub enum ConnectedNextError {
        /// The request could not be serialized with postcard.
        #[error("postcard ser: {0}")]
        PostcardSer(postcard::Error),
        /// The serialized request is larger than `MAX_MESSAGE_SIZE`.
        #[error("request too big")]
        RequestTooBig,
        /// Writing to or finishing the send stream failed with a quinn write error.
        #[error("write: {0}")]
        Write(#[from] quinn::WriteError),
        /// Any other I/O error while writing the request.
        #[error("io {0}")]
        Io(io::Error),
    }
195
196    impl ConnectedNextError {
197        fn from_io(cause: io::Error) -> Self {
198            if let Some(inner) = cause.get_ref() {
199                if let Some(e) = inner.downcast_ref::<quinn::WriteError>() {
200                    Self::Write(e.clone())
201                } else {
202                    Self::Io(cause)
203                }
204            } else {
205                Self::Io(cause)
206            }
207        }
208    }
209
210    impl From<ConnectedNextError> for io::Error {
211        fn from(cause: ConnectedNextError) -> Self {
212            match cause {
213                ConnectedNextError::Write(cause) => cause.into(),
214                ConnectedNextError::Io(cause) => cause,
215                ConnectedNextError::PostcardSer(cause) => {
216                    io::Error::new(io::ErrorKind::Other, cause)
217                }
218                _ => io::Error::new(io::ErrorKind::Other, cause),
219            }
220        }
221    }
222
223    impl AtConnected {
224        pub async fn next(self) -> Result<ConnectedNext, ConnectedNextError> {
231            let Self {
232                start,
233                reader,
234                mut writer,
235                mut request,
236            } = self;
237            {
239                debug!("sending request");
240                let wrapped = Request::Get(request);
241                let request_bytes =
242                    postcard::to_stdvec(&wrapped).map_err(ConnectedNextError::PostcardSer)?;
243                let Request::Get(x) = wrapped;
244                request = x;
245
246                if request_bytes.len() > MAX_MESSAGE_SIZE {
247                    return Err(ConnectedNextError::RequestTooBig);
248                }
249
250                writer
252                    .write_all(&request_bytes)
253                    .await
254                    .map_err(ConnectedNextError::from_io)?;
255            }
256
257            let (mut writer, bytes_written) = writer.into_parts();
259            writer.finish().await?;
260
261            let hash = request.hash;
262            let ranges_iter = RangesIter::new(request.ranges);
263            let mut misc = Box::new(Misc {
265                start,
266                bytes_written,
267                ranges_iter,
268            });
269            Ok(match misc.ranges_iter.next() {
270                Some((offset, ranges)) => {
271                    if offset == 0 {
272                        AtStartRoot {
273                            reader,
274                            ranges,
275                            misc,
276                            hash,
277                        }
278                        .into()
279                    } else {
280                        AtStartChild {
281                            reader,
282                            ranges,
283                            misc,
284                            child_offset: offset - 1,
285                        }
286                        .into()
287                    }
288                }
289                None => AtClosing::new(misc, reader, true).into(),
290            })
291        }
292    }
293
    /// State at the start of the root blob (request offset 0).
    #[derive(Debug)]
    pub struct AtStartRoot {
        /// Chunk ranges requested for the root blob.
        ranges: ChunkRanges,
        /// Receive half of the stream.
        reader: TrackingReader<TokioStreamReader<quinn::RecvStream>>,
        /// State shared across the whole transfer.
        misc: Box<Misc>,
        /// Hash of the root blob, taken from the request.
        hash: Hash,
    }

    /// State at the start of a child blob (request offset > 0).
    #[derive(Debug)]
    pub struct AtStartChild {
        /// Chunk ranges requested for this child.
        ranges: ChunkRanges,
        /// Receive half of the stream.
        reader: TrackingReader<TokioStreamReader<quinn::RecvStream>>,
        /// State shared across the whole transfer.
        misc: Box<Misc>,
        /// Request offset minus one, because offset 0 is the root blob.
        child_offset: u64,
    }
311
312    impl AtStartChild {
313        pub fn child_offset(&self) -> u64 {
319            self.child_offset
320        }
321
322        pub fn ranges(&self) -> &ChunkRanges {
324            &self.ranges
325        }
326
327        pub fn next(self, hash: Hash) -> AtBlobHeader {
331            AtBlobHeader {
332                reader: self.reader,
333                ranges: self.ranges,
334                misc: self.misc,
335                hash,
336            }
337        }
338
339        pub fn finish(self) -> AtClosing {
345            AtClosing::new(self.misc, self.reader, false)
346        }
347    }
348
349    impl AtStartRoot {
350        pub fn ranges(&self) -> &ChunkRanges {
352            &self.ranges
353        }
354
355        pub fn hash(&self) -> Hash {
357            self.hash
358        }
359
360        pub fn next(self) -> AtBlobHeader {
364            AtBlobHeader {
365                reader: self.reader,
366                ranges: self.ranges,
367                hash: self.hash,
368                misc: self.misc,
369            }
370        }
371
372        pub fn finish(self) -> AtClosing {
374            AtClosing::new(self.misc, self.reader, false)
375        }
376    }
377
    /// State before the 8-byte size header of a blob has been read.
    #[derive(Debug)]
    pub struct AtBlobHeader {
        /// Chunk ranges requested for this blob.
        ranges: ChunkRanges,
        /// Receive half of the stream.
        reader: TrackingReader<TokioStreamReader<quinn::RecvStream>>,
        /// State shared across the whole transfer.
        misc: Box<Misc>,
        /// Hash the blob content is verified against by the response decoder.
        hash: Hash,
    }
386
    /// Error when reading the size header of a blob in [`AtBlobHeader::next`].
    #[derive(Debug, thiserror::Error)]
    pub enum AtBlobHeaderNextError {
        /// The stream ended before the 8-byte size header was complete.
        ///
        /// NOTE(review): presumably this means the provider does not have the blob — confirm.
        #[error("not found")]
        NotFound,
        /// Reading the header failed with a wrapped quinn read error.
        #[error("read: {0}")]
        Read(quinn::ReadError),
        /// Any other I/O error while reading the header.
        #[error("io: {0}")]
        Io(io::Error),
    }
402
403    impl From<AtBlobHeaderNextError> for io::Error {
404        fn from(cause: AtBlobHeaderNextError) -> Self {
405            match cause {
406                AtBlobHeaderNextError::NotFound => {
407                    io::Error::new(io::ErrorKind::UnexpectedEof, cause)
408                }
409                AtBlobHeaderNextError::Read(cause) => cause.into(),
410                AtBlobHeaderNextError::Io(cause) => cause,
411            }
412        }
413    }
414
415    impl AtBlobHeader {
416        pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> {
418            let size = self.reader.read::<8>().await.map_err(|cause| {
419                if cause.kind() == io::ErrorKind::UnexpectedEof {
420                    AtBlobHeaderNextError::NotFound
421                } else if let Some(e) = cause
422                    .get_ref()
423                    .and_then(|x| x.downcast_ref::<quinn::ReadError>())
424                {
425                    AtBlobHeaderNextError::Read(e.clone())
426                } else {
427                    AtBlobHeaderNextError::Io(cause)
428                }
429            })?;
430            let size = u64::from_le_bytes(size);
431            let stream = ResponseDecoder::new(
432                self.hash.into(),
433                self.ranges,
434                BaoTree::new(size, IROH_BLOCK_SIZE),
435                self.reader,
436            );
437            Ok((
438                AtBlobContent {
439                    stream,
440                    misc: self.misc,
441                },
442                size,
443            ))
444        }
445
446        pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
448            let (content, _size) = self.next().await?;
449            content.drain().await
450        }
451
452        pub async fn concatenate_into_vec(
457            self,
458        ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
459            let (content, _size) = self.next().await?;
460            content.concatenate_into_vec().await
461        }
462
463        pub async fn write_all<D: AsyncSliceWriter>(
465            self,
466            data: D,
467        ) -> result::Result<AtEndBlob, DecodeError> {
468            let (content, _size) = self.next().await?;
469            let res = content.write_all(data).await?;
470            Ok(res)
471        }
472
473        pub async fn write_all_with_outboard<D, O>(
478            self,
479            outboard: Option<O>,
480            data: D,
481        ) -> result::Result<AtEndBlob, DecodeError>
482        where
483            D: AsyncSliceWriter,
484            O: OutboardMut,
485        {
486            let (content, _size) = self.next().await?;
487            let res = content.write_all_with_outboard(outboard, data).await?;
488            Ok(res)
489        }
490
491        pub async fn write_all_batch<B>(self, batch: B) -> result::Result<AtEndBlob, DecodeError>
493        where
494            B: BaoBatchWriter,
495        {
496            let (content, _size) = self.next().await?;
497            let res = content.write_all_batch(batch).await?;
498            Ok(res)
499        }
500
501        pub fn hash(&self) -> Hash {
503            self.hash
504        }
505
506        pub fn ranges(&self) -> &ChunkRanges {
508            &self.ranges
509        }
510
511        pub fn offset(&self) -> u64 {
513            self.misc.ranges_iter.offset()
514        }
515    }
516
    /// State while reading the content of a blob.
    #[derive(Debug)]
    pub struct AtBlobContent {
        /// Decoder for the bao-encoded response of this blob.
        stream: ResponseDecoder<WrappedRecvStream>,
        /// State shared across the whole transfer.
        misc: Box<Misc>,
    }
523
    /// Error when reading and decoding blob content.
    #[derive(Debug, thiserror::Error)]
    pub enum DecodeError {
        /// The size header could not be read because the stream ended early.
        ///
        /// NOTE(review): presumably the provider does not have the blob — confirm.
        #[error("not found")]
        NotFound,
        /// An expected parent node was not found in the response.
        #[error("parent not found {0:?}")]
        ParentNotFound(TreeNode),
        /// An expected leaf (chunk) was not found in the response.
        #[error("chunk not found {0}")]
        LeafNotFound(ChunkNum),
        /// A parent node did not match its expected hash.
        #[error("parent hash mismatch: {0:?}")]
        ParentHashMismatch(TreeNode),
        /// A leaf did not match its expected hash.
        #[error("leaf hash mismatch: {0}")]
        LeafHashMismatch(ChunkNum),
        /// Reading failed with a wrapped quinn read error.
        #[error("read: {0}")]
        Read(quinn::ReadError),
        /// Any other I/O error.
        #[error("io: {0}")]
        Io(#[from] io::Error),
    }
569
570    impl From<AtBlobHeaderNextError> for DecodeError {
571        fn from(cause: AtBlobHeaderNextError) -> Self {
572            match cause {
573                AtBlobHeaderNextError::NotFound => Self::NotFound,
574                AtBlobHeaderNextError::Read(cause) => Self::Read(cause),
575                AtBlobHeaderNextError::Io(cause) => Self::Io(cause),
576            }
577        }
578    }
579
580    impl From<DecodeError> for io::Error {
581        fn from(cause: DecodeError) -> Self {
582            match cause {
583                DecodeError::ParentNotFound(_) => {
584                    io::Error::new(io::ErrorKind::UnexpectedEof, cause)
585                }
586                DecodeError::LeafNotFound(_) => io::Error::new(io::ErrorKind::UnexpectedEof, cause),
587                DecodeError::Read(cause) => cause.into(),
588                DecodeError::Io(cause) => cause,
589                _ => io::Error::new(io::ErrorKind::Other, cause),
590            }
591        }
592    }
593
594    impl From<bao_tree::io::DecodeError> for DecodeError {
595        fn from(value: bao_tree::io::DecodeError) -> Self {
596            match value {
597                bao_tree::io::DecodeError::ParentNotFound(x) => Self::ParentNotFound(x),
598                bao_tree::io::DecodeError::LeafNotFound(x) => Self::LeafNotFound(x),
599                bao_tree::io::DecodeError::ParentHashMismatch(node) => {
600                    Self::ParentHashMismatch(node)
601                }
602                bao_tree::io::DecodeError::LeafHashMismatch(chunk) => Self::LeafHashMismatch(chunk),
603                bao_tree::io::DecodeError::Io(cause) => {
604                    if let Some(inner) = cause.get_ref() {
605                        if let Some(e) = inner.downcast_ref::<quinn::ReadError>() {
606                            Self::Read(e.clone())
607                        } else {
608                            Self::Io(cause)
609                        }
610                    } else {
611                        Self::Io(cause)
612                    }
613                }
614            }
615        }
616    }
617
    /// The next state after reading one item of blob content.
    #[derive(Debug, From)]
    pub enum BlobContentNext {
        /// The next content state, plus the decoded item or the error that decoding it produced.
        More((AtBlobContent, result::Result<BaoContentItem, DecodeError>)),
        /// The decoder is done with this blob.
        Done(AtEndBlob),
    }
626
627    impl AtBlobContent {
628        pub async fn next(self) -> BlobContentNext {
630            match self.stream.next().await {
631                ResponseDecoderNext::More((stream, res)) => {
632                    let next = Self { stream, ..self };
633                    let res = res.map_err(DecodeError::from);
634                    BlobContentNext::More((next, res))
635                }
636                ResponseDecoderNext::Done(stream) => BlobContentNext::Done(AtEndBlob {
637                    stream,
638                    misc: self.misc,
639                }),
640            }
641        }
642
643        pub fn tree(&self) -> bao_tree::BaoTree {
645            self.stream.tree()
646        }
647
648        pub fn hash(&self) -> Hash {
650            (*self.stream.hash()).into()
651        }
652
653        pub fn offset(&self) -> u64 {
655            self.misc.ranges_iter.offset()
656        }
657
658        pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
660            let mut content = self;
661            loop {
662                match content.next().await {
663                    BlobContentNext::More((content1, res)) => {
664                        let _ = res?;
665                        content = content1;
666                    }
667                    BlobContentNext::Done(end) => {
668                        break Ok(end);
669                    }
670                }
671            }
672        }
673
674        pub async fn concatenate_into_vec(
676            self,
677        ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
678            let mut res = Vec::with_capacity(1024);
679            let mut curr = self;
680            let done = loop {
681                match curr.next().await {
682                    BlobContentNext::More((next, data)) => {
683                        if let BaoContentItem::Leaf(leaf) = data? {
684                            res.extend_from_slice(&leaf.data);
685                        }
686                        curr = next;
687                    }
688                    BlobContentNext::Done(done) => {
689                        break done;
691                    }
692                }
693            };
694            Ok((done, res))
695        }
696
697        pub async fn write_all_batch<B>(self, writer: B) -> result::Result<AtEndBlob, DecodeError>
699        where
700            B: BaoBatchWriter,
701        {
702            let mut writer = writer;
703            let mut buf = Vec::new();
704            let mut content = self;
705            let size = content.tree().size();
706            loop {
707                match content.next().await {
708                    BlobContentNext::More((next, item)) => {
709                        let item = item?;
710                        match &item {
711                            BaoContentItem::Parent(_) => {
712                                buf.push(item);
713                            }
714                            BaoContentItem::Leaf(_) => {
715                                buf.push(item);
716                                let batch = std::mem::take(&mut buf);
717                                writer.write_batch(size, batch).await?;
718                            }
719                        }
720                        content = next;
721                    }
722                    BlobContentNext::Done(end) => {
723                        assert!(buf.is_empty());
724                        return Ok(end);
725                    }
726                }
727            }
728        }
729
730        pub async fn write_all_with_outboard<D, O>(
735            self,
736            mut outboard: Option<O>,
737            mut data: D,
738        ) -> result::Result<AtEndBlob, DecodeError>
739        where
740            D: AsyncSliceWriter,
741            O: OutboardMut,
742        {
743            let mut content = self;
744            loop {
745                match content.next().await {
746                    BlobContentNext::More((content1, item)) => {
747                        content = content1;
748                        match item? {
749                            BaoContentItem::Parent(parent) => {
750                                if let Some(outboard) = outboard.as_mut() {
751                                    outboard.save(parent.node, &parent.pair).await?;
752                                }
753                            }
754                            BaoContentItem::Leaf(leaf) => {
755                                data.write_bytes_at(leaf.offset, leaf.data).await?;
756                            }
757                        }
758                    }
759                    BlobContentNext::Done(end) => {
760                        return Ok(end);
761                    }
762                }
763            }
764        }
765
766        pub async fn write_all<D>(self, mut data: D) -> result::Result<AtEndBlob, DecodeError>
768        where
769            D: AsyncSliceWriter,
770        {
771            let mut content = self;
772            loop {
773                match content.next().await {
774                    BlobContentNext::More((content1, item)) => {
775                        content = content1;
776                        match item? {
777                            BaoContentItem::Parent(_) => {}
778                            BaoContentItem::Leaf(leaf) => {
779                                data.write_bytes_at(leaf.offset, leaf.data).await?;
780                            }
781                        }
782                    }
783                    BlobContentNext::Done(end) => {
784                        return Ok(end);
785                    }
786                }
787            }
788        }
789
790        pub fn finish(self) -> AtClosing {
792            AtClosing::new(self.misc, self.stream.finish(), false)
793        }
794    }
795
    /// State after the decoder has finished a blob.
    #[derive(Debug)]
    pub struct AtEndBlob {
        /// Receive half of the stream, handed back by the decoder.
        stream: WrappedRecvStream,
        /// State shared across the whole transfer.
        misc: Box<Misc>,
    }

    /// The next state after finishing a blob.
    #[derive(Debug, From)]
    pub enum EndBlobNext {
        /// The request has another non-empty child entry to read.
        MoreChildren(AtStartChild),
        /// The request is exhausted; close the stream.
        Closing(AtClosing),
    }
811
812    impl AtEndBlob {
813        pub fn next(mut self) -> EndBlobNext {
815            if let Some((offset, ranges)) = self.misc.ranges_iter.next() {
816                AtStartChild {
817                    reader: self.stream,
818                    child_offset: offset - 1,
819                    ranges,
820                    misc: self.misc,
821                }
822                .into()
823            } else {
824                AtClosing::new(self.misc, self.stream, true).into()
825            }
826        }
827    }
828
829    #[derive(Debug)]
831    pub struct AtClosing {
832        misc: Box<Misc>,
833        reader: WrappedRecvStream,
834        check_extra_data: bool,
835    }
836
837    impl AtClosing {
838        fn new(misc: Box<Misc>, reader: WrappedRecvStream, check_extra_data: bool) -> Self {
839            Self {
840                misc,
841                reader,
842                check_extra_data,
843            }
844        }
845
846        pub async fn next(self) -> result::Result<Stats, quinn::ReadError> {
848            let (reader, bytes_read) = self.reader.into_parts();
850            let mut reader = reader.into_inner();
851            if self.check_extra_data {
852                if let Some(chunk) = reader.read_chunk(8, false).await? {
853                    reader.stop(0u8.into()).ok();
854                    error!("Received unexpected data from the provider: {chunk:?}");
855                }
856            } else {
857                reader.stop(0u8.into()).ok();
858            }
859            Ok(Stats {
860                elapsed: self.misc.start.elapsed(),
861                bytes_written: self.misc.bytes_written,
862                bytes_read,
863            })
864        }
865    }
866
    /// State carried through every state after the request has been sent.
    #[derive(Debug)]
    struct Misc {
        /// Start time of the transfer; used for `Stats::elapsed`.
        start: Instant,
        /// Bytes written to the send stream (the serialized request).
        bytes_written: u64,
        /// Iterator over the remaining non-empty entries of the request ranges.
        ranges_iter: RangesIter,
    }
877}
878
/// Error when processing a response.
#[derive(thiserror::Error, Debug)]
pub enum GetResponseError {
    /// A quinn connection error.
    #[error("connection: {0}")]
    Connection(#[from] quinn::ConnectionError),
    /// A quinn stream write error.
    #[error("write: {0}")]
    Write(#[from] quinn::WriteError),
    /// A quinn stream read error.
    #[error("read: {0}")]
    Read(#[from] quinn::ReadError),
    /// A non-I/O bao decode error, e.g. a hash mismatch.
    #[error("decode: {0}")]
    Decode(bao_tree::io::DecodeError),
    /// Any other error, e.g. a postcard error or an unclassified I/O error.
    #[error("generic: {0}")]
    Generic(anyhow::Error),
}
898
899impl From<postcard::Error> for GetResponseError {
900    fn from(cause: postcard::Error) -> Self {
901        Self::Generic(cause.into())
902    }
903}
904
905impl From<bao_tree::io::DecodeError> for GetResponseError {
906    fn from(cause: bao_tree::io::DecodeError) -> Self {
907        match cause {
908            bao_tree::io::DecodeError::Io(cause) => {
909                if let Some(source) = cause.source() {
911                    if let Some(error) = source.downcast_ref::<quinn::ConnectionError>() {
912                        return Self::Connection(error.clone());
913                    }
914                    if let Some(error) = source.downcast_ref::<quinn::ReadError>() {
915                        return Self::Read(error.clone());
916                    }
917                    if let Some(error) = source.downcast_ref::<quinn::WriteError>() {
918                        return Self::Write(error.clone());
919                    }
920                }
921                Self::Generic(cause.into())
922            }
923            _ => Self::Decode(cause),
924        }
925    }
926}
927
928impl From<anyhow::Error> for GetResponseError {
929    fn from(cause: anyhow::Error) -> Self {
930        Self::Generic(cause)
931    }
932}
933
934impl From<GetResponseError> for std::io::Error {
935    fn from(cause: GetResponseError) -> Self {
936        Self::new(std::io::ErrorKind::Other, cause)
937    }
938}