1use genawaiter::sync::{Co, Gen};
5use iroh::endpoint::SendStream;
6use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
7use n0_future::{io, Stream, StreamExt};
8use n0_snafu::SpanTrace;
9use nested_enum_utils::common_fields;
10use ref_cast::RefCast;
11use snafu::{Backtrace, IntoError, Snafu};
12
13use super::blobs::{Bitfield, ExportBaoOptions};
14use crate::{
15 api::{blobs::WriteProgress, ApiClient},
16 get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats},
17 protocol::{
18 GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
19 MAX_MESSAGE_SIZE,
20 },
21 util::sink::{Sink, TokioMpscSenderSink},
22};
23
/// Entry point for operations that talk to a remote node: fetching, pushing
/// and observing blobs.
///
/// A `Remote` is a transparent wrapper around an [`ApiClient`], so a reference
/// to a client can be reinterpreted as a reference to a `Remote` for free.
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    /// Client used both for the local store and for RPC.
    client: ApiClient,
}
44
/// Item yielded by the stream of a [`GetProgress`].
#[derive(Debug)]
pub enum GetProgressItem {
    /// Payload bytes read so far, as reported by the transfer statistics.
    Progress(u64),
    /// The get operation finished successfully with these statistics.
    Done(Stats),
    /// The get operation failed.
    Error(GetError),
}
54
55impl From<GetResult<Stats>> for GetProgressItem {
56 fn from(res: GetResult<Stats>) -> Self {
57 match res {
58 Ok(stats) => GetProgressItem::Done(stats),
59 Err(e) => GetProgressItem::Error(e),
60 }
61 }
62}
63
64impl TryFrom<GetProgressItem> for GetResult<Stats> {
65 type Error = &'static str;
66
67 fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
68 match item {
69 GetProgressItem::Done(stats) => Ok(Ok(stats)),
70 GetProgressItem::Error(e) => Ok(Err(e)),
71 GetProgressItem::Progress(_) => Err("not a final item"),
72 }
73 }
74}
75
/// Handle to a get operation.
///
/// Await it to get the final result, or call [`GetProgress::stream`] to also
/// observe progress updates. The operation is not spawned; it only makes
/// progress while this handle (or its stream) is polled.
pub struct GetProgress {
    /// Receives progress items and, at the end, the final result item.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    /// The operation itself; it sends its items on the sender side of `rx`.
    fut: n0_future::boxed::BoxFuture<()>,
}
80
81impl IntoFuture for GetProgress {
82 type Output = GetResult<Stats>;
83 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
84
85 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
86 Box::pin(self.complete())
87 }
88}
89
90impl GetProgress {
91 pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
92 into_stream(self.rx, self.fut)
93 }
94
95 pub async fn complete(self) -> GetResult<Stats> {
96 just_result(self.stream()).await.unwrap_or_else(|| {
97 Err(LocalFailureSnafu
98 .into_error(anyhow::anyhow!("stream closed without result").into()))
99 })
100 }
101}
102
/// Item yielded by the stream of a [`PushProgress`].
#[derive(Debug)]
pub enum PushProgressItem {
    /// Total payload bytes written to the remote so far.
    Progress(u64),
    /// The push finished successfully with these statistics.
    Done(Stats),
    /// The push failed.
    Error(anyhow::Error),
}
112
113impl From<anyhow::Result<Stats>> for PushProgressItem {
114 fn from(res: anyhow::Result<Stats>) -> Self {
115 match res {
116 Ok(stats) => Self::Done(stats),
117 Err(e) => Self::Error(e),
118 }
119 }
120}
121
122impl TryFrom<PushProgressItem> for anyhow::Result<Stats> {
123 type Error = &'static str;
124
125 fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
126 match item {
127 PushProgressItem::Done(stats) => Ok(Ok(stats)),
128 PushProgressItem::Error(e) => Ok(Err(e)),
129 PushProgressItem::Progress(_) => Err("not a final item"),
130 }
131 }
132}
133
/// Handle to a push operation.
///
/// Await it to get the final result, or call [`PushProgress::stream`] to also
/// observe progress updates. The operation is not spawned; it only makes
/// progress while this handle (or its stream) is polled.
pub struct PushProgress {
    /// Receives progress items and, at the end, the final result item.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    /// The operation itself; it sends its items on the sender side of `rx`.
    fut: n0_future::boxed::BoxFuture<()>,
}
138
139impl IntoFuture for PushProgress {
140 type Output = anyhow::Result<Stats>;
141 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
142
143 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
144 Box::pin(self.complete())
145 }
146}
147
148impl PushProgress {
149 pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
150 into_stream(self.rx, self.fut)
151 }
152
153 pub async fn complete(self) -> anyhow::Result<Stats> {
154 just_result(self.stream())
155 .await
156 .unwrap_or_else(|| Err(anyhow::anyhow!("stream closed without result")))
157 }
158}
159
160async fn just_result<S, R>(stream: S) -> Option<R>
161where
162 S: Stream<Item: std::fmt::Debug>,
163 R: TryFrom<S::Item>,
164{
165 tokio::pin!(stream);
166 while let Some(item) = stream.next().await {
167 if let Ok(res) = R::try_from(item) {
168 return Some(res);
169 }
170 }
171 None
172}
173
/// Merges a channel receiver and the future that feeds it into one stream.
///
/// While the stream is polled, `fut` and `rx` are polled together. `rx` is
/// polled first on every iteration (`biased`), and each received item is
/// yielded. The first phase ends when either `fut` completes or `rx` reports
/// that the channel is closed; if it was the channel, `fut` is not polled again.
/// Afterwards any items still buffered in `rx` are yielded until the channel
/// reports closed. The output of `fut` is discarded.
fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
where
    F: Future,
{
    Gen::new(move |co| async move {
        tokio::pin!(fut);
        loop {
            tokio::select! {
                // prefer delivering already-received items over polling `fut`
                biased;
                item = rx.recv() => {
                    if let Some(item) = item {
                        co.yield_(item).await;
                    } else {
                        // all senders are gone
                        break;
                    }
                }
                _ = &mut fut => {
                    break;
                }
            }
        }
        // drain whatever was sent but not yet yielded (e.g. the final result)
        while let Some(item) = rx.recv().await {
            co.yield_(item).await;
        }
    })
}
200
/// Which parts of a [`GetRequest`] are already available in the local store.
///
/// Created by [`Remote::local_for_request`] or [`Remote::local`].
#[derive(Debug)]
pub struct LocalInfo {
    /// The request this information was computed for.
    request: Arc<GetRequest>,
    /// Locally available ranges of the root blob.
    bitfield: Bitfield,
    /// Local state of the requested children; `None` when the request only
    /// covers the root blob (`request.ranges.is_blob()`).
    children: Option<NonRawLocalInfo>,
}
214
impl LocalInfo {
    /// Number of requested bytes that are already present locally.
    ///
    /// Counts the locally present bytes within the requested root ranges, plus,
    /// for every requested child whose hash is known locally, the locally
    /// present bytes within that child's requested ranges.
    pub fn local_bytes(&self) -> u64 {
        let Some(root_requested) = self.requested_root_ranges() else {
            // empty request: nothing can be present
            return 0;
        };
        let mut local = self.bitfield.clone();
        local.ranges.intersection_with(root_requested);
        let mut res = local.total_bytes();
        if let Some(children) = &self.children {
            let Some(max_local_index) = children.hash_seq.keys().next_back() else {
                // no child hash is known locally, so only the root counts
                return res;
            };
            for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
                if offset == 0 {
                    // offset 0 is the root, already counted above
                    continue;
                }
                let child = offset - 1;
                if child > *max_local_index {
                    // no known child hashes beyond this point; the iterator may
                    // not end on its own (`_infinite`), so this bounds the loop
                    break;
                }
                let Some(hash) = children.hash_seq.get(&child) else {
                    continue;
                };
                let bitfield = &children.bitfields[hash];
                let mut local = bitfield.clone();
                local.ranges.intersection_with(ranges);
                res += local.total_bytes();
            }
        }
        res
    }

    /// Number of children of the root hash sequence, if known.
    ///
    /// Derived from the validated size of the root (32 bytes per hash), so it is
    /// `None` while that size is not validated locally. Returns `Some(0)` when
    /// the request does not cover children.
    pub fn children(&self) -> Option<u64> {
        if self.children.is_some() {
            self.bitfield.validated_size().map(|x| x / 32)
        } else {
            Some(0)
        }
    }

    /// Ranges requested for the root blob: the first entry of the request's
    /// range sequence, or `None` if the sequence yields nothing.
    fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
        self.request.ranges.iter().next()
    }

    /// Returns `true` if every requested range is present locally.
    ///
    /// The requested root ranges must be fully present. Each requested child
    /// whose hash is known locally must have its requested ranges present. At
    /// the first requested child whose hash is not known, the result is `true`
    /// only if the validated root size shows that this child index is past the
    /// end of the hash sequence, and `false` otherwise.
    pub fn is_complete(&self) -> bool {
        let Some(root_requested) = self.requested_root_ranges() else {
            // an empty request is trivially complete
            return true;
        };
        if !self.bitfield.ranges.is_superset(root_requested) {
            return false;
        }
        if let Some(children) = self.children.as_ref() {
            let mut iter = self.request.ranges.iter_non_empty_infinite();
            // number of children, if the root size has been validated
            let max_child = self.bitfield.validated_size().map(|x| x / 32);
            loop {
                let Some((offset, range)) = iter.next() else {
                    break;
                };
                if offset == 0 {
                    // the root has already been checked
                    continue;
                }
                let child = offset - 1;
                if let Some(hash) = children.hash_seq.get(&child) {
                    let bitfield = &children.bitfields[hash];
                    if !bitfield.ranges.is_superset(range) {
                        return false;
                    }
                } else {
                    if let Some(max_child) = max_child {
                        if child >= max_child {
                            // requested children past the end of the hash seq do
                            // not exist, so nothing more can be missing
                            return true;
                        }
                    }
                    return false;
                }
            }
        }
        true
    }

    /// Builds a request for the requested data that is not present locally.
    ///
    /// - Root: the requested root ranges minus the locally present ranges.
    /// - Children up to the highest locally known child index: the requested
    ///   ranges minus the locally present ranges, or all requested ranges if the
    ///   child's hash is not known.
    /// - Later children: requested in full. If the validated root size is known,
    ///   this stops at the number of children it implies. Otherwise the result
    ///   may be built as an open request (`build_open`) depending on how the
    ///   requested range sequence ends.
    pub fn missing(&self) -> GetRequest {
        let Some(root_requested) = self.requested_root_ranges() else {
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        // one past the highest child index with a locally known hash, as an offset
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        // number of children, if the root size has been validated
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        // phase 1: children for which local state may exist
        loop {
            let Some((offset, requested)) = iter.next() else {
                break;
            };
            if offset == 0 {
                // the root was handled above
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                break;
            }
        }
        // phase 2: children with no local state, requested in full
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                if child >= *max_offset {
                    // past the end of the hash sequence
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                if iter.is_at_end() {
                    // NOTE(review): presumably `is_at_end` means the remaining
                    // ranges repeat forever; a further item then makes the
                    // request open-ended — confirm against the iterator docs
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
}
379
/// Local state of the requested children of a hash sequence.
#[derive(Debug)]
struct NonRawLocalInfo {
    /// Child index → child hash, for requested children whose hash could be
    /// read from the locally present part of the hash sequence.
    hash_seq: BTreeMap<u64, Hash>,
    /// Locally observed bitfield for every hash in `hash_seq`.
    bitfields: BTreeMap<Hash, Bitfield>,
}
388
389impl Remote {
404 pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
405 Self::ref_cast(sender)
406 }
407
408 fn store(&self) -> &Store {
409 Store::ref_from_sender(&self.client)
410 }
411
412 pub async fn local_for_request(
413 &self,
414 request: impl Into<Arc<GetRequest>>,
415 ) -> anyhow::Result<LocalInfo> {
416 let request = request.into();
417 let root = request.hash;
418 let bitfield = self.store().observe(root).await?;
419 let children = if !request.ranges.is_blob() {
420 let opts = ExportBaoOptions {
421 hash: root,
422 ranges: bitfield.ranges.clone(),
423 };
424 let bao = self.store().export_bao_with_opts(opts, 32);
425 let mut by_index = BTreeMap::new();
426 let mut stream = bao.hashes_with_index();
427 while let Some(item) = stream.next().await {
428 if let Ok((index, hash)) = item {
429 by_index.insert(index, hash);
430 }
431 }
432 let mut bitfields = BTreeMap::new();
433 let mut hash_seq = BTreeMap::new();
434 let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
435 for (index, _) in request.ranges.iter_non_empty_infinite() {
436 if index == 0 {
437 continue;
439 }
440 let child = index - 1;
441 if child > max {
442 break;
444 }
445 let Some(hash) = by_index.get(&child) else {
446 continue;
448 };
449 let bitfield = self.store().observe(*hash).await?;
450 bitfields.insert(*hash, bitfield);
451 hash_seq.insert(child, *hash);
452 }
453 Some(NonRawLocalInfo {
454 hash_seq,
455 bitfields,
456 })
457 } else {
458 None
459 };
460 Ok(LocalInfo {
461 request: request.clone(),
462 bitfield,
463 children,
464 })
465 }
466
467 pub async fn local(&self, content: impl Into<HashAndFormat>) -> anyhow::Result<LocalInfo> {
469 let request = GetRequest::from(content.into());
470 self.local_for_request(request).await
471 }
472
473 pub fn fetch(
474 &self,
475 conn: impl GetConnection + Send + 'static,
476 content: impl Into<HashAndFormat>,
477 ) -> GetProgress {
478 let content = content.into();
479 let (tx, rx) = tokio::sync::mpsc::channel(64);
480 let tx2 = tx.clone();
481 let sink = TokioMpscSenderSink(tx)
482 .with_map(GetProgressItem::Progress)
483 .with_map_err(io::Error::other);
484 let this = self.clone();
485 let fut = async move {
486 let res = this.fetch_sink(conn, content, sink).await.into();
487 tx2.send(res).await.ok();
488 };
489 GetProgress {
490 rx,
491 fut: Box::pin(fut),
492 }
493 }
494
495 pub async fn fetch_sink(
503 &self,
504 mut conn: impl GetConnection,
505 content: impl Into<HashAndFormat>,
506 progress: impl Sink<u64, Error = io::Error>,
507 ) -> GetResult<Stats> {
508 let content = content.into();
509 let local = self
510 .local(content)
511 .await
512 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
513 if local.is_complete() {
514 return Ok(Default::default());
515 }
516 let request = local.missing();
517 let conn = conn
518 .connection()
519 .await
520 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
521 let stats = self.execute_get_sink(conn, request, progress).await?;
522 Ok(stats)
523 }
524
525 pub fn observe(
526 &self,
527 conn: Connection,
528 request: ObserveRequest,
529 ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
530 Gen::new(|co| async move {
531 if let Err(cause) = Self::observe_impl(conn, request, &co).await {
532 co.yield_(Err(cause)).await
533 }
534 })
535 }
536
537 async fn observe_impl(
538 conn: Connection,
539 request: ObserveRequest,
540 co: &Co<io::Result<Bitfield>>,
541 ) -> io::Result<()> {
542 let hash = request.hash;
543 debug!(%hash, "observing");
544 let (mut send, mut recv) = conn.open_bi().await?;
545 write_observe_request(request, &mut send).await?;
547 send.finish()?;
548 loop {
549 let msg = recv
550 .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
551 .await?;
552 co.yield_(Ok(Bitfield::from(&msg))).await;
553 }
554 }
555
556 pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
557 let (tx, rx) = tokio::sync::mpsc::channel(64);
558 let tx2 = tx.clone();
559 let sink = TokioMpscSenderSink(tx)
560 .with_map(PushProgressItem::Progress)
561 .with_map_err(io::Error::other);
562 let this = self.clone();
563 let fut = async move {
564 let res = this.execute_push_sink(conn, request, sink).await.into();
565 tx2.send(res).await.ok();
566 };
567 PushProgress {
568 rx,
569 fut: Box::pin(fut),
570 }
571 }
572
573 pub async fn execute_push_sink(
577 &self,
578 conn: Connection,
579 request: PushRequest,
580 progress: impl Sink<u64, Error = io::Error>,
581 ) -> anyhow::Result<Stats> {
582 let hash = request.hash;
583 debug!(%hash, "pushing");
584 let (mut send, mut recv) = conn.open_bi().await?;
585 let mut context = StreamContext {
586 payload_bytes_sent: 0,
587 sender: progress,
588 };
589 recv.stop(0u32.into())?;
591 let request = write_push_request(request, &mut send).await?;
593 let mut request_ranges = request.ranges.iter_infinite();
594 let root = request.hash;
595 let root_ranges = request_ranges.next().expect("infinite iterator");
596 if !root_ranges.is_empty() {
597 self.store()
598 .export_bao(root, root_ranges.clone())
599 .write_quinn_with_progress(&mut send, &mut context, &root, 0)
600 .await?;
601 }
602 if request.ranges.is_blob() {
603 send.finish()?;
605 return Ok(Default::default());
606 }
607 let hash_seq = self.store().get_bytes(root).await?;
608 let hash_seq = HashSeq::try_from(hash_seq)?;
609 for (child, (child_hash, child_ranges)) in
610 hash_seq.into_iter().zip(request_ranges).enumerate()
611 {
612 if !child_ranges.is_empty() {
613 self.store()
614 .export_bao(child_hash, child_ranges.clone())
615 .write_quinn_with_progress(
616 &mut send,
617 &mut context,
618 &child_hash,
619 (child + 1) as u64,
620 )
621 .await?;
622 }
623 }
624 send.finish()?;
625 Ok(Default::default())
626 }
627
628 pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress {
629 self.execute_get_with_opts(conn, request)
630 }
631
632 pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress {
633 let (tx, rx) = tokio::sync::mpsc::channel(64);
634 let tx2 = tx.clone();
635 let sink = TokioMpscSenderSink(tx)
636 .with_map(GetProgressItem::Progress)
637 .with_map_err(io::Error::other);
638 let this = self.clone();
639 let fut = async move {
640 let res = this.execute_get_sink(conn, request, sink).await.into();
641 tx2.send(res).await.ok();
642 };
643 GetProgress {
644 rx,
645 fut: Box::pin(fut),
646 }
647 }
648
649 pub async fn execute_get_sink(
658 &self,
659 conn: Connection,
660 request: GetRequest,
661 mut progress: impl Sink<u64, Error = io::Error>,
662 ) -> GetResult<Stats> {
663 let store = self.store();
664 let root = request.hash;
665 let start = crate::get::fsm::start(conn, request, Default::default());
666 let connected = start.next().await?;
667 trace!("Getting header");
668 let next_child = match connected.next().await? {
670 ConnectedNext::StartRoot(at_start_root) => {
671 let header = at_start_root.next();
672 let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
673 match end.next() {
674 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
675 EndBlobNext::Closing(at_closing) => Err(at_closing),
676 }
677 }
678 ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
679 ConnectedNext::Closing(at_closing) => Err(at_closing),
680 };
681 let at_closing = match next_child {
683 Ok(at_start_child) => {
684 let mut next_child = Ok(at_start_child);
685 let hash_seq = HashSeq::try_from(
686 store
687 .get_bytes(root)
688 .await
689 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?,
690 )
691 .map_err(|source| BadRequestSnafu.into_error(source.into()))?;
692 loop {
694 let at_start_child = match next_child {
695 Ok(at_start_child) => at_start_child,
696 Err(at_closing) => break at_closing,
697 };
698 let offset = at_start_child.offset() - 1;
699 let Some(hash) = hash_seq.get(offset as usize) else {
700 break at_start_child.finish();
701 };
702 trace!("getting child {offset} {}", hash.fmt_short());
703 let header = at_start_child.next(hash);
704 let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
705 next_child = match end.next() {
706 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
707 EndBlobNext::Closing(at_closing) => Err(at_closing),
708 }
709 }
710 }
711 Err(at_closing) => at_closing,
712 };
713 let stats = at_closing.next().await?;
715 trace!(?stats, "get hash seq done");
716 Ok(stats)
717 }
718
719 pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
720 let (tx, rx) = tokio::sync::mpsc::channel(64);
721 let tx2 = tx.clone();
722 let sink = TokioMpscSenderSink(tx)
723 .with_map(GetProgressItem::Progress)
724 .with_map_err(io::Error::other);
725 let this = self.clone();
726 let fut = async move {
727 let res = this.execute_get_many_sink(conn, request, sink).await.into();
728 tx2.send(res).await.ok();
729 };
730 GetProgress {
731 rx,
732 fut: Box::pin(fut),
733 }
734 }
735
736 pub async fn execute_get_many_sink(
745 &self,
746 conn: Connection,
747 request: GetManyRequest,
748 mut progress: impl Sink<u64, Error = io::Error>,
749 ) -> GetResult<Stats> {
750 let store = self.store();
751 let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
752 let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
753 let at_closing = match next_child {
755 Ok(at_start_child) => {
756 let mut next_child = Ok(at_start_child);
757 loop {
758 let at_start_child = match next_child {
759 Ok(at_start_child) => at_start_child,
760 Err(at_closing) => break at_closing,
761 };
762 let offset = at_start_child.offset();
763 println!("offset {offset}");
764 let Some(hash) = hash_seq.get(offset as usize) else {
765 break at_start_child.finish();
766 };
767 trace!("getting child {offset} {}", hash.fmt_short());
768 let header = at_start_child.next(hash);
769 let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
770 next_child = match end.next() {
771 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
772 EndBlobNext::Closing(at_closing) => Err(at_closing),
773 }
774 }
775 }
776 Err(at_closing) => at_closing,
777 };
778 let stats = at_closing.next().await?;
780 trace!(?stats, "get hash seq done");
781 Ok(stats)
782 }
783}
784
/// Errors that can occur while executing a request against a remote.
#[common_fields({
    backtrace: Option<Backtrace>,
    #[snafu(implicit)]
    span_trace: SpanTrace,
})]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Debug, Snafu)]
pub enum ExecuteError {
    /// Opening a bidirectional stream on the connection failed.
    #[snafu(display("Unable to open bidi stream"))]
    Connection {
        source: iroh::endpoint::ConnectionError,
    },
    /// Reading from the remote failed.
    #[snafu(display("Unable to read from the remote"))]
    Read { source: iroh::endpoint::ReadError },
    /// Sending the request failed.
    #[snafu(display("Error sending the request"))]
    Send {
        source: crate::get::fsm::ConnectedNextError,
    },
    /// Reading the size header of a blob failed.
    #[snafu(display("Unable to read size"))]
    Size {
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    /// Decoding the received data failed.
    #[snafu(display("Error while decoding the data"))]
    Decode {
        source: crate::get::fsm::DecodeError,
    },
    /// Exporting the hash sequence from the local store failed.
    #[snafu(display("Internal error while reading the hash sequence"))]
    ExportBao { source: api::ExportBaoError },
    /// The hash sequence could not be parsed.
    #[snafu(display("Hash sequence has an invalid length"))]
    InvalidHashSeq { source: anyhow::Error },
    /// Importing data into the local store failed.
    #[snafu(display("Internal error importing the data"))]
    ImportBao { source: crate::api::RequestError },
    /// The receiver of download progress was closed.
    #[snafu(display("Error sending download progress - receiver closed"))]
    SendDownloadProgress { source: irpc::channel::SendError },
    /// Forwarding a content item over the mpsc channel failed.
    #[snafu(display("Internal error importing the data"))]
    MpscSend {
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
827
828use std::{
829 collections::BTreeMap,
830 future::{Future, IntoFuture},
831 num::NonZeroU64,
832 sync::Arc,
833};
834
835use bao_tree::{
836 io::{BaoContentItem, Leaf},
837 ChunkNum, ChunkRanges,
838};
839use iroh::endpoint::Connection;
840use tracing::{debug, trace};
841
842use crate::{
843 api::{self, blobs::Blobs, Store},
844 get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext},
845 hashseq::{HashSeq, HashSeqIter},
846 protocol::{ChunkRangesSeq, GetRequest},
847 store::IROH_BLOCK_SIZE,
848 Hash, HashAndFormat,
849};
850
/// Source of a [`Connection`] for a get request.
///
/// Used by [`Remote::fetch`] and [`Remote::fetch_sink`], which only ask for a
/// connection if some data is actually missing locally.
pub trait GetConnection {
    /// Returns the connection to use for the request.
    fn connection(&mut self)
        -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_;
}
856
857impl GetConnection for Connection {
859 fn connection(
860 &mut self,
861 ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
862 let conn = self.clone();
863 async { Ok(conn) }
864 }
865}
866
867impl GetConnection for &Connection {
869 fn connection(
870 &mut self,
871 ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
872 let conn = self.clone();
873 async { Ok(conn) }
874 }
875}
876
877fn get_buffer_size(size: NonZeroU64) -> usize {
878 (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
879}
880
/// Receives a single blob from the remote and imports it into the local store.
///
/// Reads the size from `header`. A size of 0 is accepted only for
/// [`Hash::EMPTY`], in which case the content is drained. For any other hash,
/// size 0 is reported as a leaf hash mismatch at chunk 0. Otherwise the received
/// content items are forwarded to a bao import in `store`, and the
/// `payload_bytes_read` statistic is sent to `progress` after each item. Import
/// completion and the forwarding loop are awaited together; the first error from
/// either one is returned.
async fn get_blob_ranges_impl(
    header: AtBlobHeader,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = io::Error>,
) -> GetResult<AtEndBlob> {
    let (mut content, size) = header.next().await?;
    let Some(size) = NonZeroU64::new(size) else {
        return if hash == Hash::EMPTY {
            let end = content.drain().await?;
            Ok(end)
        } else {
            Err(DecodeError::leaf_hash_mismatch(ChunkNum(0)).into())
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
    // forwards received items to the import; `handle.tx` is moved in here
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
                    handle.tx.send(item).await?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // NOTE(review): presumably closing the sender lets the import
                    // observe end-of-input and resolve `handle.rx` — confirm
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    // waits for the import to report completion; `handle.rx` is moved in here
    let complete = async move {
        handle.rx.await.map_err(|e| {
            LocalFailureSnafu
                .into_error(anyhow::anyhow!("error reading from import stream: {e}").into())
        })
    };
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
930
/// Reads the hashes of a hash sequence blob on demand, one leaf at a time.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    /// Blob API used to export leaves of the hash sequence.
    blobs: Blobs,
    /// Hash of the hash sequence blob itself.
    hash: Hash,
    /// Most recently loaded leaf, used as a single-entry cache.
    current_chunk: Option<HashSeqChunk>,
}
937
/// The hashes contained in one leaf of a hash sequence blob.
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// Byte offset of the leaf within the hash sequence blob (taken from
    /// `Leaf::offset`); divide by 32 to get the index of the first hash.
    offset: u64,
    /// The hashes decoded from the leaf data.
    chunk: HashSeq,
}
945
946impl TryFrom<Leaf> for HashSeqChunk {
947 type Error = anyhow::Error;
948
949 fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
950 let offset = leaf.offset;
951 let chunk = HashSeq::try_from(leaf.data)?;
952 Ok(Self { offset, chunk })
953 }
954}
955
956impl IntoIterator for HashSeqChunk {
957 type Item = Hash;
958 type IntoIter = HashSeqIter;
959
960 fn into_iter(self) -> Self::IntoIter {
961 self.chunk.into_iter()
962 }
963}
964
965impl HashSeqChunk {
966 pub fn base(&self) -> u64 {
967 self.offset / 32
968 }
969
970 #[allow(dead_code)]
971 fn get(&self, offset: u64) -> Option<Hash> {
972 let start = self.offset;
973 let end = start + self.chunk.len() as u64;
974 if offset >= start && offset < end {
975 let o = (offset - start) as usize;
976 self.chunk.get(o)
977 } else {
978 None
979 }
980 }
981}
982
983impl LazyHashSeq {
984 #[allow(dead_code)]
985 pub fn new(blobs: Blobs, hash: Hash) -> Self {
986 Self {
987 blobs,
988 hash,
989 current_chunk: None,
990 }
991 }
992
993 #[allow(dead_code)]
994 pub async fn get_from_offset(&mut self, offset: u64) -> anyhow::Result<Option<Hash>> {
995 if offset == 0 {
996 Ok(Some(self.hash))
997 } else {
998 self.get(offset - 1).await
999 }
1000 }
1001
1002 #[allow(dead_code)]
1003 pub async fn get(&mut self, child_offset: u64) -> anyhow::Result<Option<Hash>> {
1004 if let Some(chunk) = &self.current_chunk {
1006 if let Some(hash) = chunk.get(child_offset) {
1007 return Ok(Some(hash));
1008 }
1009 }
1010 let leaf = self
1012 .blobs
1013 .export_chunk(self.hash, child_offset * 32)
1014 .await?;
1015 let hs = HashSeqChunk::try_from(leaf)?;
1017 Ok(hs.get(child_offset).inspect(|_hash| {
1018 self.current_chunk = Some(hs);
1019 }))
1020 }
1021}
1022
1023async fn write_push_request(
1024 request: PushRequest,
1025 stream: &mut SendStream,
1026) -> anyhow::Result<PushRequest> {
1027 let mut request_bytes = Vec::new();
1028 request_bytes.push(RequestType::Push as u8);
1029 request_bytes.write_length_prefixed(&request).unwrap();
1030 stream.write_all(&request_bytes).await?;
1031 Ok(request)
1032}
1033
1034async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> {
1035 let request = Request::Observe(request);
1036 let request_bytes = postcard::to_allocvec(&request)
1037 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1038 stream.write_all(&request_bytes).await?;
1039 Ok(())
1040}
1041
/// Progress bookkeeping while writing bao data during a push.
struct StreamContext<S> {
    /// Total payload bytes written so far.
    payload_bytes_sent: u64,
    /// Sink that receives the running total after every payload write.
    sender: S,
}
1046
1047impl<S> WriteProgress for StreamContext<S>
1048where
1049 S: Sink<u64, Error = io::Error>,
1050{
1051 async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) {
1052 self.payload_bytes_sent += len as u64;
1053 self.sender.send(self.payload_bytes_sent).await.ok();
1054 }
1055
1056 fn log_other_write(&mut self, _len: usize) {}
1057
1058 async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
1059}
1060
#[cfg(test)]
mod tests {
    use bao_tree::{ChunkNum, ChunkRanges};
    use testresult::TestResult;

    use crate::{
        protocol::{ChunkRangesSeq, GetRequest},
        store::fs::{tests::INTERESTING_SIZES, FsStore},
        tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
        util::ChunkRangesExt,
    };

    /// A fully stored raw blob is complete and nothing is missing.
    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = *tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }

    /// A large hash sequence (1029 children of size 0..1029) of which only
    /// chunks 16..32 of the hash sequence and the first chunk of each child are
    /// stored locally.
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // chunks 16..32 of the hash sequence hold the hashes of children
        // 512..1024 (32 hashes of 32 bytes per 1 KiB chunk)
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            let present = |i| {
                // i == 0 is the hash sequence itself, i > 0 are the children
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // 16 locally present 1 KiB chunks of the hash sequence plus the children
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }

    /// Local info for a hash sequence of INTERESTING_SIZES children in three
    /// states: children absent, first chunk of each child present, complete.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            let present = |i| {
                // only the hash sequence itself is stored
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    // first entry is the root (complete); the rest are the children
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            let present = |i| {
                // the hash sequence and the first chunk of every child are stored
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    // first entry is the root (complete); the rest are the children
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }

    /// Requests that cover only a prefix of the data are complete when exactly
    /// that prefix (first two chunks of each child) is stored.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let present = |i| {
            // the hash sequence fully, and the first two chunks of every child
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // only the hash sequence itself
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // hash sequence and the first child
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // hash sequence and the first two children
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // open request: chunk 0 of every child
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
}
1275}