1use std::{
5 collections::BTreeMap,
6 future::{Future, IntoFuture},
7 num::NonZeroU64,
8 sync::Arc,
9};
10
11use bao_tree::{
12 io::{BaoContentItem, Leaf},
13 ChunkNum, ChunkRanges,
14};
15use genawaiter::sync::{Co, Gen};
16use iroh::endpoint::Connection;
17use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
18use n0_future::{io, Stream, StreamExt};
19use n0_snafu::SpanTrace;
20use nested_enum_utils::common_fields;
21use ref_cast::RefCast;
22use snafu::{Backtrace, IntoError, ResultExt, Snafu};
23use tracing::{debug, trace};
24
25use super::blobs::{Bitfield, ExportBaoOptions};
26use crate::{
27 api::{
28 self,
29 blobs::{Blobs, WriteProgress},
30 ApiClient, Store,
31 },
32 get::{
33 fsm::{
34 AtBlobHeader, AtConnected, AtEndBlob, BlobContentNext, ConnectedNext, DecodeError,
35 EndBlobNext,
36 },
37 get_error::{BadRequestSnafu, LocalFailureSnafu},
38 GetError, GetResult, Stats, StreamPair,
39 },
40 hashseq::{HashSeq, HashSeqIter},
41 protocol::{
42 ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest,
43 Request, RequestType, MAX_MESSAGE_SIZE,
44 },
45 provider::events::{ClientResult, ProgressError},
46 store::IROH_BLOCK_SIZE,
47 util::{
48 sink::{Sink, TokioMpscSenderSink},
49 RecvStream, SendStream,
50 },
51 Hash, HashAndFormat,
52};
53
/// Entry point for operations that involve a remote peer (fetch, push,
/// observe) on top of an [`ApiClient`].
///
/// The struct is a `#[repr(transparent)]` wrapper deriving [`RefCast`], so a
/// `&ApiClient` can be reinterpreted as a `&Remote` without copying.
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    /// The underlying client all store operations are issued through.
    client: ApiClient,
}
74
/// An item produced while a get operation is running.
///
/// `Done` and `Error` are final items; `Progress` is an intermediate update.
#[derive(Debug)]
pub enum GetProgressItem {
    /// Intermediate progress value as reported through the progress sink.
    Progress(u64),
    /// The operation finished successfully with the given statistics.
    Done(Stats),
    /// The operation failed with the given error.
    Error(GetError),
}
84
85impl From<GetResult<Stats>> for GetProgressItem {
86 fn from(res: GetResult<Stats>) -> Self {
87 match res {
88 Ok(stats) => GetProgressItem::Done(stats),
89 Err(e) => GetProgressItem::Error(e),
90 }
91 }
92}
93
94impl TryFrom<GetProgressItem> for GetResult<Stats> {
95 type Error = &'static str;
96
97 fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
98 match item {
99 GetProgressItem::Done(stats) => Ok(Ok(stats)),
100 GetProgressItem::Error(e) => Ok(Err(e)),
101 GetProgressItem::Progress(_) => Err("not a final item"),
102 }
103 }
104}
105
/// Handle to a running get operation.
///
/// Can be turned into a stream of [`GetProgressItem`]s via
/// [`GetProgress::stream`], or awaited directly (through [`IntoFuture`]) to
/// obtain only the final result.
pub struct GetProgress {
    /// Receives progress items and the final result.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    /// The future that performs the operation and feeds `rx`; it only makes
    /// progress while the stream built from this handle is polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
110
111impl IntoFuture for GetProgress {
112 type Output = GetResult<Stats>;
113 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
114
115 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
116 Box::pin(self.complete())
117 }
118}
119
120impl GetProgress {
121 pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
122 into_stream(self.rx, self.fut)
123 }
124
125 pub async fn complete(self) -> GetResult<Stats> {
126 just_result(self.stream()).await.unwrap_or_else(|| {
127 Err(LocalFailureSnafu.into_error(anyhow::anyhow!("stream closed without result")))
128 })
129 }
130}
131
/// An item produced while a push operation is running.
///
/// `Done` and `Error` are final items; `Progress` is an intermediate update.
#[derive(Debug)]
pub enum PushProgressItem {
    /// Intermediate progress value as reported through the progress sink.
    Progress(u64),
    /// The operation finished successfully with the given statistics.
    Done(Stats),
    /// The operation failed with the given error.
    Error(anyhow::Error),
}
141
142impl From<anyhow::Result<Stats>> for PushProgressItem {
143 fn from(res: anyhow::Result<Stats>) -> Self {
144 match res {
145 Ok(stats) => Self::Done(stats),
146 Err(e) => Self::Error(e),
147 }
148 }
149}
150
151impl TryFrom<PushProgressItem> for anyhow::Result<Stats> {
152 type Error = &'static str;
153
154 fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
155 match item {
156 PushProgressItem::Done(stats) => Ok(Ok(stats)),
157 PushProgressItem::Error(e) => Ok(Err(e)),
158 PushProgressItem::Progress(_) => Err("not a final item"),
159 }
160 }
161}
162
/// Handle to a running push operation.
///
/// Can be turned into a stream of [`PushProgressItem`]s via
/// [`PushProgress::stream`], or awaited directly (through [`IntoFuture`]) to
/// obtain only the final result.
pub struct PushProgress {
    /// Receives progress items and the final result.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    /// The future that performs the operation and feeds `rx`; it only makes
    /// progress while the stream built from this handle is polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
167
168impl IntoFuture for PushProgress {
169 type Output = anyhow::Result<Stats>;
170 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
171
172 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
173 Box::pin(self.complete())
174 }
175}
176
177impl PushProgress {
178 pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
179 into_stream(self.rx, self.fut)
180 }
181
182 pub async fn complete(self) -> anyhow::Result<Stats> {
183 just_result(self.stream())
184 .await
185 .unwrap_or_else(|| Err(anyhow::anyhow!("stream closed without result")))
186 }
187}
188
189async fn just_result<S, R>(stream: S) -> Option<R>
190where
191 S: Stream<Item: std::fmt::Debug>,
192 R: TryFrom<S::Item>,
193{
194 tokio::pin!(stream);
195 while let Some(item) = stream.next().await {
196 if let Ok(res) = R::try_from(item) {
197 return Some(res);
198 }
199 }
200 None
201}
202
203fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
204where
205 F: Future,
206{
207 Gen::new(move |co| async move {
208 tokio::pin!(fut);
209 loop {
210 tokio::select! {
211 biased;
212 item = rx.recv() => {
213 if let Some(item) = item {
214 co.yield_(item).await;
215 } else {
216 break;
217 }
218 }
219 _ = &mut fut => {
220 break;
221 }
222 }
223 }
224 while let Some(item) = rx.recv().await {
225 co.yield_(item).await;
226 }
227 })
228}
229
/// Local availability information for a [`GetRequest`].
///
/// Produced by [`Remote::local_for_request`]; combines the request with the
/// locally observed bitfields of the root blob and, for non-blob requests,
/// of the locally known children.
#[derive(Debug)]
pub struct LocalInfo {
    /// The request this information was computed for.
    request: Arc<GetRequest>,
    /// Bitfield observed in the local store for the root hash.
    bitfield: Bitfield,
    /// Per-child information; `None` when the request is for a single blob
    /// (`request.ranges.is_blob()`).
    children: Option<NonRawLocalInfo>,
}
243
244impl LocalInfo {
245 pub fn local_bytes(&self) -> u64 {
247 let Some(root_requested) = self.requested_root_ranges() else {
248 return 0;
250 };
251 let mut local = self.bitfield.clone();
252 local.ranges.intersection_with(root_requested);
253 let mut res = local.total_bytes();
254 if let Some(children) = &self.children {
255 let Some(max_local_index) = children.hash_seq.keys().next_back() else {
256 return res;
258 };
259 for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
260 if offset == 0 {
261 continue;
263 }
264 let child = offset - 1;
265 if child > *max_local_index {
266 break;
268 }
269 let Some(hash) = children.hash_seq.get(&child) else {
270 continue;
271 };
272 let bitfield = &children.bitfields[hash];
273 let mut local = bitfield.clone();
274 local.ranges.intersection_with(ranges);
275 res += local.total_bytes();
276 }
277 }
278 res
279 }
280
281 pub fn children(&self) -> Option<u64> {
283 if self.children.is_some() {
284 self.bitfield.validated_size().map(|x| x / 32)
285 } else {
286 Some(0)
287 }
288 }
289
290 fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
295 self.request.ranges.iter().next()
296 }
297
298 pub fn is_complete(&self) -> bool {
304 let Some(root_requested) = self.requested_root_ranges() else {
305 return true;
307 };
308 if !self.bitfield.ranges.is_superset(root_requested) {
309 return false;
310 }
311 if let Some(children) = self.children.as_ref() {
312 let mut iter = self.request.ranges.iter_non_empty_infinite();
313 let max_child = self.bitfield.validated_size().map(|x| x / 32);
314 loop {
315 let Some((offset, range)) = iter.next() else {
316 break;
317 };
318 if offset == 0 {
319 continue;
321 }
322 let child = offset - 1;
323 if let Some(hash) = children.hash_seq.get(&child) {
324 let bitfield = &children.bitfields[hash];
325 if !bitfield.ranges.is_superset(range) {
326 return false;
328 }
329 } else {
330 if let Some(max_child) = max_child {
331 if child >= max_child {
332 return true;
334 }
335 }
336 return false;
337 }
338 }
339 }
340 true
341 }
342
343 pub fn missing(&self) -> GetRequest {
345 let Some(root_requested) = self.requested_root_ranges() else {
346 return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
348 };
349 let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);
350
351 let Some(children) = self.children.as_ref() else {
352 return builder.build(self.request.hash);
353 };
354 let mut iter = self.request.ranges.iter_non_empty_infinite();
355 let max_local = children
356 .hash_seq
357 .keys()
358 .next_back()
359 .map(|x| *x + 1)
360 .unwrap_or_default();
361 let max_offset = self.bitfield.validated_size().map(|x| x / 32);
362 loop {
363 let Some((offset, requested)) = iter.next() else {
364 break;
365 };
366 if offset == 0 {
367 continue;
369 }
370 let child = offset - 1;
371 let missing = match children.hash_seq.get(&child) {
372 Some(hash) => requested.difference(&children.bitfields[hash].ranges),
373 None => requested.clone(),
374 };
375 builder = builder.child(child, missing);
376 if offset >= max_local {
377 break;
379 }
380 }
381 loop {
382 let Some((offset, requested)) = iter.next() else {
383 return builder.build(self.request.hash);
384 };
385 if offset == 0 {
386 continue;
388 }
389 let child = offset - 1;
390 if let Some(max_offset) = &max_offset {
391 if child >= *max_offset {
392 return builder.build(self.request.hash);
393 }
394 builder = builder.child(child, requested.clone());
395 } else {
396 builder = builder.child(child, requested.clone());
397 if iter.is_at_end() {
398 if iter.next().is_none() {
399 return builder.build(self.request.hash);
400 } else {
401 return builder.build_open(self.request.hash);
402 }
403 }
404 }
405 }
406 }
407}
408
/// Local information about the children of a non-blob request.
#[derive(Debug)]
struct NonRawLocalInfo {
    /// Child index (0-based, i.e. request offset minus one) to child hash,
    /// for the requested children whose hash is known locally.
    hash_seq: BTreeMap<u64, Hash>,
    /// Locally observed bitfield for each hash that appears in `hash_seq`.
    bitfields: BTreeMap<Hash, Bitfield>,
}
417
418impl Remote {
433 pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
434 Self::ref_cast(sender)
435 }
436
437 fn store(&self) -> &Store {
438 Store::ref_from_sender(&self.client)
439 }
440
441 pub async fn local_for_request(
442 &self,
443 request: impl Into<Arc<GetRequest>>,
444 ) -> anyhow::Result<LocalInfo> {
445 let request = request.into();
446 let root = request.hash;
447 let bitfield = self.store().observe(root).await?;
448 let children = if !request.ranges.is_blob() {
449 let opts = ExportBaoOptions {
450 hash: root,
451 ranges: bitfield.ranges.clone(),
452 };
453 let bao = self.store().export_bao_with_opts(opts, 32);
454 let mut by_index = BTreeMap::new();
455 let mut stream = bao.hashes_with_index();
456 while let Some(item) = stream.next().await {
457 if let Ok((index, hash)) = item {
458 by_index.insert(index, hash);
459 }
460 }
461 let mut bitfields = BTreeMap::new();
462 let mut hash_seq = BTreeMap::new();
463 let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
464 for (index, _) in request.ranges.iter_non_empty_infinite() {
465 if index == 0 {
466 continue;
468 }
469 let child = index - 1;
470 if child > max {
471 break;
473 }
474 let Some(hash) = by_index.get(&child) else {
475 continue;
477 };
478 let bitfield = self.store().observe(*hash).await?;
479 bitfields.insert(*hash, bitfield);
480 hash_seq.insert(child, *hash);
481 }
482 Some(NonRawLocalInfo {
483 hash_seq,
484 bitfields,
485 })
486 } else {
487 None
488 };
489 Ok(LocalInfo {
490 request: request.clone(),
491 bitfield,
492 children,
493 })
494 }
495
496 pub async fn local(&self, content: impl Into<HashAndFormat>) -> anyhow::Result<LocalInfo> {
498 let request = GetRequest::from(content.into());
499 self.local_for_request(request).await
500 }
501
502 pub fn fetch(
503 &self,
504 sp: impl GetStreamPair + 'static,
505 content: impl Into<HashAndFormat>,
506 ) -> GetProgress {
507 let content = content.into();
508 let (tx, rx) = tokio::sync::mpsc::channel(64);
509 let tx2 = tx.clone();
510 let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
511 let this = self.clone();
512 let fut = async move {
513 let res = this.fetch_sink(sp, content, sink).await.into();
514 tx2.send(res).await.ok();
515 };
516 GetProgress {
517 rx,
518 fut: Box::pin(fut),
519 }
520 }
521
522 pub(crate) async fn fetch_sink(
530 &self,
531 sp: impl GetStreamPair,
532 content: impl Into<HashAndFormat>,
533 progress: impl Sink<u64, Error = irpc::channel::SendError>,
534 ) -> GetResult<Stats> {
535 let content = content.into();
536 let local = self
537 .local(content)
538 .await
539 .map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e))?;
540 if local.is_complete() {
541 return Ok(Default::default());
542 }
543 let request = local.missing();
544 let stats = self.execute_get_sink(sp, request, progress).await?;
545 Ok(stats)
546 }
547
548 pub fn observe(
549 &self,
550 conn: Connection,
551 request: ObserveRequest,
552 ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
553 Gen::new(|co| async move {
554 if let Err(cause) = Self::observe_impl(conn, request, &co).await {
555 co.yield_(Err(cause)).await
556 }
557 })
558 }
559
560 async fn observe_impl(
561 conn: Connection,
562 request: ObserveRequest,
563 co: &Co<io::Result<Bitfield>>,
564 ) -> io::Result<()> {
565 let hash = request.hash;
566 debug!(%hash, "observing");
567 let (mut send, mut recv) = conn.open_bi().await?;
568 write_observe_request(request, &mut send).await?;
570 send.finish()?;
571 loop {
572 let msg = recv
573 .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
574 .await?;
575 co.yield_(Ok(Bitfield::from(&msg))).await;
576 }
577 }
578
579 pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
580 let (tx, rx) = tokio::sync::mpsc::channel(64);
581 let tx2 = tx.clone();
582 let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress);
583 let this = self.clone();
584 let fut = async move {
585 let res = this.execute_push_sink(conn, request, sink).await.into();
586 tx2.send(res).await.ok();
587 };
588 PushProgress {
589 rx,
590 fut: Box::pin(fut),
591 }
592 }
593
594 pub(crate) async fn execute_push_sink(
598 &self,
599 conn: Connection,
600 request: PushRequest,
601 progress: impl Sink<u64, Error = irpc::channel::SendError>,
602 ) -> anyhow::Result<Stats> {
603 let hash = request.hash;
604 debug!(%hash, "pushing");
605 let (mut send, mut recv) = conn.open_bi().await?;
606 let mut context = StreamContext {
607 payload_bytes_sent: 0,
608 sender: progress,
609 };
610 recv.stop(0u32.into())?;
612 let request = write_push_request(request, &mut send).await?;
614 let mut request_ranges = request.ranges.iter_infinite();
615 let root = request.hash;
616 let root_ranges = request_ranges.next().expect("infinite iterator");
617 if !root_ranges.is_empty() {
618 self.store()
619 .export_bao(root, root_ranges.clone())
620 .write_with_progress(&mut send, &mut context, &root, 0)
621 .await?;
622 }
623 if request.ranges.is_blob() {
624 send.finish()?;
626 return Ok(Default::default());
627 }
628 let hash_seq = self.store().get_bytes(root).await?;
629 let hash_seq = HashSeq::try_from(hash_seq)?;
630 for (child, (child_hash, child_ranges)) in
631 hash_seq.into_iter().zip(request_ranges).enumerate()
632 {
633 if !child_ranges.is_empty() {
634 self.store()
635 .export_bao(child_hash, child_ranges.clone())
636 .write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64)
637 .await?;
638 }
639 }
640 send.finish()?;
641 Ok(Default::default())
642 }
643
644 pub fn execute_get(&self, conn: impl GetStreamPair, request: GetRequest) -> GetProgress {
645 self.execute_get_with_opts(conn, request)
646 }
647
648 pub fn execute_get_with_opts(
649 &self,
650 conn: impl GetStreamPair,
651 request: GetRequest,
652 ) -> GetProgress {
653 let (tx, rx) = tokio::sync::mpsc::channel(64);
654 let tx2 = tx.clone();
655 let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
656 let this = self.clone();
657 let fut = async move {
658 let res = this.execute_get_sink(conn, request, sink).await.into();
659 tx2.send(res).await.ok();
660 };
661 GetProgress {
662 rx,
663 fut: Box::pin(fut),
664 }
665 }
666
667 pub(crate) async fn execute_get_sink(
676 &self,
677 conn: impl GetStreamPair,
678 request: GetRequest,
679 mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
680 ) -> GetResult<Stats> {
681 let store = self.store();
682 let root = request.hash;
683 let conn = conn.open_stream_pair().await.map_err(|e| {
684 LocalFailureSnafu.into_error(anyhow::anyhow!("failed to open stream pair: {e}"))
685 })?;
686 let connected =
689 AtConnected::new(conn.t0, conn.recv, conn.send, request, Default::default());
690 trace!("Getting header");
691 let next_child = match connected.next().await? {
693 ConnectedNext::StartRoot(at_start_root) => {
694 let header = at_start_root.next();
695 let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
696 match end.next() {
697 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
698 EndBlobNext::Closing(at_closing) => Err(at_closing),
699 }
700 }
701 ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
702 ConnectedNext::Closing(at_closing) => Err(at_closing),
703 };
704 let at_closing = match next_child {
706 Ok(at_start_child) => {
707 let mut next_child = Ok(at_start_child);
708 let hash_seq = HashSeq::try_from(
709 store
710 .get_bytes(root)
711 .await
712 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?,
713 )
714 .context(BadRequestSnafu)?;
715 loop {
717 let at_start_child = match next_child {
718 Ok(at_start_child) => at_start_child,
719 Err(at_closing) => break at_closing,
720 };
721 let offset = at_start_child.offset() - 1;
722 let Some(hash) = hash_seq.get(offset as usize) else {
723 break at_start_child.finish();
724 };
725 trace!("getting child {offset} {}", hash.fmt_short());
726 let header = at_start_child.next(hash);
727 let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
728 next_child = match end.next() {
729 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
730 EndBlobNext::Closing(at_closing) => Err(at_closing),
731 }
732 }
733 }
734 Err(at_closing) => at_closing,
735 };
736 let stats = at_closing.next().await?;
738 trace!(?stats, "get hash seq done");
739 Ok(stats)
740 }
741
742 pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
743 let (tx, rx) = tokio::sync::mpsc::channel(64);
744 let tx2 = tx.clone();
745 let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
746 let this = self.clone();
747 let fut = async move {
748 let res = this.execute_get_many_sink(conn, request, sink).await.into();
749 tx2.send(res).await.ok();
750 };
751 GetProgress {
752 rx,
753 fut: Box::pin(fut),
754 }
755 }
756
757 pub async fn execute_get_many_sink(
766 &self,
767 conn: Connection,
768 request: GetManyRequest,
769 mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
770 ) -> GetResult<Stats> {
771 let store = self.store();
772 let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
773 let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
774 let at_closing = match next_child {
776 Ok(at_start_child) => {
777 let mut next_child = Ok(at_start_child);
778 loop {
779 let at_start_child = match next_child {
780 Ok(at_start_child) => at_start_child,
781 Err(at_closing) => break at_closing,
782 };
783 let offset = at_start_child.offset();
784 let Some(hash) = hash_seq.get(offset as usize) else {
785 break at_start_child.finish();
786 };
787 trace!("getting child {offset} {}", hash.fmt_short());
788 let header = at_start_child.next(hash);
789 let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
790 next_child = match end.next() {
791 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
792 EndBlobNext::Closing(at_closing) => Err(at_closing),
793 }
794 }
795 }
796 Err(at_closing) => at_closing,
797 };
798 let stats = at_closing.next().await?;
800 trace!(?stats, "get hash seq done");
801 Ok(stats)
802 }
803}
804
/// Error type describing the ways executing a request can fail: connection
/// and read errors, protocol state machine errors, and local store errors.
///
/// Every variant additionally carries the fields listed in `common_fields`
/// (an optional backtrace and a span trace).
#[common_fields({
    backtrace: Option<Backtrace>,
    #[snafu(implicit)]
    span_trace: SpanTrace,
})]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Debug, Snafu)]
pub enum ExecuteError {
    // Opening the bidirectional stream on the connection failed.
    #[snafu(display("Unable to open bidi stream"))]
    Connection {
        source: iroh::endpoint::ConnectionError,
    },
    // Reading from the remote failed.
    #[snafu(display("Unable to read from the remote"))]
    Read { source: iroh::endpoint::ReadError },
    // The state machine failed while sending the request.
    #[snafu(display("Error sending the request"))]
    Send {
        source: crate::get::fsm::ConnectedNextError,
    },
    // The state machine failed while reading the blob header.
    #[snafu(display("Unable to read size"))]
    Size {
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    // The received data could not be decoded.
    #[snafu(display("Error while decoding the data"))]
    Decode {
        source: crate::get::fsm::DecodeError,
    },
    // Exporting from the local store failed.
    #[snafu(display("Internal error while reading the hash sequence"))]
    ExportBao { source: api::ExportBaoError },
    // The bytes of a hash sequence were not valid.
    #[snafu(display("Hash sequence has an invalid length"))]
    InvalidHashSeq { source: anyhow::Error },
    // Importing into the local store failed.
    #[snafu(display("Internal error importing the data"))]
    ImportBao { source: crate::api::RequestError },
    // The progress receiver was closed.
    #[snafu(display("Error sending download progress - receiver closed"))]
    SendDownloadProgress { source: irpc::channel::SendError },
    // Sending a content item over an mpsc channel failed.
    // NOTE(review): shares its display text with `ImportBao` — confirm this is intended.
    #[snafu(display("Internal error importing the data"))]
    MpscSend {
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
847
/// Something that can be turned into a [`StreamPair`] for a get operation.
///
/// Implemented in this file for an already existing [`StreamPair`] and for a
/// [`Connection`], on which a new bidirectional stream is opened.
pub trait GetStreamPair: Send + 'static {
    /// Consumes `self` and produces the stream pair, or an I/O error if it
    /// could not be opened.
    fn open_stream_pair(
        self,
    ) -> impl Future<Output = io::Result<StreamPair<impl RecvStream, impl SendStream>>> + Send + 'static;
}
853
impl<R: RecvStream + 'static, W: SendStream + 'static> GetStreamPair for StreamPair<R, W> {
    /// An existing stream pair is returned as is; this never fails.
    async fn open_stream_pair(self) -> io::Result<StreamPair<impl RecvStream, impl SendStream>> {
        Ok(self)
    }
}
859
860impl GetStreamPair for Connection {
861 async fn open_stream_pair(
862 self,
863 ) -> io::Result<StreamPair<impl crate::util::RecvStream, impl crate::util::SendStream>> {
864 let connection_id = self.stable_id() as u64;
865 let (send, recv) = self.open_bi().await?;
866 Ok(StreamPair::new(connection_id, recv, send))
867 }
868}
869
870fn get_buffer_size(size: NonZeroU64) -> usize {
871 (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
872}
873
/// Reads the content of one blob from the remote and imports it into `store`.
///
/// Starts at the blob header state, forwards every received content item to
/// the store's import handle while reporting the payload bytes read so far
/// to `progress`, and returns the end-of-blob state once both the reading
/// side and the import have finished.
///
/// A blob with size 0 is not imported: its content is drained if `hash` is
/// the hash of the empty blob, otherwise a leaf hash mismatch error is
/// returned.
async fn get_blob_ranges_impl<R: RecvStream>(
    header: AtBlobHeader<R>,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
) -> GetResult<AtEndBlob<R>> {
    let (mut content, size) = header.next().await?;
    let Some(size) = NonZeroU64::new(size) else {
        // Size 0 is only acceptable for the empty blob.
        return if hash == Hash::EMPTY {
            let end = content.drain().await?;
            Ok(end)
        } else {
            Err(DecodeError::leaf_hash_mismatch(ChunkNum(0)).into())
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
    // Reads items from the remote and feeds them into the import.
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
                    handle.tx.send(item).await?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // Dropping the sender tells the import no more items follow.
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    // Waits for the import to report its result.
    let complete = async move {
        handle.rx.await.map_err(|e| {
            LocalFailureSnafu.into_error(anyhow::anyhow!("error reading from import stream: {e}"))
        })
    };
    // Both sides run concurrently; the first error aborts the other.
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
922
/// A hash sequence whose entries are loaded from the store on demand, one
/// chunk at a time.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    /// Store API used to export chunks of the hash sequence blob.
    blobs: Blobs,
    /// Hash of the blob that contains the hash sequence.
    hash: Hash,
    /// The most recently loaded chunk, if any, kept as a one-entry cache.
    current_chunk: Option<HashSeqChunk>,
}
929
/// A contiguous piece of a hash sequence, built from a single [`Leaf`].
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// Offset taken from the leaf this chunk was built from.
    /// NOTE(review): `base()` divides this by 32, so it is presumably a byte
    /// offset into the hash sequence blob — confirm against `Leaf::offset`.
    offset: u64,
    /// The hashes contained in this chunk.
    chunk: HashSeq,
}
937
938impl TryFrom<Leaf> for HashSeqChunk {
939 type Error = anyhow::Error;
940
941 fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
942 let offset = leaf.offset;
943 let chunk = HashSeq::try_from(leaf.data)?;
944 Ok(Self { offset, chunk })
945 }
946}
947
impl IntoIterator for HashSeqChunk {
    type Item = Hash;
    type IntoIter = HashSeqIter;

    /// Iterates over the hashes of the chunk; the chunk's offset is dropped.
    fn into_iter(self) -> Self::IntoIter {
        self.chunk.into_iter()
    }
}
956
957impl HashSeqChunk {
958 pub fn base(&self) -> u64 {
959 self.offset / 32
960 }
961
962 #[allow(dead_code)]
963 fn get(&self, offset: u64) -> Option<Hash> {
964 let start = self.offset;
965 let end = start + self.chunk.len() as u64;
966 if offset >= start && offset < end {
967 let o = (offset - start) as usize;
968 self.chunk.get(o)
969 } else {
970 None
971 }
972 }
973}
974
975impl LazyHashSeq {
976 #[allow(dead_code)]
977 pub fn new(blobs: Blobs, hash: Hash) -> Self {
978 Self {
979 blobs,
980 hash,
981 current_chunk: None,
982 }
983 }
984
985 #[allow(dead_code)]
986 pub async fn get_from_offset(&mut self, offset: u64) -> anyhow::Result<Option<Hash>> {
987 if offset == 0 {
988 Ok(Some(self.hash))
989 } else {
990 self.get(offset - 1).await
991 }
992 }
993
994 #[allow(dead_code)]
995 pub async fn get(&mut self, child_offset: u64) -> anyhow::Result<Option<Hash>> {
996 if let Some(chunk) = &self.current_chunk {
998 if let Some(hash) = chunk.get(child_offset) {
999 return Ok(Some(hash));
1000 }
1001 }
1002 let leaf = self
1004 .blobs
1005 .export_chunk(self.hash, child_offset * 32)
1006 .await?;
1007 let hs = HashSeqChunk::try_from(leaf)?;
1009 Ok(hs.get(child_offset).inspect(|_hash| {
1010 self.current_chunk = Some(hs);
1011 }))
1012 }
1013}
1014
1015async fn write_push_request(
1016 request: PushRequest,
1017 stream: &mut impl SendStream,
1018) -> anyhow::Result<PushRequest> {
1019 let mut request_bytes = Vec::new();
1020 request_bytes.push(RequestType::Push as u8);
1021 request_bytes.write_length_prefixed(&request).unwrap();
1022 stream.send_bytes(request_bytes.into()).await?;
1023 Ok(request)
1024}
1025
1026async fn write_observe_request(
1027 request: ObserveRequest,
1028 stream: &mut impl SendStream,
1029) -> io::Result<()> {
1030 let request = Request::Observe(request);
1031 let request_bytes = postcard::to_allocvec(&request)
1032 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1033 stream.send_bytes(request_bytes.into()).await?;
1034 Ok(())
1035}
1036
/// Progress context used while writing push data to a stream.
struct StreamContext<S> {
    /// Running total of payload bytes reported as written so far.
    payload_bytes_sent: u64,
    /// Sink that receives the running total after each payload write.
    sender: S,
}
1041
1042impl<S> WriteProgress for StreamContext<S>
1043where
1044 S: Sink<u64, Error = irpc::channel::SendError>,
1045{
1046 async fn notify_payload_write(
1047 &mut self,
1048 _index: u64,
1049 _offset: u64,
1050 len: usize,
1051 ) -> ClientResult {
1052 self.payload_bytes_sent += len as u64;
1053 self.sender
1054 .send(self.payload_bytes_sent)
1055 .await
1056 .map_err(|e| ProgressError::Internal { source: e.into() })?;
1057 Ok(())
1058 }
1059
1060 fn log_other_write(&mut self, _len: usize) {}
1061
1062 async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
1063}
1064
// Tests for the local availability computation (`LocalInfo`) and for
// observing partially imported blobs. Only built with the `fs-store` feature.
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use bao_tree::{ChunkNum, ChunkRanges};
    use testresult::TestResult;

    use crate::{
        api::blobs::Blobs,
        protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
        store::{
            fs::{
                tests::{test_data, INTERESTING_SIZES},
                FsStore,
            },
            mem::MemStore,
            util::tests::create_n0_bao,
        },
        tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
    };

    /// A small raw blob that was added locally is reported as fully present,
    /// complete, and with nothing missing.
    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }

    /// A large hash sequence of which only chunks 16..32 of the root and the
    /// first chunk of each child are present.
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // Sizes of the children whose hashes lie in the present root chunks.
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Index 0 is the root (the hash sequence), the rest are children.
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // Child bytes plus the 16 present 1024 byte chunks of the root.
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }

    /// Imports only the first chunk of blobs of various sizes and checks the
    /// observed bitfield: blobs larger than one chunk report just that
    /// chunk, smaller ones report all ranges.
    async fn test_observe_partial(blobs: &Blobs) -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        for size in sizes {
            let data = test_data(size);
            let ranges = ChunkRanges::chunk(0);
            let (hash, bao) = create_n0_bao(&data, &ranges)?;
            blobs.import_bao_bytes(hash, ranges.clone(), bao).await?;
            let bitfield = blobs.observe(hash).await?;
            if size > 1024 {
                assert_eq!(bitfield.ranges, ranges);
            } else {
                assert_eq!(bitfield.ranges, ChunkRanges::all());
            }
        }
        Ok(())
    }

    /// Runs the partial observe checks against the in-memory store.
    #[tokio::test]
    async fn test_observe_partial_mem() -> TestResult<()> {
        let store = MemStore::new();
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    /// Runs the partial observe checks against the file system store.
    #[tokio::test]
    async fn test_observe_partial_fs() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path()).await?;
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    /// Local info for a hash sequence in three states: only the root
    /// present, root plus the first chunk of each child, and everything.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Only the hash sequence itself is present.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            // First entry is the root, the remaining entries are children.
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            // The hash sequence and the first chunk of every child are present.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            // Everything is present.
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }

    /// Local info for requests that ask for less than the full content,
    /// against a store holding the root and the first two chunks of each
    /// child.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // Request for the root only.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // Request for the root and the first child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Request for the root and the first two children.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Open-ended request: the root and the first chunk of every child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
}