1use genawaiter::sync::{Co, Gen};
5use iroh::endpoint::SendStream;
6use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
7use n0_future::{io, Stream, StreamExt};
8use n0_snafu::SpanTrace;
9use nested_enum_utils::common_fields;
10use ref_cast::RefCast;
11use snafu::{Backtrace, IntoError, Snafu};
12
13use super::blobs::Bitfield;
14use crate::{
15 api::{blobs::WriteProgress, ApiClient},
16 get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats},
17 protocol::{
18 GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
19 MAX_MESSAGE_SIZE,
20 },
21 util::sink::{Sink, TokioMpscSenderSink},
22};
23
/// API for interacting with remote nodes, backed by a local store client.
///
/// This is a transparent wrapper around [`ApiClient`]; the `RefCast` derive
/// together with `#[repr(transparent)]` is what lets `ref_from_sender` turn
/// an `&ApiClient` into a `&Remote` without copying.
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    client: ApiClient,
}
44
/// An item produced while a get operation is running.
///
/// `Done` and `Error` are the two final items; `Progress` is intermediate.
#[derive(Debug)]
pub enum GetProgressItem {
    /// Intermediate progress value (the sinks in this file feed it with the
    /// number of payload bytes read so far).
    Progress(u64),
    /// The operation finished successfully with the given statistics.
    Done(Stats),
    /// The operation failed.
    Error(GetError),
}
54
55impl From<GetResult<Stats>> for GetProgressItem {
56 fn from(res: GetResult<Stats>) -> Self {
57 match res {
58 Ok(stats) => GetProgressItem::Done(stats),
59 Err(e) => GetProgressItem::Error(e),
60 }
61 }
62}
63
64impl TryFrom<GetProgressItem> for GetResult<Stats> {
65 type Error = &'static str;
66
67 fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
68 match item {
69 GetProgressItem::Done(stats) => Ok(Ok(stats)),
70 GetProgressItem::Error(e) => Ok(Err(e)),
71 GetProgressItem::Progress(_) => Err("not a final item"),
72 }
73 }
74}
75
/// Handle for a running get operation.
///
/// Can be consumed either as a stream of [`GetProgressItem`]s (`stream`) or
/// awaited directly for the final result (`complete` / `IntoFuture`).
pub struct GetProgress {
    /// Receives progress items and the final result.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    /// The future that performs the operation and feeds `rx`. It is not
    /// spawned by the constructors in this file; it makes progress only
    /// while the stream built by `into_stream` is polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
80
81impl IntoFuture for GetProgress {
82 type Output = GetResult<Stats>;
83 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
84
85 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
86 Box::pin(self.complete())
87 }
88}
89
90impl GetProgress {
91 pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
92 into_stream(self.rx, self.fut)
93 }
94
95 pub async fn complete(self) -> GetResult<Stats> {
96 just_result(self.stream()).await.unwrap_or_else(|| {
97 Err(LocalFailureSnafu
98 .into_error(anyhow::anyhow!("stream closed without result").into()))
99 })
100 }
101}
102
/// An item produced while a push operation is running.
///
/// `Done` and `Error` are the two final items; `Progress` is intermediate.
#[derive(Debug)]
pub enum PushProgressItem {
    /// Intermediate progress value (fed by `StreamContext` with the running
    /// total of payload bytes sent).
    Progress(u64),
    /// The operation finished successfully with the given statistics.
    Done(Stats),
    /// The operation failed.
    Error(anyhow::Error),
}
112
113impl From<anyhow::Result<Stats>> for PushProgressItem {
114 fn from(res: anyhow::Result<Stats>) -> Self {
115 match res {
116 Ok(stats) => Self::Done(stats),
117 Err(e) => Self::Error(e),
118 }
119 }
120}
121
122impl TryFrom<PushProgressItem> for anyhow::Result<Stats> {
123 type Error = &'static str;
124
125 fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
126 match item {
127 PushProgressItem::Done(stats) => Ok(Ok(stats)),
128 PushProgressItem::Error(e) => Ok(Err(e)),
129 PushProgressItem::Progress(_) => Err("not a final item"),
130 }
131 }
132}
133
/// Handle for a running push operation.
///
/// Can be consumed either as a stream of [`PushProgressItem`]s (`stream`) or
/// awaited directly for the final result (`complete` / `IntoFuture`).
pub struct PushProgress {
    /// Receives progress items and the final result.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    /// The future that performs the operation and feeds `rx`; it is polled
    /// by the stream built in `into_stream`.
    fut: n0_future::boxed::BoxFuture<()>,
}
138
139impl IntoFuture for PushProgress {
140 type Output = anyhow::Result<Stats>;
141 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
142
143 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
144 Box::pin(self.complete())
145 }
146}
147
148impl PushProgress {
149 pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
150 into_stream(self.rx, self.fut)
151 }
152
153 pub async fn complete(self) -> anyhow::Result<Stats> {
154 just_result(self.stream())
155 .await
156 .unwrap_or_else(|| Err(anyhow::anyhow!("stream closed without result")))
157 }
158}
159
160async fn just_result<S, R>(stream: S) -> Option<R>
161where
162 S: Stream,
163 R: TryFrom<S::Item>,
164{
165 tokio::pin!(stream);
166 while let Some(item) = stream.next().await {
167 if let Ok(res) = R::try_from(item) {
168 return Some(res);
169 }
170 }
171 None
172}
173
/// Combines a receiver and the future that feeds it into a single stream.
///
/// The returned stream polls `fut` while yielding whatever arrives on `rx`.
/// Once `fut` completes (or the channel closes), any items still queued in
/// `rx` are yielded before the stream ends.
fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
where
    F: Future,
{
    Gen::new(move |co| async move {
        tokio::pin!(fut);
        loop {
            tokio::select! {
                // `biased` makes select poll the branches in order, so a
                // ready item is always preferred over completion of `fut`.
                biased;
                item = rx.recv() => {
                    if let Some(item) = item {
                        co.yield_(item).await;
                    } else {
                        // all senders are gone
                        break;
                    }
                }
                _ = &mut fut => {
                    break;
                }
            }
        }
        // Drain whatever was sent before the loop above ended.
        while let Some(item) = rx.recv().await {
            co.yield_(item).await;
        }
    })
}
200
/// Local availability information for a [`GetRequest`].
///
/// Produced by `Remote::local_for_request`.
#[derive(Debug)]
pub struct LocalInfo {
    /// The request this information was computed for.
    request: Arc<GetRequest>,
    /// The bitfield observed for the root blob (`request.hash`).
    bitfield: Bitfield,
    /// Information about children; `None` when the request covers only a
    /// single blob (`request.ranges.is_blob()`).
    children: Option<NonRawLocalInfo>,
}
214
215impl LocalInfo {
216 pub fn local_bytes(&self) -> u64 {
218 let Some(root_requested) = self.requested_root_ranges() else {
219 return 0;
221 };
222 let mut local = self.bitfield.clone();
223 local.ranges.intersection_with(root_requested);
224 let mut res = local.total_bytes();
225 if let Some(children) = &self.children {
226 let Some(max_local_index) = children.hash_seq.keys().next_back() else {
227 return res;
229 };
230 for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
231 if offset == 0 {
232 continue;
234 }
235 let child = offset - 1;
236 if child > *max_local_index {
237 break;
239 }
240 let Some(hash) = children.hash_seq.get(&child) else {
241 continue;
242 };
243 let bitfield = &children.bitfields[hash];
244 let mut local = bitfield.clone();
245 local.ranges.intersection_with(ranges);
246 res += local.total_bytes();
247 }
248 }
249 res
250 }
251
252 pub fn children(&self) -> Option<u64> {
254 if self.children.is_some() {
255 self.bitfield.validated_size().map(|x| x / 32)
256 } else {
257 Some(0)
258 }
259 }
260
    /// The first entry of the request's range sequence, i.e. the ranges
    /// requested for the root blob, or `None` if the sequence yields nothing.
    fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
        self.request.ranges.iter().next()
    }
268
269 pub fn is_complete(&self) -> bool {
275 let Some(root_requested) = self.requested_root_ranges() else {
276 return true;
278 };
279 if !self.bitfield.ranges.is_superset(root_requested) {
280 return false;
281 }
282 if let Some(children) = self.children.as_ref() {
283 let mut iter = self.request.ranges.iter_non_empty_infinite();
284 let max_child = self.bitfield.validated_size().map(|x| x / 32);
285 loop {
286 let Some((offset, range)) = iter.next() else {
287 break;
288 };
289 if offset == 0 {
290 continue;
292 }
293 let child = offset - 1;
294 if let Some(hash) = children.hash_seq.get(&child) {
295 let bitfield = &children.bitfields[hash];
296 if !bitfield.ranges.is_superset(range) {
297 return false;
299 }
300 } else {
301 if let Some(max_child) = max_child {
302 if child >= max_child {
303 return true;
305 }
306 }
307 return false;
308 }
309 }
310 }
311 true
312 }
313
    /// Builds a request for the parts of the original request that are not
    /// available locally.
    ///
    /// The root entry is the requested root ranges minus the local root
    /// ranges. Children are handled in two phases: first those up to the last
    /// locally known child (requested minus local), then those beyond local
    /// knowledge (requested as is), bounded by the validated size of the root
    /// when that is known.
    pub fn missing(&self) -> GetRequest {
        let Some(root_requested) = self.requested_root_ranges() else {
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        // Offset (child index + 1) of the last child known locally, or 0.
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        // Number of children according to the validated root size, if known.
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        // Phase 1: children up to the last locally known one.
        loop {
            let Some((offset, requested)) = iter.next() else {
                break;
            };
            if offset == 0 {
                // the root was handled above
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                break;
            }
        }
        // Phase 2: children beyond local knowledge; `iter` continues where
        // phase 1 stopped.
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                if child >= *max_offset {
                    // past the end of the hash sequence
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                // NOTE(review): `is_at_end` presumably means only the
                // repeating tail of the range sequence remains; a further
                // item then selects `build_open` — confirm against
                // `ChunkRangesSeq`.
                if iter.is_at_end() {
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
378}
379
/// Local information about the children of a root, as collected by
/// `Remote::local_for_request`.
#[derive(Debug)]
struct NonRawLocalInfo {
    /// Locally known child hashes, keyed by child index (request offset - 1).
    hash_seq: BTreeMap<u64, Hash>,
    /// Observed bitfield for each hash in `hash_seq`.
    bitfields: BTreeMap<Hash, Bitfield>,
}
388
389impl Remote {
    /// Reinterprets a reference to an [`ApiClient`] as a reference to a
    /// `Remote` via `RefCast`, without copying.
    pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
        Self::ref_cast(sender)
    }
407
    /// Views the underlying client as a [`Store`].
    fn store(&self) -> &Store {
        Store::ref_from_sender(&self.client)
    }
411
412 pub async fn local_for_request(
413 &self,
414 request: impl Into<Arc<GetRequest>>,
415 ) -> anyhow::Result<LocalInfo> {
416 let request = request.into();
417 let root = request.hash;
418 let bitfield = self.store().observe(root).await?;
419 let children = if !request.ranges.is_blob() {
420 let bao = self.store().export_bao(root, bitfield.ranges.clone());
421 let mut by_index = BTreeMap::new();
422 let mut stream = bao.hashes_with_index();
423 while let Some(item) = stream.next().await {
424 let (index, hash) = item?;
425 by_index.insert(index, hash);
426 }
427 let mut bitfields = BTreeMap::new();
428 let mut hash_seq = BTreeMap::new();
429 let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
430 for (index, _) in request.ranges.iter_non_empty_infinite() {
431 if index == 0 {
432 continue;
434 }
435 let child = index - 1;
436 if child > max {
437 break;
439 }
440 let Some(hash) = by_index.get(&child) else {
441 continue;
443 };
444 let bitfield = self.store().observe(*hash).await?;
445 bitfields.insert(*hash, bitfield);
446 hash_seq.insert(child, *hash);
447 }
448 Some(NonRawLocalInfo {
449 hash_seq,
450 bitfields,
451 })
452 } else {
453 None
454 };
455 Ok(LocalInfo {
456 request: request.clone(),
457 bitfield,
458 children,
459 })
460 }
461
462 pub async fn local(&self, content: impl Into<HashAndFormat>) -> anyhow::Result<LocalInfo> {
464 let request = GetRequest::from(content.into());
465 self.local_for_request(request).await
466 }
467
468 pub fn fetch(
469 &self,
470 conn: impl GetConnection + Send + 'static,
471 content: impl Into<HashAndFormat>,
472 ) -> GetProgress {
473 let content = content.into();
474 let (tx, rx) = tokio::sync::mpsc::channel(64);
475 let tx2 = tx.clone();
476 let sink = TokioMpscSenderSink(tx)
477 .with_map(GetProgressItem::Progress)
478 .with_map_err(io::Error::other);
479 let this = self.clone();
480 let fut = async move {
481 let res = this.fetch_sink(conn, content, sink).await.into();
482 tx2.send(res).await.ok();
483 };
484 GetProgress {
485 rx,
486 fut: Box::pin(fut),
487 }
488 }
489
490 pub async fn fetch_sink(
498 &self,
499 mut conn: impl GetConnection,
500 content: impl Into<HashAndFormat>,
501 progress: impl Sink<u64, Error = io::Error>,
502 ) -> GetResult<Stats> {
503 let content = content.into();
504 let local = self
505 .local(content)
506 .await
507 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
508 if local.is_complete() {
509 return Ok(Default::default());
510 }
511 let request = local.missing();
512 let conn = conn
513 .connection()
514 .await
515 .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
516 let stats = self.execute_get_sink(conn, request, progress).await?;
517 Ok(stats)
518 }
519
    /// Observes a blob on a remote node.
    ///
    /// Returns a stream of bitfields received from the remote. When the
    /// underlying exchange fails, the error is yielded as the last item.
    pub fn observe(
        &self,
        conn: Connection,
        request: ObserveRequest,
    ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
        Gen::new(|co| async move {
            if let Err(cause) = Self::observe_impl(conn, request, &co).await {
                co.yield_(Err(cause)).await
            }
        })
    }
531
    /// Sends the observe request on a new bidirectional stream and yields a
    /// bitfield for every message read from the remote.
    ///
    /// The read loop has no success exit: this function only returns once
    /// opening, writing or reading fails, with the corresponding error.
    async fn observe_impl(
        conn: Connection,
        request: ObserveRequest,
        co: &Co<io::Result<Bitfield>>,
    ) -> io::Result<()> {
        let hash = request.hash;
        debug!(%hash, "observing");
        let (mut send, mut recv) = conn.open_bi().await?;
        write_observe_request(request, &mut send).await?;
        // Nothing more is sent after the request.
        send.finish()?;
        loop {
            let msg = recv
                .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
                .await?;
            co.yield_(Ok(Bitfield::from(&msg))).await;
        }
    }
550
551 pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
552 let (tx, rx) = tokio::sync::mpsc::channel(64);
553 let tx2 = tx.clone();
554 let sink = TokioMpscSenderSink(tx)
555 .with_map(PushProgressItem::Progress)
556 .with_map_err(io::Error::other);
557 let this = self.clone();
558 let fut = async move {
559 let res = this.execute_push_sink(conn, request, sink).await.into();
560 tx2.send(res).await.ok();
561 };
562 PushProgress {
563 rx,
564 fut: Box::pin(fut),
565 }
566 }
567
    /// Pushes the data described by `request` to the remote, sending progress
    /// to `progress`.
    ///
    /// Writes the push request, then the requested ranges of the root and,
    /// unless the request covers only a single blob, the requested ranges of
    /// each child listed in the root's hash sequence.
    ///
    /// The returned stats are always the default value; no statistics are
    /// collected here.
    pub async fn execute_push_sink(
        &self,
        conn: Connection,
        request: PushRequest,
        progress: impl Sink<u64, Error = io::Error>,
    ) -> anyhow::Result<Stats> {
        let hash = request.hash;
        debug!(%hash, "pushing");
        let (mut send, mut recv) = conn.open_bi().await?;
        let mut context = StreamContext {
            payload_bytes_sent: 0,
            sender: progress,
        };
        // Nothing is read from the receive side of the stream.
        recv.stop(0u32.into())?;
        let request = write_push_request(request, &mut send).await?;
        let mut request_ranges = request.ranges.iter_infinite();
        let root = request.hash;
        let root_ranges = request_ranges.next().expect("infinite iterator");
        if !root_ranges.is_empty() {
            self.store()
                .export_bao(root, root_ranges.clone())
                .write_quinn_with_progress(&mut send, &mut context, &root, 0)
                .await?;
        }
        if request.ranges.is_blob() {
            // Single blob: no children to send.
            send.finish()?;
            return Ok(Default::default());
        }
        let hash_seq = self.store().get_bytes(root).await?;
        let hash_seq = HashSeq::try_from(hash_seq)?;
        // `zip` ends with the hash sequence, since the ranges are infinite.
        for (child, (child_hash, child_ranges)) in
            hash_seq.into_iter().zip(request_ranges).enumerate()
        {
            if !child_ranges.is_empty() {
                self.store()
                    .export_bao(child_hash, child_ranges.clone())
                    .write_quinn_with_progress(
                        &mut send,
                        &mut context,
                        &child_hash,
                        // offset 0 is the root, so children start at 1
                        (child + 1) as u64,
                    )
                    .await?;
            }
        }
        send.finish()?;
        Ok(Default::default())
    }
622
    /// Executes a get request against the remote; currently just forwards to
    /// `execute_get_with_opts`.
    pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress {
        self.execute_get_with_opts(conn, request)
    }
626
627 pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress {
628 let (tx, rx) = tokio::sync::mpsc::channel(64);
629 let tx2 = tx.clone();
630 let sink = TokioMpscSenderSink(tx)
631 .with_map(GetProgressItem::Progress)
632 .with_map_err(io::Error::other);
633 let this = self.clone();
634 let fut = async move {
635 let res = this.execute_get_sink(conn, request, sink).await.into();
636 tx2.send(res).await.ok();
637 };
638 GetProgress {
639 rx,
640 fut: Box::pin(fut),
641 }
642 }
643
    /// Executes a get request against the remote and imports the received
    /// data into the local store, sending progress to `progress`.
    ///
    /// Drives the get state machine: first the root (if requested), then each
    /// child, looking child hashes up in the root's hash sequence as read
    /// from the local store.
    ///
    /// Returns the stats reported when the exchange is closed.
    pub async fn execute_get_sink(
        &self,
        conn: Connection,
        request: GetRequest,
        mut progress: impl Sink<u64, Error = io::Error>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let root = request.hash;
        let start = crate::get::fsm::start(conn, request, Default::default());
        let connected = start.next().await?;
        trace!("Getting header");
        // `Ok` means there is a child to read next, `Err` carries the closing state.
        let next_child = match connected.next().await? {
            ConnectedNext::StartRoot(at_start_root) => {
                let header = at_start_root.next();
                let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
                match end.next() {
                    EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                    EndBlobNext::Closing(at_closing) => Err(at_closing),
                }
            }
            ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
            ConnectedNext::Closing(at_closing) => Err(at_closing),
        };
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                // The root is read back from the store to resolve child hashes.
                let hash_seq = HashSeq::try_from(
                    store
                        .get_bytes(root)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?,
                )
                .map_err(|source| BadRequestSnafu.into_error(source.into()))?;
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    // offset 0 is the root, so child indices are offset - 1
                    let offset = at_start_child.offset() - 1;
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        // no such child in the hash sequence: stop here
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        let stats = at_closing.next().await?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }
713
714 pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
715 let (tx, rx) = tokio::sync::mpsc::channel(64);
716 let tx2 = tx.clone();
717 let sink = TokioMpscSenderSink(tx)
718 .with_map(GetProgressItem::Progress)
719 .with_map_err(io::Error::other);
720 let this = self.clone();
721 let fut = async move {
722 let res = this.execute_get_many_sink(conn, request, sink).await.into();
723 tx2.send(res).await.ok();
724 };
725 GetProgress {
726 rx,
727 fut: Box::pin(fut),
728 }
729 }
730
731 pub async fn execute_get_many_sink(
740 &self,
741 conn: Connection,
742 request: GetManyRequest,
743 mut progress: impl Sink<u64, Error = io::Error>,
744 ) -> GetResult<Stats> {
745 let store = self.store();
746 let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
747 let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
748 let at_closing = match next_child {
750 Ok(at_start_child) => {
751 let mut next_child = Ok(at_start_child);
752 loop {
753 let at_start_child = match next_child {
754 Ok(at_start_child) => at_start_child,
755 Err(at_closing) => break at_closing,
756 };
757 let offset = at_start_child.offset();
758 println!("offset {offset}");
759 let Some(hash) = hash_seq.get(offset as usize) else {
760 break at_start_child.finish();
761 };
762 trace!("getting child {offset} {}", hash.fmt_short());
763 let header = at_start_child.next(hash);
764 let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
765 next_child = match end.next() {
766 EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
767 EndBlobNext::Closing(at_closing) => Err(at_closing),
768 }
769 }
770 }
771 Err(at_closing) => at_closing,
772 };
773 let stats = at_closing.next().await?;
775 trace!(?stats, "get hash seq done");
776 Ok(stats)
777 }
778}
779
// Error type covering the failure modes of executing a request against a
// remote. Every variant carries the common `backtrace` and `span_trace`
// fields added by `common_fields`.
// NOTE(review): `ImportBao` and `MpscSend` share the same display message,
// and the `ExportBao` message mentions the hash sequence — confirm that this
// is intended.
#[common_fields({
    backtrace: Option<Backtrace>,
    #[snafu(implicit)]
    span_trace: SpanTrace,
})]
#[allow(missing_docs)]
#[non_exhaustive]
#[derive(Debug, Snafu)]
pub enum ExecuteError {
    // Opening the bidirectional stream failed.
    #[snafu(display("Unable to open bidi stream"))]
    Connection {
        source: iroh::endpoint::ConnectionError,
    },
    // Reading from the remote failed.
    #[snafu(display("Unable to read from the remote"))]
    Read { source: iroh::endpoint::ReadError },
    // Sending the request failed.
    #[snafu(display("Error sending the request"))]
    Send {
        source: crate::get::fsm::ConnectedNextError,
    },
    // Reading the blob header (size) failed.
    #[snafu(display("Unable to read size"))]
    Size {
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    // Decoding the received data failed.
    #[snafu(display("Error while decoding the data"))]
    Decode {
        source: crate::get::fsm::DecodeError,
    },
    // Exporting bao data from the local store failed.
    #[snafu(display("Internal error while reading the hash sequence"))]
    ExportBao { source: api::ExportBaoError },
    // The hash sequence could not be parsed.
    #[snafu(display("Hash sequence has an invalid length"))]
    InvalidHashSeq { source: anyhow::Error },
    // Importing data into the local store failed.
    #[snafu(display("Internal error importing the data"))]
    ImportBao { source: crate::api::RequestError },
    // The progress receiver was closed.
    #[snafu(display("Error sending download progress - receiver closed"))]
    SendDownloadProgress { source: irpc::channel::SendError },
    // Sending a content item over the internal channel failed.
    #[snafu(display("Internal error importing the data"))]
    MpscSend {
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
822
823use std::{
824 collections::BTreeMap,
825 future::{Future, IntoFuture},
826 num::NonZeroU64,
827 sync::Arc,
828};
829
830use bao_tree::{
831 io::{BaoContentItem, Leaf},
832 ChunkNum, ChunkRanges,
833};
834use iroh::endpoint::Connection;
835use tracing::{debug, trace};
836
837use crate::{
838 api::{self, blobs::Blobs, Store},
839 get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext},
840 hashseq::{HashSeq, HashSeqIter},
841 protocol::{ChunkRangesSeq, GetRequest},
842 store::IROH_BLOCK_SIZE,
843 Hash, HashAndFormat,
844};
845
/// Source of a [`Connection`] for fetching.
///
/// `Remote::fetch_sink` asks for the connection only after it has determined
/// that something is missing locally, so implementations may connect lazily.
pub trait GetConnection {
    /// Returns the connection to use, or an error if none can be provided.
    fn connection(&mut self)
        -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_;
}
851
852impl GetConnection for Connection {
854 fn connection(
855 &mut self,
856 ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
857 let conn = self.clone();
858 async { Ok(conn) }
859 }
860}
861
862impl GetConnection for &Connection {
864 fn connection(
865 &mut self,
866 ) -> impl Future<Output = Result<Connection, anyhow::Error>> + Send + '_ {
867 let conn = self.clone();
868 async { Ok(conn) }
869 }
870}
871
872fn get_buffer_size(size: NonZeroU64) -> usize {
873 (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
874}
875
/// Reads one blob from the get state machine and imports it into `store`.
///
/// Reads the size from `header`, starts a bao import for `hash`, then
/// forwards every received content item to the import while reporting the
/// number of payload bytes read to `progress`. Returns the state reached
/// after the blob has been read.
///
/// A size of zero is only accepted for the empty hash; for any other hash it
/// is reported as a leaf hash mismatch at chunk 0.
async fn get_blob_ranges_impl(
    header: AtBlobHeader,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = io::Error>,
) -> GetResult<AtEndBlob> {
    let (mut content, size) = header.next().await?;
    let Some(size) = NonZeroU64::new(size) else {
        return if hash == Hash::EMPTY {
            // Nothing to import; just drain the stream.
            let end = content.drain().await?;
            Ok(end)
        } else {
            Err(DecodeError::leaf_hash_mismatch(ChunkNum(0)).into())
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
    // Forwards content items from the network to the import.
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
                    handle.tx.send(item).await?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // Dropping the sender signals the end of the data.
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    // Waits for the import side to report its result.
    let complete = async move {
        handle.rx.await.map_err(|e| {
            LocalFailureSnafu
                .into_error(anyhow::anyhow!("error reading from import stream: {e}").into())
        })
    };
    // Run both halves concurrently; the first error aborts the join.
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
925
/// A hash sequence that is read from the store one chunk at a time, on
/// demand, with the most recently read chunk cached.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    /// Store API used to export chunks.
    blobs: Blobs,
    /// Hash of the blob that contains the hash sequence.
    hash: Hash,
    /// The most recently read chunk, if any.
    current_chunk: Option<HashSeqChunk>,
}
932
/// One leaf of a hash sequence blob, parsed as a [`HashSeq`].
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// Offset of the leaf this chunk was parsed from (`leaf.offset`);
    /// `base()` divides it by 32 to get the index of the first hash.
    offset: u64,
    /// The hashes contained in the leaf.
    chunk: HashSeq,
}
940
941impl TryFrom<Leaf> for HashSeqChunk {
942 type Error = anyhow::Error;
943
944 fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
945 let offset = leaf.offset;
946 let chunk = HashSeq::try_from(leaf.data)?;
947 Ok(Self { offset, chunk })
948 }
949}
950
impl IntoIterator for HashSeqChunk {
    type Item = Hash;
    type IntoIter = HashSeqIter;

    /// Iterates over the hashes in this chunk; the offset is discarded.
    fn into_iter(self) -> Self::IntoIter {
        self.chunk.into_iter()
    }
}
959
960impl HashSeqChunk {
961 pub fn base(&self) -> u64 {
962 self.offset / 32
963 }
964
965 #[allow(dead_code)]
966 fn get(&self, offset: u64) -> Option<Hash> {
967 let start = self.offset;
968 let end = start + self.chunk.len() as u64;
969 if offset >= start && offset < end {
970 let o = (offset - start) as usize;
971 self.chunk.get(o)
972 } else {
973 None
974 }
975 }
976}
977
978impl LazyHashSeq {
979 #[allow(dead_code)]
980 pub fn new(blobs: Blobs, hash: Hash) -> Self {
981 Self {
982 blobs,
983 hash,
984 current_chunk: None,
985 }
986 }
987
988 #[allow(dead_code)]
989 pub async fn get_from_offset(&mut self, offset: u64) -> anyhow::Result<Option<Hash>> {
990 if offset == 0 {
991 Ok(Some(self.hash))
992 } else {
993 self.get(offset - 1).await
994 }
995 }
996
997 #[allow(dead_code)]
998 pub async fn get(&mut self, child_offset: u64) -> anyhow::Result<Option<Hash>> {
999 if let Some(chunk) = &self.current_chunk {
1001 if let Some(hash) = chunk.get(child_offset) {
1002 return Ok(Some(hash));
1003 }
1004 }
1005 let leaf = self
1007 .blobs
1008 .export_chunk(self.hash, child_offset * 32)
1009 .await?;
1010 let hs = HashSeqChunk::try_from(leaf)?;
1012 Ok(hs.get(child_offset).inspect(|_hash| {
1013 self.current_chunk = Some(hs);
1014 }))
1015 }
1016}
1017
1018async fn write_push_request(
1019 request: PushRequest,
1020 stream: &mut SendStream,
1021) -> anyhow::Result<PushRequest> {
1022 let mut request_bytes = Vec::new();
1023 request_bytes.push(RequestType::Push as u8);
1024 request_bytes.write_length_prefixed(&request).unwrap();
1025 stream.write_all(&request_bytes).await?;
1026 Ok(request)
1027}
1028
1029async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> {
1030 let request = Request::Observe(request);
1031 let request_bytes = postcard::to_allocvec(&request)
1032 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1033 stream.write_all(&request_bytes).await?;
1034 Ok(())
1035}
1036
/// Progress state for a push: counts payload bytes and reports the running
/// total to a sink.
struct StreamContext<S> {
    /// Total number of payload bytes written so far.
    payload_bytes_sent: u64,
    /// Sink that receives the running total after every payload write.
    sender: S,
}
1041
impl<S> WriteProgress for StreamContext<S>
where
    S: Sink<u64, Error = io::Error>,
{
    /// Adds `len` to the running total and sends the new total to the sink.
    /// A failure to send is ignored.
    async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) {
        self.payload_bytes_sent += len as u64;
        self.sender.send(self.payload_bytes_sent).await.ok();
    }

    /// Non-payload writes are not tracked.
    fn log_other_write(&mut self, _len: usize) {}

    /// Transfer start events are not reported.
    async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
}
1055
1056#[cfg(test)]
1057mod tests {
1058 use bao_tree::{ChunkNum, ChunkRanges};
1059 use testresult::TestResult;
1060
1061 use crate::{
1062 protocol::{ChunkRangesSeq, GetRequest},
1063 store::fs::{tests::INTERESTING_SIZES, FsStore},
1064 tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
1065 util::ChunkRangesExt,
1066 };
1067
    /// Local info for a single, fully present blob: complete, nothing missing.
    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = *tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        // b"test" is 4 bytes long
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }
1085
    /// Local info for a large hash sequence of which only chunks 16..32 of
    /// the root are present locally.
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // Sizes of the children with indices 32 * 16 .. 32 * 32.
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Index 0 is the hash sequence itself; every child has only its
            // first chunk present.
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // The relevant children plus 16 * 1024 bytes of the hash sequence.
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }
1114
    /// Local info for a hash sequence in three states: only the root present,
    /// root plus the first chunk of every child, and everything present.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        // 32 bytes per hash
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Only the hash sequence (index 0) is present.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::all(), ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            // The hash sequence plus the first chunk of every child.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            // Each child contributes at most its first 1024 bytes.
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            // Everything is present.
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }
1201
    /// Local info for requests that cover only part of a hash sequence whose
    /// children each have their first two chunks present.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        // 32 bytes per hash
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        // The hash sequence is complete; every child has chunks ..2 present.
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // Request for the root only.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // Request for the root and the first child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Request for the root and the first two children.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Open-ended request: the root and chunk 0 of every child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
1270}