1use std::{
7 fmt::Debug,
8 future::Future,
9 io,
10 time::{Duration, Instant},
11};
12
13use anyhow::Result;
14use bao_tree::ChunkRanges;
15use iroh::endpoint::{self, VarInt};
16use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
17use n0_future::StreamExt;
18use quinn::ConnectionError;
19use serde::{Deserialize, Serialize};
20use snafu::Snafu;
21use tokio::select;
22use tracing::{debug, debug_span, Instrument};
23
24use crate::{
25 api::{
26 blobs::{Bitfield, WriteProgress},
27 ExportBaoError, ExportBaoResult, RequestError, Store,
28 },
29 hashseq::HashSeq,
30 protocol::{
31 GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
32 },
33 provider::events::{
34 ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
35 RequestTracker,
36 },
37 util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
38 Hash,
39};
40pub mod events;
41use events::EventSender;
42
/// Receive half used by `StreamPair` / `ProgressReader` when no explicit
/// reader type is given: an iroh endpoint receive stream.
type DefaultReader = iroh::endpoint::RecvStream;
/// Send half used by `StreamPair` / `ProgressWriter` when no explicit
/// writer type is given: an iroh endpoint send stream.
type DefaultWriter = iroh::endpoint::SendStream;
45
/// Statistics about a single request/response exchange on one stream.
///
/// A snapshot is handed to the request tracker when a transfer is reported
/// as completed or aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// Payload (blob content) bytes written to the stream.
    pub payload_bytes_sent: u64,
    /// Non-payload bytes written to the stream (e.g. observe items).
    pub other_bytes_sent: u64,
    /// Non-payload bytes read from the stream (e.g. the encoded request).
    pub other_bytes_read: u64,
    /// Time elapsed between creating the stream pair and taking the snapshot.
    pub duration: Duration,
}
63
/// A bidirectional stream accepted from a client, together with the
/// bookkeeping needed to report provider events for the request it carries.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Creation time; the origin for the `duration` reported in stats.
    t0: Instant,
    /// Stable id of the connection this stream belongs to (for events).
    connection_id: u64,
    /// Receive half of the stream.
    reader: R,
    /// Send half of the stream.
    writer: W,
    /// Non-payload bytes read so far (the request, once it has been read).
    other_bytes_read: u64,
    /// Sink for provider events.
    events: EventSender,
}
74
75impl StreamPair {
76 pub async fn accept(
77 conn: &endpoint::Connection,
78 events: EventSender,
79 ) -> Result<Self, ConnectionError> {
80 let (writer, reader) = conn.accept_bi().await?;
81 Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
82 }
83}
84
85impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
86 pub fn stream_id(&self) -> u64 {
87 self.reader.id()
88 }
89
90 pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
91 Self {
92 t0: Instant::now(),
93 connection_id,
94 reader,
95 writer,
96 other_bytes_read: 0,
97 events,
98 }
99 }
100
101 pub async fn read_request(&mut self) -> Result<Request> {
110 let (res, size) = Request::read_async(&mut self.reader).await?;
111 self.other_bytes_read += size as u64;
112 Ok(res)
113 }
114
115 pub async fn into_writer(
117 mut self,
118 tracker: RequestTracker,
119 ) -> Result<ProgressWriter<W>, io::Error> {
120 self.reader.expect_eof().await?;
121 drop(self.reader);
122 Ok(ProgressWriter::new(
123 self.writer,
124 WriterContext {
125 t0: self.t0,
126 other_bytes_read: self.other_bytes_read,
127 payload_bytes_written: 0,
128 other_bytes_written: 0,
129 tracker,
130 },
131 ))
132 }
133
134 pub async fn into_reader(
135 mut self,
136 tracker: RequestTracker,
137 ) -> Result<ProgressReader<R>, io::Error> {
138 self.writer.sync().await?;
139 drop(self.writer);
140 Ok(ProgressReader {
141 inner: self.reader,
142 context: ReaderContext {
143 t0: self.t0,
144 other_bytes_read: self.other_bytes_read,
145 tracker,
146 },
147 })
148 }
149
150 pub async fn get_request(
151 &self,
152 f: impl FnOnce() -> GetRequest,
153 ) -> Result<RequestTracker, ProgressError> {
154 self.events
155 .request(f, self.connection_id, self.reader.id())
156 .await
157 }
158
159 pub async fn get_many_request(
160 &self,
161 f: impl FnOnce() -> GetManyRequest,
162 ) -> Result<RequestTracker, ProgressError> {
163 self.events
164 .request(f, self.connection_id, self.reader.id())
165 .await
166 }
167
168 pub async fn push_request(
169 &self,
170 f: impl FnOnce() -> PushRequest,
171 ) -> Result<RequestTracker, ProgressError> {
172 self.events
173 .request(f, self.connection_id, self.reader.id())
174 .await
175 }
176
177 pub async fn observe_request(
178 &self,
179 f: impl FnOnce() -> ObserveRequest,
180 ) -> Result<RequestTracker, ProgressError> {
181 self.events
182 .request(f, self.connection_id, self.reader.id())
183 .await
184 }
185
186 pub fn stats(&self) -> TransferStats {
187 TransferStats {
188 payload_bytes_sent: 0,
189 other_bytes_sent: 0,
190 other_bytes_read: self.other_bytes_read,
191 duration: self.t0.elapsed(),
192 }
193 }
194}
195
/// Bookkeeping carried by a `ProgressReader` for event reporting.
#[derive(Debug)]
struct ReaderContext {
    /// Start time, copied from the originating `StreamPair`.
    t0: Instant,
    /// Non-payload bytes read before the upload started (the request).
    other_bytes_read: u64,
    /// Tracker used to report completion or abort of the request.
    tracker: RequestTracker,
}
205
206impl ReaderContext {
207 fn stats(&self) -> TransferStats {
208 TransferStats {
209 payload_bytes_sent: 0,
210 other_bytes_sent: 0,
211 other_bytes_read: self.other_bytes_read,
212 duration: self.t0.elapsed(),
213 }
214 }
215}
216
/// Bookkeeping carried by a `ProgressWriter`: byte counters for the response
/// plus the tracker that receives progress and outcome events.
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// Start time, copied from the originating `StreamPair`.
    t0: Instant,
    /// Non-payload bytes read before the response started (the request).
    other_bytes_read: u64,
    /// Payload (blob content) bytes written so far.
    payload_bytes_written: u64,
    /// Non-payload bytes written so far.
    other_bytes_written: u64,
    /// Tracker receiving progress, completion and abort events.
    tracker: RequestTracker,
}
230
231impl WriterContext {
232 fn stats(&self) -> TransferStats {
233 TransferStats {
234 payload_bytes_sent: self.payload_bytes_written,
235 other_bytes_sent: self.other_bytes_written,
236 other_bytes_read: self.other_bytes_read,
237 duration: self.t0.elapsed(),
238 }
239 }
240}
241
242impl WriteProgress for WriterContext {
243 async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
244 let len = len as u64;
245 let end_offset = offset + len;
246 self.payload_bytes_written += len;
247 self.tracker.transfer_progress(len, end_offset).await
248 }
249
250 fn log_other_write(&mut self, len: usize) {
251 self.other_bytes_written += len as u64;
252 }
253
254 async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
255 self.tracker.transfer_started(index, hash, size).await.ok();
256 }
257}
258
/// A send stream wrapped together with the bookkeeping needed to report
/// progress and final statistics for one request.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream.
    pub inner: W,
    /// Byte counters, start time and request tracker.
    pub(crate) context: WriterContext,
}
266
267impl<W: SendStream> ProgressWriter<W> {
268 fn new(inner: W, context: WriterContext) -> Self {
269 Self { inner, context }
270 }
271
272 async fn transfer_aborted(&self) {
273 self.context
274 .tracker
275 .transfer_aborted(|| Box::new(self.context.stats()))
276 .await
277 .ok();
278 }
279
280 async fn transfer_completed(&self) {
281 self.context
282 .tracker
283 .transfer_completed(|| Box::new(self.context.stats()))
284 .await
285 .ok();
286 }
287}
288
289pub async fn handle_connection(
291 connection: endpoint::Connection,
292 store: Store,
293 progress: EventSender,
294) {
295 let connection_id = connection.stable_id() as u64;
296 let span = debug_span!("connection", connection_id);
297 async move {
298 if let Err(cause) = progress
299 .client_connected(|| ClientConnected {
300 connection_id,
301 node_id: connection.remote_node_id().ok(),
302 })
303 .await
304 {
305 connection.close(cause.code(), cause.reason());
306 debug!("closing connection: {cause}");
307 return;
308 }
309 while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
310 let span = debug_span!("stream", stream_id = %pair.stream_id());
311 let store = store.clone();
312 tokio::spawn(handle_stream(pair, store).instrument(span));
313 }
314 progress
315 .connection_closed(|| ConnectionClosed { connection_id })
316 .await
317 .ok();
318 }
319 .instrument(span)
320 .await
321}
322
/// Abstraction over aborting the two halves of a stream with an application
/// error code.
///
/// NOTE(review): not referenced by the handlers visible in this file; the
/// descriptions below are inferred from the method names — confirm against
/// implementors.
pub trait ErrorHandler {
    /// Writer type whose stream can be reset.
    type W: AsyncStreamWriter;
    /// Reader type whose stream can be stopped.
    type R: AsyncStreamReader;
    /// Presumably stops `reader` with the error `code`.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Presumably resets `writer` with the error `code`.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
330
331async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
332 pair: &mut StreamPair<R, W>,
333 r: Result<T, E>,
334) -> Result<T, E> {
335 match r {
336 Ok(x) => Ok(x),
337 Err(e) => {
338 pair.writer.reset(e.code()).ok();
339 Err(e)
340 }
341 }
342}
343async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
344 writer: &mut ProgressWriter<W>,
345 r: Result<T, E>,
346) -> Result<T, E> {
347 match r {
348 Ok(x) => {
349 writer.transfer_completed().await;
350 Ok(x)
351 }
352 Err(e) => {
353 writer.inner.reset(e.code()).ok();
354 writer.transfer_aborted().await;
355 Err(e)
356 }
357 }
358}
359async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
360 reader: &mut ProgressReader<R>,
361 r: Result<T, E>,
362) -> Result<T, E> {
363 match r {
364 Ok(x) => {
365 reader.transfer_completed().await;
366 Ok(x)
367 }
368 Err(e) => {
369 reader.inner.stop(e.code()).ok();
370 reader.transfer_aborted().await;
371 Err(e)
372 }
373 }
374}
375
376pub async fn handle_stream<R: RecvStream, W: SendStream>(
377 mut pair: StreamPair<R, W>,
378 store: Store,
379) -> anyhow::Result<()> {
380 let request = pair.read_request().await?;
381 match request {
382 Request::Get(request) => handle_get(pair, store, request).await?,
383 Request::GetMany(request) => handle_get_many(pair, store, request).await?,
384 Request::Observe(request) => handle_observe(pair, store, request).await?,
385 Request::Push(request) => handle_push(pair, store, request).await?,
386 _ => {}
387 }
388 Ok(())
389}
390
/// Errors that can occur while serving a get request.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
    // Exporting/sending blob data or syncing the writer failed.
    // (Plain comments on purpose: snafu uses variant doc comments as Display text.)
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },
    // Child offsets were requested but the root blob is not a valid hash sequence.
    InvalidHashSeq,
    // A child offset does not fit in `usize` on this platform.
    InvalidOffset,
}
401
402impl HasErrorCode for HandleGetError {
403 fn code(&self) -> VarInt {
404 match self {
405 HandleGetError::ExportBao {
406 source: ExportBaoError::ClientError { source, .. },
407 } => source.code(),
408 HandleGetError::InvalidHashSeq => ERR_INTERNAL,
409 HandleGetError::InvalidOffset => ERR_INTERNAL,
410 _ => ERR_INTERNAL,
411 }
412 }
413}
414
415async fn handle_get_impl<W: SendStream>(
419 store: Store,
420 request: GetRequest,
421 writer: &mut ProgressWriter<W>,
422) -> Result<(), HandleGetError> {
423 let hash = request.hash;
424 debug!(%hash, "get received request");
425 let mut hash_seq = None;
426 for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
427 if offset == 0 {
428 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
429 } else {
430 let hash_seq = match &hash_seq {
436 Some(b) => b,
437 None => {
438 let bytes = store.get_bytes(hash).await?;
439 let hs =
440 HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
441 hash_seq = Some(hs);
442 hash_seq.as_ref().unwrap()
443 }
444 };
445 let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
446 let Some(hash) = hash_seq.get(o) else {
447 break;
448 };
449 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
450 }
451 }
452 writer
453 .inner
454 .sync()
455 .await
456 .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
457
458 Ok(())
459}
460
461pub async fn handle_get<R: RecvStream, W: SendStream>(
462 mut pair: StreamPair<R, W>,
463 store: Store,
464 request: GetRequest,
465) -> anyhow::Result<()> {
466 let res = pair.get_request(|| request.clone()).await;
467 let tracker = handle_read_request_result(&mut pair, res).await?;
468 let mut writer = pair.into_writer(tracker).await?;
469 let res = handle_get_impl(store, request, &mut writer).await;
470 handle_write_result(&mut writer, res).await?;
471 Ok(())
472}
473
/// Errors that can occur while serving a get-many request.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
    // Exporting or sending blob data failed.
    // (Plain comment on purpose: snafu uses variant doc comments as Display text.)
    #[snafu(transparent)]
    ExportBao { source: ExportBaoError },
}
479
480impl HasErrorCode for HandleGetManyError {
481 fn code(&self) -> VarInt {
482 match self {
483 Self::ExportBao {
484 source: ExportBaoError::ClientError { source, .. },
485 } => source.code(),
486 _ => ERR_INTERNAL,
487 }
488 }
489}
490
491async fn handle_get_many_impl<W: SendStream>(
495 store: Store,
496 request: GetManyRequest,
497 writer: &mut ProgressWriter<W>,
498) -> Result<(), HandleGetManyError> {
499 debug!("get_many received request");
500 let request_ranges = request.ranges.iter_infinite();
501 for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
502 if !ranges.is_empty() {
503 send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
504 }
505 }
506 Ok(())
507}
508
509pub async fn handle_get_many<R: RecvStream, W: SendStream>(
510 mut pair: StreamPair<R, W>,
511 store: Store,
512 request: GetManyRequest,
513) -> anyhow::Result<()> {
514 let res = pair.get_many_request(|| request.clone()).await;
515 let tracker = handle_read_request_result(&mut pair, res).await?;
516 let mut writer = pair.into_writer(tracker).await?;
517 let res = handle_get_many_impl(store, request, &mut writer).await;
518 handle_write_result(&mut writer, res).await?;
519 Ok(())
520}
521
/// Errors that can occur while receiving a push.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
    // Transparent wrapper for `ExportBaoError`s from the store calls in
    // `handle_push_impl`.
    // (Plain comments on purpose: snafu uses variant doc comments as Display text.)
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },

    // Child ranges were requested but the pushed root blob is not a valid
    // hash sequence.
    InvalidHashSeq,

    // Transparent wrapper for `RequestError`s from the store calls in
    // `handle_push_impl`.
    #[snafu(transparent)]
    Request {
        source: RequestError,
    },
}
536
537impl HasErrorCode for HandlePushError {
538 fn code(&self) -> VarInt {
539 match self {
540 Self::ExportBao {
541 source: ExportBaoError::ClientError { source, .. },
542 } => source.code(),
543 _ => ERR_INTERNAL,
544 }
545 }
546}
547
548async fn handle_push_impl<R: RecvStream>(
552 store: Store,
553 request: PushRequest,
554 reader: &mut ProgressReader<R>,
555) -> Result<(), HandlePushError> {
556 let hash = request.hash;
557 debug!(%hash, "push received request");
558 let mut request_ranges = request.ranges.iter_infinite();
559 let root_ranges = request_ranges.next().expect("infinite iterator");
560 if !root_ranges.is_empty() {
561 store
563 .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
564 .await?;
565 }
566 if request.ranges.is_blob() {
567 debug!("push request complete");
568 return Ok(());
569 }
570 let hash_seq = store.get_bytes(hash).await?;
572 let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
573 for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
574 if child_ranges.is_empty() {
575 continue;
576 }
577 store
578 .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
579 .await?;
580 }
581 Ok(())
582}
583
584pub async fn handle_push<R: RecvStream, W: SendStream>(
585 mut pair: StreamPair<R, W>,
586 store: Store,
587 request: PushRequest,
588) -> anyhow::Result<()> {
589 let res = pair.push_request(|| request.clone()).await;
590 let tracker = handle_read_request_result(&mut pair, res).await?;
591 let mut reader = pair.into_reader(tracker).await?;
592 let res = handle_push_impl(store, request, &mut reader).await;
593 handle_read_result(&mut reader, res).await?;
594 Ok(())
595}
596
597pub(crate) async fn send_blob<W: SendStream>(
599 store: &Store,
600 index: u64,
601 hash: Hash,
602 ranges: ChunkRanges,
603 writer: &mut ProgressWriter<W>,
604) -> ExportBaoResult<()> {
605 store
606 .export_bao(hash, ranges)
607 .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
608 .await
609}
610
/// Errors that can occur while serving an observe request.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
    // The store's observe stream could not be opened or ended.
    // (Plain comments on purpose: snafu uses variant doc comments as Display text.)
    ObserveStreamClosed,

    // Writing an update to the client failed.
    #[snafu(transparent)]
    RemoteClosed {
        source: io::Error,
    },
}
620
621impl HasErrorCode for HandleObserveError {
622 fn code(&self) -> VarInt {
623 ERR_INTERNAL
624 }
625}
626
627async fn handle_observe_impl<W: SendStream>(
631 store: Store,
632 request: ObserveRequest,
633 writer: &mut ProgressWriter<W>,
634) -> std::result::Result<(), HandleObserveError> {
635 let mut stream = store
636 .observe(request.hash)
637 .stream()
638 .await
639 .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
640 let mut old = stream
641 .next()
642 .await
643 .ok_or(HandleObserveError::ObserveStreamClosed)?;
644 send_observe_item(writer, &old).await?;
646 loop {
648 select! {
649 new = stream.next() => {
650 let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
651 let diff = old.diff(&new);
652 if diff.is_empty() {
653 continue;
654 }
655 send_observe_item(writer, &diff).await?;
656 old = new;
657 }
658 _ = writer.inner.stopped() => {
659 debug!("observer closed");
660 break;
661 }
662 }
663 }
664 Ok(())
665}
666
667async fn send_observe_item<W: SendStream>(
668 writer: &mut ProgressWriter<W>,
669 item: &Bitfield,
670) -> io::Result<()> {
671 let item = ObserveItem::from(item);
672 let len = writer.inner.write_length_prefixed(item).await?;
673 writer.context.log_other_write(len);
674 Ok(())
675}
676
677pub async fn handle_observe<R: RecvStream, W: SendStream>(
678 mut pair: StreamPair<R, W>,
679 store: Store,
680 request: ObserveRequest,
681) -> anyhow::Result<()> {
682 let res = pair.observe_request(|| request.clone()).await;
683 let tracker = handle_read_request_result(&mut pair, res).await?;
684 let mut writer = pair.into_writer(tracker).await?;
685 let res = handle_observe_impl(store, request, &mut writer).await;
686 handle_write_result(&mut writer, res).await?;
687 Ok(())
688}
689
690pub struct ProgressReader<R: RecvStream = DefaultReader> {
691 inner: R,
692 context: ReaderContext,
693}
694
695impl<R: RecvStream> ProgressReader<R> {
696 async fn transfer_aborted(&self) {
697 self.context
698 .tracker
699 .transfer_aborted(|| Box::new(self.context.stats()))
700 .await
701 .ok();
702 }
703
704 async fn transfer_completed(&self) {
705 self.context
706 .tracker
707 .transfer_completed(|| Box::new(self.context.stats()))
708 .await
709 .ok();
710 }
711}