1use std::fmt::Debug;
3use std::time::Duration;
4
5use anyhow::{Context, Result};
6use bao_tree::io::fsm::{encode_ranges_validated, Outboard};
7use bao_tree::io::EncodeError;
8use futures_lite::future::Boxed as BoxFuture;
9use iroh_base::rpc::RpcError;
10use iroh_io::stats::{
11 SliceReaderStats, StreamWriterStats, TrackingSliceReader, TrackingStreamWriter,
12};
13use iroh_io::{AsyncSliceReader, AsyncStreamWriter, TokioStreamWriter};
14use serde::{Deserialize, Serialize};
15use tokio_util::task::LocalPoolHandle;
16use tracing::{debug, debug_span, info, trace, warn};
17use tracing_futures::Instrument;
18
19use crate::hashseq::parse_hash_seq;
20use crate::protocol::{GetRequest, RangeSpec, Request};
21use crate::store::*;
22use crate::util::Tag;
23use crate::{BlobFormat, Hash};
24
/// Events emitted by the provider while serving connections and requests.
///
/// Events are delivered through an [`EventSender`]. Where present, `connection_id` is the
/// QUIC connection's stable id and `request_id` is the index of the stream id that
/// carries the request.
#[derive(Debug, Clone)]
pub enum Event {
    /// A blob was added under a tag.
    ///
    /// NOTE(review): not emitted anywhere in this module; presumably produced by the code
    /// that adds data to the store — confirm.
    TaggedBlobAdded {
        /// Hash of the added blob.
        hash: Hash,
        /// Format of the added blob.
        format: BlobFormat,
        /// Tag the blob was stored under.
        tag: Tag,
    },
    /// Sent by `handle_connection` every time a bi-directional stream is accepted,
    /// i.e. once per stream rather than once per connection.
    ClientConnected {
        /// Id of the connection the stream was accepted on.
        connection_id: u64,
    },
    /// A get request was decoded; sent by `handle_get` before the hash is looked up.
    GetRequestReceived {
        /// Id of the connection the request arrived on.
        connection_id: u64,
        /// Id of the request (stream id index).
        request_id: u64,
        /// Root hash of the request.
        hash: Hash,
    },
    /// A custom get request was received.
    ///
    /// NOTE(review): not emitted in this module — confirm where it is produced.
    CustomGetRequestReceived {
        /// Id of the connection the request arrived on.
        connection_id: u64,
        /// Id of the request.
        request_id: u64,
        /// Presumably the length of the custom request payload — confirm.
        len: usize,
    },
    /// The root blob was parsed as a hash sequence because children were requested.
    TransferHashSeqStarted {
        /// Id of the connection the request arrived on.
        connection_id: u64,
        /// Id of the request.
        request_id: u64,
        /// Number of blobs reported by the hash sequence parser.
        num_blobs: u64,
    },
    /// A child blob of the hash sequence was fully sent.
    TransferBlobCompleted {
        /// Id of the connection the request arrived on.
        connection_id: u64,
        /// Id of the request.
        request_id: u64,
        /// Hash of the child blob.
        hash: Hash,
        /// Position of the child in the hash sequence (request offset minus one).
        index: u64,
        /// Total size of the child blob according to its outboard; not necessarily the
        /// number of bytes sent.
        size: u64,
    },
    /// The whole request was served.
    TransferCompleted {
        /// Id of the connection the request arrived on.
        connection_id: u64,
        /// Id of the request.
        request_id: u64,
        /// Statistics for the transfer.
        stats: Box<TransferStats>,
    },
    /// The request was not served completely.
    TransferAborted {
        /// Id of the connection the request arrived on.
        connection_id: u64,
        /// Id of the request.
        request_id: u64,
        /// Statistics for the partial transfer; `None` when the request failed before a
        /// transfer started (e.g. it could not be read or the hash was not found).
        stats: Option<Box<TransferStats>>,
    },
}
102
/// Statistics collected while serving one request.
#[derive(Debug, Clone, Copy, Default)]
pub struct TransferStats {
    /// Accumulated statistics of the writes to the response stream.
    pub send: StreamWriterStats,
    /// Accumulated statistics of the reads from the blobs' data readers.
    pub read: SliceReaderStats,
    /// Wall-clock time spent in `transfer_collection`, as measured by `handle_get`.
    pub duration: Duration,
}
113
/// Progress updates for an add operation.
///
/// NOTE(review): not produced or consumed in this module. The serde derives suggest it is
/// streamed to a client — confirm. With serde's derive, fields and variants are encoded in
/// declaration order (and by variant index for formats like postcard), so do not reorder.
#[derive(Debug, Serialize, Deserialize)]
pub enum AddProgress {
    /// An item to add was found.
    Found {
        /// Presumably identifies this item in later `Progress`/`Done` updates — confirm.
        id: u64,
        /// Name of the item.
        name: String,
        /// Size of the item, presumably in bytes — confirm.
        size: u64,
    },
    /// Progress for the item `id`.
    Progress {
        /// Id of the item, as announced by `Found`.
        id: u64,
        /// Offset reached so far, presumably in bytes — confirm.
        offset: u64,
    },
    /// The item `id` has been added.
    Done {
        /// Id of the item, as announced by `Found`.
        id: u64,
        /// Hash of the added item.
        hash: Hash,
    },
    /// All items have been added.
    AllDone {
        /// Hash of the overall result.
        hash: Hash,
        /// Format of the overall result.
        format: BlobFormat,
        /// Tag the result was stored under.
        tag: Tag,
    },
    /// The operation was aborted with an error.
    Abort(RpcError),
}
154
155pub async fn read_request(mut reader: quinn::RecvStream) -> Result<Request> {
162 let payload = reader
163 .read_to_end(crate::protocol::MAX_MESSAGE_SIZE)
164 .await?;
165 let request: Request = postcard::from_bytes(&payload)?;
166 Ok(request)
167}
168
169pub async fn transfer_collection<D: Map, E: EventSender>(
182 request: GetRequest,
183 db: &D,
185 writer: &mut ResponseWriter<E>,
187 mut outboard: impl Outboard,
189 mut data: impl AsyncSliceReader,
190 stats: &mut TransferStats,
191) -> Result<SentStatus> {
192 let hash = request.hash;
193
194 let just_root = matches!(request.ranges.as_single(), Some((0, _)));
196 let mut c = if !just_root {
197 let (stream, num_blobs) = parse_hash_seq(&mut data).await?;
199 writer
200 .events
201 .send(Event::TransferHashSeqStarted {
202 connection_id: writer.connection_id(),
203 request_id: writer.request_id(),
204 num_blobs,
205 })
206 .await;
207 Some(stream)
208 } else {
209 None
210 };
211
212 let mut prev = 0;
213 for (offset, ranges) in request.ranges.iter_non_empty() {
214 let mut tw = writer.tracking_writer();
216 if offset == 0 {
217 debug!("writing ranges '{:?}' of sequence {}", ranges, hash);
218 let mut tracking_reader = TrackingSliceReader::new(&mut data);
220 tw.write(outboard.tree().size().to_le_bytes().as_slice())
222 .await?;
223 encode_ranges_validated(
224 &mut tracking_reader,
225 &mut outboard,
226 &ranges.to_chunk_ranges(),
227 &mut tw,
228 )
229 .await?;
230 stats.read += tracking_reader.stats();
231 stats.send += tw.stats();
232 debug!(
233 "finished writing ranges '{:?}' of collection {}",
234 ranges, hash
235 );
236 } else {
237 let c = c.as_mut().context("collection parser not available")?;
238 debug!("wrtiting ranges '{:?}' of child {}", ranges, offset);
239 if prev < offset - 1 {
241 c.skip(offset - prev - 1).await?;
242 }
243 if let Some(hash) = c.next().await? {
244 tokio::task::yield_now().await;
245 let (status, size, blob_read_stats) = send_blob(db, hash, ranges, &mut tw).await?;
246 stats.send += tw.stats();
247 stats.read += blob_read_stats;
248 if SentStatus::NotFound == status {
249 writer.inner.finish().await?;
250 return Ok(status);
251 }
252
253 writer
254 .events
255 .send(Event::TransferBlobCompleted {
256 connection_id: writer.connection_id(),
257 request_id: writer.request_id(),
258 hash,
259 index: offset - 1,
260 size,
261 })
262 .await;
263 } else {
264 break;
266 }
267 prev = offset;
268 }
269 }
270
271 debug!("done writing");
272 Ok(SentStatus::Sent)
273}
274
/// A sink for provider [`Event`]s.
///
/// Senders are cloned for every accepted stream and moved, inside the response writer,
/// into the task spawned to handle that stream.
pub trait EventSender: Clone + Sync + Send + 'static {
    /// Delivers `event`. Callers in this module await the returned boxed future before
    /// continuing.
    fn send(&self, event: Event) -> BoxFuture<()>;
}
280
281pub async fn handle_connection<D: Map, E: EventSender>(
283 connecting: quinn::Connecting,
284 db: D,
285 events: E,
286 rt: LocalPoolHandle,
287) {
288 let remote_addr = connecting.remote_address();
289 let connection = match connecting.await {
290 Ok(conn) => conn,
291 Err(err) => {
292 warn!(%remote_addr, "Error connecting: {err:#}");
293 return;
294 }
295 };
296 let connection_id = connection.stable_id() as u64;
297 let span = debug_span!("connection", connection_id, %remote_addr);
298 async move {
299 while let Ok((writer, reader)) = connection.accept_bi().await {
300 let request_id = reader.id().index();
303 let span = debug_span!("stream", stream_id = %request_id);
304 let writer = ResponseWriter {
305 connection_id,
306 events: events.clone(),
307 inner: writer,
308 };
309 events.send(Event::ClientConnected { connection_id }).await;
310 let db = db.clone();
311 rt.spawn_pinned(|| {
312 async move {
313 if let Err(err) = handle_stream(db, reader, writer).await {
314 warn!("error: {err:#?}",);
315 }
316 }
317 .instrument(span)
318 });
319 }
320 }
321 .instrument(span)
322 .await
323}
324
325async fn handle_stream<D: Map, E: EventSender>(
326 db: D,
327 reader: quinn::RecvStream,
328 writer: ResponseWriter<E>,
329) -> Result<()> {
330 debug!("reading request");
332 let request = match read_request(reader).await {
333 Ok(r) => r,
334 Err(e) => {
335 writer.notify_transfer_aborted(None).await;
336 return Err(e);
337 }
338 };
339
340 match request {
341 Request::Get(request) => handle_get(db, request, writer).await,
342 }
343}
344
345pub async fn handle_get<D: Map, E: EventSender>(
347 db: D,
348 request: GetRequest,
349 mut writer: ResponseWriter<E>,
350) -> Result<()> {
351 let hash = request.hash;
352 debug!(%hash, "received request");
353 writer
354 .events
355 .send(Event::GetRequestReceived {
356 hash,
357 connection_id: writer.connection_id(),
358 request_id: writer.request_id(),
359 })
360 .await;
361
362 match db.get(&hash).await? {
364 Some(entry) => {
366 let mut stats = Box::<TransferStats>::default();
367 let t0 = std::time::Instant::now();
368 let res = transfer_collection(
370 request,
371 &db,
372 &mut writer,
373 entry.outboard().await?,
374 entry.data_reader().await?,
375 &mut stats,
376 )
377 .await;
378 stats.duration = t0.elapsed();
379 match res {
380 Ok(SentStatus::Sent) => {
381 writer.notify_transfer_completed(&hash, stats).await;
382 }
383 Ok(SentStatus::NotFound) => {
384 writer.notify_transfer_aborted(Some(stats)).await;
385 }
386 Err(e) => {
387 writer.notify_transfer_aborted(Some(stats)).await;
388 return Err(e);
389 }
390 }
391
392 debug!("finished response");
393 }
394 None => {
395 debug!("not found {}", hash);
396 writer.notify_transfer_aborted(None).await;
397 writer.inner.finish().await?;
398 }
399 };
400
401 Ok(())
402}
403
/// The sending half of a request stream, bundled with what is needed to report events
/// about the request it answers.
#[derive(Debug)]
pub struct ResponseWriter<E> {
    /// Stream the response is written to; the index of its id is used as the request id.
    inner: quinn::SendStream,
    /// Sink for events about this request.
    events: E,
    /// Stable id of the connection the stream belongs to.
    connection_id: u64,
}
411
412impl<E: EventSender> ResponseWriter<E> {
413 fn tracking_writer(
414 &mut self,
415 ) -> TrackingStreamWriter<TokioStreamWriter<&mut quinn::SendStream>> {
416 TrackingStreamWriter::new(TokioStreamWriter(&mut self.inner))
417 }
418
419 fn connection_id(&self) -> u64 {
420 self.connection_id
421 }
422
423 fn request_id(&self) -> u64 {
424 self.inner.id().index()
425 }
426
427 fn print_stats(stats: &TransferStats) {
428 let send = stats.send.total();
429 let read = stats.read.total();
430 let total_sent_bytes = send.size;
431 let send_duration = send.stats.duration;
432 let read_duration = read.stats.duration;
433 let total_duration = stats.duration;
434 let other_duration = total_duration
435 .saturating_sub(send_duration)
436 .saturating_sub(read_duration);
437 let avg_send_size = total_sent_bytes.checked_div(send.stats.count).unwrap_or(0);
438 info!(
439 "sent {} bytes in {}s",
440 total_sent_bytes,
441 total_duration.as_secs_f64()
442 );
443 debug!(
444 "{}s sending, {}s reading, {}s other",
445 send_duration.as_secs_f64(),
446 read_duration.as_secs_f64(),
447 other_duration.as_secs_f64()
448 );
449 trace!(
450 "send_count: {} avg_send_size {}",
451 send.stats.count,
452 avg_send_size,
453 )
454 }
455
456 async fn notify_transfer_completed(&self, hash: &Hash, stats: Box<TransferStats>) {
457 info!("transfer completed for {}", hash);
458 Self::print_stats(&stats);
459 self.events
460 .send(Event::TransferCompleted {
461 connection_id: self.connection_id(),
462 request_id: self.request_id(),
463 stats,
464 })
465 .await;
466 }
467
468 async fn notify_transfer_aborted(&self, stats: Option<Box<TransferStats>>) {
469 if let Some(stats) = &stats {
470 Self::print_stats(stats);
471 };
472 self.events
473 .send(Event::TransferAborted {
474 connection_id: self.connection_id(),
475 request_id: self.request_id(),
476 stats,
477 })
478 .await;
479 }
480}
481
/// Outcome of sending a blob or a collection.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SentStatus {
    /// The requested data was written to the stream.
    Sent,
    /// A requested blob was not present in the store.
    NotFound,
}
490
491pub async fn send_blob<D: Map, W: AsyncStreamWriter>(
493 db: &D,
494 hash: Hash,
495 ranges: &RangeSpec,
496 mut writer: W,
497) -> Result<(SentStatus, u64, SliceReaderStats)> {
498 match db.get(&hash).await? {
499 Some(entry) => {
500 let outboard = entry.outboard().await?;
501 let size = outboard.tree().size();
502 let mut file_reader = TrackingSliceReader::new(entry.data_reader().await?);
503 writer.write(size.to_le_bytes().as_slice()).await?;
504 encode_ranges_validated(
505 &mut file_reader,
506 outboard,
507 &ranges.to_chunk_ranges(),
508 writer,
509 )
510 .await
511 .map_err(|e| encode_error_to_anyhow(e, &hash))?;
512
513 Ok((SentStatus::Sent, size, file_reader.stats()))
514 }
515 _ => {
516 debug!("blob not found {}", hash.to_hex());
517 Ok((SentStatus::NotFound, 0, SliceReaderStats::default()))
518 }
519 }
520}
521
522fn encode_error_to_anyhow(err: EncodeError, hash: &Hash) -> anyhow::Error {
523 match err {
524 EncodeError::LeafHashMismatch(x) => anyhow::Error::from(EncodeError::LeafHashMismatch(x))
525 .context(format!("hash {} offset {}", hash.to_hex(), x.to_bytes())),
526 EncodeError::ParentHashMismatch(n) => {
527 let r = n.chunk_range();
528 anyhow::Error::from(EncodeError::ParentHashMismatch(n)).context(format!(
529 "hash {} range {}..{}",
530 hash.to_hex(),
531 r.start.to_bytes(),
532 r.end.to_bytes()
533 ))
534 }
535 e => anyhow::Error::from(e).context(format!("hash {}", hash.to_hex())),
536 }
537}