1use std::{future::Future, io, num::NonZeroU64, pin::Pin};
4
5use anyhow::anyhow;
6use bao_tree::{ChunkNum, ChunkRanges};
7use futures_lite::StreamExt;
8use genawaiter::{
9 rc::{Co, Gen},
10 GeneratorState,
11};
12use iroh::endpoint::Connection;
13use iroh_io::AsyncSliceReader;
14use serde::{Deserialize, Serialize};
15use tokio::sync::oneshot;
16use tracing::trace;
17
18use crate::{
19 get::{
20 self,
21 error::GetError,
22 fsm::{AtBlobHeader, AtEndBlob, ConnectedNext, EndBlobNext},
23 progress::TransferState,
24 Stats,
25 },
26 hashseq::parse_hash_seq,
27 protocol::{GetRequest, RangeSpec, RangeSpecSeq},
28 store::{
29 BaoBatchWriter, BaoBlobSize, FallibleProgressBatchWriter, MapEntry, MapEntryMut, MapMut,
30 Store as BaoStore,
31 },
32 util::progress::{IdGenerator, ProgressSender},
33 BlobFormat, Hash, HashAndFormat,
34};
35
36type GetGenerator = Gen<Yield, (), Pin<Box<dyn Future<Output = Result<Stats, GetError>>>>>;
37type GetFuture = Pin<Box<dyn Future<Output = Result<Stats, GetError>> + 'static>>;
38
39pub async fn get_to_db<
48 D: BaoStore,
49 C: FnOnce() -> F,
50 F: Future<Output = anyhow::Result<Connection>>,
51>(
52 db: &D,
53 get_conn: C,
54 hash_and_format: &HashAndFormat,
55 progress_sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
56) -> Result<Stats, GetError> {
57 match get_to_db_in_steps(db.clone(), *hash_and_format, progress_sender).await? {
58 GetState::Complete(res) => Ok(res),
59 GetState::NeedsConn(state) => {
60 let conn = get_conn().await.map_err(GetError::Io)?;
61 state.proceed(conn).await
62 }
63 }
64}
65
66pub async fn get_to_db_in_steps<
77 D: BaoStore,
78 P: ProgressSender<Msg = DownloadProgress> + IdGenerator,
79>(
80 db: D,
81 hash_and_format: HashAndFormat,
82 progress_sender: P,
83) -> Result<GetState, GetError> {
84 let mut gen: GetGenerator = genawaiter::rc::Gen::new(move |co| {
85 let fut = async move { producer(co, &db, &hash_and_format, progress_sender).await };
86 let fut: GetFuture = Box::pin(fut);
87 fut
88 });
89 match gen.async_resume().await {
90 GeneratorState::Yielded(Yield::NeedConn(reply)) => {
91 Ok(GetState::NeedsConn(GetStateNeedsConn(gen, reply)))
92 }
93 GeneratorState::Complete(res) => res.map(GetState::Complete),
94 }
95}
96
97#[derive(derive_more::Debug)]
100#[debug("GetStateNeedsConn")]
101pub struct GetStateNeedsConn(GetGenerator, oneshot::Sender<Connection>);
102
103impl GetStateNeedsConn {
104 pub async fn proceed(mut self, conn: Connection) -> Result<Stats, GetError> {
106 self.1.send(conn).expect("receiver is not dropped");
107 match self.0.async_resume().await {
108 GeneratorState::Yielded(y) => match y {
109 Yield::NeedConn(_) => panic!("NeedsConn may only be yielded once"),
110 },
111 GeneratorState::Complete(res) => res,
112 }
113 }
114}
115
116#[derive(Debug)]
118pub enum GetState {
119 Complete(Stats),
122 NeedsConn(GetStateNeedsConn),
127}
128
129struct GetCo(Co<Yield>);
130
131impl GetCo {
132 async fn get_conn(&self) -> Connection {
133 let (tx, rx) = oneshot::channel();
134 self.0.yield_(Yield::NeedConn(tx)).await;
135 rx.await.expect("sender may not be dropped")
136 }
137}
138
139enum Yield {
140 NeedConn(oneshot::Sender<Connection>),
141}
142
143async fn producer<D: BaoStore>(
144 co: Co<Yield, ()>,
145 db: &D,
146 hash_and_format: &HashAndFormat,
147 progress: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
148) -> Result<Stats, GetError> {
149 let HashAndFormat { hash, format } = hash_and_format;
150 let co = GetCo(co);
151 match format {
152 BlobFormat::Raw => get_blob(db, co, hash, progress).await,
153 BlobFormat::HashSeq => get_hash_seq(db, co, hash, progress).await,
154 }
155}
156
157async fn get_blob<D: BaoStore>(
162 db: &D,
163 co: GetCo,
164 hash: &Hash,
165 progress: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
166) -> Result<Stats, GetError> {
167 let end = match db.get_mut(hash).await? {
168 Some(entry) if entry.is_complete() => {
169 tracing::info!("already got entire blob");
170 progress
171 .send(DownloadProgress::FoundLocal {
172 child: BlobId::Root,
173 hash: *hash,
174 size: entry.size(),
175 valid_ranges: RangeSpec::all(),
176 })
177 .await?;
178 return Ok(Stats::default());
179 }
180 Some(entry) => {
181 trace!("got partial data for {}", hash);
182 let valid_ranges = valid_ranges::<D>(&entry)
183 .await
184 .ok()
185 .unwrap_or_else(ChunkRanges::all);
186 progress
187 .send(DownloadProgress::FoundLocal {
188 child: BlobId::Root,
189 hash: *hash,
190 size: entry.size(),
191 valid_ranges: RangeSpec::new(&valid_ranges),
192 })
193 .await?;
194 let required_ranges: ChunkRanges = ChunkRanges::all().difference(&valid_ranges);
195
196 let request = GetRequest::new(*hash, RangeSpecSeq::from_ranges([required_ranges]));
197 let conn = co.get_conn().await;
199 let request = get::fsm::start(conn, request);
200 let connected = request.next().await?;
202 let ConnectedNext::StartRoot(start) = connected.next().await? else {
204 return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
205 };
206 let header = start.next();
208 get_blob_inner_partial(db, header, entry, progress).await?
211 }
212 None => {
213 let conn = co.get_conn().await;
215 let request = get::fsm::start(conn, GetRequest::single(*hash));
216 let connected = request.next().await?;
218 let ConnectedNext::StartRoot(start) = connected.next().await? else {
220 return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
221 };
222 let header = start.next();
224 get_blob_inner(db, header, progress).await?
226 }
227 };
228
229 let EndBlobNext::Closing(end) = end.next() else {
231 return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
232 };
233 let stats = end.next().await?;
235 Ok(stats)
236}
237
238pub async fn valid_ranges<D: MapMut>(entry: &D::EntryMut) -> anyhow::Result<ChunkRanges> {
240 use tracing::trace as log;
241 let mut data_reader = entry.data_reader().await?;
243 let data_size = data_reader.size().await?;
244 let valid_from_data = ChunkRanges::from(..ChunkNum::full_chunks(data_size));
245 let mut outboard = entry.outboard().await?;
247 let all = ChunkRanges::all();
248 let mut stream = bao_tree::io::fsm::valid_outboard_ranges(&mut outboard, &all);
249 let mut valid_from_outboard = ChunkRanges::empty();
250 while let Some(range) = stream.next().await {
251 valid_from_outboard |= ChunkRanges::from(range?);
252 }
253 let valid: ChunkRanges = valid_from_data.intersection(&valid_from_outboard);
254 log!("valid_from_data: {:?}", valid_from_data);
255 log!("valid_from_outboard: {:?}", valid_from_data);
256 Ok(valid)
257}
258
259async fn get_blob_inner<D: BaoStore>(
264 db: &D,
265 at_header: AtBlobHeader,
266 sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
267) -> Result<AtEndBlob, GetError> {
268 let (at_content, size) = at_header.next().await?;
271 let hash = at_content.hash();
272 let child_offset = at_content.offset();
273 let entry = db.get_or_create(hash, size).await?;
275 let bw = entry.batch_writer().await?;
277 let id = sender.new_id();
279 sender
280 .send(DownloadProgress::Found {
281 id,
282 hash,
283 size,
284 child: BlobId::from_offset(child_offset),
285 })
286 .await?;
287 let sender2 = sender.clone();
288 let on_write = move |offset: u64, _length: usize| {
289 sender2
292 .try_send(DownloadProgress::Progress { id, offset })
293 .inspect_err(|_| {
294 tracing::info!("aborting download of {}", hash);
295 })?;
296 Ok(())
297 };
298 let mut bw = FallibleProgressBatchWriter::new(bw, on_write);
299 let end = at_content.write_all_batch(&mut bw).await?;
301 bw.sync().await?;
303 drop(bw);
304 db.insert_complete(entry).await?;
305 sender.send(DownloadProgress::Done { id }).await?;
307 Ok(end)
308}
309
310async fn get_blob_inner_partial<D: BaoStore>(
315 db: &D,
316 at_header: AtBlobHeader,
317 entry: D::EntryMut,
318 sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
319) -> Result<AtEndBlob, GetError> {
320 let (at_content, size) = at_header.next().await?;
323 let bw = entry.batch_writer().await?;
325 let id = sender.new_id();
327 let hash = at_content.hash();
328 let child_offset = at_content.offset();
329 sender
330 .send(DownloadProgress::Found {
331 id,
332 hash,
333 size,
334 child: BlobId::from_offset(child_offset),
335 })
336 .await?;
337 let sender2 = sender.clone();
338 let on_write = move |offset: u64, _length: usize| {
339 sender2
342 .try_send(DownloadProgress::Progress { id, offset })
343 .inspect_err(|_| {
344 tracing::info!("aborting download of {}", hash);
345 })?;
346 Ok(())
347 };
348 let mut bw = FallibleProgressBatchWriter::new(bw, on_write);
349 let at_end = at_content.write_all_batch(&mut bw).await?;
351 bw.sync().await?;
353 drop(bw);
354 db.insert_complete(entry).await?;
359 sender.send(DownloadProgress::Done { id }).await?;
361 Ok(at_end)
362}
363
364pub async fn blob_info<D: BaoStore>(db: &D, hash: &Hash) -> io::Result<BlobInfo<D>> {
368 io::Result::Ok(match db.get_mut(hash).await? {
369 Some(entry) if entry.is_complete() => BlobInfo::Complete {
370 size: entry.size().value(),
371 },
372 Some(entry) => {
373 let valid_ranges = valid_ranges::<D>(&entry)
374 .await
375 .ok()
376 .unwrap_or_else(ChunkRanges::all);
377 BlobInfo::Partial {
378 entry,
379 valid_ranges,
380 }
381 }
382 None => BlobInfo::Missing,
383 })
384}
385
386async fn blob_infos<D: BaoStore>(db: &D, hash_seq: &[Hash]) -> io::Result<Vec<BlobInfo<D>>> {
388 let items = futures_lite::stream::iter(hash_seq)
389 .then(|hash| blob_info(db, hash))
390 .collect::<Vec<_>>();
391 items.await.into_iter().collect()
392}
393
394async fn get_hash_seq<D: BaoStore>(
396 db: &D,
397 co: GetCo,
398 root_hash: &Hash,
399 sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
400) -> Result<Stats, GetError> {
401 use tracing::info as log;
402 let finishing = match db.get_mut(root_hash).await? {
403 Some(entry) if entry.is_complete() => {
404 log!("already got collection - doing partial download");
405 sender
407 .send(DownloadProgress::FoundLocal {
408 child: BlobId::Root,
409 hash: *root_hash,
410 size: entry.size(),
411 valid_ranges: RangeSpec::all(),
412 })
413 .await?;
414 let reader = entry.data_reader().await?;
416 let (mut hash_seq, children) = parse_hash_seq(reader).await.map_err(|err| {
417 GetError::NoncompliantNode(anyhow!("Failed to parse downloaded HashSeq: {err}"))
418 })?;
419 sender
420 .send(DownloadProgress::FoundHashSeq {
421 hash: *root_hash,
422 children,
423 })
424 .await?;
425 let mut children: Vec<Hash> = vec![];
426 while let Some(hash) = hash_seq.next().await? {
427 children.push(hash);
428 }
429 let missing_info = blob_infos(db, &children).await?;
430 for (i, info) in missing_info.iter().enumerate() {
432 if let Some(size) = info.size() {
433 sender
434 .send(DownloadProgress::FoundLocal {
435 child: BlobId::from_offset((i as u64) + 1),
436 hash: children[i],
437 size,
438 valid_ranges: RangeSpec::new(info.valid_ranges()),
439 })
440 .await?;
441 }
442 }
443 if missing_info
444 .iter()
445 .all(|x| matches!(x, BlobInfo::Complete { .. }))
446 {
447 log!("nothing to do");
448 return Ok(Stats::default());
449 }
450
451 let missing_iter = std::iter::once(ChunkRanges::empty())
452 .chain(missing_info.iter().map(|x| x.missing_ranges()))
453 .collect::<Vec<_>>();
454 log!("requesting chunks {:?}", missing_iter);
455 let request = GetRequest::new(*root_hash, RangeSpecSeq::from_ranges(missing_iter));
456 let conn = co.get_conn().await;
457 let request = get::fsm::start(conn, request);
458 let connected = request.next().await?;
460 log!("connected");
461 let ConnectedNext::StartChild(start) = connected.next().await? else {
463 return Err(GetError::NoncompliantNode(anyhow!("expected StartChild")));
464 };
465 let mut next = EndBlobNext::MoreChildren(start);
466 loop {
468 let start = match next {
469 EndBlobNext::MoreChildren(start) => start,
470 EndBlobNext::Closing(finish) => break finish,
471 };
472 let child_offset = usize::try_from(start.child_offset())
473 .map_err(|_| GetError::NoncompliantNode(anyhow!("child offset too large")))?;
474 let (child_hash, info) =
475 match (children.get(child_offset), missing_info.get(child_offset)) {
476 (Some(blob), Some(info)) => (*blob, info),
477 _ => break start.finish(),
478 };
479 tracing::info!(
480 "requesting child {} {:?}",
481 child_hash,
482 info.missing_ranges()
483 );
484 let header = start.next(child_hash);
485 let end_blob = match info {
486 BlobInfo::Missing => get_blob_inner(db, header, sender.clone()).await?,
487 BlobInfo::Partial { entry, .. } => {
488 get_blob_inner_partial(db, header, entry.clone(), sender.clone()).await?
489 }
490 BlobInfo::Complete { .. } => {
491 return Err(GetError::NoncompliantNode(anyhow!(
492 "got data we have not requested"
493 )));
494 }
495 };
496 next = end_blob.next();
497 }
498 }
499 _ => {
500 tracing::debug!("don't have collection - doing full download");
501 let conn = co.get_conn().await;
503 let request = get::fsm::start(conn, GetRequest::all(*root_hash));
504 let connected = request.next().await?;
506 let ConnectedNext::StartRoot(start) = connected.next().await? else {
508 return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
509 };
510 let header = start.next();
512 let end_root = get_blob_inner(db, header, sender.clone()).await?;
514 let entry = db
516 .get(root_hash)
517 .await?
518 .ok_or_else(|| GetError::LocalFailure(anyhow!("just downloaded but not in db")))?;
519 let reader = entry.data_reader().await?;
520 let (mut collection, count) = parse_hash_seq(reader).await.map_err(|err| {
521 GetError::NoncompliantNode(anyhow!("Failed to parse downloaded HashSeq: {err}"))
522 })?;
523 sender
524 .send(DownloadProgress::FoundHashSeq {
525 hash: *root_hash,
526 children: count,
527 })
528 .await?;
529 let mut children = vec![];
530 while let Some(hash) = collection.next().await? {
531 children.push(hash);
532 }
533 let mut next = end_root.next();
534 loop {
536 let start = match next {
537 EndBlobNext::MoreChildren(start) => start,
538 EndBlobNext::Closing(finish) => break finish,
539 };
540 let child_offset = usize::try_from(start.child_offset())
541 .map_err(|_| GetError::NoncompliantNode(anyhow!("child offset too large")))?;
542
543 let child_hash = match children.get(child_offset) {
544 Some(blob) => *blob,
545 None => break start.finish(),
546 };
547 let header = start.next(child_hash);
548 let end_blob = get_blob_inner(db, header, sender.clone()).await?;
549 next = end_blob.next();
550 }
551 }
552 };
553 let stats = finishing.next().await?;
555 Ok(stats)
556}
557
558#[derive(Debug, Clone)]
560pub enum BlobInfo<D: BaoStore> {
561 Complete {
563 size: u64,
565 },
566 Partial {
568 entry: D::EntryMut,
570 valid_ranges: ChunkRanges,
572 },
573 Missing,
575}
576
577impl<D: BaoStore> BlobInfo<D> {
578 pub fn size(&self) -> Option<BaoBlobSize> {
580 match self {
581 BlobInfo::Complete { size } => Some(BaoBlobSize::Verified(*size)),
582 BlobInfo::Partial { entry, .. } => Some(entry.size()),
583 BlobInfo::Missing => None,
584 }
585 }
586
587 pub fn valid_ranges(&self) -> ChunkRanges {
592 match self {
593 BlobInfo::Complete { .. } => ChunkRanges::all(),
594 BlobInfo::Partial { valid_ranges, .. } => valid_ranges.clone(),
595 BlobInfo::Missing => ChunkRanges::empty(),
596 }
597 }
598
599 pub fn missing_ranges(&self) -> ChunkRanges {
604 match self {
605 BlobInfo::Complete { .. } => ChunkRanges::empty(),
606 BlobInfo::Partial { valid_ranges, .. } => ChunkRanges::all().difference(valid_ranges),
607 BlobInfo::Missing => ChunkRanges::all(),
608 }
609 }
610}
611
612#[derive(Debug, Clone, Serialize, Deserialize)]
615pub enum DownloadProgress {
616 InitialState(TransferState),
618 FoundLocal {
620 child: BlobId,
622 hash: Hash,
624 size: BaoBlobSize,
626 valid_ranges: RangeSpec,
628 },
629 Connected,
631 Found {
633 id: u64,
635 child: BlobId,
640 hash: Hash,
642 size: u64,
644 },
645 FoundHashSeq {
647 hash: Hash,
649 children: u64,
651 },
652 Progress {
654 id: u64,
656 offset: u64,
658 },
659 Done {
661 id: u64,
663 },
664 AllDone(Stats),
668 Abort(serde_error::Error),
672}
673
674#[derive(
676 Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, std::hash::Hash, Serialize, Deserialize,
677)]
678pub enum BlobId {
679 Root,
681 Child(NonZeroU64),
683}
684
685impl BlobId {
686 fn from_offset(id: u64) -> Self {
687 NonZeroU64::new(id).map(Self::Child).unwrap_or(Self::Root)
688 }
689}
690
691impl From<BlobId> for u64 {
692 fn from(value: BlobId) -> Self {
693 match value {
694 BlobId::Root => 0,
695 BlobId::Child(id) => id.into(),
696 }
697 }
698}