// iroh_blobs/provider.rs

//! The low level server side API
//!
//! Note that while using this API directly is fine, the standard way
//! to provide data is to just register a [`crate::BlobsProtocol`] protocol
//! handler with an [`iroh::protocol::Router`].
6use std::{
7    fmt::Debug,
8    future::Future,
9    io,
10    time::{Duration, Instant},
11};
12
13use anyhow::Result;
14use bao_tree::ChunkRanges;
15use iroh::endpoint::{self, VarInt};
16use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
17use n0_future::StreamExt;
18use quinn::ConnectionError;
19use serde::{Deserialize, Serialize};
20use snafu::Snafu;
21use tokio::select;
22use tracing::{debug, debug_span, Instrument};
23
24use crate::{
25    api::{
26        blobs::{Bitfield, WriteProgress},
27        ExportBaoError, ExportBaoResult, RequestError, Store,
28    },
29    hashseq::HashSeq,
30    protocol::{
31        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
32    },
33    provider::events::{
34        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
35        RequestTracker,
36    },
37    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
38    Hash,
39};
40pub mod events;
41use events::EventSender;
42
/// Default receive stream used when [`StreamPair`] is not instantiated with a custom reader.
type DefaultReader = iroh::endpoint::RecvStream;
/// Default send stream used when [`StreamPair`] is not instantiated with a custom writer.
type DefaultWriter = iroh::endpoint::SendStream;
45
/// Statistics about a successful or failed transfer.
///
/// Serializable so it can be forwarded as part of provider events.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// In most cases this is just the request, for push requests this is
    /// request, size header and hash pairs.
    pub other_bytes_read: u64,
    /// Total duration from reading the request to transfer completed.
    pub duration: Duration,
}
63
/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    // Start time of the request, used to compute the transfer duration.
    t0: Instant,
    // Id of the connection this stream belongs to, used for event reporting.
    connection_id: u64,
    // Receive side of the bidirectional stream.
    reader: R,
    // Send side of the bidirectional stream.
    writer: W,
    // Number of non-payload bytes read so far (the request bytes).
    other_bytes_read: u64,
    // Sink for provider events.
    events: EventSender,
}
74
75impl StreamPair {
76    pub async fn accept(
77        conn: &endpoint::Connection,
78        events: EventSender,
79    ) -> Result<Self, ConnectionError> {
80        let (writer, reader) = conn.accept_bi().await?;
81        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
82    }
83}
84
85impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
86    pub fn stream_id(&self) -> u64 {
87        self.reader.id()
88    }
89
90    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
91        Self {
92            t0: Instant::now(),
93            connection_id,
94            reader,
95            writer,
96            other_bytes_read: 0,
97            events,
98        }
99    }
100
101    /// Read the request.
102    ///
103    /// Will fail if there is an error while reading, or if no valid request is sent.
104    ///
105    /// This will read exactly the number of bytes needed for the request, and
106    /// leave the rest of the stream for the caller to read.
107    ///
108    /// It is up to the caller do decide if there should be more data.
109    pub async fn read_request(&mut self) -> Result<Request> {
110        let (res, size) = Request::read_async(&mut self.reader).await?;
111        self.other_bytes_read += size as u64;
112        Ok(res)
113    }
114
115    /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
116    pub async fn into_writer(
117        mut self,
118        tracker: RequestTracker,
119    ) -> Result<ProgressWriter<W>, io::Error> {
120        self.reader.expect_eof().await?;
121        drop(self.reader);
122        Ok(ProgressWriter::new(
123            self.writer,
124            WriterContext {
125                t0: self.t0,
126                other_bytes_read: self.other_bytes_read,
127                payload_bytes_written: 0,
128                other_bytes_written: 0,
129                tracker,
130            },
131        ))
132    }
133
134    pub async fn into_reader(
135        mut self,
136        tracker: RequestTracker,
137    ) -> Result<ProgressReader<R>, io::Error> {
138        self.writer.sync().await?;
139        drop(self.writer);
140        Ok(ProgressReader {
141            inner: self.reader,
142            context: ReaderContext {
143                t0: self.t0,
144                other_bytes_read: self.other_bytes_read,
145                tracker,
146            },
147        })
148    }
149
150    pub async fn get_request(
151        &self,
152        f: impl FnOnce() -> GetRequest,
153    ) -> Result<RequestTracker, ProgressError> {
154        self.events
155            .request(f, self.connection_id, self.reader.id())
156            .await
157    }
158
159    pub async fn get_many_request(
160        &self,
161        f: impl FnOnce() -> GetManyRequest,
162    ) -> Result<RequestTracker, ProgressError> {
163        self.events
164            .request(f, self.connection_id, self.reader.id())
165            .await
166    }
167
168    pub async fn push_request(
169        &self,
170        f: impl FnOnce() -> PushRequest,
171    ) -> Result<RequestTracker, ProgressError> {
172        self.events
173            .request(f, self.connection_id, self.reader.id())
174            .await
175    }
176
177    pub async fn observe_request(
178        &self,
179        f: impl FnOnce() -> ObserveRequest,
180    ) -> Result<RequestTracker, ProgressError> {
181        self.events
182            .request(f, self.connection_id, self.reader.id())
183            .await
184    }
185
186    pub fn stats(&self) -> TransferStats {
187        TransferStats {
188            payload_bytes_sent: 0,
189            other_bytes_sent: 0,
190            other_bytes_read: self.other_bytes_read,
191            duration: self.t0.elapsed(),
192        }
193    }
194}
195
/// Accounting context for the read side of a request (push requests).
#[derive(Debug)]
struct ReaderContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// Progress tracking for the request
    tracker: RequestTracker,
}
205
206impl ReaderContext {
207    fn stats(&self) -> TransferStats {
208        TransferStats {
209            payload_bytes_sent: 0,
210            other_bytes_sent: 0,
211            other_bytes_read: self.other_bytes_read,
212            duration: self.t0.elapsed(),
213        }
214    }
215}
216
/// Accounting and progress context for the write side of a request.
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// The number of payload bytes written to the stream
    payload_bytes_written: u64,
    /// The number of bytes written that are not part of the payload
    other_bytes_written: u64,
    /// Way to report progress
    tracker: RequestTracker,
}
230
231impl WriterContext {
232    fn stats(&self) -> TransferStats {
233        TransferStats {
234            payload_bytes_sent: self.payload_bytes_written,
235            other_bytes_sent: self.other_bytes_written,
236            other_bytes_read: self.other_bytes_read,
237            duration: self.t0.elapsed(),
238        }
239    }
240}
241
242impl WriteProgress for WriterContext {
243    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
244        let len = len as u64;
245        let end_offset = offset + len;
246        self.payload_bytes_written += len;
247        self.tracker.transfer_progress(len, end_offset).await
248    }
249
250    fn log_other_write(&mut self, len: usize) {
251        self.other_bytes_written += len as u64;
252    }
253
254    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
255        self.tracker.transfer_started(index, hash, size).await.ok();
256    }
257}
258
/// Wrapper for a [`SendStream`] with additional per request information.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The send stream to write to
    pub inner: W,
    // Accounting and progress reporting for the current request.
    pub(crate) context: WriterContext,
}
266
267impl<W: SendStream> ProgressWriter<W> {
268    fn new(inner: W, context: WriterContext) -> Self {
269        Self { inner, context }
270    }
271
272    async fn transfer_aborted(&self) {
273        self.context
274            .tracker
275            .transfer_aborted(|| Box::new(self.context.stats()))
276            .await
277            .ok();
278    }
279
280    async fn transfer_completed(&self) {
281        self.context
282            .tracker
283            .transfer_completed(|| Box::new(self.context.stats()))
284            .await
285            .ok();
286    }
287}
288
289/// Handle a single connection.
290pub async fn handle_connection(
291    connection: endpoint::Connection,
292    store: Store,
293    progress: EventSender,
294) {
295    let connection_id = connection.stable_id() as u64;
296    let span = debug_span!("connection", connection_id);
297    async move {
298        if let Err(cause) = progress
299            .client_connected(|| ClientConnected {
300                connection_id,
301                node_id: connection.remote_node_id().ok(),
302            })
303            .await
304        {
305            connection.close(cause.code(), cause.reason());
306            debug!("closing connection: {cause}");
307            return;
308        }
309        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
310            let span = debug_span!("stream", stream_id = %pair.stream_id());
311            let store = store.clone();
312            tokio::spawn(handle_stream(pair, store).instrument(span));
313        }
314        progress
315            .connection_closed(|| ConnectionClosed { connection_id })
316            .await
317            .ok();
318    }
319    .instrument(span)
320    .await
321}
322
/// Describes how to handle errors for a stream.
///
/// NOTE(review): this trait is not referenced anywhere in this file;
/// confirm it is used by other modules before considering removal.
pub trait ErrorHandler {
    /// Writer side of the stream.
    type W: AsyncStreamWriter;
    /// Reader side of the stream.
    type R: AsyncStreamReader;
    /// Stop the read side with the given error `code`.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Reset the write side with the given error `code`.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
330
331async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
332    pair: &mut StreamPair<R, W>,
333    r: Result<T, E>,
334) -> Result<T, E> {
335    match r {
336        Ok(x) => Ok(x),
337        Err(e) => {
338            pair.writer.reset(e.code()).ok();
339            Err(e)
340        }
341    }
342}
343async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
344    writer: &mut ProgressWriter<W>,
345    r: Result<T, E>,
346) -> Result<T, E> {
347    match r {
348        Ok(x) => {
349            writer.transfer_completed().await;
350            Ok(x)
351        }
352        Err(e) => {
353            writer.inner.reset(e.code()).ok();
354            writer.transfer_aborted().await;
355            Err(e)
356        }
357    }
358}
359async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
360    reader: &mut ProgressReader<R>,
361    r: Result<T, E>,
362) -> Result<T, E> {
363    match r {
364        Ok(x) => {
365            reader.transfer_completed().await;
366            Ok(x)
367        }
368        Err(e) => {
369            reader.inner.stop(e.code()).ok();
370            reader.transfer_aborted().await;
371            Err(e)
372        }
373    }
374}
375
376pub async fn handle_stream<R: RecvStream, W: SendStream>(
377    mut pair: StreamPair<R, W>,
378    store: Store,
379) -> anyhow::Result<()> {
380    let request = pair.read_request().await?;
381    match request {
382        Request::Get(request) => handle_get(pair, store, request).await?,
383        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
384        Request::Observe(request) => handle_observe(pair, store, request).await?,
385        Request::Push(request) => handle_push(pair, store, request).await?,
386        _ => {}
387    }
388    Ok(())
389}
390
/// Errors that can occur while handling a get request.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
    /// Exporting bao-encoded data from the store failed.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },
    /// The blob under the root hash could not be parsed as a hash sequence.
    InvalidHashSeq,
    /// The requested child offset does not fit into a usize.
    InvalidOffset,
}
401
402impl HasErrorCode for HandleGetError {
403    fn code(&self) -> VarInt {
404        match self {
405            HandleGetError::ExportBao {
406                source: ExportBaoError::ClientError { source, .. },
407            } => source.code(),
408            HandleGetError::InvalidHashSeq => ERR_INTERNAL,
409            HandleGetError::InvalidOffset => ERR_INTERNAL,
410            _ => ERR_INTERNAL,
411        }
412    }
413}
414
415/// Handle a single get request.
416///
417/// Requires a database, the request, and a writer.
418async fn handle_get_impl<W: SendStream>(
419    store: Store,
420    request: GetRequest,
421    writer: &mut ProgressWriter<W>,
422) -> Result<(), HandleGetError> {
423    let hash = request.hash;
424    debug!(%hash, "get received request");
425    let mut hash_seq = None;
426    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
427        if offset == 0 {
428            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
429        } else {
430            // todo: this assumes that 1. the hashseq is complete and 2. it is
431            // small enough to fit in memory.
432            //
433            // This should really read the hashseq from the store in chunks,
434            // only where needed, so we can deal with holes and large hashseqs.
435            let hash_seq = match &hash_seq {
436                Some(b) => b,
437                None => {
438                    let bytes = store.get_bytes(hash).await?;
439                    let hs =
440                        HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
441                    hash_seq = Some(hs);
442                    hash_seq.as_ref().unwrap()
443                }
444            };
445            let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
446            let Some(hash) = hash_seq.get(o) else {
447                break;
448            };
449            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
450        }
451    }
452    writer
453        .inner
454        .sync()
455        .await
456        .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
457
458    Ok(())
459}
460
461pub async fn handle_get<R: RecvStream, W: SendStream>(
462    mut pair: StreamPair<R, W>,
463    store: Store,
464    request: GetRequest,
465) -> anyhow::Result<()> {
466    let res = pair.get_request(|| request.clone()).await;
467    let tracker = handle_read_request_result(&mut pair, res).await?;
468    let mut writer = pair.into_writer(tracker).await?;
469    let res = handle_get_impl(store, request, &mut writer).await;
470    handle_write_result(&mut writer, res).await?;
471    Ok(())
472}
473
/// Errors that can occur while handling a get-many request.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
    /// Exporting bao-encoded data from the store failed.
    #[snafu(transparent)]
    ExportBao { source: ExportBaoError },
}
479
480impl HasErrorCode for HandleGetManyError {
481    fn code(&self) -> VarInt {
482        match self {
483            Self::ExportBao {
484                source: ExportBaoError::ClientError { source, .. },
485            } => source.code(),
486            _ => ERR_INTERNAL,
487        }
488    }
489}
490
491/// Handle a single get request.
492///
493/// Requires a database, the request, and a writer.
494async fn handle_get_many_impl<W: SendStream>(
495    store: Store,
496    request: GetManyRequest,
497    writer: &mut ProgressWriter<W>,
498) -> Result<(), HandleGetManyError> {
499    debug!("get_many received request");
500    let request_ranges = request.ranges.iter_infinite();
501    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
502        if !ranges.is_empty() {
503            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
504        }
505    }
506    Ok(())
507}
508
509pub async fn handle_get_many<R: RecvStream, W: SendStream>(
510    mut pair: StreamPair<R, W>,
511    store: Store,
512    request: GetManyRequest,
513) -> anyhow::Result<()> {
514    let res = pair.get_many_request(|| request.clone()).await;
515    let tracker = handle_read_request_result(&mut pair, res).await?;
516    let mut writer = pair.into_writer(tracker).await?;
517    let res = handle_get_many_impl(store, request, &mut writer).await;
518    handle_write_result(&mut writer, res).await?;
519    Ok(())
520}
521
/// Errors that can occur while handling a push request.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
    /// Exporting bao-encoded data from the store failed.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },

    /// The blob under the root hash could not be parsed as a hash sequence.
    InvalidHashSeq,

    /// Importing data into the store failed.
    #[snafu(transparent)]
    Request {
        source: RequestError,
    },
}
536
537impl HasErrorCode for HandlePushError {
538    fn code(&self) -> VarInt {
539        match self {
540            Self::ExportBao {
541                source: ExportBaoError::ClientError { source, .. },
542            } => source.code(),
543            _ => ERR_INTERNAL,
544        }
545    }
546}
547
548/// Handle a single push request.
549///
550/// Requires a database, the request, and a reader.
551async fn handle_push_impl<R: RecvStream>(
552    store: Store,
553    request: PushRequest,
554    reader: &mut ProgressReader<R>,
555) -> Result<(), HandlePushError> {
556    let hash = request.hash;
557    debug!(%hash, "push received request");
558    let mut request_ranges = request.ranges.iter_infinite();
559    let root_ranges = request_ranges.next().expect("infinite iterator");
560    if !root_ranges.is_empty() {
561        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
562        store
563            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
564            .await?;
565    }
566    if request.ranges.is_blob() {
567        debug!("push request complete");
568        return Ok(());
569    }
570    // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
571    let hash_seq = store.get_bytes(hash).await?;
572    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
573    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
574        if child_ranges.is_empty() {
575            continue;
576        }
577        store
578            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
579            .await?;
580    }
581    Ok(())
582}
583
584pub async fn handle_push<R: RecvStream, W: SendStream>(
585    mut pair: StreamPair<R, W>,
586    store: Store,
587    request: PushRequest,
588) -> anyhow::Result<()> {
589    let res = pair.push_request(|| request.clone()).await;
590    let tracker = handle_read_request_result(&mut pair, res).await?;
591    let mut reader = pair.into_reader(tracker).await?;
592    let res = handle_push_impl(store, request, &mut reader).await;
593    handle_read_result(&mut reader, res).await?;
594    Ok(())
595}
596
597/// Send a blob to the client.
598pub(crate) async fn send_blob<W: SendStream>(
599    store: &Store,
600    index: u64,
601    hash: Hash,
602    ranges: ChunkRanges,
603    writer: &mut ProgressWriter<W>,
604) -> ExportBaoResult<()> {
605    store
606        .export_bao(hash, ranges)
607        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
608        .await
609}
610
/// Errors that can occur while handling an observe request.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
    /// The local observe stream ended unexpectedly.
    ObserveStreamClosed,

    /// Writing an update to the remote failed.
    #[snafu(transparent)]
    RemoteClosed {
        source: io::Error,
    },
}
620
impl HasErrorCode for HandleObserveError {
    // Observe failures carry no client-attributable error code, so they are
    // always reported as internal.
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
626
/// Handle a single observe request.
///
/// Sends the current bitfield for the observed hash, then streams bitfield
/// diffs until either the local observe stream closes or the remote stops
/// the stream. Requires a database, the request, and a writer.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
    let mut old = stream
        .next()
        .await
        .ok_or(HandleObserveError::ObserveStreamClosed)?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
                let diff = old.diff(&new);
                // Skip no-op updates; only actual changes are forwarded.
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            // The remote stopping the stream signals loss of interest.
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
666
667async fn send_observe_item<W: SendStream>(
668    writer: &mut ProgressWriter<W>,
669    item: &Bitfield,
670) -> io::Result<()> {
671    let item = ObserveItem::from(item);
672    let len = writer.inner.write_length_prefixed(item).await?;
673    writer.context.log_other_write(len);
674    Ok(())
675}
676
677pub async fn handle_observe<R: RecvStream, W: SendStream>(
678    mut pair: StreamPair<R, W>,
679    store: Store,
680    request: ObserveRequest,
681) -> anyhow::Result<()> {
682    let res = pair.observe_request(|| request.clone()).await;
683    let tracker = handle_read_request_result(&mut pair, res).await?;
684    let mut writer = pair.into_writer(tracker).await?;
685    let res = handle_observe_impl(store, request, &mut writer).await;
686    handle_write_result(&mut writer, res).await?;
687    Ok(())
688}
689
/// Wrapper for a [`RecvStream`] with additional per request information.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    // The receive stream to read from.
    inner: R,
    // Accounting and progress reporting for the current request.
    context: ReaderContext,
}
694
695impl<R: RecvStream> ProgressReader<R> {
696    async fn transfer_aborted(&self) {
697        self.context
698            .tracker
699            .transfer_aborted(|| Box::new(self.context.stats()))
700            .await
701            .ok();
702    }
703
704    async fn transfer_completed(&self) {
705        self.context
706            .tracker
707            .transfer_completed(|| Box::new(self.context.stats()))
708            .await
709            .ok();
710    }
711}