//! Functions that get blobs from a remote node and write them into a local
//! bao store, resuming from any partial data that is already present.
use std::future::Future;
use std::io;
use std::num::NonZeroU64;

use futures_lite::StreamExt;
use iroh_base::hash::Hash;
use iroh_base::rpc::RpcError;
use serde::{Deserialize, Serialize};

use crate::hashseq::parse_hash_seq;
use crate::protocol::RangeSpec;
use crate::store::BaoBatchWriter;
use crate::store::BaoBlobSize;
use crate::store::FallibleProgressBatchWriter;

use crate::{
    get::{
        self,
        error::GetError,
        fsm::{AtBlobHeader, AtEndBlob, ConnectedNext, EndBlobNext},
        progress::TransferState,
        Stats,
    },
    protocol::{GetRequest, RangeSpecSeq},
    store::{MapEntry, MapEntryMut, MapMut, Store as BaoStore},
    util::progress::{IdGenerator, ProgressSender},
    BlobFormat, HashAndFormat,
};
use anyhow::anyhow;
use bao_tree::{ChunkNum, ChunkRanges};
use iroh_io::AsyncSliceReader;
use tracing::trace;

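/// Get a blob or hash sequence into a store.
///
/// This considers data that is already in the store, and will only request
/// the remaining data. Progress is reported as [`DownloadProgress`] through
/// the given [`ProgressSender`].
///
/// A minimal usage sketch; `db`, `connect` and `progress` are hypothetical
/// stand-ins for a store, a function that opens a `quinn::Connection` to a
/// provider, and a progress sender:
///
/// ```ignore
/// let stats = get_to_db(
///     &db,
///     || async move { connect().await },
///     &HashAndFormat {
///         hash,
///         format: BlobFormat::Raw,
///     },
///     progress,
/// )
/// .await?;
/// ```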
pub async fn get_to_db<
    D: BaoStore,
    C: FnOnce() -> F,
    F: Future<Output = anyhow::Result<quinn::Connection>>,
>(
    db: &D,
    get_conn: C,
    hash_and_format: &HashAndFormat,
    sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
) -> Result<Stats, GetError> {
    let HashAndFormat { hash, format } = hash_and_format;
    match format {
        BlobFormat::Raw => get_blob(db, get_conn, hash, sender).await,
        BlobFormat::HashSeq => get_hash_seq(db, get_conn, hash, sender).await,
    }
}

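/// Get a single raw blob into the store.
///
/// If the blob is already complete in the store, this only reports a
/// [`DownloadProgress::FoundLocal`] and makes no request; if it is partially
/// present, only the missing ranges are requested.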
async fn get_blob<
    D: BaoStore,
    C: FnOnce() -> F,
    F: Future<Output = anyhow::Result<quinn::Connection>>,
>(
    db: &D,
    get_conn: C,
    hash: &Hash,
    progress: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
) -> Result<Stats, GetError> {
    let end = match db.get_mut(hash).await? {
        Some(entry) if entry.is_complete() => {
            tracing::info!("already got entire blob");
            progress
                .send(DownloadProgress::FoundLocal {
                    child: BlobId::Root,
                    hash: *hash,
                    size: entry.size(),
                    valid_ranges: RangeSpec::all(),
                })
                .await?;
            return Ok(Stats::default());
        }
        Some(entry) => {
            trace!("got partial data for {}", hash);
            let valid_ranges = valid_ranges::<D>(&entry)
                .await
                .ok()
                .unwrap_or_else(ChunkRanges::all);
            progress
                .send(DownloadProgress::FoundLocal {
                    child: BlobId::Root,
                    hash: *hash,
                    size: entry.size(),
                    valid_ranges: RangeSpec::new(&valid_ranges),
                })
                .await?;
            // request just the ranges we are still missing
            let required_ranges: ChunkRanges = ChunkRanges::all().difference(&valid_ranges);

            let request = GetRequest::new(*hash, RangeSpecSeq::from_ranges([required_ranges]));
            let conn = get_conn().await.map_err(GetError::Io)?;
            let request = get::fsm::start(conn, request);
            let connected = request.next().await?;
            // we requested a single blob, so the response must start with the root
            let ConnectedNext::StartRoot(start) = connected.next().await? else {
                return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
            };
            let header = start.next();
            get_blob_inner_partial(db, header, entry, progress).await?
        }
        None => {
            // we have no local data, so request the entire blob
            let conn = get_conn().await.map_err(GetError::Io)?;
            let request = get::fsm::start(conn, GetRequest::single(*hash));
            let connected = request.next().await?;
            let ConnectedNext::StartRoot(start) = connected.next().await? else {
                return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
            };
            let header = start.next();
            get_blob_inner(db, header, progress).await?
        }
    };

    // we requested a single blob, so the next state must be Closing
    let EndBlobNext::Closing(end) = end.next() else {
        return Err(GetError::NoncompliantNode(anyhow!("expected Closing")));
    };
    let stats = end.next().await?;
    Ok(stats)
}

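/// Given a partial entry, compute its valid ranges: the chunks for which
/// both the data and the outboard have been written and verified.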
pub async fn valid_ranges<D: MapMut>(entry: &D::EntryMut) -> anyhow::Result<ChunkRanges> {
    use tracing::trace as log;
    // compute the valid range from just looking at the data file
    let mut data_reader = entry.data_reader().await?;
    let data_size = data_reader.size().await?;
    let valid_from_data = ChunkRanges::from(..ChunkNum::full_chunks(data_size));
    // compute the valid range from looking at the outboard
    let mut outboard = entry.outboard().await?;
    let all = ChunkRanges::all();
    let mut stream = bao_tree::io::fsm::valid_outboard_ranges(&mut outboard, &all);
    let mut valid_from_outboard = ChunkRanges::empty();
    while let Some(range) = stream.next().await {
        valid_from_outboard |= ChunkRanges::from(range?);
    }
    // a chunk is valid only if both the data and the outboard are valid
    let valid: ChunkRanges = valid_from_data.intersection(&valid_from_outboard);
    log!("valid_from_data: {:?}", valid_from_data);
    log!("valid_from_outboard: {:?}", valid_from_outboard);
    Ok(valid)
}

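/// Get a blob we have no local data for, from the header state onwards.
///
/// Creates (or reuses) the partial entry for the hash, streams the verified
/// chunks through a progress-reporting batch writer, and finally marks the
/// entry as complete in the store.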
async fn get_blob_inner<D: BaoStore>(
    db: &D,
    at_header: AtBlobHeader,
    sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
) -> Result<AtEndBlob, GetError> {
    // read the size from the header
    let (at_content, size) = at_header.next().await?;
    let hash = at_content.hash();
    let child_offset = at_content.offset();
    // get or create the partial entry, and a batch writer for it
    let entry = db.get_or_create(hash, size).await?;
    let bw = entry.batch_writer().await?;
    // allocate a new id for progress reports for this transfer
    let id = sender.new_id();
    sender
        .send(DownloadProgress::Found {
            id,
            hash,
            size,
            child: BlobId::from_offset(child_offset),
        })
        .await?;
    let sender2 = sender.clone();
    let on_write = move |offset: u64, _length: usize| {
        // if try_send fails it means the receiver has been dropped,
        // so abort the download by propagating the error
        sender2
            .try_send(DownloadProgress::Progress { id, offset })
            .map_err(|e| {
                tracing::info!("aborting download of {}", hash);
                e
            })?;
        Ok(())
    };
    let mut bw = FallibleProgressBatchWriter::new(bw, on_write);
    // write the received (and verified) chunks to the batch writer
    let end = at_content.write_all_batch(&mut bw).await?;
    // sync the underlying storage, if needed
    bw.sync().await?;
    drop(bw);
    db.insert_complete(entry).await?;
    sender.send(DownloadProgress::Done { id }).await?;
    Ok(end)
}

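/// Get a blob that was requested partially, from the header state onwards.
///
/// Unlike [`get_blob_inner`], this writes into an existing partial entry,
/// so only the missing ranges that were requested arrive over the wire.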
async fn get_blob_inner_partial<D: BaoStore>(
    db: &D,
    at_header: AtBlobHeader,
    entry: D::EntryMut,
    sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
) -> Result<AtEndBlob, GetError> {
    // read the size from the header
    let (at_content, size) = at_header.next().await?;
    // reuse the existing partial entry and get a batch writer for it
    let bw = entry.batch_writer().await?;
    // allocate a new id for progress reports for this transfer
    let id = sender.new_id();
    let hash = at_content.hash();
    let child_offset = at_content.offset();
    sender
        .send(DownloadProgress::Found {
            id,
            hash,
            size,
            child: BlobId::from_offset(child_offset),
        })
        .await?;
    let sender2 = sender.clone();
    let on_write = move |offset: u64, _length: usize| {
        // if try_send fails it means the receiver has been dropped,
        // so abort the download by propagating the error
        sender2
            .try_send(DownloadProgress::Progress { id, offset })
            .map_err(|e| {
                tracing::info!("aborting download of {}", hash);
                e
            })?;
        Ok(())
    };
    let mut bw = FallibleProgressBatchWriter::new(bw, on_write);
    // write the received (and verified) chunks to the batch writer
    let at_end = at_content.write_all_batch(&mut bw).await?;
    // sync the underlying storage, if needed
    bw.sync().await?;
    drop(bw);
    // we got to the end without error, so the entry is now complete
    db.insert_complete(entry).await?;
    sender.send(DownloadProgress::Done { id }).await?;
    Ok(at_end)
}

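/// Get the status of a blob in the local store: complete, partial (with its
/// valid ranges), or missing.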
pub async fn blob_info<D: BaoStore>(db: &D, hash: &Hash) -> io::Result<BlobInfo<D>> {
    io::Result::Ok(match db.get_mut(hash).await? {
        Some(entry) if entry.is_complete() => BlobInfo::Complete {
            size: entry.size().value(),
        },
        Some(entry) => {
            let valid_ranges = valid_ranges::<D>(&entry)
                .await
                .ok()
                .unwrap_or_else(ChunkRanges::all);
            BlobInfo::Partial {
                entry,
                valid_ranges,
            }
        }
        None => BlobInfo::Missing,
    })
}

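/// Compute the [`BlobInfo`] for each hash in a hash sequence.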
async fn blob_infos<D: BaoStore>(db: &D, hash_seq: &[Hash]) -> io::Result<Vec<BlobInfo<D>>> {
    let items = futures_lite::stream::iter(hash_seq)
        .then(|hash| blob_info(db, hash))
        .collect::<Vec<_>>();
    items.await.into_iter().collect()
}

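/// Get a hash sequence and all of its children into the store.
///
/// If the root hash seq is already complete locally, only the missing ranges
/// of the children are requested; otherwise everything is fetched in a
/// single request.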
async fn get_hash_seq<
    D: BaoStore,
    C: FnOnce() -> F,
    F: Future<Output = anyhow::Result<quinn::Connection>>,
>(
    db: &D,
    get_conn: C,
    root_hash: &Hash,
    sender: impl ProgressSender<Msg = DownloadProgress> + IdGenerator,
) -> Result<Stats, GetError> {
    use tracing::info as log;
    let finishing = match db.get_mut(root_hash).await? {
        Some(entry) if entry.is_complete() => {
            log!("already got collection - doing partial download");
            // report that we have the hash seq itself entirely
            sender
                .send(DownloadProgress::FoundLocal {
                    child: BlobId::Root,
                    hash: *root_hash,
                    size: entry.size(),
                    valid_ranges: RangeSpec::all(),
                })
                .await?;
            // parse the local hash seq to get the child hashes
            let reader = entry.data_reader().await?;
            let (mut hash_seq, children) = parse_hash_seq(reader).await.map_err(|err| {
                GetError::NoncompliantNode(anyhow!("Failed to parse downloaded HashSeq: {err}"))
            })?;
            sender
                .send(DownloadProgress::FoundHashSeq {
                    hash: *root_hash,
                    children,
                })
                .await?;
            let mut children: Vec<Hash> = vec![];
            while let Some(hash) = hash_seq.next().await? {
                children.push(hash);
            }
            let missing_info = blob_infos(db, &children).await?;
            // report what we already have locally for each child
            for (i, info) in missing_info.iter().enumerate() {
                if let Some(size) = info.size() {
                    sender
                        .send(DownloadProgress::FoundLocal {
                            child: BlobId::from_offset((i as u64) + 1),
                            hash: children[i],
                            size,
                            valid_ranges: RangeSpec::new(&info.valid_ranges()),
                        })
                        .await?;
                }
            }
            if missing_info
                .iter()
                .all(|x| matches!(x, BlobInfo::Complete { .. }))
            {
                log!("nothing to do");
                return Ok(Stats::default());
            }

            // build a request for the missing ranges; the first entry is for
            // the root hash seq itself, which we already have, so it is empty
            let missing_iter = std::iter::once(ChunkRanges::empty())
                .chain(missing_info.iter().map(|x| x.missing_ranges()))
                .collect::<Vec<_>>();
            log!("requesting chunks {:?}", missing_iter);
            let request = GetRequest::new(*root_hash, RangeSpecSeq::from_ranges(missing_iter));
            let conn = get_conn().await.map_err(GetError::Io)?;
            let request = get::fsm::start(conn, request);
            let connected = request.next().await?;
            log!("connected");
            // we did not request the root, so the response must start with a child
            let ConnectedNext::StartChild(start) = connected.next().await? else {
                return Err(GetError::NoncompliantNode(anyhow!("expected StartChild")));
            };
            let mut next = EndBlobNext::MoreChildren(start);
            // read all the children
            loop {
                let start = match next {
                    EndBlobNext::MoreChildren(start) => start,
                    EndBlobNext::Closing(finish) => break finish,
                };
                let child_offset = usize::try_from(start.child_offset())
                    .map_err(|_| GetError::NoncompliantNode(anyhow!("child offset too large")))?;
                let (child_hash, info) =
                    match (children.get(child_offset), missing_info.get(child_offset)) {
                        (Some(blob), Some(info)) => (*blob, info),
                        _ => break start.finish(),
                    };
                tracing::info!(
                    "requesting child {} {:?}",
                    child_hash,
                    info.missing_ranges()
                );
                let header = start.next(child_hash);
                let end_blob = match info {
                    BlobInfo::Missing => get_blob_inner(db, header, sender.clone()).await?,
                    BlobInfo::Partial { entry, .. } => {
                        get_blob_inner_partial(db, header, entry.clone(), sender.clone()).await?
                    }
                    BlobInfo::Complete { .. } => {
                        return Err(GetError::NoncompliantNode(anyhow!(
                            "got data we have not requested"
                        )));
                    }
                };
                next = end_blob.next();
            }
        }
        _ => {
            tracing::debug!("don't have collection - doing full download");
            // we don't have the root hash seq, so request everything
            let conn = get_conn().await.map_err(GetError::Io)?;
            let request = get::fsm::start(conn, GetRequest::all(*root_hash));
            let connected = request.next().await?;
            let ConnectedNext::StartRoot(start) = connected.next().await? else {
                return Err(GetError::NoncompliantNode(anyhow!("expected StartRoot")));
            };
            let header = start.next();
            // download the root hash seq itself first
            let end_root = get_blob_inner(db, header, sender.clone()).await?;
            // parse the hash seq we just downloaded to learn the child hashes
            let entry = db
                .get(root_hash)
                .await?
                .ok_or_else(|| GetError::LocalFailure(anyhow!("just downloaded but not in db")))?;
            let reader = entry.data_reader().await?;
            let (mut collection, count) = parse_hash_seq(reader).await.map_err(|err| {
                GetError::NoncompliantNode(anyhow!("Failed to parse downloaded HashSeq: {err}"))
            })?;
            sender
                .send(DownloadProgress::FoundHashSeq {
                    hash: *root_hash,
                    children: count,
                })
                .await?;
            let mut children = vec![];
            while let Some(hash) = collection.next().await? {
                children.push(hash);
            }
            let mut next = end_root.next();
            // read all the children
            loop {
                let start = match next {
                    EndBlobNext::MoreChildren(start) => start,
                    EndBlobNext::Closing(finish) => break finish,
                };
                let child_offset = usize::try_from(start.child_offset())
                    .map_err(|_| GetError::NoncompliantNode(anyhow!("child offset too large")))?;

                let child_hash = match children.get(child_offset) {
                    Some(blob) => *blob,
                    None => break start.finish(),
                };
                let header = start.next(child_hash);
                let end_blob = get_blob_inner(db, header, sender.clone()).await?;
                next = end_blob.next();
            }
        }
    };
    // wait for the closing handshake and return the stats
    let stats = finishing.next().await?;
    Ok(stats)
}

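/// Information about the local availability of a blob in a store.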
#[derive(Debug, Clone)]
pub enum BlobInfo<D: BaoStore> {
    /// The blob is fully present in the local store.
    Complete {
        /// The verified size of the blob.
        size: u64,
    },
    /// Some ranges of the blob are present locally.
    Partial {
        /// The partial entry in the store.
        entry: D::EntryMut,
        /// The ranges that are present and verified.
        valid_ranges: ChunkRanges,
    },
    /// The blob is not present locally at all.
    Missing,
}

impl<D: BaoStore> BlobInfo<D> {
    /// The size of the blob, if known.
    pub fn size(&self) -> Option<BaoBlobSize> {
        match self {
            BlobInfo::Complete { size } => Some(BaoBlobSize::Verified(*size)),
            BlobInfo::Partial { entry, .. } => Some(entry.size()),
            BlobInfo::Missing => None,
        }
    }

    /// The ranges that are valid locally: all for complete blobs, empty for
    /// missing blobs.
    pub fn valid_ranges(&self) -> ChunkRanges {
        match self {
            BlobInfo::Complete { .. } => ChunkRanges::all(),
            BlobInfo::Partial { valid_ranges, .. } => valid_ranges.clone(),
            BlobInfo::Missing => ChunkRanges::empty(),
        }
    }

    /// The ranges that still need to be downloaded: the complement of
    /// [`Self::valid_ranges`].
    pub fn missing_ranges(&self) -> ChunkRanges {
        match self {
            BlobInfo::Complete { .. } => ChunkRanges::empty(),
            BlobInfo::Partial { valid_ranges, .. } => ChunkRanges::all().difference(valid_ranges),
            BlobInfo::Missing => ChunkRanges::all(),
        }
    }
}

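/// Progress updates for a blob or hash sequence download, as reported
/// through the [`ProgressSender`] passed to [`get_to_db`].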
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DownloadProgress {
    /// Initial state, sent when subscribing to a running or queued transfer.
    InitialState(TransferState),
    /// Data was found locally.
    FoundLocal {
        /// The blob this is for: the root, or a child of a hash seq.
        child: BlobId,
        /// The hash of the entry.
        hash: Hash,
        /// The size of the entry.
        size: BaoBlobSize,
        /// The ranges that are available locally.
        valid_ranges: RangeSpec,
    },
    /// A new connection was established.
    Connected,
    /// An item was found with hash `hash`, from now on referred to via `id`.
    Found {
        /// A new unique progress id for this entry.
        id: u64,
        /// Identifier for this blob within this download.
        child: BlobId,
        /// The hash of the entry.
        hash: Hash,
        /// The size of the entry in bytes.
        size: u64,
    },
    /// A hash sequence was found, with the given number of children.
    FoundHashSeq {
        /// The hash of the hash seq entry.
        hash: Hash,
        /// The number of children in the hash seq.
        children: u64,
    },
    /// We made progress ingesting item `id`.
    Progress {
        /// The unique id of the entry.
        id: u64,
        /// The offset of the progress, in bytes.
        offset: u64,
    },
    /// We are done with `id`.
    Done {
        /// The unique id of the entry.
        id: u64,
    },
    /// The download is complete; contains transfer statistics.
    AllDone(Stats),
    /// The download was aborted with an error.
    Abort(RpcError),
}

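/// A unique id for a blob within a transfer: the root blob, or a child
/// identified by its offset in the hash sequence.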
#[derive(
    Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, std::hash::Hash, Serialize, Deserialize,
)]
pub enum BlobId {
    /// The root blob (child offset 0).
    Root,
    /// A child blob (child offset > 0).
    Child(NonZeroU64),
}

impl BlobId {
    /// Map a child offset to a `BlobId`: offset 0 is the root, anything
    /// else is a child.
    fn from_offset(id: u64) -> Self {
        NonZeroU64::new(id).map(Self::Child).unwrap_or(Self::Root)
    }
}

impl From<BlobId> for u64 {
    fn from(value: BlobId) -> Self {
        match value {
            BlobId::Root => 0,
            BlobId::Child(id) => id.into(),
        }
    }
}
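
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the `BlobId` offset mapping, checking that
    // `from_offset` and `From<BlobId> for u64` round-trip: offset 0 is the
    // root (the hash seq itself), offsets >= 1 are children.
    #[test]
    fn blob_id_offset_roundtrip() {
        assert_eq!(BlobId::from_offset(0), BlobId::Root);
        assert_eq!(u64::from(BlobId::Root), 0);
        for offset in [1u64, 2, 1234] {
            let id = BlobId::from_offset(offset);
            assert!(matches!(id, BlobId::Child(_)));
            assert_eq!(u64::from(id), offset);
        }
    }
}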