use std::{fmt::Debug, sync::Arc, time::Duration};

use anyhow::{Context, Result};
use bao_tree::io::{
    fsm::{encode_ranges_validated, Outboard},
    EncodeError,
};
use futures_lite::future::Boxed as BoxFuture;
use iroh::endpoint::{self, RecvStream, SendStream};
use iroh_io::{
    stats::{SliceReaderStats, StreamWriterStats, TrackingSliceReader, TrackingStreamWriter},
    AsyncSliceReader, AsyncStreamWriter, TokioStreamWriter,
};
use serde::{Deserialize, Serialize};
use tracing::{debug, debug_span, info, trace, warn};
use tracing_futures::Instrument;

use crate::{
    hashseq::parse_hash_seq,
    protocol::{GetRequest, RangeSpec, Request},
    store::*,
    util::{local_pool::LocalPoolHandle, Tag},
    BlobFormat, Hash,
};

/// Events emitted by the provider, informing about the current status.
#[derive(Debug, Clone)]
pub enum Event {
    /// A new collection or tagged blob has been added.
    TaggedBlobAdded {
        /// The hash of the added data.
        hash: Hash,
        /// The format of the added data.
        format: BlobFormat,
        /// The tag of the added data.
        tag: Tag,
    },
    /// A new client connected to the node.
    ClientConnected {
        /// A unique connection id.
        connection_id: u64,
    },
    /// A request was received from a client.
    GetRequestReceived {
        /// A unique connection id.
        connection_id: u64,
        /// An identifier uniquely identifying this request.
        request_id: u64,
        /// The hash for which the client wants to receive data.
        hash: Hash,
    },
    /// A sequence of hashes has been found and is being transferred.
    TransferHashSeqStarted {
        /// A unique connection id.
        connection_id: u64,
        /// An identifier uniquely identifying this request.
        request_id: u64,
        /// The number of blobs in the sequence.
        num_blobs: u64,
    },
    /// A chunk of a blob was transferred.
    ///
    /// These events are sent with `try_send`, so you can not assume that you
    /// will receive all of them.
    TransferProgress {
        /// A unique connection id.
        connection_id: u64,
        /// An identifier uniquely identifying this request.
        request_id: u64,
        /// The hash for which we are transferring data.
        hash: Hash,
        /// Offset up to which data has been transferred.
        end_offset: u64,
    },
    /// A blob in a sequence was transferred.
    TransferBlobCompleted {
        /// A unique connection id.
        connection_id: u64,
        /// An identifier uniquely identifying this request.
        request_id: u64,
        /// The hash of the blob.
        hash: Hash,
        /// The index of the blob in the sequence.
        index: u64,
        /// The size of the blob that was transferred.
        size: u64,
    },
    /// A request was completed and the data was sent to the client.
    TransferCompleted {
        /// A unique connection id.
        connection_id: u64,
        /// An identifier uniquely identifying this request.
        request_id: u64,
        /// Statistics about the transfer.
        stats: Box<TransferStats>,
    },
    /// A request was aborted before completion.
    TransferAborted {
        /// A unique connection id.
        connection_id: u64,
        /// An identifier uniquely identifying this request.
        request_id: u64,
        /// Statistics about the transfer. This is `None` if the transfer
        /// was aborted before any data was sent.
        stats: Option<Box<TransferStats>>,
    },
}

/// The stats for a transfer of a collection or blob.
#[derive(Debug, Clone, Copy, Default)]
pub struct TransferStats {
    /// Stats for sending to the client.
    pub send: StreamWriterStats,
    /// Stats for reading from storage.
    pub read: SliceReaderStats,
    /// The total duration of the transfer.
    pub duration: Duration,
}

/// Progress updates for the add operation.
#[derive(Debug, Serialize, Deserialize)]
pub enum AddProgress {
    /// An item was found with name `name`, from now on referred to via `id`.
    Found {
        /// A new unique id for this entry.
        id: u64,
        /// The name of the entry.
        name: String,
        /// The size of the entry in bytes.
        size: u64,
    },
    /// We got progress ingesting item `id`.
    Progress {
        /// The unique id of the entry.
        id: u64,
        /// The offset of the progress, in bytes.
        offset: u64,
    },
    /// We are done with `id`, and the hash is `hash`.
    Done {
        /// The unique id of the entry.
        id: u64,
        /// The hash of the entry.
        hash: Hash,
    },
    /// We are done with the whole operation.
    AllDone {
        /// The hash of the created data.
        hash: Hash,
        /// The format of the added data.
        format: BlobFormat,
        /// The tag of the added data.
        tag: Tag,
    },
    /// We got an error and need to abort.
    ///
    /// This will be the last message in the stream.
    Abort(serde_error::Error),
}

/// Progress updates for the batch add operation.
#[derive(Debug, Serialize, Deserialize)]
pub enum BatchAddPathProgress {
    /// An item was found with the given size.
    Found {
        /// The size of the entry in bytes.
        size: u64,
    },
    /// We got progress ingesting the item.
    Progress {
        /// The offset of the progress, in bytes.
        offset: u64,
    },
    /// We are done, and the hash is `hash`.
    Done {
        /// The hash of the entry.
        hash: Hash,
    },
    /// We got an error and need to abort.
    ///
    /// This will be the last message in the stream.
    Abort(serde_error::Error),
}

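/// Read the request from the getter.
///
/// Will fail if there is an error while reading, if the request exceeds
/// `MAX_MESSAGE_SIZE`, or if no valid request is sent.
///
/// A minimal sketch of how this is typically called, assuming an accepted
/// bidirectional stream (the variable names are illustrative):
///
/// ```ignore
/// let (send, recv) = connection.accept_bi().await?;
/// let request = read_request(recv).await?;
/// ```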
pub async fn read_request(mut reader: RecvStream) -> Result<Request> {
    let payload = reader
        .read_to_end(crate::protocol::MAX_MESSAGE_SIZE)
        .await?;
    let request: Request = postcard::from_bytes(&payload)?;
    Ok(request)
}

/// Transfer a hash sequence to the client.
///
/// First sends the root blob, i.e. the hash sequence itself, and then, if
/// further ranges were requested, each child blob in turn.
///
/// If a child blob cannot be found in the database, the transfer stops and
/// returns [`SentStatus::NotFound`].
pub(crate) async fn transfer_hash_seq<D: Map>(
    request: GetRequest,
    // Store from which to fetch blobs.
    db: &D,
    // Response writer, containing the quinn stream.
    writer: &mut ResponseWriter,
    // The outboard of the hash sequence.
    mut outboard: impl Outboard,
    // The data of the hash sequence.
    mut data: impl AsyncSliceReader,
    stats: &mut TransferStats,
) -> Result<SentStatus> {
    let hash = request.hash;
    let events = writer.events.clone();
    let request_id = writer.request_id();
    let connection_id = writer.connection_id();

    // if the request is just for the root, we don't need to deserialize the hash sequence
    let just_root = matches!(request.ranges.as_single(), Some((0, _)));
    let mut c = if !just_root {
        let (stream, num_blobs) = parse_hash_seq(&mut data).await?;
        writer
            .events
            .send(|| Event::TransferHashSeqStarted {
                connection_id: writer.connection_id(),
                request_id: writer.request_id(),
                num_blobs,
            })
            .await;
        Some(stream)
    } else {
        None
    };

    let mk_progress = |end_offset| Event::TransferProgress {
        connection_id,
        request_id,
        hash,
        end_offset,
    };

    let mut prev = 0;
    for (offset, ranges) in request.ranges.iter_non_empty() {
        // create a tracking writer so we can get some stats for writing
        let mut tw = writer.tracking_writer();
        if offset == 0 {
            debug!("writing ranges '{:?}' of sequence {}", ranges, hash);
            // wrap the data reader in a tracking reader so we can get some stats for reading
            let mut tracking_reader = TrackingSliceReader::new(&mut data);
            let mut sending_reader =
                SendingSliceReader::new(&mut tracking_reader, &events, mk_progress);
            // send the root: the size as a little-endian u64, then the encoded ranges
            tw.write(outboard.tree().size().to_le_bytes().as_slice())
                .await?;
            encode_ranges_validated(
                &mut sending_reader,
                &mut outboard,
                &ranges.to_chunk_ranges(),
                &mut tw,
            )
            .await?;
            stats.read += tracking_reader.stats();
            stats.send += tw.stats();
            debug!(
                "finished writing ranges '{:?}' of sequence {}",
                ranges, hash
            );
        } else {
            let c = c.as_mut().context("collection parser not available")?;
            debug!("writing ranges '{:?}' of child {}", ranges, offset);
            // skip to the next blob if there is a gap
            if prev < offset - 1 {
                c.skip(offset - prev - 1).await?;
            }
            if let Some(hash) = c.next().await? {
                tokio::task::yield_now().await;
                let (status, size, blob_read_stats) =
                    send_blob(db, hash, ranges, &mut tw, events.clone(), mk_progress).await?;
                stats.send += tw.stats();
                stats.read += blob_read_stats;
                if SentStatus::NotFound == status {
                    writer.inner.finish()?;
                    return Ok(status);
                }

                writer
                    .events
                    .send(|| Event::TransferBlobCompleted {
                        connection_id: writer.connection_id(),
                        request_id: writer.request_id(),
                        hash,
                        index: offset - 1,
                        size,
                    })
                    .await;
            } else {
                // nothing more we can send
                break;
            }
            prev = offset;
        }
    }

    debug!("done writing");
    Ok(SentStatus::Sent)
}

/// A wrapper around a slice reader that emits a progress event after each read.
struct SendingSliceReader<'a, R, F> {
    inner: R,
    sender: &'a EventSender,
    make_event: F,
}

impl<'a, R: AsyncSliceReader, F: Fn(u64) -> Event> SendingSliceReader<'a, R, F> {
    fn new(inner: R, sender: &'a EventSender, make_event: F) -> Self {
        Self {
            inner,
            sender,
            make_event,
        }
    }
}

impl<R: AsyncSliceReader, F: Fn(u64) -> Event> AsyncSliceReader for SendingSliceReader<'_, R, F> {
    async fn read_at(&mut self, offset: u64, len: usize) -> std::io::Result<bytes::Bytes> {
        let res = self.inner.read_at(offset, len).await;
        if let Ok(res) = res.as_ref() {
            let end_offset = offset + res.len() as u64;
            // progress events are best-effort, so use try_send
            self.sender.try_send(|| (self.make_event)(end_offset));
        }
        res
    }

    async fn size(&mut self) -> std::io::Result<u64> {
        self.inner.size().await
    }
}

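/// A sender for provider events.
///
/// Below is a minimal sketch of an implementation that forwards events to a
/// tokio mpsc channel; the `ChannelEventSender` type is illustrative and not
/// part of this module:
///
/// ```ignore
/// #[derive(Debug, Clone)]
/// struct ChannelEventSender(tokio::sync::mpsc::Sender<Event>);
///
/// impl CustomEventSender for ChannelEventSender {
///     fn send(&self, event: Event) -> BoxFuture<()> {
///         let sender = self.0.clone();
///         Box::pin(async move {
///             // Ignore send errors: the receiver may already be gone.
///             sender.send(event).await.ok();
///         })
///     }
///
///     fn try_send(&self, event: Event) {
///         // Drop the event if the channel is full.
///         self.0.try_send(event).ok();
///     }
/// }
/// ```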
pub trait CustomEventSender: std::fmt::Debug + Sync + Send + 'static {
    /// Send an event and wait for it to be accepted.
    fn send(&self, event: Event) -> BoxFuture<()>;

    /// Try to send an event without waiting; the event may be dropped.
    fn try_send(&self, event: Event);
}

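/// A possibly disabled sender for provider events.
///
/// Events are constructed lazily via closures, so when no underlying sender is
/// configured the event is never built. For example, this is how the module
/// itself reports a new client connection:
///
/// ```ignore
/// events.send(|| Event::ClientConnected { connection_id }).await;
/// ```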
#[derive(Debug, Clone, Default)]
pub struct EventSender {
    inner: Option<Arc<dyn CustomEventSender>>,
}

impl<T: CustomEventSender> From<T> for EventSender {
    fn from(inner: T) -> Self {
        Self {
            inner: Some(Arc::new(inner)),
        }
    }
}

impl EventSender {
    /// Create a new event sender with an optional inner sender.
    pub fn new(inner: Option<Arc<dyn CustomEventSender>>) -> Self {
        Self { inner }
    }

    /// Send an event.
    ///
    /// If no inner sender is configured, the closure producing the event is
    /// not called.
    pub async fn send(&self, event: impl FnOnce() -> Event) {
        if let Some(inner) = &self.inner {
            let event = event();
            inner.as_ref().send(event).await;
        }
    }

    /// Try to send an event without waiting.
    ///
    /// The event may be dropped if the inner sender is not ready. If no inner
    /// sender is configured, the closure producing the event is not called.
    pub fn try_send(&self, event: impl FnOnce() -> Event) {
        if let Some(inner) = &self.inner {
            let event = event();
            inner.as_ref().try_send(event);
        }
    }
}

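/// Handle a single connection.
///
/// A minimal sketch of how this might be driven from an accept loop, assuming
/// an `iroh::Endpoint` configured with this protocol's ALPN (`endpoint`, `db`,
/// `events`, and `rt` are illustrative names for values set up by the caller):
///
/// ```ignore
/// while let Some(incoming) = endpoint.accept().await {
///     let connection = incoming.await?;
///     handle_connection(connection, db.clone(), events.clone(), rt.clone()).await;
/// }
/// ```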
pub async fn handle_connection<D: Map>(
    connection: endpoint::Connection,
    db: D,
    events: EventSender,
    rt: LocalPoolHandle,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        while let Ok((writer, reader)) = connection.accept_bi().await {
            // The stream ID index is used to identify this request.
            let request_id = reader.id().index();
            let span = debug_span!("stream", stream_id = %request_id);
            let writer = ResponseWriter {
                connection_id,
                events: events.clone(),
                inner: writer,
            };
            events
                .send(|| Event::ClientConnected { connection_id })
                .await;
            let db = db.clone();
            rt.spawn_detached(|| {
                async move {
                    if let Err(err) = handle_stream(db, reader, writer).await {
                        warn!("error: {err:#?}");
                    }
                }
                .instrument(span)
            });
        }
    }
    .instrument(span)
    .await
}

async fn handle_stream<D: Map>(db: D, reader: RecvStream, writer: ResponseWriter) -> Result<()> {
    // decode the request; on failure, notify listeners and bail out
    debug!("reading request");
    let request = match read_request(reader).await {
        Ok(r) => r,
        Err(e) => {
            writer.notify_transfer_aborted(None).await;
            return Err(e);
        }
    };

    match request {
        Request::Get(request) => handle_get(db, request, writer).await,
    }
}

/// Handle a single get request.
///
/// Requires the request, a database, and a writer.
pub async fn handle_get<D: Map>(
    db: D,
    request: GetRequest,
    mut writer: ResponseWriter,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "received request");
    writer
        .events
        .send(|| Event::GetRequestReceived {
            hash,
            connection_id: writer.connection_id(),
            request_id: writer.request_id(),
        })
        .await;

    // attempt to find the hash in the database
    match db.get(&hash).await? {
        Some(entry) => {
            let mut stats = Box::<TransferStats>::default();
            let t0 = std::time::Instant::now();
            // transfer the data, timing the whole operation
            let res = transfer_hash_seq(
                request,
                &db,
                &mut writer,
                entry.outboard().await?,
                entry.data_reader().await?,
                &mut stats,
            )
            .await;
            stats.duration = t0.elapsed();
            match res {
                Ok(SentStatus::Sent) => {
                    writer.notify_transfer_completed(&hash, stats).await;
                }
                Ok(SentStatus::NotFound) => {
                    writer.notify_transfer_aborted(Some(stats)).await;
                }
                Err(e) => {
                    writer.notify_transfer_aborted(Some(stats)).await;
                    return Err(e);
                }
            }

            debug!("finished response");
        }
        None => {
            debug!("not found {}", hash);
            writer.notify_transfer_aborted(None).await;
            writer.inner.finish()?;
        }
    };

    Ok(())
}

/// A helper struct that combines the send stream with the event sender and
/// connection id for the current request.
#[derive(Debug)]
pub struct ResponseWriter {
    inner: SendStream,
    events: EventSender,
    connection_id: u64,
}

impl ResponseWriter {
    fn tracking_writer(&mut self) -> TrackingStreamWriter<TokioStreamWriter<&mut SendStream>> {
        TrackingStreamWriter::new(TokioStreamWriter(&mut self.inner))
    }

    fn connection_id(&self) -> u64 {
        self.connection_id
    }

    fn request_id(&self) -> u64 {
        self.inner.id().index()
    }

    fn print_stats(stats: &TransferStats) {
        let send = stats.send.total();
        let read = stats.read.total();
        let total_sent_bytes = send.size;
        let send_duration = send.stats.duration;
        let read_duration = read.stats.duration;
        let total_duration = stats.duration;
        let other_duration = total_duration
            .saturating_sub(send_duration)
            .saturating_sub(read_duration);
        let avg_send_size = total_sent_bytes.checked_div(send.stats.count).unwrap_or(0);
        info!(
            "sent {} bytes in {}s",
            total_sent_bytes,
            total_duration.as_secs_f64()
        );
        debug!(
            "{}s sending, {}s reading, {}s other",
            send_duration.as_secs_f64(),
            read_duration.as_secs_f64(),
            other_duration.as_secs_f64()
        );
        trace!(
            "send_count: {} avg_send_size {}",
            send.stats.count,
            avg_send_size,
        )
    }

    async fn notify_transfer_completed(&self, hash: &Hash, stats: Box<TransferStats>) {
        info!("transfer completed for {}", hash);
        Self::print_stats(&stats);
        self.events
            .send(move || Event::TransferCompleted {
                connection_id: self.connection_id(),
                request_id: self.request_id(),
                stats,
            })
            .await;
    }

    async fn notify_transfer_aborted(&self, stats: Option<Box<TransferStats>>) {
        if let Some(stats) = &stats {
            Self::print_stats(stats);
        }
        self.events
            .send(move || Event::TransferAborted {
                connection_id: self.connection_id(),
                request_id: self.request_id(),
                stats,
            })
            .await;
    }
}

/// Status of a send operation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SentStatus {
    /// The requested data was sent completely.
    Sent,
    /// The requested data was not found.
    NotFound,
}

/// Send a blob to the client.
pub async fn send_blob<D: Map, W: AsyncStreamWriter>(
    db: &D,
    hash: Hash,
    ranges: &RangeSpec,
    mut writer: W,
    events: EventSender,
    mk_progress: impl Fn(u64) -> Event,
) -> Result<(SentStatus, u64, SliceReaderStats)> {
    match db.get(&hash).await? {
        Some(entry) => {
            let outboard = entry.outboard().await?;
            let size = outboard.tree().size();
            let mut file_reader = TrackingSliceReader::new(entry.data_reader().await?);
            let mut sending_reader =
                SendingSliceReader::new(&mut file_reader, &events, mk_progress);
            // send the size as a little-endian u64, then the encoded ranges
            writer.write(size.to_le_bytes().as_slice()).await?;
            encode_ranges_validated(
                &mut sending_reader,
                outboard,
                &ranges.to_chunk_ranges(),
                writer,
            )
            .await
            .map_err(|e| encode_error_to_anyhow(e, &hash))?;

            Ok((SentStatus::Sent, size, file_reader.stats()))
        }
        _ => {
            debug!("blob not found {}", hash.to_hex());
            Ok((SentStatus::NotFound, 0, SliceReaderStats::default()))
        }
    }
}

fn encode_error_to_anyhow(err: EncodeError, hash: &Hash) -> anyhow::Error {
    match err {
        EncodeError::LeafHashMismatch(x) => anyhow::Error::from(EncodeError::LeafHashMismatch(x))
            .context(format!("hash {} offset {}", hash.to_hex(), x.to_bytes())),
        EncodeError::ParentHashMismatch(n) => {
            let r = n.chunk_range();
            anyhow::Error::from(EncodeError::ParentHashMismatch(n)).context(format!(
                "hash {} range {}..{}",
                hash.to_hex(),
                r.start.to_bytes(),
                r.end.to_bytes()
            ))
        }
        e => anyhow::Error::from(e).context(format!("hash {}", hash.to_hex())),
    }
}