1use genawaiter::sync::{Co, Gen};
5use iroh::endpoint::SendStream;
6use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
7use n0_future::{io, Stream, StreamExt};
8use n0_snafu::SpanTrace;
9use nested_enum_utils::common_fields;
10use ref_cast::RefCast;
11use snafu::{Backtrace, IntoError, Snafu};
12
13use super::blobs::{Bitfield, ExportBaoOptions};
14use crate::{
15 api::{blobs::WriteProgress, ApiClient},
16 get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats},
17 protocol::{
18 GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
19 MAX_MESSAGE_SIZE,
20 },
21 provider::events::{ClientResult, ProgressError},
22 util::sink::{Sink, TokioMpscSenderSink},
23};
24
/// Handle for remote-facing operations, wrapping an [`ApiClient`].
///
/// The type is `#[repr(transparent)]` and derives [`RefCast`], which is what
/// `Remote::ref_from_sender` relies on to view a `&ApiClient` as a `&Remote`.
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    // Client from which the local `Store` view is derived (see `Remote::store`).
    client: ApiClient,
}
45
/// An item produced while a get operation is running.
///
/// `Progress` items are intermediate; `Done` and `Error` are the final items
/// (they are the only ones convertible into a `GetResult<Stats>`).
#[derive(Debug)]
pub enum GetProgressItem {
    /// Intermediate progress value reported by the running operation.
    Progress(u64),
    /// The operation finished successfully with the given [`Stats`].
    Done(Stats),
    /// The operation failed with the given [`GetError`].
    Error(GetError),
}
55
56impl From<GetResult<Stats>> for GetProgressItem {
57 fn from(res: GetResult<Stats>) -> Self {
58 match res {
59 Ok(stats) => GetProgressItem::Done(stats),
60 Err(e) => GetProgressItem::Error(e),
61 }
62 }
63}
64
65impl TryFrom<GetProgressItem> for GetResult<Stats> {
66 type Error = &'static str;
67
68 fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
69 match item {
70 GetProgressItem::Done(stats) => Ok(Ok(stats)),
71 GetProgressItem::Error(e) => Ok(Err(e)),
72 GetProgressItem::Progress(_) => Err("not a final item"),
73 }
74 }
75}
76
/// Progress handle for a get operation.
///
/// Can be turned into a stream of [`GetProgressItem`]s via `stream`, or
/// awaited directly (via [`IntoFuture`]) to obtain only the final result.
pub struct GetProgress {
    // Receives progress items and the final result sent by `fut`.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    // The future performing the actual work; it is only driven while the
    // stream produced from this handle is being polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
81
82impl IntoFuture for GetProgress {
83 type Output = GetResult<Stats>;
84 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
85
86 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
87 Box::pin(self.complete())
88 }
89}
90
91impl GetProgress {
92 pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
93 into_stream(self.rx, self.fut)
94 }
95
96 pub async fn complete(self) -> GetResult<Stats> {
97 just_result(self.stream()).await.unwrap_or_else(|| {
98 Err(LocalFailureSnafu
99 .into_error(anyhow::anyhow!("stream closed without result").into()))
100 })
101 }
102}
103
/// An item produced while a push operation is running.
///
/// `Progress` items are intermediate; `Done` and `Error` are the final items
/// (they are the only ones convertible into an `anyhow::Result<Stats>`).
#[derive(Debug)]
pub enum PushProgressItem {
    /// Intermediate progress value reported by the running operation.
    Progress(u64),
    /// The operation finished successfully with the given [`Stats`].
    Done(Stats),
    /// The operation failed with the given error.
    Error(anyhow::Error),
}
113
114impl From<anyhow::Result<Stats>> for PushProgressItem {
115 fn from(res: anyhow::Result<Stats>) -> Self {
116 match res {
117 Ok(stats) => Self::Done(stats),
118 Err(e) => Self::Error(e),
119 }
120 }
121}
122
123impl TryFrom<PushProgressItem> for anyhow::Result<Stats> {
124 type Error = &'static str;
125
126 fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
127 match item {
128 PushProgressItem::Done(stats) => Ok(Ok(stats)),
129 PushProgressItem::Error(e) => Ok(Err(e)),
130 PushProgressItem::Progress(_) => Err("not a final item"),
131 }
132 }
133}
134
/// Progress handle for a push operation.
///
/// Can be turned into a stream of [`PushProgressItem`]s via `stream`, or
/// awaited directly (via [`IntoFuture`]) to obtain only the final result.
pub struct PushProgress {
    // Receives progress items and the final result sent by `fut`.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    // The future performing the actual work; it is only driven while the
    // stream produced from this handle is being polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
139
140impl IntoFuture for PushProgress {
141 type Output = anyhow::Result<Stats>;
142 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
143
144 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
145 Box::pin(self.complete())
146 }
147}
148
149impl PushProgress {
150 pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
151 into_stream(self.rx, self.fut)
152 }
153
154 pub async fn complete(self) -> anyhow::Result<Stats> {
155 just_result(self.stream())
156 .await
157 .unwrap_or_else(|| Err(anyhow::anyhow!("stream closed without result")))
158 }
159}
160
161async fn just_result<S, R>(stream: S) -> Option<R>
162where
163 S: Stream<Item: std::fmt::Debug>,
164 R: TryFrom<S::Item>,
165{
166 tokio::pin!(stream);
167 while let Some(item) = stream.next().await {
168 if let Ok(res) = R::try_from(item) {
169 return Some(res);
170 }
171 }
172 None
173}
174
175fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
176where
177 F: Future,
178{
179 Gen::new(move |co| async move {
180 tokio::pin!(fut);
181 loop {
182 tokio::select! {
183 biased;
184 item = rx.recv() => {
185 if let Some(item) = item {
186 co.yield_(item).await;
187 } else {
188 break;
189 }
190 }
191 _ = &mut fut => {
192 break;
193 }
194 }
195 }
196 while let Some(item) = rx.recv().await {
197 co.yield_(item).await;
198 }
199 })
200}
201
/// Local availability information for the data selected by a [`GetRequest`].
#[derive(Debug)]
pub struct LocalInfo {
    // The request this info was computed for.
    request: Arc<GetRequest>,
    // Bitfield observed locally for the root blob (`request.hash`).
    bitfield: Bitfield,
    // Info about children; `None` when the request only concerns the root blob
    // (`request.ranges.is_blob()`).
    children: Option<NonRawLocalInfo>,
}
215
216impl LocalInfo {
217 pub fn local_bytes(&self) -> u64 {
219 let Some(root_requested) = self.requested_root_ranges() else {
220 return 0;
222 };
223 let mut local = self.bitfield.clone();
224 local.ranges.intersection_with(root_requested);
225 let mut res = local.total_bytes();
226 if let Some(children) = &self.children {
227 let Some(max_local_index) = children.hash_seq.keys().next_back() else {
228 return res;
230 };
231 for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
232 if offset == 0 {
233 continue;
235 }
236 let child = offset - 1;
237 if child > *max_local_index {
238 break;
240 }
241 let Some(hash) = children.hash_seq.get(&child) else {
242 continue;
243 };
244 let bitfield = &children.bitfields[hash];
245 let mut local = bitfield.clone();
246 local.ranges.intersection_with(ranges);
247 res += local.total_bytes();
248 }
249 }
250 res
251 }
252
253 pub fn children(&self) -> Option<u64> {
255 if self.children.is_some() {
256 self.bitfield.validated_size().map(|x| x / 32)
257 } else {
258 Some(0)
259 }
260 }
261
262 fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
267 self.request.ranges.iter().next()
268 }
269
270 pub fn is_complete(&self) -> bool {
276 let Some(root_requested) = self.requested_root_ranges() else {
277 return true;
279 };
280 if !self.bitfield.ranges.is_superset(root_requested) {
281 return false;
282 }
283 if let Some(children) = self.children.as_ref() {
284 let mut iter = self.request.ranges.iter_non_empty_infinite();
285 let max_child = self.bitfield.validated_size().map(|x| x / 32);
286 loop {
287 let Some((offset, range)) = iter.next() else {
288 break;
289 };
290 if offset == 0 {
291 continue;
293 }
294 let child = offset - 1;
295 if let Some(hash) = children.hash_seq.get(&child) {
296 let bitfield = &children.bitfields[hash];
297 if !bitfield.ranges.is_superset(range) {
298 return false;
300 }
301 } else {
302 if let Some(max_child) = max_child {
303 if child >= max_child {
304 return true;
306 }
307 }
308 return false;
309 }
310 }
311 }
312 true
313 }
314
    /// Builds a [`GetRequest`] for the parts of the original request that are
    /// not available locally.
    ///
    /// The resulting request uses the same root hash as the original request.
    pub fn missing(&self) -> GetRequest {
        // A request without any ranges has nothing that could be missing.
        let Some(root_requested) = self.requested_root_ranges() else {
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        // For the root: what was requested minus what the local bitfield covers.
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        // Without child info only the root part is considered.
        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        // Highest child index with local info, plus one (0 if there is none).
        // Compared against request offsets (child index + 1) below.
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        // Number of children derived from the validated size of the root blob,
        // if known (presumably 32 bytes per hash — TODO confirm).
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        // First phase: children up to the last one for which local info exists.
        loop {
            let Some((offset, requested)) = iter.next() else {
                break;
            };
            // Offset 0 is the root, handled above.
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                // Known child: request only what is not present locally.
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                // Unknown child: everything requested is missing.
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                break;
            }
        }
        // Second phase: continues with the same iterator; no local info is
        // consulted here, so the requested ranges are passed on unchanged.
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                // Known number of children: stop once past the last child.
                if child >= *max_offset {
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                // Number of children unknown: once the iterator reports being
                // at its end, finish with `build` if it yields nothing more and
                // with `build_open` otherwise.
                // NOTE(review): the exact semantics of `is_at_end` are not
                // visible here — confirm against the iterator implementation.
                if iter.is_at_end() {
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
379}
380
/// Local info about the children referenced by a request.
#[derive(Debug)]
struct NonRawLocalInfo {
    // Child index -> hash, for requested children whose hash is known locally.
    hash_seq: BTreeMap<u64, Hash>,
    // Hash -> locally observed bitfield, for the hashes in `hash_seq`.
    bitfields: BTreeMap<Hash, Bitfield>,
}
389
390impl Remote {
    /// Views a `&ApiClient` as a `&Remote` without copying, via [`RefCast`].
    pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
        Self::ref_cast(sender)
    }
408
    /// Returns a [`Store`] view backed by the same client as this `Remote`.
    fn store(&self) -> &Store {
        Store::ref_from_sender(&self.client)
    }
412
413 pub async fn local_for_request(
414 &self,
415 request: impl Into<Arc<GetRequest>>,
416 ) -> anyhow::Result<LocalInfo> {
417 let request = request.into();
418 let root = request.hash;
419 let bitfield = self.store().observe(root).await?;
420 let children = if !request.ranges.is_blob() {
421 let opts = ExportBaoOptions {
422 hash: root,
423 ranges: bitfield.ranges.clone(),
424 };
425 let bao = self.store().export_bao_with_opts(opts, 32);
426 let mut by_index = BTreeMap::new();
427 let mut stream = bao.hashes_with_index();
428 while let Some(item) = stream.next().await {
429 if let Ok((index, hash)) = item {
430 by_index.insert(index, hash);
431 }
432 }
433 let mut bitfields = BTreeMap::new();
434 let mut hash_seq = BTreeMap::new();
435 let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
436 for (index, _) in request.ranges.iter_non_empty_infinite() {
437 if index == 0 {
438 continue;
440 }
441 let child = index - 1;
442 if child > max {
443 break;
445 }
446 let Some(hash) = by_index.get(&child) else {
447 continue;
449 };
450 let bitfield = self.store().observe(*hash).await?;
451 bitfields.insert(*hash, bitfield);
452 hash_seq.insert(child, *hash);
453 }
454 Some(NonRawLocalInfo {
455 hash_seq,
456 bitfields,
457 })
458 } else {
459 None
460 };
461 Ok(LocalInfo {
462 request: request.clone(),
463 bitfield,
464 children,
465 })
466 }
467
468 pub async fn local(&self, content: impl Into<HashAndFormat>) -> anyhow::Result<LocalInfo> {
470 let request = GetRequest::from(content.into());
471 self.local_for_request(request).await
472 }
473
474 pub fn fetch(
475 &self,
476 conn: impl GetConnection + Send + 'static,
477 content: impl Into<HashAndFormat>,
478 ) -> GetProgress {
479 let content = content.into();
480 let (tx, rx) = tokio::sync::mpsc::channel(64);
481 let tx2 = tx.clone();
482 let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
483 let this = self.clone();
484 let fut = async move {
485 let res = this.fetch_sink(conn, content, sink).await.into();
486 tx2.send(res).await.ok();
487 };
488 GetProgress {
489 rx,
490 fut: Box::pin(fut),
491 }
492 }
493
494 pub(crate) async fn fetch_sink(
502 &self,
503 mut conn: impl GetConnection,
504 content: impl Into<HashAndFormat>,
505 progress: impl Sink<u64, Error = irpc::channel::SendError>,
506 ) -> GetResult<Stats> {
507 let content = content.into();
508 let local = self
509 .local(content)
510 .await
511 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
512 if local.is_complete() {
513 return Ok(Default::default());
514 }
515 let request = local.missing();
516 let conn = conn
517 .connection()
518 .await
519 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
520 let stats = self.execute_get_sink(&conn, request, progress).await?;
521 Ok(stats)
522 }
523
524 pub fn observe(
525 &self,
526 conn: Connection,
527 request: ObserveRequest,
528 ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
529 Gen::new(|co| async move {
530 if let Err(cause) = Self::observe_impl(conn, request, &co).await {
531 co.yield_(Err(cause)).await
532 }
533 })
534 }
535
    /// Sends an observe request on a new bidirectional stream and yields every
    /// received update through `co`.
    ///
    /// The loop has no regular exit: the function only returns when opening
    /// the stream, writing the request, or reading a message fails.
    async fn observe_impl(
        conn: Connection,
        request: ObserveRequest,
        co: &Co<io::Result<Bitfield>>,
    ) -> io::Result<()> {
        let hash = request.hash;
        debug!(%hash, "observing");
        let (mut send, mut recv) = conn.open_bi().await?;
        // Write the request, then finish the send side; nothing else is sent.
        write_observe_request(request, &mut send).await?;
        send.finish()?;
        loop {
            // Each message is a length-prefixed `ObserveItem`, limited to
            // `MAX_MESSAGE_SIZE`.
            let msg = recv
                .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
                .await?;
            co.yield_(Ok(Bitfield::from(&msg))).await;
        }
    }
554
555 pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
556 let (tx, rx) = tokio::sync::mpsc::channel(64);
557 let tx2 = tx.clone();
558 let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress);
559 let this = self.clone();
560 let fut = async move {
561 let res = this.execute_push_sink(conn, request, sink).await.into();
562 tx2.send(res).await.ok();
563 };
564 PushProgress {
565 rx,
566 fut: Box::pin(fut),
567 }
568 }
569
    /// Pushes the data described by `request` over a new bidirectional stream
    /// on `conn`, sending the running total of payload bytes to `progress`.
    ///
    /// The root blob is written first (if any root ranges are requested),
    /// followed by the requested ranges of each child taken from the locally
    /// stored hash sequence. The returned [`Stats`] are always the default value.
    pub(crate) async fn execute_push_sink(
        &self,
        conn: Connection,
        request: PushRequest,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> anyhow::Result<Stats> {
        let hash = request.hash;
        debug!(%hash, "pushing");
        let (mut send, mut recv) = conn.open_bi().await?;
        let mut context = StreamContext {
            payload_bytes_sent: 0,
            sender: progress,
        };
        // Nothing is read from this stream, so the receive side is stopped
        // right away.
        recv.stop(0u32.into())?;
        // Write the request header; the request is handed back for further use.
        let request = write_push_request(request, &mut send).await?;
        let mut request_ranges = request.ranges.iter_infinite();
        let root = request.hash;
        // The first entry of the ranges belongs to the root blob.
        let root_ranges = request_ranges.next().expect("infinite iterator");
        if !root_ranges.is_empty() {
            self.store()
                .export_bao(root, root_ranges.clone())
                .write_quinn_with_progress(&mut send, &mut context, &root, 0)
                .await?;
        }
        // A request for just the root blob is done at this point.
        if request.ranges.is_blob() {
            send.finish()?;
            return Ok(Default::default());
        }
        // Load the hash sequence from the local store to resolve child hashes.
        let hash_seq = self.store().get_bytes(root).await?;
        let hash_seq = HashSeq::try_from(hash_seq)?;
        // `zip` stops at the end of the hash sequence even though the ranges
        // iterator is infinite.
        for (child, (child_hash, child_ranges)) in
            hash_seq.into_iter().zip(request_ranges).enumerate()
        {
            if !child_ranges.is_empty() {
                self.store()
                    .export_bao(child_hash, child_ranges.clone())
                    .write_quinn_with_progress(
                        &mut send,
                        &mut context,
                        &child_hash,
                        // Index 0 is the root, so children start at 1.
                        (child + 1) as u64,
                    )
                    .await?;
            }
        }
        send.finish()?;
        Ok(Default::default())
    }
624
    /// Executes a get request over `conn`, reporting progress.
    ///
    /// Thin wrapper around [`Remote::execute_get_with_opts`].
    pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress {
        self.execute_get_with_opts(conn, request)
    }
628
629 pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress {
630 let (tx, rx) = tokio::sync::mpsc::channel(64);
631 let tx2 = tx.clone();
632 let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
633 let this = self.clone();
634 let fut = async move {
635 let res = this.execute_get_sink(&conn, request, sink).await.into();
636 tx2.send(res).await.ok();
637 };
638 GetProgress {
639 rx,
640 fut: Box::pin(fut),
641 }
642 }
643
    /// Executes a get request over `conn`, importing the received data into
    /// the local store and sending progress values to `progress`.
    ///
    /// Drives the get state machine: first the root blob (if the request
    /// starts with it), then each child the remote sends, and finally the
    /// closing step that produces the [`Stats`].
    pub(crate) async fn execute_get_sink(
        &self,
        conn: &Connection,
        request: GetRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let root = request.hash;
        let start = crate::get::fsm::start(conn.clone(), request, Default::default());
        // Connect and send the request.
        let connected = start.next().await?;
        trace!("Getting header");
        // `Ok` means more children follow, `Err` means the transfer is closing.
        let next_child = match connected.next().await? {
            // The request starts with the root blob: receive it first.
            ConnectedNext::StartRoot(at_start_root) => {
                let header = at_start_root.next();
                let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
                match end.next() {
                    EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                    EndBlobNext::Closing(at_closing) => Err(at_closing),
                }
            }
            ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
            ConnectedNext::Closing(at_closing) => Err(at_closing),
        };
        let at_closing = match next_child {
            // There are children to receive.
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                // Child hashes are resolved through the hash sequence stored
                // locally under the root hash.
                let hash_seq = HashSeq::try_from(
                    store
                        .get_bytes(root)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?,
                )
                .map_err(|source| BadRequestSnafu.into_error(source.into()))?;
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    // Offset 0 is the root, so child indices are shifted by one.
                    let offset = at_start_child.offset() - 1;
                    // A child beyond the end of the hash sequence ends the transfer.
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        // Finish the transfer and obtain the stats.
        let stats = at_closing.next().await?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }
715
716 pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
717 let (tx, rx) = tokio::sync::mpsc::channel(64);
718 let tx2 = tx.clone();
719 let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
720 let this = self.clone();
721 let fut = async move {
722 let res = this.execute_get_many_sink(conn, request, sink).await.into();
723 tx2.send(res).await.ok();
724 };
725 GetProgress {
726 rx,
727 fut: Box::pin(fut),
728 }
729 }
730
731 pub async fn execute_get_many_sink(
740 &self,
741 conn: Connection,
742 request: GetManyRequest,
743 mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
744 ) -> GetResult<Stats> {
745 let store = self.store();
746 let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
747 let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
748 let at_closing = match next_child {
750 Ok(at_start_child) => {
751 let mut next_child = Ok(at_start_child);
752 loop {
753 let at_start_child = match next_child {
754 Ok(at_start_child) => at_start_child,
755 Err(at_closing) => break at_closing,
756 };
757 let offset = at_start_child.offset();
758 println!("offset {offset}");
759 let Some(hash) = hash_seq.get(offset as usize) else {
760 break at_start_child.finish();
761 };
762 trace!("getting child {offset} {}", hash.fmt_short());
763 let header = at_start_child.next(hash);
764 let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
765 next_child = match end.next() {
766 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
767 EndBlobNext::Closing(at_closing) => Err(at_closing),
768 }
769 }
770 }
771 Err(at_closing) => at_closing,
772 };
773 let stats = at_closing.next().await?;
775 trace!(?stats, "get hash seq done");
776 Ok(stats)
777 }
778}
779
#[common_fields({
    // Fields listed here are shared by the variants of the enum below.
    backtrace: Option<Backtrace>,
    #[snafu(implicit)]
    span_trace: SpanTrace,
})]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Debug, Snafu)]
// Error cases for executing a request; each variant wraps the underlying
// source error and carries its own display message.
pub enum ExecuteError {
    #[snafu(display("Unable to open bidi stream"))]
    Connection {
        source: iroh::endpoint::ConnectionError,
    },
    #[snafu(display("Unable to read from the remote"))]
    Read { source: iroh::endpoint::ReadError },
    #[snafu(display("Error sending the request"))]
    Send {
        source: crate::get::fsm::ConnectedNextError,
    },
    #[snafu(display("Unable to read size"))]
    Size {
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    #[snafu(display("Error while decoding the data"))]
    Decode {
        source: crate::get::fsm::DecodeError,
    },
    #[snafu(display("Internal error while reading the hash sequence"))]
    ExportBao { source: api::ExportBaoError },
    #[snafu(display("Hash sequence has an invalid length"))]
    InvalidHashSeq { source: anyhow::Error },
    #[snafu(display("Internal error importing the data"))]
    ImportBao { source: crate::api::RequestError },
    #[snafu(display("Error sending download progress - receiver closed"))]
    SendDownloadProgress { source: irpc::channel::SendError },
    // NOTE(review): same display message as `ImportBao` — confirm this is intended.
    #[snafu(display("Internal error importing the data"))]
    MpscSend {
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
822
823use std::{
824 collections::BTreeMap,
825 future::{Future, IntoFuture},
826 num::NonZeroU64,
827 sync::Arc,
828};
829
830use bao_tree::{
831 io::{BaoContentItem, Leaf},
832 ChunkNum, ChunkRanges,
833};
834use iroh::endpoint::Connection;
835use tracing::{debug, trace};
836
837use crate::{
838 api::{self, blobs::Blobs, Store},
839 get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext},
840 hashseq::{HashSeq, HashSeqIter},
841 protocol::{ChunkRangesSeq, GetRequest},
842 store::IROH_BLOCK_SIZE,
843 Hash, HashAndFormat,
844};
845
/// Something that can provide a [`Connection`] on demand.
///
/// `Remote::fetch_sink` only asks for the connection once it has determined
/// that data is actually missing locally.
pub trait GetConnection {
    /// Returns a future that resolves to the connection, or to an error if
    /// none could be provided.
    fn connection(&mut self)
        -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_;
}
851
852impl GetConnection for Connection {
854 fn connection(
855 &mut self,
856 ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
857 let conn = self.clone();
858 async { Ok(conn) }
859 }
860}
861
862impl GetConnection for &Connection {
864 fn connection(
865 &mut self,
866 ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
867 let conn = self.clone();
868 async { Ok(conn) }
869 }
870}
871
872fn get_buffer_size(size: NonZeroU64) -> usize {
873 (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
874}
875
/// Receives the content of one blob from the remote and imports it into `store`.
///
/// Reads the size from `header`, opens an import for `hash` and forwards every
/// received content item to it, while sending the payload byte count reported
/// by the state machine to `progress`. Returns the state reached at the end of
/// the blob.
///
/// A reported size of zero is only accepted for the empty hash; for any other
/// hash it is reported as a leaf hash mismatch at chunk 0.
async fn get_blob_ranges_impl(
    header: AtBlobHeader,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
) -> GetResult<AtEndBlob> {
    let (mut content, size) = header.next().await?;
    let Some(size) = NonZeroU64::new(size) else {
        return if hash == Hash::EMPTY {
            // Nothing to import; just consume the remaining content.
            let end = content.drain().await?;
            Ok(end)
        } else {
            Err(DecodeError::leaf_hash_mismatch(ChunkNum(0)).into())
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
    // Forwards received items to the import until the blob is done.
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
                    handle.tx.send(item).await?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // Dropping the sender closes the import's input.
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    // Waits for the import side to report completion.
    let complete = async move {
        handle.rx.await.map_err(|e| {
            LocalFailureSnafu
                .into_error(anyhow::anyhow!("error reading from import stream: {e}").into())
        })
    };
    // Both run concurrently; the first error from either one is returned.
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
925
/// A hash sequence that is read from the store one chunk at a time, on demand.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    // Used to export chunks of the hash sequence blob.
    blobs: Blobs,
    // Hash of the hash sequence blob itself.
    hash: Hash,
    // The most recently loaded chunk that produced a hit, if any.
    current_chunk: Option<HashSeqChunk>,
}
932
/// A part of a hash sequence, created from a single [`Leaf`].
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    // Offset of the leaf this chunk was created from (`leaf.offset`).
    offset: u64,
    // The hashes parsed from the leaf data.
    chunk: HashSeq,
}
940
941impl TryFrom<Leaf> for HashSeqChunk {
942 type Error = anyhow::Error;
943
944 fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
945 let offset = leaf.offset;
946 let chunk = HashSeq::try_from(leaf.data)?;
947 Ok(Self { offset, chunk })
948 }
949}
950
951impl IntoIterator for HashSeqChunk {
952 type Item = Hash;
953 type IntoIter = HashSeqIter;
954
955 fn into_iter(self) -> Self::IntoIter {
956 self.chunk.into_iter()
957 }
958}
959
960impl HashSeqChunk {
961 pub fn base(&self) -> u64 {
962 self.offset / 32
963 }
964
965 #[allow(dead_code)]
966 fn get(&self, offset: u64) -> Option<Hash> {
967 let start = self.offset;
968 let end = start + self.chunk.len() as u64;
969 if offset >= start && offset < end {
970 let o = (offset - start) as usize;
971 self.chunk.get(o)
972 } else {
973 None
974 }
975 }
976}
977
978impl LazyHashSeq {
979 #[allow(dead_code)]
980 pub fn new(blobs: Blobs, hash: Hash) -> Self {
981 Self {
982 blobs,
983 hash,
984 current_chunk: None,
985 }
986 }
987
988 #[allow(dead_code)]
989 pub async fn get_from_offset(&mut self, offset: u64) -> anyhow::Result<Option<Hash>> {
990 if offset == 0 {
991 Ok(Some(self.hash))
992 } else {
993 self.get(offset - 1).await
994 }
995 }
996
997 #[allow(dead_code)]
998 pub async fn get(&mut self, child_offset: u64) -> anyhow::Result<Option<Hash>> {
999 if let Some(chunk) = &self.current_chunk {
1001 if let Some(hash) = chunk.get(child_offset) {
1002 return Ok(Some(hash));
1003 }
1004 }
1005 let leaf = self
1007 .blobs
1008 .export_chunk(self.hash, child_offset * 32)
1009 .await?;
1010 let hs = HashSeqChunk::try_from(leaf)?;
1012 Ok(hs.get(child_offset).inspect(|_hash| {
1013 self.current_chunk = Some(hs);
1014 }))
1015 }
1016}
1017
1018async fn write_push_request(
1019 request: PushRequest,
1020 stream: &mut SendStream,
1021) -> anyhow::Result<PushRequest> {
1022 let mut request_bytes = Vec::new();
1023 request_bytes.push(RequestType::Push as u8);
1024 request_bytes.write_length_prefixed(&request).unwrap();
1025 stream.write_all(&request_bytes).await?;
1026 Ok(request)
1027}
1028
1029async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> {
1030 let request = Request::Observe(request);
1031 let request_bytes = postcard::to_allocvec(&request)
1032 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1033 stream.write_all(&request_bytes).await?;
1034 Ok(())
1035}
1036
/// Progress-tracking state used while writing a push to a stream.
struct StreamContext<S> {
    // Running total of payload bytes written so far.
    payload_bytes_sent: u64,
    // Sink that receives the running total after each payload write.
    sender: S,
}
1041
1042impl<S> WriteProgress for StreamContext<S>
1043where
1044 S: Sink<u64, Error = irpc::channel::SendError>,
1045{
1046 async fn notify_payload_write(
1047 &mut self,
1048 _index: u64,
1049 _offset: u64,
1050 len: usize,
1051 ) -> ClientResult {
1052 self.payload_bytes_sent += len as u64;
1053 self.sender
1054 .send(self.payload_bytes_sent)
1055 .await
1056 .map_err(|e| ProgressError::Internal { source: e.into() })?;
1057 Ok(())
1058 }
1059
1060 fn log_other_write(&mut self, _len: usize) {}
1061
1062 async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
1063}
1064
1065#[cfg(test)]
1066#[cfg(feature = "fs-store")]
1067mod tests {
1068 use bao_tree::{ChunkNum, ChunkRanges};
1069 use testresult::TestResult;
1070
1071 use crate::{
1072 api::blobs::Blobs,
1073 protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
1074 store::{
1075 fs::{
1076 tests::{create_n0_bao, test_data, INTERESTING_SIZES},
1077 FsStore,
1078 },
1079 mem::MemStore,
1080 },
1081 tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
1082 };
1083
    /// A small blob added locally is fully present: all of its bytes are
    /// local, it is complete, and nothing is missing.
    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = *tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        // "test" is 4 bytes long.
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }
1101
    /// A large hash sequence of which only chunks 16..32 are present locally,
    /// together with the first chunk of every child.
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // Children 512..1024 are the ones referenced by chunks 16..32 of the
        // hash sequence (32 hashes per chunk). Each child is at most 1024
        // bytes, so its first chunk is the whole child.
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Index 0 is the hash sequence itself, everything else is a child.
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // 16 chunks of 1024 bytes of the hash sequence, plus the children.
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }
1130
    /// Imports only the first chunk of blobs of various sizes and checks the
    /// observed bitfield. Shared by the mem and fs store variants below.
    async fn test_observe_partial(blobs: &Blobs) -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        for size in sizes {
            let data = test_data(size);
            let ranges = ChunkRanges::chunk(0);
            let (hash, bao) = create_n0_bao(&data, &ranges)?;
            blobs.import_bao_bytes(hash, ranges.clone(), bao).await?;
            let bitfield = blobs.observe(hash).await?;
            // Blobs of at most 1024 bytes are expected to be reported as fully
            // present once their first chunk is imported.
            if size > 1024 {
                assert_eq!(bitfield.ranges, ranges);
            } else {
                assert_eq!(bitfield.ranges, ChunkRanges::all());
            }
        }
        Ok(())
    }
1147
    /// Runs `test_observe_partial` against the in-memory store.
    #[tokio::test]
    async fn test_observe_partial_mem() -> TestResult<()> {
        let store = MemStore::new();
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }
1154
    /// Runs `test_observe_partial` against the file system store.
    #[tokio::test]
    async fn test_observe_partial_fs() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path()).await?;
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }
1162
    /// Checks local info for a hash sequence in three states: only the hash
    /// sequence present, the first chunk of every child present, and
    /// everything present.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        // 32 bytes per child hash.
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Only the hash sequence itself (index 0) is present.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            // First entry is the root (nothing missing), the rest are children.
            // NOTE(review): the first child has nothing missing although no
            // child data was added — presumably it is the empty blob; confirm
            // against INTERESTING_SIZES.
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            // The hash sequence and the first chunk of every child are present.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            // Root and the first three children have nothing missing; the
            // remaining children are missing everything after the first chunk.
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            // Everything is present.
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }
1249
    /// Checks local info for requests that only cover part of a hash sequence
    /// whose children have their first two chunks present locally.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        // 32 bytes per child hash.
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        // The hash sequence is fully present, children only up to chunk 2.
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // Request only the root.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // Request the root and the first child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Request the root and the first two children.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Open-ended request: the root plus the first chunk of every child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
1318}