1use std::{
16 error::Error,
17 fmt::{self, Debug},
18 time::{Duration, Instant},
19};
20
21use anyhow::Result;
22use bao_tree::{io::fsm::BaoContentItem, ChunkNum};
23use iroh::endpoint::{self, ClosedStream, RecvStream, SendStream, WriteError};
24use serde::{Deserialize, Serialize};
25use tracing::{debug, error};
26
27use crate::{
28 protocol::RangeSpecSeq,
29 util::io::{TrackingReader, TrackingWriter},
30 Hash, IROH_BLOCK_SIZE,
31};
32
33pub mod db;
34pub mod error;
35pub mod progress;
36pub mod request;
37
38#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
40pub struct Stats {
41 pub bytes_written: u64,
43 pub bytes_read: u64,
45 pub elapsed: Duration,
47}
48
49impl Stats {
50 pub fn mbits(&self) -> f64 {
52 let data_len_bit = self.bytes_read * 8;
53 data_len_bit as f64 / (1000. * 1000.) / self.elapsed.as_secs_f64()
54 }
55}
56
57#[doc = include_str!("../docs/img/get_machine.drawio.svg")]
61pub mod fsm {
62 use std::{io, result};
63
64 use bao_tree::{
65 io::fsm::{OutboardMut, ResponseDecoder, ResponseDecoderNext},
66 BaoTree, ChunkRanges, TreeNode,
67 };
68 use derive_more::From;
69 use iroh::endpoint::Connection;
70 use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader};
71 use tokio::io::AsyncWriteExt;
72
73 use super::*;
74 use crate::{
75 protocol::{GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE},
76 store::BaoBatchWriter,
77 };
78
79 type WrappedRecvStream = TrackingReader<TokioStreamReader<RecvStream>>;
80
81 self_cell::self_cell! {
82 struct RangesIterInner {
83 owner: RangeSpecSeq,
84 #[covariant]
85 dependent: NonEmptyRequestRangeSpecIter,
86 }
87 }
88
89 pub fn start(connection: Connection, request: GetRequest) -> AtInitial {
91 AtInitial::new(connection, request)
92 }
93
94 struct RangesIter(RangesIterInner);
99
100 impl fmt::Debug for RangesIter {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 f.debug_struct("RangesIter").finish()
103 }
104 }
105
106 impl RangesIter {
107 pub fn new(owner: RangeSpecSeq) -> Self {
108 Self(RangesIterInner::new(owner, |owner| owner.iter_non_empty()))
109 }
110
111 pub fn offset(&self) -> u64 {
112 self.0.with_dependent(|_owner, iter| iter.offset())
113 }
114 }
115
116 impl Iterator for RangesIter {
117 type Item = (u64, ChunkRanges);
118
119 fn next(&mut self) -> Option<Self::Item> {
120 self.0.with_dependent_mut(|_owner, iter| {
121 iter.next()
122 .map(|(offset, ranges)| (offset, ranges.to_chunk_ranges()))
123 })
124 }
125 }
126
127 #[derive(Debug)]
129 pub struct AtInitial {
130 connection: Connection,
131 request: GetRequest,
132 }
133
134 impl AtInitial {
135 pub fn new(connection: Connection, request: GetRequest) -> Self {
140 Self {
141 connection,
142 request,
143 }
144 }
145
146 pub async fn next(self) -> Result<AtConnected, endpoint::ConnectionError> {
148 let start = Instant::now();
149 let (writer, reader) = self.connection.open_bi().await?;
150 let reader = TrackingReader::new(TokioStreamReader::new(reader));
151 let writer = TrackingWriter::new(writer);
152 Ok(AtConnected {
153 start,
154 reader,
155 writer,
156 request: self.request,
157 })
158 }
159 }
160
161 #[derive(Debug)]
163 pub struct AtConnected {
164 start: Instant,
165 reader: WrappedRecvStream,
166 writer: TrackingWriter<SendStream>,
167 request: GetRequest,
168 }
169
170 #[derive(Debug, From)]
172 pub enum ConnectedNext {
173 StartRoot(AtStartRoot),
175 StartChild(AtStartChild),
177 Closing(AtClosing),
179 }
180
181 #[derive(Debug, thiserror::Error)]
183 pub enum ConnectedNextError {
184 #[error("postcard ser: {0}")]
186 PostcardSer(postcard::Error),
187 #[error("request too big")]
189 RequestTooBig,
190 #[error("write: {0}")]
192 Write(#[from] WriteError),
193 #[error("closed")]
195 Closed(#[from] ClosedStream),
196 #[error("io {0}")]
198 Io(io::Error),
199 }
200
201 impl ConnectedNextError {
202 fn from_io(cause: io::Error) -> Self {
203 if let Some(inner) = cause.get_ref() {
204 if let Some(e) = inner.downcast_ref::<endpoint::WriteError>() {
205 Self::Write(e.clone())
206 } else {
207 Self::Io(cause)
208 }
209 } else {
210 Self::Io(cause)
211 }
212 }
213 }
214
215 impl From<ConnectedNextError> for io::Error {
216 fn from(cause: ConnectedNextError) -> Self {
217 match cause {
218 ConnectedNextError::Write(cause) => cause.into(),
219 ConnectedNextError::Io(cause) => cause,
220 ConnectedNextError::PostcardSer(cause) => {
221 io::Error::new(io::ErrorKind::Other, cause)
222 }
223 _ => io::Error::new(io::ErrorKind::Other, cause),
224 }
225 }
226 }
227
228 impl AtConnected {
229 pub async fn next(self) -> Result<ConnectedNext, ConnectedNextError> {
236 let Self {
237 start,
238 reader,
239 mut writer,
240 mut request,
241 } = self;
242 {
244 debug!("sending request");
245 let wrapped = Request::Get(request);
246 let request_bytes =
247 postcard::to_stdvec(&wrapped).map_err(ConnectedNextError::PostcardSer)?;
248 let Request::Get(x) = wrapped;
249 request = x;
250
251 if request_bytes.len() > MAX_MESSAGE_SIZE {
252 return Err(ConnectedNextError::RequestTooBig);
253 }
254
255 writer
257 .write_all(&request_bytes)
258 .await
259 .map_err(ConnectedNextError::from_io)?;
260 }
261
262 let (mut writer, bytes_written) = writer.into_parts();
264 writer.finish()?;
265
266 let hash = request.hash;
267 let ranges_iter = RangesIter::new(request.ranges);
268 let mut misc = Box::new(Misc {
270 start,
271 bytes_written,
272 ranges_iter,
273 });
274 Ok(match misc.ranges_iter.next() {
275 Some((offset, ranges)) => {
276 if offset == 0 {
277 AtStartRoot {
278 reader,
279 ranges,
280 misc,
281 hash,
282 }
283 .into()
284 } else {
285 AtStartChild {
286 reader,
287 ranges,
288 misc,
289 child_offset: offset - 1,
290 }
291 .into()
292 }
293 }
294 None => AtClosing::new(misc, reader, true).into(),
295 })
296 }
297 }
298
299 #[derive(Debug)]
301 pub struct AtStartRoot {
302 ranges: ChunkRanges,
303 reader: TrackingReader<TokioStreamReader<RecvStream>>,
304 misc: Box<Misc>,
305 hash: Hash,
306 }
307
308 #[derive(Debug)]
310 pub struct AtStartChild {
311 ranges: ChunkRanges,
312 reader: TrackingReader<TokioStreamReader<RecvStream>>,
313 misc: Box<Misc>,
314 child_offset: u64,
315 }
316
317 impl AtStartChild {
318 pub fn child_offset(&self) -> u64 {
324 self.child_offset
325 }
326
327 pub fn ranges(&self) -> &ChunkRanges {
329 &self.ranges
330 }
331
332 pub fn next(self, hash: Hash) -> AtBlobHeader {
336 AtBlobHeader {
337 reader: self.reader,
338 ranges: self.ranges,
339 misc: self.misc,
340 hash,
341 }
342 }
343
344 pub fn finish(self) -> AtClosing {
350 AtClosing::new(self.misc, self.reader, false)
351 }
352 }
353
354 impl AtStartRoot {
355 pub fn ranges(&self) -> &ChunkRanges {
357 &self.ranges
358 }
359
360 pub fn hash(&self) -> Hash {
362 self.hash
363 }
364
365 pub fn next(self) -> AtBlobHeader {
369 AtBlobHeader {
370 reader: self.reader,
371 ranges: self.ranges,
372 hash: self.hash,
373 misc: self.misc,
374 }
375 }
376
377 pub fn finish(self) -> AtClosing {
379 AtClosing::new(self.misc, self.reader, false)
380 }
381 }
382
383 #[derive(Debug)]
385 pub struct AtBlobHeader {
386 ranges: ChunkRanges,
387 reader: TrackingReader<TokioStreamReader<RecvStream>>,
388 misc: Box<Misc>,
389 hash: Hash,
390 }
391
392 #[derive(Debug, thiserror::Error)]
394 pub enum AtBlobHeaderNextError {
395 #[error("not found")]
399 NotFound,
400 #[error("read: {0}")]
402 Read(endpoint::ReadError),
403 #[error("io: {0}")]
405 Io(io::Error),
406 }
407
408 impl From<AtBlobHeaderNextError> for io::Error {
409 fn from(cause: AtBlobHeaderNextError) -> Self {
410 match cause {
411 AtBlobHeaderNextError::NotFound => {
412 io::Error::new(io::ErrorKind::UnexpectedEof, cause)
413 }
414 AtBlobHeaderNextError::Read(cause) => cause.into(),
415 AtBlobHeaderNextError::Io(cause) => cause,
416 }
417 }
418 }
419
420 impl AtBlobHeader {
421 pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> {
423 let size = self.reader.read::<8>().await.map_err(|cause| {
424 if cause.kind() == io::ErrorKind::UnexpectedEof {
425 AtBlobHeaderNextError::NotFound
426 } else if let Some(e) = cause
427 .get_ref()
428 .and_then(|x| x.downcast_ref::<endpoint::ReadError>())
429 {
430 AtBlobHeaderNextError::Read(e.clone())
431 } else {
432 AtBlobHeaderNextError::Io(cause)
433 }
434 })?;
435 let size = u64::from_le_bytes(size);
436 let stream = ResponseDecoder::new(
437 self.hash.into(),
438 self.ranges,
439 BaoTree::new(size, IROH_BLOCK_SIZE),
440 self.reader,
441 );
442 Ok((
443 AtBlobContent {
444 stream,
445 misc: self.misc,
446 },
447 size,
448 ))
449 }
450
451 pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
453 let (content, _size) = self.next().await?;
454 content.drain().await
455 }
456
457 pub async fn concatenate_into_vec(
462 self,
463 ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
464 let (content, _size) = self.next().await?;
465 content.concatenate_into_vec().await
466 }
467
468 pub async fn write_all<D: AsyncSliceWriter>(
470 self,
471 data: D,
472 ) -> result::Result<AtEndBlob, DecodeError> {
473 let (content, _size) = self.next().await?;
474 let res = content.write_all(data).await?;
475 Ok(res)
476 }
477
478 pub async fn write_all_with_outboard<D, O>(
483 self,
484 outboard: Option<O>,
485 data: D,
486 ) -> result::Result<AtEndBlob, DecodeError>
487 where
488 D: AsyncSliceWriter,
489 O: OutboardMut,
490 {
491 let (content, _size) = self.next().await?;
492 let res = content.write_all_with_outboard(outboard, data).await?;
493 Ok(res)
494 }
495
496 pub async fn write_all_batch<B>(self, batch: B) -> result::Result<AtEndBlob, DecodeError>
498 where
499 B: BaoBatchWriter,
500 {
501 let (content, _size) = self.next().await?;
502 let res = content.write_all_batch(batch).await?;
503 Ok(res)
504 }
505
506 pub fn hash(&self) -> Hash {
508 self.hash
509 }
510
511 pub fn ranges(&self) -> &ChunkRanges {
513 &self.ranges
514 }
515
516 pub fn offset(&self) -> u64 {
518 self.misc.ranges_iter.offset()
519 }
520 }
521
522 #[derive(Debug)]
524 pub struct AtBlobContent {
525 stream: ResponseDecoder<WrappedRecvStream>,
526 misc: Box<Misc>,
527 }
528
529 #[derive(Debug, thiserror::Error)]
553 pub enum DecodeError {
554 #[error("not found")]
556 NotFound,
557 #[error("parent not found {0:?}")]
559 ParentNotFound(TreeNode),
560 #[error("chunk not found {0}")]
562 LeafNotFound(ChunkNum),
563 #[error("parent hash mismatch: {0:?}")]
565 ParentHashMismatch(TreeNode),
566 #[error("leaf hash mismatch: {0}")]
568 LeafHashMismatch(ChunkNum),
569 #[error("read: {0}")]
571 Read(endpoint::ReadError),
572 #[error("io: {0}")]
574 Io(#[from] io::Error),
575 }
576
577 impl From<AtBlobHeaderNextError> for DecodeError {
578 fn from(cause: AtBlobHeaderNextError) -> Self {
579 match cause {
580 AtBlobHeaderNextError::NotFound => Self::NotFound,
581 AtBlobHeaderNextError::Read(cause) => Self::Read(cause),
582 AtBlobHeaderNextError::Io(cause) => Self::Io(cause),
583 }
584 }
585 }
586
587 impl From<DecodeError> for io::Error {
588 fn from(cause: DecodeError) -> Self {
589 match cause {
590 DecodeError::ParentNotFound(_) => {
591 io::Error::new(io::ErrorKind::UnexpectedEof, cause)
592 }
593 DecodeError::LeafNotFound(_) => io::Error::new(io::ErrorKind::UnexpectedEof, cause),
594 DecodeError::Read(cause) => cause.into(),
595 DecodeError::Io(cause) => cause,
596 _ => io::Error::new(io::ErrorKind::Other, cause),
597 }
598 }
599 }
600
601 impl From<bao_tree::io::DecodeError> for DecodeError {
602 fn from(value: bao_tree::io::DecodeError) -> Self {
603 match value {
604 bao_tree::io::DecodeError::ParentNotFound(x) => Self::ParentNotFound(x),
605 bao_tree::io::DecodeError::LeafNotFound(x) => Self::LeafNotFound(x),
606 bao_tree::io::DecodeError::ParentHashMismatch(node) => {
607 Self::ParentHashMismatch(node)
608 }
609 bao_tree::io::DecodeError::LeafHashMismatch(chunk) => Self::LeafHashMismatch(chunk),
610 bao_tree::io::DecodeError::Io(cause) => {
611 if let Some(inner) = cause.get_ref() {
612 if let Some(e) = inner.downcast_ref::<endpoint::ReadError>() {
613 Self::Read(e.clone())
614 } else {
615 Self::Io(cause)
616 }
617 } else {
618 Self::Io(cause)
619 }
620 }
621 }
622 }
623 }
624
625 #[derive(Debug, From)]
627 pub enum BlobContentNext {
628 More((AtBlobContent, result::Result<BaoContentItem, DecodeError>)),
630 Done(AtEndBlob),
632 }
633
634 impl AtBlobContent {
635 pub async fn next(self) -> BlobContentNext {
637 match self.stream.next().await {
638 ResponseDecoderNext::More((stream, res)) => {
639 let next = Self { stream, ..self };
640 let res = res.map_err(DecodeError::from);
641 BlobContentNext::More((next, res))
642 }
643 ResponseDecoderNext::Done(stream) => BlobContentNext::Done(AtEndBlob {
644 stream,
645 misc: self.misc,
646 }),
647 }
648 }
649
650 pub fn tree(&self) -> bao_tree::BaoTree {
652 self.stream.tree()
653 }
654
655 pub fn hash(&self) -> Hash {
657 (*self.stream.hash()).into()
658 }
659
660 pub fn offset(&self) -> u64 {
662 self.misc.ranges_iter.offset()
663 }
664
665 pub async fn drain(self) -> result::Result<AtEndBlob, DecodeError> {
667 let mut content = self;
668 loop {
669 match content.next().await {
670 BlobContentNext::More((content1, res)) => {
671 let _ = res?;
672 content = content1;
673 }
674 BlobContentNext::Done(end) => {
675 break Ok(end);
676 }
677 }
678 }
679 }
680
681 pub async fn concatenate_into_vec(
683 self,
684 ) -> result::Result<(AtEndBlob, Vec<u8>), DecodeError> {
685 let mut res = Vec::with_capacity(1024);
686 let mut curr = self;
687 let done = loop {
688 match curr.next().await {
689 BlobContentNext::More((next, data)) => {
690 if let BaoContentItem::Leaf(leaf) = data? {
691 res.extend_from_slice(&leaf.data);
692 }
693 curr = next;
694 }
695 BlobContentNext::Done(done) => {
696 break done;
698 }
699 }
700 };
701 Ok((done, res))
702 }
703
704 pub async fn write_all_batch<B>(self, writer: B) -> result::Result<AtEndBlob, DecodeError>
706 where
707 B: BaoBatchWriter,
708 {
709 let mut writer = writer;
710 let mut buf = Vec::new();
711 let mut content = self;
712 let size = content.tree().size();
713 loop {
714 match content.next().await {
715 BlobContentNext::More((next, item)) => {
716 let item = item?;
717 match &item {
718 BaoContentItem::Parent(_) => {
719 buf.push(item);
720 }
721 BaoContentItem::Leaf(_) => {
722 buf.push(item);
723 let batch = std::mem::take(&mut buf);
724 writer.write_batch(size, batch).await?;
725 }
726 }
727 content = next;
728 }
729 BlobContentNext::Done(end) => {
730 assert!(buf.is_empty());
731 return Ok(end);
732 }
733 }
734 }
735 }
736
737 pub async fn write_all_with_outboard<D, O>(
742 self,
743 mut outboard: Option<O>,
744 mut data: D,
745 ) -> result::Result<AtEndBlob, DecodeError>
746 where
747 D: AsyncSliceWriter,
748 O: OutboardMut,
749 {
750 let mut content = self;
751 loop {
752 match content.next().await {
753 BlobContentNext::More((content1, item)) => {
754 content = content1;
755 match item? {
756 BaoContentItem::Parent(parent) => {
757 if let Some(outboard) = outboard.as_mut() {
758 outboard.save(parent.node, &parent.pair).await?;
759 }
760 }
761 BaoContentItem::Leaf(leaf) => {
762 data.write_bytes_at(leaf.offset, leaf.data).await?;
763 }
764 }
765 }
766 BlobContentNext::Done(end) => {
767 return Ok(end);
768 }
769 }
770 }
771 }
772
773 pub async fn write_all<D>(self, mut data: D) -> result::Result<AtEndBlob, DecodeError>
775 where
776 D: AsyncSliceWriter,
777 {
778 let mut content = self;
779 loop {
780 match content.next().await {
781 BlobContentNext::More((content1, item)) => {
782 content = content1;
783 match item? {
784 BaoContentItem::Parent(_) => {}
785 BaoContentItem::Leaf(leaf) => {
786 data.write_bytes_at(leaf.offset, leaf.data).await?;
787 }
788 }
789 }
790 BlobContentNext::Done(end) => {
791 return Ok(end);
792 }
793 }
794 }
795 }
796
797 pub fn finish(self) -> AtClosing {
799 AtClosing::new(self.misc, self.stream.finish(), false)
800 }
801 }
802
803 #[derive(Debug)]
805 pub struct AtEndBlob {
806 stream: WrappedRecvStream,
807 misc: Box<Misc>,
808 }
809
810 #[derive(Debug, From)]
812 pub enum EndBlobNext {
813 MoreChildren(AtStartChild),
815 Closing(AtClosing),
817 }
818
819 impl AtEndBlob {
820 pub fn next(mut self) -> EndBlobNext {
822 if let Some((offset, ranges)) = self.misc.ranges_iter.next() {
823 AtStartChild {
824 reader: self.stream,
825 child_offset: offset - 1,
826 ranges,
827 misc: self.misc,
828 }
829 .into()
830 } else {
831 AtClosing::new(self.misc, self.stream, true).into()
832 }
833 }
834 }
835
836 #[derive(Debug)]
838 pub struct AtClosing {
839 misc: Box<Misc>,
840 reader: WrappedRecvStream,
841 check_extra_data: bool,
842 }
843
844 impl AtClosing {
845 fn new(misc: Box<Misc>, reader: WrappedRecvStream, check_extra_data: bool) -> Self {
846 Self {
847 misc,
848 reader,
849 check_extra_data,
850 }
851 }
852
853 pub async fn next(self) -> result::Result<Stats, endpoint::ReadError> {
855 let (reader, bytes_read) = self.reader.into_parts();
857 let mut reader = reader.into_inner();
858 if self.check_extra_data {
859 if let Some(chunk) = reader.read_chunk(8, false).await? {
860 reader.stop(0u8.into()).ok();
861 error!("Received unexpected data from the provider: {chunk:?}");
862 }
863 } else {
864 reader.stop(0u8.into()).ok();
865 }
866 Ok(Stats {
867 elapsed: self.misc.start.elapsed(),
868 bytes_written: self.misc.bytes_written,
869 bytes_read,
870 })
871 }
872 }
873
874 #[derive(Debug)]
876 struct Misc {
877 start: Instant,
879 bytes_written: u64,
881 ranges_iter: RangesIter,
883 }
884}
885
886#[derive(thiserror::Error, Debug)]
888pub enum GetResponseError {
889 #[error("connection: {0}")]
891 Connection(#[from] endpoint::ConnectionError),
892 #[error("write: {0}")]
894 Write(#[from] endpoint::WriteError),
895 #[error("read: {0}")]
897 Read(#[from] endpoint::ReadError),
898 #[error("decode: {0}")]
900 Decode(bao_tree::io::DecodeError),
901 #[error("generic: {0}")]
903 Generic(anyhow::Error),
904}
905
906impl From<postcard::Error> for GetResponseError {
907 fn from(cause: postcard::Error) -> Self {
908 Self::Generic(cause.into())
909 }
910}
911
912impl From<bao_tree::io::DecodeError> for GetResponseError {
913 fn from(cause: bao_tree::io::DecodeError) -> Self {
914 match cause {
915 bao_tree::io::DecodeError::Io(cause) => {
916 if let Some(source) = cause.source() {
918 if let Some(error) = source.downcast_ref::<endpoint::ConnectionError>() {
919 return Self::Connection(error.clone());
920 }
921 if let Some(error) = source.downcast_ref::<endpoint::ReadError>() {
922 return Self::Read(error.clone());
923 }
924 if let Some(error) = source.downcast_ref::<endpoint::WriteError>() {
925 return Self::Write(error.clone());
926 }
927 }
928 Self::Generic(cause.into())
929 }
930 _ => Self::Decode(cause),
931 }
932 }
933}
934
935impl From<anyhow::Error> for GetResponseError {
936 fn from(cause: anyhow::Error) -> Self {
937 Self::Generic(cause)
938 }
939}
940
941impl From<GetResponseError> for std::io::Error {
942 fn from(cause: GetResponseError) -> Self {
943 Self::new(std::io::ErrorKind::Other, cause)
944 }
945}