use std::{
    fmt::Debug,
    io,
    ops::{Deref, DerefMut},
    pin::Pin,
    task::Poll,
    time::Duration,
};

use anyhow::{Context, Result};
use bao_tree::ChunkRanges;
use iroh::{
    endpoint::{self, RecvStream, SendStream},
    NodeId,
};
use irpc::channel::oneshot;
use n0_future::StreamExt;
use serde::de::DeserializeOwned;
use tokio::{io::AsyncRead, select, sync::mpsc};
use tracing::{debug, debug_span, error, warn, Instrument};

use crate::{
    api::{self, blobs::Bitfield, Store},
    hashseq::HashSeq,
    protocol::{
        ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest,
        Request,
    },
    Hash,
};

/// Events emitted by the provider while serving requests.
#[derive(Debug)]
pub enum Event {
    /// A new client connected to the provider.
    ClientConnected {
        connection_id: u64,
        node_id: NodeId,
        /// Answer with `true` to accept the connection, `false` to reject it.
        permitted: oneshot::Sender<bool>,
    },
    /// A client connection was closed.
    ConnectionClosed { connection_id: u64 },
    /// A get request was received.
    GetRequestReceived {
        /// The connection id. A connection can carry multiple requests.
        connection_id: u64,
        /// The request id, unique within the connection.
        request_id: u64,
        /// The root hash of the request.
        hash: Hash,
        /// The exact ranges queried by the request.
        ranges: ChunkRangesSeq,
    },
    /// A get_many request was received.
    GetManyRequestReceived {
        connection_id: u64,
        request_id: u64,
        /// The root hashes of the request.
        hashes: Vec<Hash>,
        /// The exact ranges queried by the request.
        ranges: ChunkRangesSeq,
    },
    /// A push request was received.
    PushRequestReceived {
        connection_id: u64,
        request_id: u64,
        /// The root hash of the request.
        hash: Hash,
        /// The exact ranges pushed by the request.
        ranges: ChunkRangesSeq,
        /// Answer with `true` to accept the push, `false` to reject it.
        permitted: oneshot::Sender<bool>,
    },
    /// Transfer of a single blob started.
    TransferStarted {
        connection_id: u64,
        request_id: u64,
        /// The index of the blob within the request.
        index: u64,
        /// The hash of the blob.
        hash: Hash,
        /// The size of the blob in bytes.
        size: u64,
    },
    /// Progress while transferring a single blob.
    TransferProgress {
        connection_id: u64,
        request_id: u64,
        /// The index of the blob within the request.
        index: u64,
        /// The end offset of the chunk that was just written.
        end_offset: u64,
    },
    /// The whole request completed successfully.
    TransferCompleted {
        connection_id: u64,
        request_id: u64,
        /// Statistics about the transfer.
        stats: Box<TransferStats>,
    },
    /// The request was aborted because of an error.
    TransferAborted {
        connection_id: u64,
        request_id: u64,
        /// Statistics about the part of the transfer that did complete, if available.
        stats: Option<Box<TransferStats>>,
    },
}

/// Statistics about a single transfer.
#[derive(Debug)]
pub struct TransferStats {
    /// The number of payload bytes sent.
    pub payload_bytes_sent: u64,
    /// The number of sent bytes that are not blob payload.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the request stream.
    pub bytes_read: u64,
    /// The duration of the transfer.
    pub duration: Duration,
}

/// Read a [`Request`] from the stream, adding the bytes read to the reader's byte count.
pub async fn read_request(reader: &mut ProgressReader) -> Result<Request> {
    let mut counting = CountingReader::new(&mut reader.inner);
    let res = Request::read_async(&mut counting).await?;
    reader.bytes_read += counting.read();
    Ok(res)
}

/// Per-request state shared between the reader and writer wrappers.
#[derive(Debug)]
pub struct StreamContext {
    /// The connection id.
    pub connection_id: u64,
    /// The request id.
    pub request_id: u64,
    /// The number of payload bytes sent so far.
    pub payload_bytes_sent: u64,
    /// The number of non-payload bytes sent so far.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream so far.
    pub bytes_read: u64,
    /// The sender for progress events.
    pub progress: EventSender,
}

/// A wrapper for a [`SendStream`] that tracks request progress via its [`StreamContext`].
#[derive(Debug)]
pub struct ProgressWriter {
    /// The underlying send stream.
    pub inner: SendStream,
    pub(crate) context: StreamContext,
}

impl Deref for ProgressWriter {
    type Target = StreamContext;

    fn deref(&self) -> &Self::Target {
        &self.context
    }
}

impl DerefMut for ProgressWriter {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.context
    }
}

impl StreamContext {
    /// Increase the non-payload byte count by `len`.
    pub fn log_other_write(&mut self, len: usize) {
        self.other_bytes_sent += len as u64;
    }

    /// Send [`Event::TransferCompleted`] with the statistics collected so far.
    pub async fn send_transfer_completed(&mut self) {
        self.progress
            .send(|| Event::TransferCompleted {
                connection_id: self.connection_id,
                request_id: self.request_id,
                stats: Box::new(TransferStats {
                    payload_bytes_sent: self.payload_bytes_sent,
                    other_bytes_sent: self.other_bytes_sent,
                    bytes_read: self.bytes_read,
                    duration: Duration::ZERO,
                }),
            })
            .await;
    }

    /// Send [`Event::TransferAborted`] with the statistics collected so far.
    pub async fn send_transfer_aborted(&mut self) {
        self.progress
            .send(|| Event::TransferAborted {
                connection_id: self.connection_id,
                request_id: self.request_id,
                stats: Some(Box::new(TransferStats {
                    payload_bytes_sent: self.payload_bytes_sent,
                    other_bytes_sent: self.other_bytes_sent,
                    bytes_read: self.bytes_read,
                    duration: Duration::ZERO,
                })),
            })
            .await;
    }

    /// Record `len` payload bytes written at `offset` for blob `index` and emit a
    /// non-blocking [`Event::TransferProgress`].
    pub fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) {
        self.payload_bytes_sent += len as u64;
        self.progress.try_send(|| Event::TransferProgress {
            connection_id: self.connection_id,
            request_id: self.request_id,
            index,
            end_offset: offset + len as u64,
        });
    }

    /// Send [`Event::GetRequestReceived`].
    pub async fn send_get_request_received(&self, hash: &Hash, ranges: &ChunkRangesSeq) {
        self.progress
            .send(|| Event::GetRequestReceived {
                connection_id: self.connection_id,
                request_id: self.request_id,
                hash: *hash,
                ranges: ranges.clone(),
            })
            .await;
    }

    /// Send [`Event::GetManyRequestReceived`].
    pub async fn send_get_many_request_received(&self, hashes: &[Hash], ranges: &ChunkRangesSeq) {
        self.progress
            .send(|| Event::GetManyRequestReceived {
                connection_id: self.connection_id,
                request_id: self.request_id,
                hashes: hashes.to_vec(),
                ranges: ranges.clone(),
            })
            .await;
    }

    /// Ask the event consumer whether a push request is allowed.
    ///
    /// Returns `false` if events are disabled or if the permit channel is dropped
    /// without an answer.
    #[must_use = "permit should be checked by the caller"]
    pub async fn authorize_push_request(&self, hash: &Hash, ranges: &ChunkRangesSeq) -> bool {
        let mut wait_for_permit = None;
        self.progress
            .send(|| {
                let (tx, rx) = oneshot::channel();
                wait_for_permit = Some(rx);
                Event::PushRequestReceived {
                    connection_id: self.connection_id,
                    request_id: self.request_id,
                    hash: *hash,
                    ranges: ranges.clone(),
                    permitted: tx,
                }
            })
            .await;
        if let Some(wait_for_permit) = wait_for_permit {
            // if the event is not handled, the sender side is dropped and the permit is denied
            wait_for_permit.await.unwrap_or(false)
        } else {
            false
        }
    }

    /// Send [`Event::TransferStarted`].
    pub async fn send_transfer_started(&self, index: u64, hash: &Hash, size: u64) {
        self.progress
            .send(|| Event::TransferStarted {
                connection_id: self.connection_id,
                request_id: self.request_id,
                index,
                hash: *hash,
                size,
            })
            .await;
    }
}
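/// Handle a single connection from a client.
///
/// The connection is first authorized via [`EventSender::authorize_client_connection`],
/// then bi-directional streams are accepted in a loop and a task is spawned per request.
/// When the loop ends, a [`Event::ConnectionClosed`] event is sent.
///
/// A minimal accept-loop sketch (hypothetical wiring; assumes an already bound
/// `endpoint: iroh::Endpoint`, plus `store: Store` and `events: EventSender`):
///
/// ```ignore
/// while let Some(incoming) = endpoint.accept().await {
///     let connection = incoming.await?;
///     tokio::spawn(handle_connection(connection, store.clone(), events.clone()));
/// }
/// ```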
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        let Ok(node_id) = connection.remote_node_id() else {
            warn!("failed to get node id");
            return;
        };
        if !progress
            .authorize_client_connection(connection_id, node_id)
            .await
        {
            debug!("client not authorized to connect");
            return;
        }
        while let Ok((writer, reader)) = connection.accept_bi().await {
            let request_id = reader.id().index();
            let span = debug_span!("stream", stream_id = %request_id);
            let store = store.clone();
            let mut writer = ProgressWriter {
                inner: writer,
                context: StreamContext {
                    connection_id,
                    request_id,
                    payload_bytes_sent: 0,
                    other_bytes_sent: 0,
                    bytes_read: 0,
                    progress: progress.clone(),
                },
            };
            tokio::spawn(
                async move {
                    match handle_stream(store, reader, &mut writer).await {
                        Ok(()) => {
                            writer.send_transfer_completed().await;
                        }
                        Err(err) => {
                            warn!("error: {err:#?}");
                            writer.send_transfer_aborted().await;
                        }
                    }
                }
                .instrument(span),
            );
        }
        progress
            .send(Event::ConnectionClosed { connection_id })
            .await;
    }
    .instrument(span)
    .await
}

async fn handle_stream(
    store: Store,
    reader: RecvStream,
    writer: &mut ProgressWriter,
) -> Result<()> {
    // decode the request, counting the bytes read
    debug!("reading request");
    let mut reader = ProgressReader {
        inner: reader,
        context: StreamContext {
            connection_id: writer.connection_id,
            request_id: writer.request_id,
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            bytes_read: 0,
            progress: writer.progress.clone(),
        },
    };
    let request = read_request(&mut reader).await?;

    // dispatch on the request type
    match request {
        Request::Get(request) => {
            // the request must be the only data on the stream; a limit of 0 enforces that
            reader.inner.read_to_end(0).await?;
            // carry over the byte counts collected while reading the request
            writer.context = reader.context;
            handle_get(store, request, writer).await
        }
        Request::GetMany(request) => {
            reader.inner.read_to_end(0).await?;
            writer.context = reader.context;
            handle_get_many(store, request, writer).await
        }
        Request::Observe(request) => {
            reader.inner.read_to_end(0).await?;
            handle_observe(store, request, writer).await
        }
        Request::Push(request) => {
            // a push only receives data, so finish the send side immediately
            writer.inner.finish()?;
            handle_push(store, request, reader).await
        }
        _ => anyhow::bail!("unsupported request: {request:?}"),
    }
}

/// Handle a single get request.
///
/// Sends the requested ranges of the root blob and, for offsets past the root,
/// the children listed in the root hash sequence.
pub async fn handle_get(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "get received request");

    writer
        .send_get_request_received(&hash, &request.ranges)
        .await;
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // load the hash sequence lazily, only once a child range is requested
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs = HashSeq::try_from(bytes)?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            let o = usize::try_from(offset - 1).context("offset too large")?;
            let Some(hash) = hash_seq.get(o) else {
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }

    Ok(())
}

/// Handle a single get_many request.
///
/// Sends the requested ranges for each hash in the request.
pub async fn handle_get_many(
    store: Store,
    request: GetManyRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    debug!("get_many received request");
    writer
        .send_get_many_request_received(&request.hashes, &request.ranges)
        .await;
    let request_ranges = request.ranges.iter_infinite();
    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
        if !ranges.is_empty() {
            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
        }
    }
    Ok(())
}

/// Handle a single push request.
///
/// Imports the pushed ranges of the root blob and, if the request covers more
/// than the root, of each child listed in the root hash sequence.
pub async fn handle_push(
    store: Store,
    request: PushRequest,
    mut reader: ProgressReader,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    if !reader.authorize_push_request(&hash, &request.ranges).await {
        debug!("push request not authorized");
        return Ok(());
    }
    let mut request_ranges = request.ranges.iter_infinite();
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // import the root blob
        store
            .import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        debug!("push request complete");
        return Ok(());
    }
    // the root blob is a hash sequence, so import the requested children as well
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq)?;
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}

/// Send the requested ranges of a single blob to the client, reporting progress
/// via the writer's [`StreamContext`].
pub(crate) async fn send_blob(
    store: &Store,
    index: u64,
    hash: Hash,
    ranges: ChunkRanges,
    writer: &mut ProgressWriter,
) -> api::Result<()> {
    Ok(store
        .export_bao(hash, ranges)
        .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
        .await?)
}

/// Handle a single observe request.
///
/// Sends the current bitfield for the observed hash, followed by diffs whenever
/// it changes, until the client stops the stream.
pub async fn handle_observe(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    let mut stream = store.observe(request.hash).stream().await?;
    let mut old = stream
        .next()
        .await
        .ok_or(anyhow::anyhow!("observe stream closed before first value"))?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.context("observe stream closed")?;
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}

async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> {
    use irpc::util::AsyncWriteVarintExt;
    let item = ObserveItem::from(item);
    let len = writer.inner.write_length_prefixed(item).await?;
    writer.log_other_write(len);
    Ok(())
}

/// A lazy event, either an [`Event`] value or a closure that produces one.
///
/// Passing closures avoids constructing events (and cloning their payloads)
/// when the event sender is disabled.
pub trait LazyEvent {
    fn call(self) -> Event;
}

impl<T> LazyEvent for T
where
    T: FnOnce() -> Event,
{
    fn call(self) -> Event {
        self()
    }
}

impl LazyEvent for Event {
    fn call(self) -> Event {
        self
    }
}
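/// A sender for provider [`Event`]s.
///
/// When constructed with `None`, all events are dropped; client connections are
/// then allowed by default, while push requests are rejected.
///
/// A minimal consumer sketch (hypothetical wiring; assumes the `irpc` oneshot
/// sender is completed with an async `send`, matching how the receiver is awaited
/// in this module):
///
/// ```ignore
/// let (tx, mut rx) = tokio::sync::mpsc::channel(32);
/// let events = EventSender::new(Some(tx));
/// tokio::spawn(async move {
///     while let Some(event) = rx.recv().await {
///         match event {
///             Event::ClientConnected { node_id, permitted, .. } => {
///                 // grant the permit so the connection proceeds
///                 permitted.send(true).await.ok();
///                 println!("client connected: {node_id}");
///             }
///             other => println!("event: {other:?}"),
///         }
///     }
/// });
/// // `events` can now be passed to `handle_connection`.
/// ```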
#[derive(Debug, Clone)]
pub struct EventSender(EventSenderInner);

#[derive(Debug, Clone)]
enum EventSenderInner {
    Disabled,
    Enabled(mpsc::Sender<Event>),
}

impl EventSender {
    /// Create a new event sender; pass `None` to disable events entirely.
    pub fn new(sender: Option<mpsc::Sender<Event>>) -> Self {
        match sender {
            Some(sender) => Self(EventSenderInner::Enabled(sender)),
            None => Self(EventSenderInner::Disabled),
        }
    }

    /// Ask the event consumer whether a client is allowed to connect.
    ///
    /// Returns `true` if events are disabled, and `false` if the permit channel
    /// is dropped without an answer.
    #[must_use = "permit should be checked by the caller"]
    pub async fn authorize_client_connection(&self, connection_id: u64, node_id: NodeId) -> bool {
        let mut wait_for_permit = None;
        self.send(|| {
            let (tx, rx) = oneshot::channel();
            wait_for_permit = Some(rx);
            Event::ClientConnected {
                connection_id,
                node_id,
                permitted: tx,
            }
        })
        .await;
        if let Some(wait_for_permit) = wait_for_permit {
            // if the event is not handled, the sender side is dropped and the permit is denied
            wait_for_permit.await.unwrap_or(false)
        } else {
            true
        }
    }

    /// Send an event without waiting, dropping it if the channel is full.
    ///
    /// Used for frequent progress updates that may safely be lost.
    fn try_send(&self, event: impl LazyEvent) {
        match &self.0 {
            EventSenderInner::Enabled(sender) => {
                let value = event.call();
                sender.try_send(value).ok();
            }
            EventSenderInner::Disabled => {}
        }
    }

    /// Send an event, waiting for channel capacity if necessary.
    async fn send(&self, event: impl LazyEvent) {
        match &self.0 {
            EventSenderInner::Enabled(sender) => {
                let value = event.call();
                if let Err(err) = sender.send(value).await {
                    error!("failed to send progress event: {:?}", err);
                }
            }
            EventSenderInner::Disabled => {}
        }
    }
}

pub struct ProgressReader {
    inner: RecvStream,
    context: StreamContext,
}

impl Deref for ProgressReader {
    type Target = StreamContext;

    fn deref(&self) -> &Self::Target {
        &self.context
    }
}

impl DerefMut for ProgressReader {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.context
    }
}

pub struct CountingReader<R> {
    pub inner: R,
    pub read: u64,
}

impl<R> CountingReader<R> {
    pub fn new(inner: R) -> Self {
        Self { inner, read: 0 }
    }

    pub fn read(&self) -> u64 {
        self.read
    }
}

impl CountingReader<&mut iroh::endpoint::RecvStream> {
    pub async fn read_to_end_as<T: DeserializeOwned>(&mut self, max_size: usize) -> io::Result<T> {
        let data = self
            .inner
            .read_to_end(max_size)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let value = postcard::from_bytes(&data)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        self.read += data.len() as u64;
        Ok(value)
    }
}

impl<R: AsyncRead + Unpin> AsyncRead for CountingReader<R> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // only count the bytes added by this poll, not data already in the buffer
        let filled_before = buf.filled().len();
        let result = Pin::new(&mut this.inner).poll_read(cx, buf);
        if let Poll::Ready(Ok(())) = result {
            this.read += (buf.filled().len() - filled_before) as u64;
        }
        result
    }
}