use std::{
    fmt::Debug,
    io,
    time::{Duration, Instant},
};

use anyhow::{Context, Result};
use bao_tree::ChunkRanges;
use iroh::endpoint::{self, RecvStream, SendStream};
use n0_future::StreamExt;
use quinn::{ClosedStream, ConnectionError, ReadToEndError};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::select;
use tracing::{debug, debug_span, Instrument};

use crate::{
    api::{
        blobs::{Bitfield, WriteProgress},
        ExportBaoResult, Store,
    },
    hashseq::HashSeq,
    protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request},
    provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker},
    Hash,
};
pub mod events;
use events::EventSender;

/// Statistics about a single transfer, reported when the transfer completes
/// or is aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// Payload (blob data) bytes written to the stream.
    pub payload_bytes_sent: u64,
    /// Non-payload bytes written to the stream, e.g. encoding overhead.
    pub other_bytes_sent: u64,
    /// Bytes read from the stream, e.g. the encoded request.
    pub other_bytes_read: u64,
    /// Time from accepting the stream until the transfer ended.
    pub duration: Duration,
}

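/// A pair of QUIC streams used to handle a single incoming request.
///
/// It remembers when the stream was accepted, counts the non-payload bytes
/// read so far, and holds the [`EventSender`] used to authorize requests and
/// report progress. Once the request has been read, the pair is converted into
/// either a [`ProgressWriter`] (get, get_many, observe) or a [`ProgressReader`]
/// (push).
///
/// A minimal usage sketch, assuming an already accepted `endpoint::Connection`
/// named `conn`, an `EventSender` named `events` and a `Store` named `store`
/// (hypothetical bindings, not part of this module):
///
/// ```ignore
/// let mut pair = StreamPair::accept(&conn, &events).await?;
/// match pair.read_request().await? {
///     Request::Get(request) => {
///         let mut writer = pair.get_request(|| request.clone()).await?;
///         handle_get(store.clone(), request, &mut writer).await?;
///     }
///     _ => { /* other request types are handled analogously */ }
/// }
/// ```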
#[derive(Debug)]
pub struct StreamPair {
    t0: Instant,
    connection_id: u64,
    request_id: u64,
    reader: RecvStream,
    writer: SendStream,
    other_bytes_read: u64,
    events: EventSender,
}

impl StreamPair {
    /// Accept the next bidirectional stream on the connection and set up
    /// bookkeeping for handling a single request on it.
    pub async fn accept(
        conn: &endpoint::Connection,
        events: &EventSender,
    ) -> Result<Self, ConnectionError> {
        let (writer, reader) = conn.accept_bi().await?;
        Ok(Self {
            t0: Instant::now(),
            connection_id: conn.stable_id() as u64,
            request_id: reader.id().into(),
            reader,
            writer,
            other_bytes_read: 0,
            events: events.clone(),
        })
    }

    /// Read and decode the request, counting its size as non-payload bytes read.
    pub async fn read_request(&mut self) -> Result<Request> {
        let (res, size) = Request::read_async(&mut self.reader).await?;
        self.other_bytes_read += size as u64;
        Ok(res)
    }

    /// Check that the client sent nothing after the request, then convert into
    /// a [`ProgressWriter`] for sending data back.
    async fn into_writer(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressWriter, ReadToEndError> {
        let res = self.reader.read_to_end(0).await;
        if let Err(e) = res {
            tracker
                .transfer_aborted(|| Box::new(self.stats()))
                .await
                .ok();
            return Err(e);
        };
        Ok(ProgressWriter::new(
            self.writer,
            WriterContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                payload_bytes_written: 0,
                other_bytes_written: 0,
                tracker,
            },
        ))
    }

    /// Finish the send side, then convert into a [`ProgressReader`] for
    /// receiving pushed data.
    async fn into_reader(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressReader, ClosedStream> {
        let res = self.writer.finish();
        if let Err(e) = res {
            tracker
                .transfer_aborted(|| Box::new(self.stats()))
                .await
                .ok();
            return Err(e);
        };
        Ok(ProgressReader {
            inner: self.reader,
            context: ReaderContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                tracker,
            },
        })
    }

    /// Report a get request to the event sender. On acceptance, convert into a
    /// writer; on rejection, reset the send stream with the error code.
    pub async fn get_request(
        mut self,
        f: impl FnOnce() -> GetRequest,
    ) -> anyhow::Result<ProgressWriter> {
        let res = self
            .events
            .request(f, self.connection_id, self.request_id)
            .await;
        match res {
            Err(e) => {
                self.writer.reset(e.code()).ok();
                Err(e.into())
            }
            Ok(tracker) => Ok(self.into_writer(tracker).await?),
        }
    }

    /// Like [`Self::get_request`], for get_many requests.
    pub async fn get_many_request(
        mut self,
        f: impl FnOnce() -> GetManyRequest,
    ) -> anyhow::Result<ProgressWriter> {
        let res = self
            .events
            .request(f, self.connection_id, self.request_id)
            .await;
        match res {
            Err(e) => {
                self.writer.reset(e.code()).ok();
                Err(e.into())
            }
            Ok(tracker) => Ok(self.into_writer(tracker).await?),
        }
    }

    /// Like [`Self::get_request`], but converts into a reader, since for push
    /// requests the client sends the data.
    pub async fn push_request(
        mut self,
        f: impl FnOnce() -> PushRequest,
    ) -> anyhow::Result<ProgressReader> {
        let res = self
            .events
            .request(f, self.connection_id, self.request_id)
            .await;
        match res {
            Err(e) => {
                self.writer.reset(e.code()).ok();
                Err(e.into())
            }
            Ok(tracker) => Ok(self.into_reader(tracker).await?),
        }
    }

    /// Like [`Self::get_request`], for observe requests.
    pub async fn observe_request(
        mut self,
        f: impl FnOnce() -> ObserveRequest,
    ) -> anyhow::Result<ProgressWriter> {
        let res = self
            .events
            .request(f, self.connection_id, self.request_id)
            .await;
        match res {
            Err(e) => {
                self.writer.reset(e.code()).ok();
                Err(e.into())
            }
            Ok(tracker) => Ok(self.into_writer(tracker).await?),
        }
    }

    /// Stats for a transfer that ended before any payload was written.
    fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}

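/// Bookkeeping for a [`ProgressReader`]: transfer start time, non-payload
/// bytes read, and the tracker used to report the final outcome.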
#[derive(Debug)]
struct ReaderContext {
    t0: Instant,
    other_bytes_read: u64,
    tracker: RequestTracker,
}

impl ReaderContext {
    fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}

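/// Bookkeeping for a [`ProgressWriter`]. The counters are updated through the
/// [`WriteProgress`] implementation below and folded into [`TransferStats`]
/// when the transfer completes or is aborted.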
#[derive(Debug)]
pub(crate) struct WriterContext {
    t0: Instant,
    other_bytes_read: u64,
    payload_bytes_written: u64,
    other_bytes_written: u64,
    tracker: RequestTracker,
}

impl WriterContext {
    fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: self.payload_bytes_written,
            other_bytes_sent: self.other_bytes_written,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}

impl WriteProgress for WriterContext {
    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
        let len = len as u64;
        let end_offset = offset + len;
        self.payload_bytes_written += len;
        self.tracker.transfer_progress(len, end_offset).await
    }

    fn log_other_write(&mut self, len: usize) {
        self.other_bytes_written += len as u64;
    }

    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
        self.tracker.transfer_started(index, hash, size).await.ok();
    }
}

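/// A [`SendStream`] wrapper that counts payload and non-payload bytes written
/// and reports progress, completion, and aborts for the transfer it carries.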
#[derive(Debug)]
pub struct ProgressWriter {
    pub inner: SendStream,
    pub(crate) context: WriterContext,
}

impl ProgressWriter {
    fn new(inner: SendStream, context: WriterContext) -> Self {
        Self { inner, context }
    }

    async fn transfer_aborted(&self) {
        self.context
            .tracker
            .transfer_aborted(|| Box::new(self.context.stats()))
            .await
            .ok();
    }

    async fn transfer_completed(&self) {
        self.context
            .tracker
            .transfer_completed(|| Box::new(self.context.stats()))
            .await
            .ok();
    }
}

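/// Handle a single client connection: report `ClientConnected`, then accept
/// bidirectional streams in a loop and spawn a handler task per stream until
/// the connection is closed, finally reporting `ConnectionClosed`.
///
/// A minimal sketch of wiring this into an accept loop, assuming an
/// `iroh::Endpoint` named `endpoint`, a `Store` named `store` and an
/// `EventSender` named `events` (hypothetical bindings):
///
/// ```ignore
/// while let Some(incoming) = endpoint.accept().await {
///     let connection = incoming.await?;
///     tokio::spawn(handle_connection(connection, store.clone(), events.clone()));
/// }
/// ```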
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        if let Err(cause) = progress
            .client_connected(|| ClientConnected {
                connection_id,
                node_id: connection.remote_node_id().ok(),
            })
            .await
        {
            connection.close(cause.code(), cause.reason());
            debug!("closing connection: {cause}");
            return;
        }
        while let Ok(context) = StreamPair::accept(&connection, &progress).await {
            let span = debug_span!("stream", stream_id = %context.request_id);
            let store = store.clone();
            tokio::spawn(handle_stream(store, context).instrument(span));
        }
        progress
            .connection_closed(|| ConnectionClosed { connection_id })
            .await
            .ok();
    }
    .instrument(span)
    .await
}

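/// Read a single request from the stream pair and dispatch it to the matching
/// handler, reporting the transfer as completed or aborted on the way out.
/// Request types without a handler here are ignored.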
async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> {
    debug!("reading request");
    let request = context.read_request().await?;

    match request {
        Request::Get(request) => {
            let mut writer = context.get_request(|| request.clone()).await?;
            let res = handle_get(store, request, &mut writer).await;
            if res.is_ok() {
                writer.transfer_completed().await;
            } else {
                writer.transfer_aborted().await;
            }
        }
        Request::GetMany(request) => {
            let mut writer = context.get_many_request(|| request.clone()).await?;
            if handle_get_many(store, request, &mut writer).await.is_ok() {
                writer.transfer_completed().await;
            } else {
                writer.transfer_aborted().await;
            }
        }
        Request::Observe(request) => {
            let mut writer = context.observe_request(|| request.clone()).await?;
            if handle_observe(store, request, &mut writer).await.is_ok() {
                writer.transfer_completed().await;
            } else {
                writer.transfer_aborted().await;
            }
        }
        Request::Push(request) => {
            let mut reader = context.push_request(|| request.clone()).await?;
            if handle_push(store, request, &mut reader).await.is_ok() {
                reader.transfer_completed().await;
            } else {
                reader.transfer_aborted().await;
            }
        }
        _ => {}
    }
    Ok(())
}

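/// Send the requested ranges of a blob or hash sequence to the client.
///
/// Offset 0 in the request refers to the root blob itself; offset `n > 0`
/// refers to the `(n - 1)`-th child hash of the root, which is interpreted as
/// a [`HashSeq`] and loaded from the store the first time a child is
/// requested. Iteration stops at the first child offset that is not present
/// in the hash sequence.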
pub async fn handle_get(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter,
) -> anyhow::Result<()> {
    let hash = request.hash;
    debug!(%hash, "get received request");
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // Load the hash sequence lazily, the first time a child is requested.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs = HashSeq::try_from(bytes)?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            let o = usize::try_from(offset - 1).context("offset too large")?;
            let Some(hash) = hash_seq.get(o) else {
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }

    Ok(())
}

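/// Send the requested ranges for a sequence of independent hashes, in request
/// order. Hashes whose requested ranges are empty are skipped.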
pub async fn handle_get_many(
    store: Store,
    request: GetManyRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    debug!("get_many received request");
    let request_ranges = request.ranges.iter_infinite();
    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
        if !ranges.is_empty() {
            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
        }
    }
    Ok(())
}

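/// Receive blob data pushed by the client and import it into the store.
///
/// The root ranges are imported first. If the request addresses more than the
/// root blob, the root is interpreted as a [`HashSeq`] and the remaining
/// ranges are imported for its children in order.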
pub async fn handle_push(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    let mut request_ranges = request.ranges.iter_infinite();
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        store
            .import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        debug!("push request complete");
        return Ok(());
    }
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq)?;
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}

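/// Export the requested ranges of a single blob from the store and write them
/// to the client, updating the writer's progress counters as data is sent.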
pub(crate) async fn send_blob(
    store: &Store,
    index: u64,
    hash: Hash,
    ranges: ChunkRanges,
    writer: &mut ProgressWriter,
) -> ExportBaoResult<()> {
    store
        .export_bao(hash, ranges)
        .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
        .await
}

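/// Stream bitfield updates for a hash to the client.
///
/// The current bitfield is sent first, followed by the diff between the last
/// sent state and each newly observed state; empty diffs are skipped. The loop
/// ends when the client stops the stream.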
pub async fn handle_observe(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    let mut stream = store.observe(request.hash).stream().await?;
    let mut old = stream
        .next()
        .await
        .ok_or(anyhow::anyhow!("observe stream closed before first value"))?;
    send_observe_item(writer, &old).await?;
    loop {
        select! {
            new = stream.next() => {
                let new = new.context("observe stream closed")?;
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}

async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> {
    use irpc::util::AsyncWriteVarintExt;
    let item = ObserveItem::from(item);
    let len = writer.inner.write_length_prefixed(item).await?;
    writer.context.log_other_write(len);
    Ok(())
}

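/// A [`RecvStream`] wrapper carrying the bookkeeping needed to report a push
/// transfer as completed or aborted.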
pub struct ProgressReader {
    inner: RecvStream,
    context: ReaderContext,
}

impl ProgressReader {
    async fn transfer_aborted(&self) {
        self.context
            .tracker
            .transfer_aborted(|| Box::new(self.context.stats()))
            .await
            .ok();
    }

    async fn transfer_completed(&self) {
        self.context
            .tracker
            .transfer_completed(|| Box::new(self.context.stats()))
            .await
            .ok();
    }
}

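/// Extension trait for [`RecvStream`]: read the complete remaining stream
/// (up to `max_size` bytes) and decode it as a postcard value, returning the
/// value together with the number of bytes read.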
pub(crate) trait RecvStreamExt {
    async fn read_to_end_as<T: DeserializeOwned>(
        &mut self,
        max_size: usize,
    ) -> io::Result<(T, usize)>;
}

impl RecvStreamExt for RecvStream {
    async fn read_to_end_as<T: DeserializeOwned>(
        &mut self,
        max_size: usize,
    ) -> io::Result<(T, usize)> {
        let data = self
            .read_to_end(max_size)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let value = postcard::from_bytes(&data)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        Ok((value, data.len()))
    }
}