iroh_blobs/
provider.rs

//! The low-level server-side API.
//!
//! Note that while using this API directly is fine, the standard way
//! to provide data is to just register a [`crate::net_protocol`] protocol
//! handler with an [`iroh::protocol::Router`].
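//!
//! If you do use this API directly, [`handle_connection`] is the entry point.
//! Below is a minimal sketch of an accept loop (not from the original docs;
//! the `endpoint` and `store` setup is assumed to exist already):
//!
//! ```rust,ignore
//! use iroh_blobs::provider::{handle_connection, EventSender};
//!
//! async fn accept_loop(endpoint: iroh::Endpoint, store: iroh_blobs::api::Store) {
//!     // No event channel: all clients are permitted, no progress is reported.
//!     let events = EventSender::new(None);
//!     while let Some(incoming) = endpoint.accept().await {
//!         let Ok(connection) = incoming.await else { continue };
//!         tokio::spawn(handle_connection(connection, store.clone(), events.clone()));
//!     }
//! }
//! ```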
use std::{
    fmt::Debug,
    io,
    ops::{Deref, DerefMut},
    pin::Pin,
    task::Poll,
    time::Duration,
};

use anyhow::{Context, Result};
use bao_tree::ChunkRanges;
use iroh::{
    endpoint::{self, RecvStream, SendStream},
    NodeId,
};
use irpc::channel::oneshot;
use n0_future::StreamExt;
use serde::de::DeserializeOwned;
use tokio::{io::AsyncRead, select, sync::mpsc};
use tracing::{debug, debug_span, error, warn, Instrument};

use crate::{
    api::{self, blobs::Bitfield, Store},
    hashseq::HashSeq,
    protocol::{
        ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest,
        Request,
    },
    Hash,
};

/// Provider progress events, to keep track of what the provider is doing.
///
/// ClientConnected ->
///    (GetRequestReceived -> (TransferStarted -> TransferProgress*n)*n -> (TransferCompleted | TransferAborted))*n ->
/// ConnectionClosed
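///
/// A typical consumer drains these events from the [`mpsc`] channel passed to
/// [`EventSender::new`]. A minimal sketch (the handling policy here is made up
/// for illustration, and the irpc oneshot `send` is assumed to be async):
///
/// ```rust,ignore
/// use iroh_blobs::provider::{Event, EventSender};
///
/// let (tx, mut rx) = tokio::sync::mpsc::channel(32);
/// let events = EventSender::new(Some(tx));
/// tokio::spawn(async move {
///     while let Some(event) = rx.recv().await {
///         match event {
///             // Permit requests are answered via the embedded oneshot sender;
///             // dropping the sender counts as a rejection.
///             Event::ClientConnected { permitted, .. } => {
///                 permitted.send(true).await.ok();
///             }
///             Event::PushRequestReceived { permitted, .. } => {
///                 permitted.send(false).await.ok();
///             }
///             other => tracing::debug!("provider event: {other:?}"),
///         }
///     }
/// });
/// ```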
#[derive(Debug)]
pub enum Event {
    /// A new client connected to the provider.
    ClientConnected {
        connection_id: u64,
        node_id: NodeId,
        permitted: oneshot::Sender<bool>,
    },
    /// Connection closed.
    ConnectionClosed { connection_id: u64 },
    /// A new get request was received from the client.
    GetRequestReceived {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// The root hash of the request.
        hash: Hash,
        /// The exact query ranges of the request.
        ranges: ChunkRangesSeq,
    },
    /// A new get_many request was received from the client.
    GetManyRequestReceived {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// The hashes of the request.
        hashes: Vec<Hash>,
        /// The exact query ranges of the request.
        ranges: ChunkRangesSeq,
    },
    /// A new push request was received from the client.
    PushRequestReceived {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// The root hash of the request.
        hash: Hash,
        /// The exact query ranges of the request.
        ranges: ChunkRangesSeq,
        /// Complete this to permit the request.
        permitted: oneshot::Sender<bool>,
    },
    /// Transfer for the nth blob started.
    TransferStarted {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// The index of the blob in the request. 0 for the first blob or for raw blob requests.
        index: u64,
        /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs.
        hash: Hash,
        /// The size of the blob. This is the full size of the blob, not the size we are sending.
        size: u64,
    },
    /// Progress of the transfer.
    TransferProgress {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// The index of the blob in the request. 0 for the first blob or for raw blob requests.
        index: u64,
        /// The end offset of the chunk that was sent.
        end_offset: u64,
    },
    /// Entire transfer completed.
    TransferCompleted {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// Statistics about the transfer.
        stats: Box<TransferStats>,
    },
    /// Entire transfer aborted.
    TransferAborted {
        /// The connection id. Multiple requests can be sent over the same connection.
        connection_id: u64,
        /// The request id. There is a new id for each request.
        request_id: u64,
        /// Statistics about the part of the transfer that was aborted.
        stats: Option<Box<TransferStats>>,
    },
}

/// Statistics about a successful or failed transfer.
#[derive(Debug)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// This is the size of the request.
    pub bytes_read: u64,
    /// Total duration from reading the request to completion of the transfer.
    pub duration: Duration,
}

/// Read the request from the getter.
///
/// Will fail if there is an error while reading, or if no valid request is sent.
///
/// This will read exactly the number of bytes needed for the request, and
/// leave the rest of the stream for the caller to read.
///
/// It is up to the caller to decide if there should be more data.
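///
/// A sketch of the crate-internal calling pattern (fields such as `inner` are
/// private, so this is illustration only): for request types that carry no
/// trailing data, an empty `read_to_end` asserts the stream ends here.
///
/// ```rust,ignore
/// let request = read_request(&mut reader).await?;
/// // Fails if the client sent anything beyond the encoded request.
/// reader.inner.read_to_end(0).await?;
/// ```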
pub async fn read_request(reader: &mut ProgressReader) -> Result<Request> {
    let mut counting = CountingReader::new(&mut reader.inner);
    let res = Request::read_async(&mut counting).await?;
    reader.bytes_read += counting.read();
    Ok(res)
}

#[derive(Debug)]
pub struct StreamContext {
    /// The connection ID from the connection.
    pub connection_id: u64,
    /// The request ID from the recv stream.
    pub request_id: u64,
    /// The number of bytes written that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes written that are not part of the payload.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    pub bytes_read: u64,
    /// The progress sender to send events to.
    pub progress: EventSender,
}

/// Wrapper for a [`SendStream`] with additional per-request information.
#[derive(Debug)]
pub struct ProgressWriter {
    /// The [`SendStream`] to write to.
    pub inner: SendStream,
    pub(crate) context: StreamContext,
}

impl Deref for ProgressWriter {
    type Target = StreamContext;

    fn deref(&self) -> &Self::Target {
        &self.context
    }
}

impl DerefMut for ProgressWriter {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.context
    }
}

impl StreamContext {
    /// Increase the write count due to a non-payload write.
    pub fn log_other_write(&mut self, len: usize) {
        self.other_bytes_sent += len as u64;
    }

    /// Send a transfer completed event with the accumulated statistics.
    pub async fn send_transfer_completed(&mut self) {
        self.progress
            .send(|| Event::TransferCompleted {
                connection_id: self.connection_id,
                request_id: self.request_id,
                stats: Box::new(TransferStats {
                    payload_bytes_sent: self.payload_bytes_sent,
                    other_bytes_sent: self.other_bytes_sent,
                    bytes_read: self.bytes_read,
                    duration: Duration::ZERO,
                }),
            })
            .await;
    }

    /// Send a transfer aborted event with the statistics accumulated so far.
    pub async fn send_transfer_aborted(&mut self) {
        self.progress
            .send(|| Event::TransferAborted {
                connection_id: self.connection_id,
                request_id: self.request_id,
                stats: Some(Box::new(TransferStats {
                    payload_bytes_sent: self.payload_bytes_sent,
                    other_bytes_sent: self.other_bytes_sent,
                    bytes_read: self.bytes_read,
                    duration: Duration::ZERO,
                })),
            })
            .await;
    }

    /// Increase the write count due to a payload write, and notify the progress sender.
    ///
    /// `index` is the index of the blob in the request.
    /// `offset` is the offset in the blob where the write started.
    /// `len` is the length of the write.
    pub fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) {
        self.payload_bytes_sent += len as u64;
        self.progress.try_send(|| Event::TransferProgress {
            connection_id: self.connection_id,
            request_id: self.request_id,
            index,
            end_offset: offset + len as u64,
        });
    }

    /// Send a get request received event.
    ///
    /// This sends all the required information to make sense of subsequent events such as
    /// [`Event::TransferStarted`] and [`Event::TransferProgress`].
    pub async fn send_get_request_received(&self, hash: &Hash, ranges: &ChunkRangesSeq) {
        self.progress
            .send(|| Event::GetRequestReceived {
                connection_id: self.connection_id,
                request_id: self.request_id,
                hash: *hash,
                ranges: ranges.clone(),
            })
            .await;
    }

    /// Send a get_many request received event.
    ///
    /// This sends all the required information to make sense of subsequent events such as
    /// [`Event::TransferStarted`] and [`Event::TransferProgress`].
    pub async fn send_get_many_request_received(&self, hashes: &[Hash], ranges: &ChunkRangesSeq) {
        self.progress
            .send(|| Event::GetManyRequestReceived {
                connection_id: self.connection_id,
                request_id: self.request_id,
                hashes: hashes.to_vec(),
                ranges: ranges.clone(),
            })
            .await;
    }

    /// Authorize a push request.
    ///
    /// This will send a request to the event sender, and wait for a response if a
    /// progress sender is enabled. If not, it will always fail.
    ///
    /// We want to make accepting push requests very explicit, since this allows
    /// remote nodes to add arbitrary data to our store.
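    ///
    /// A consumer that wants to accept pushes must answer on the `permitted`
    /// channel of [`Event::PushRequestReceived`]. A hypothetical allowlist
    /// check (`allowed_hashes` is made up for this sketch):
    ///
    /// ```rust,ignore
    /// Event::PushRequestReceived { hash, permitted, .. } => {
    ///     permitted.send(allowed_hashes.contains(&hash)).await.ok();
    /// }
    /// ```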
    #[must_use = "permit should be checked by the caller"]
    pub async fn authorize_push_request(&self, hash: &Hash, ranges: &ChunkRangesSeq) -> bool {
        let mut wait_for_permit = None;
        // send the request, including the permit channel
        self.progress
            .send(|| {
                let (tx, rx) = oneshot::channel();
                wait_for_permit = Some(rx);
                Event::PushRequestReceived {
                    connection_id: self.connection_id,
                    request_id: self.request_id,
                    hash: *hash,
                    ranges: ranges.clone(),
                    permitted: tx,
                }
            })
            .await;
        // wait for the permit, if necessary
        if let Some(wait_for_permit) = wait_for_permit {
            // if somebody does not handle the request, they will drop the channel,
            // and this will fail immediately.
            wait_for_permit.await.unwrap_or(false)
        } else {
            false
        }
    }

    /// Send a transfer started event.
    pub async fn send_transfer_started(&self, index: u64, hash: &Hash, size: u64) {
        self.progress
            .send(|| Event::TransferStarted {
                connection_id: self.connection_id,
                request_id: self.request_id,
                index,
                hash: *hash,
                size,
            })
            .await;
    }
}

/// Handle a single connection.
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        let Ok(node_id) = connection.remote_node_id() else {
            warn!("failed to get node id");
            return;
        };
        if !progress
            .authorize_client_connection(connection_id, node_id)
            .await
        {
            debug!("client not authorized to connect");
            return;
        }
        while let Ok((writer, reader)) = connection.accept_bi().await {
            // The stream ID index is used to identify this request. Requests only arrive in
            // bi-directional RecvStreams initiated by the client, so this uniquely identifies them.
            let request_id = reader.id().index();
            let span = debug_span!("stream", stream_id = %request_id);
            let store = store.clone();
            let mut writer = ProgressWriter {
                inner: writer,
                context: StreamContext {
                    connection_id,
                    request_id,
                    payload_bytes_sent: 0,
                    other_bytes_sent: 0,
                    bytes_read: 0,
                    progress: progress.clone(),
                },
            };
            tokio::spawn(
                async move {
                    match handle_stream(store, reader, &mut writer).await {
                        Ok(()) => {
                            writer.send_transfer_completed().await;
                        }
                        Err(err) => {
                            warn!("error: {err:#?}");
                            writer.send_transfer_aborted().await;
                        }
                    }
                }
                .instrument(span),
            );
        }
        progress
            .send(Event::ConnectionClosed { connection_id })
            .await;
    }
    .instrument(span)
    .await
}

async fn handle_stream(
    store: Store,
    reader: RecvStream,
    writer: &mut ProgressWriter,
) -> Result<()> {
    // 1. Decode the request.
    debug!("reading request");
    let mut reader = ProgressReader {
        inner: reader,
        context: StreamContext {
            connection_id: writer.connection_id,
            request_id: writer.request_id,
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            bytes_read: 0,
            progress: writer.progress.clone(),
        },
    };
    let request = match read_request(&mut reader).await {
        Ok(request) => request,
        Err(e) => {
            // todo: increase invalid requests metric counter
            return Err(e);
        }
    };

    match request {
        Request::Get(request) => {
            // we expect no more bytes after the request, so if there are more bytes, it is an invalid request.
            reader.inner.read_to_end(0).await?;
            // move the context so we don't lose the bytes read
            writer.context = reader.context;
            handle_get(store, request, writer).await
        }
        Request::GetMany(request) => {
            // we expect no more bytes after the request, so if there are more bytes, it is an invalid request.
            reader.inner.read_to_end(0).await?;
            // move the context so we don't lose the bytes read
            writer.context = reader.context;
            handle_get_many(store, request, writer).await
        }
        Request::Observe(request) => {
            // we expect no more bytes after the request, so if there are more bytes, it is an invalid request.
            reader.inner.read_to_end(0).await?;
            handle_observe(store, request, writer).await
        }
        Request::Push(request) => {
            writer.inner.finish()?;
            handle_push(store, request, reader).await
        }
        _ => anyhow::bail!("unsupported request: {request:?}"),
    }
}

/// Handle a single get request.
///
/// Requires a database, the request, and a writer.
pub async fn handle_get(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "get received request");

    writer
        .send_get_request_received(&hash, &request.ranges)
        .await;
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // todo: this assumes that 1. the hashseq is complete and 2. it is
            // small enough to fit in memory.
            //
            // This should really read the hashseq from the store in chunks,
            // only where needed, so we can deal with holes and large hashseqs.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs = HashSeq::try_from(bytes)?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            let o = usize::try_from(offset - 1).context("offset too large")?;
            let Some(hash) = hash_seq.get(o) else {
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }

    Ok(())
}

/// Handle a single get_many request.
///
/// Requires a database, the request, and a writer.
pub async fn handle_get_many(
    store: Store,
    request: GetManyRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    debug!("get_many received request");
    writer
        .send_get_many_request_received(&request.hashes, &request.ranges)
        .await;
    let request_ranges = request.ranges.iter_infinite();
    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
        if !ranges.is_empty() {
            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
        }
    }
    Ok(())
}

/// Handle a single push request.
///
/// Requires a database, the request, and a reader.
pub async fn handle_push(
    store: Store,
    request: PushRequest,
    mut reader: ProgressReader,
) -> Result<()> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    if !reader.authorize_push_request(&hash, &request.ranges).await {
        debug!("push request not authorized");
        return Ok(());
    }
    let mut request_ranges = request.ranges.iter_infinite();
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
        store
            .import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        debug!("push request complete");
        return Ok(());
    }
    // todo: we assume here that the hash sequence is complete. For some requests this
    // might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq)?;
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}

/// Send a blob to the client.
pub(crate) async fn send_blob(
    store: &Store,
    index: u64,
    hash: Hash,
    ranges: ChunkRanges,
    writer: &mut ProgressWriter,
) -> api::Result<()> {
    Ok(store
        .export_bao(hash, ranges)
        .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
        .await?)
}

/// Handle a single observe request.
///
/// Requires a database, the request, and a writer.
pub async fn handle_observe(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter,
) -> Result<()> {
    let mut stream = store.observe(request.hash).stream().await?;
    let mut old = stream
        .next()
        .await
        .ok_or(anyhow::anyhow!("observe stream closed before first value"))?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.context("observe stream closed")?;
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}

async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> {
    use irpc::util::AsyncWriteVarintExt;
    let item = ObserveItem::from(item);
    let len = writer.inner.write_length_prefixed(item).await?;
    writer.log_other_write(len);
    Ok(())
}

/// Helper to lazily create an [`Event`], in the case that the event creation
/// is expensive and we want to avoid it if the progress sender is disabled.
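///
/// Both closures and plain [`Event`] values implement this, so crate-internal
/// callers can pass either form; the closure defers construction until the
/// sender is known to be enabled (sketch of internal usage):
///
/// ```rust,ignore
/// // Cheap event: pass the value directly.
/// progress.send(Event::ConnectionClosed { connection_id }).await;
/// // Expensive event: pass a closure, only invoked if the sender is enabled.
/// progress
///     .send(|| Event::GetRequestReceived {
///         connection_id,
///         request_id,
///         hash,
///         ranges: ranges.clone(),
///     })
///     .await;
/// ```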
pub trait LazyEvent {
    fn call(self) -> Event;
}

impl<T> LazyEvent for T
where
    T: FnOnce() -> Event,
{
    fn call(self) -> Event {
        self()
    }
}

impl LazyEvent for Event {
    fn call(self) -> Event {
        self
    }
}

/// A sender for provider events.
#[derive(Debug, Clone)]
pub struct EventSender(EventSenderInner);

#[derive(Debug, Clone)]
enum EventSenderInner {
    Disabled,
    Enabled(mpsc::Sender<Event>),
}

impl EventSender {
    /// Create a new event sender: enabled if a channel is given, disabled otherwise.
    pub fn new(sender: Option<mpsc::Sender<Event>>) -> Self {
        match sender {
            Some(sender) => Self(EventSenderInner::Enabled(sender)),
            None => Self(EventSenderInner::Disabled),
        }
    }

    /// Send a client connected event, if the progress sender is enabled.
    ///
    /// This will permit the client to connect if the sender is disabled.
    #[must_use = "permit should be checked by the caller"]
    pub async fn authorize_client_connection(&self, connection_id: u64, node_id: NodeId) -> bool {
        let mut wait_for_permit = None;
        self.send(|| {
            let (tx, rx) = oneshot::channel();
            wait_for_permit = Some(rx);
            Event::ClientConnected {
                connection_id,
                node_id,
                permitted: tx,
            }
        })
        .await;
        if let Some(wait_for_permit) = wait_for_permit {
            // if we have events configured, and they drop the channel, we consider that as a no!
            // todo: this will be confusing and needs to be properly documented.
            wait_for_permit.await.unwrap_or(false)
        } else {
            true
        }
    }

    /// Send an ephemeral event, if the progress sender is enabled.
    ///
    /// The event will only be created if the sender is enabled.
    fn try_send(&self, event: impl LazyEvent) {
        match &self.0 {
            EventSenderInner::Enabled(sender) => {
                let value = event.call();
                sender.try_send(value).ok();
            }
            EventSenderInner::Disabled => {}
        }
    }

    /// Send a mandatory event, if the progress sender is enabled.
    ///
    /// The event will only be created if the sender is enabled.
    async fn send(&self, event: impl LazyEvent) {
        match &self.0 {
            EventSenderInner::Enabled(sender) => {
                let value = event.call();
                if let Err(err) = sender.send(value).await {
                    error!("failed to send progress event: {:?}", err);
                }
            }
            EventSenderInner::Disabled => {}
        }
    }
}

/// Wrapper for a [`RecvStream`] with additional per-request information.
pub struct ProgressReader {
    inner: RecvStream,
    context: StreamContext,
}

impl Deref for ProgressReader {
    type Target = StreamContext;

    fn deref(&self) -> &Self::Target {
        &self.context
    }
}

impl DerefMut for ProgressReader {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.context
    }
}

/// Reader wrapper that counts the number of bytes read.
pub struct CountingReader<R> {
    pub inner: R,
    pub read: u64,
}

impl<R> CountingReader<R> {
    /// Wrap a reader, starting the byte count at zero.
    pub fn new(inner: R) -> Self {
        Self { inner, read: 0 }
    }

    /// The number of bytes read so far.
    pub fn read(&self) -> u64 {
        self.read
    }
}

impl CountingReader<&mut iroh::endpoint::RecvStream> {
    /// Read the rest of the stream and deserialize it as a postcard-encoded `T`.
    pub async fn read_to_end_as<T: DeserializeOwned>(&mut self, max_size: usize) -> io::Result<T> {
        let data = self
            .inner
            .read_to_end(max_size)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let value = postcard::from_bytes(&data)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        self.read += data.len() as u64;
        Ok(value)
    }
}

impl<R: AsyncRead + Unpin> AsyncRead for CountingReader<R> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // Record the fill level before the read so only newly read bytes are counted,
        // even if the caller passes a partially filled buffer.
        let before = buf.filled().len();
        let result = Pin::new(&mut this.inner).poll_read(cx, buf);
        if let Poll::Ready(Ok(())) = result {
            this.read += (buf.filled().len() - before) as u64;
        }
        result
    }
}