// iroh_blobs/provider/events.rs

1use std::{fmt::Debug, io, ops::Deref};
2
3use irpc::{
4    channel::{mpsc, none::NoSender, oneshot},
5    rpc_requests, Channels, WithChannels,
6};
7use serde::{Deserialize, Serialize};
8use snafu::Snafu;
9
10use crate::{
11    protocol::{
12        GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
13        ERR_PERMISSION,
14    },
15    provider::{events::irpc_ext::IrpcClientExt, TransferStats},
16    Hash,
17};
18
/// How the event handler is told about incoming connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ConnectMode {
    /// Connect events are not reported at all.
    #[default]
    None = 0,
    /// A notification is sent for every connect event.
    Notify = 1,
    /// Every connect event is sent as a request, so the handler can reject the connection.
    Intercept = 2,
}
31
/// Request mode for observe requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ObserveMode {
    /// We don't get events for observe requests at all.
    #[default]
    None,
    /// We get a notification for each observe request.
    Notify,
    /// Each observe request is sent to the event handler as a request, so it can
    /// reject incoming observe requests.
    Intercept,
}
44
/// Request mode for all data related requests.
///
/// The `*Log` variants additionally stream detailed transfer events for each request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum RequestMode {
    /// No events are sent for this request type; requests are processed normally.
    #[default]
    None = 0,
    /// A notification is sent for each request, but no transfer events.
    Notify = 1,
    /// Each request is sent to the handler, which can reject it. No transfer events.
    Intercept = 2,
    /// A notification is sent for each request, followed by detailed transfer events.
    NotifyLog = 3,
    /// Each request is sent to the handler, which can reject it.
    /// Detailed transfer events follow.
    InterceptLog = 4,
    /// This request type is completely disabled. All requests will be rejected.
    ///
    /// This differs from [`RequestMode::None`], which only suppresses events while
    /// requests are still processed normally.
    Disabled = 5,
}
67
/// Throttling mode for requests that support throttling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ThrottleMode {
    /// Throttle events are not sent at all.
    #[default]
    None = 0,
    /// A throttle request is sent as data is transferred, giving the event handler
    /// a way to throttle requests.
    Intercept = 1,
}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
80pub enum AbortReason {
81    /// The request was aborted because a limit was exceeded. It is OK to try again later.
82    RateLimited,
83    /// The request was aborted because the client does not have permission to perform the operation.
84    Permission,
85}
86
87/// Errors that can occur when sending progress updates.
88#[derive(Debug, Snafu)]
89pub enum ProgressError {
90    Limit,
91    Permission,
92    #[snafu(transparent)]
93    Internal {
94        source: irpc::Error,
95    },
96}
97
98impl From<ProgressError> for io::Error {
99    fn from(value: ProgressError) -> Self {
100        match value {
101            ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
102            ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
103            ProgressError::Internal { source } => source.into(),
104        }
105    }
106}
107
108impl ProgressError {
109    pub fn code(&self) -> quinn::VarInt {
110        match self {
111            ProgressError::Limit => ERR_LIMIT,
112            ProgressError::Permission => ERR_PERMISSION,
113            ProgressError::Internal { .. } => ERR_INTERNAL,
114        }
115    }
116
117    pub fn reason(&self) -> &'static [u8] {
118        match self {
119            ProgressError::Limit => b"limit",
120            ProgressError::Permission => b"permission",
121            ProgressError::Internal { .. } => b"internal",
122        }
123    }
124}
125
126impl From<AbortReason> for ProgressError {
127    fn from(value: AbortReason) -> Self {
128        match value {
129            AbortReason::RateLimited => ProgressError::Limit,
130            AbortReason::Permission => ProgressError::Permission,
131        }
132    }
133}
134
135impl From<irpc::channel::RecvError> for ProgressError {
136    fn from(value: irpc::channel::RecvError) -> Self {
137        ProgressError::Internal {
138            source: value.into(),
139        }
140    }
141}
142
143impl From<irpc::channel::SendError> for ProgressError {
144    fn from(value: irpc::channel::SendError) -> Self {
145        ProgressError::Internal {
146            source: value.into(),
147        }
148    }
149}
150
151pub type EventResult = Result<(), AbortReason>;
152pub type ClientResult = Result<(), ProgressError>;
153
154/// Event mask to configure which events are sent to the event handler.
155///
156/// This can also be used to completely disable certain request types. E.g.
157/// push requests are disabled by default, as they can write to the local store.
158#[derive(Debug, Clone, Copy, PartialEq, Eq)]
159pub struct EventMask {
160    /// Connection event mask
161    pub connected: ConnectMode,
162    /// Get request event mask
163    pub get: RequestMode,
164    /// Get many request event mask
165    pub get_many: RequestMode,
166    /// Push request event mask
167    pub push: RequestMode,
168    /// Observe request event mask
169    pub observe: ObserveMode,
170    /// throttling is somewhat costly, so you can disable it completely
171    pub throttle: ThrottleMode,
172}
173
174impl Default for EventMask {
175    fn default() -> Self {
176        Self::DEFAULT
177    }
178}
179
180impl EventMask {
181    /// All event notifications are fully disabled. Push requests are disabled by default.
182    pub const DEFAULT: Self = Self {
183        connected: ConnectMode::None,
184        get: RequestMode::None,
185        get_many: RequestMode::None,
186        push: RequestMode::Disabled,
187        throttle: ThrottleMode::None,
188        observe: ObserveMode::None,
189    };
190
191    /// All event notifications for read-only requests are fully enabled.
192    ///
193    /// If you want to enable push requests, which can write to the local store, you
194    /// need to do it manually. Providing constants that have push enabled would
195    /// risk misuse.
196    pub const ALL_READONLY: Self = Self {
197        connected: ConnectMode::Intercept,
198        get: RequestMode::InterceptLog,
199        get_many: RequestMode::InterceptLog,
200        push: RequestMode::Disabled,
201        throttle: ThrottleMode::Intercept,
202        observe: ObserveMode::Intercept,
203    };
204}
205
206/// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant.
207#[derive(Debug, Serialize, Deserialize)]
208pub struct Notify<T>(T);
209
210impl<T> Deref for Notify<T> {
211    type Target = T;
212
213    fn deref(&self) -> &Self::Target {
214        &self.0
215    }
216}
217
218#[derive(Debug, Default, Clone)]
219pub struct EventSender {
220    mask: EventMask,
221    inner: Option<irpc::Client<ProviderProto>>,
222}
223
224#[derive(Debug, Default)]
225enum RequestUpdates {
226    /// Request tracking was not configured, all ops are no-ops
227    #[default]
228    None,
229    /// Active request tracking, all ops actually send
230    Active(mpsc::Sender<RequestUpdate>),
231    /// Disabled request tracking, we just hold on to the sender so it drops
232    /// once the request is completed or aborted.
233    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
234}
235
236#[derive(Debug)]
237pub struct RequestTracker {
238    updates: RequestUpdates,
239    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
240}
241
242impl RequestTracker {
243    fn new(
244        updates: RequestUpdates,
245        throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
246    ) -> Self {
247        Self { updates, throttle }
248    }
249
250    /// A request tracker that doesn't track anything.
251    pub const NONE: Self = Self {
252        updates: RequestUpdates::None,
253        throttle: None,
254    };
255
256    /// Transfer for index `index` started, size `size` in bytes.
257    pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
258        if let RequestUpdates::Active(tx) = &self.updates {
259            tx.send(
260                TransferStarted {
261                    index,
262                    hash: *hash,
263                    size,
264                }
265                .into(),
266            )
267            .await?;
268        }
269        Ok(())
270    }
271
272    /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes.
273    pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
274        if let RequestUpdates::Active(tx) = &mut self.updates {
275            tx.try_send(TransferProgress { end_offset }.into()).await?;
276        }
277        if let Some((throttle, connection_id, request_id)) = &self.throttle {
278            throttle
279                .rpc(Throttle {
280                    connection_id: *connection_id,
281                    request_id: *request_id,
282                    size: len,
283                })
284                .await??;
285        }
286        Ok(())
287    }
288
289    /// Transfer completed for the previously reported blob.
290    pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
291        if let RequestUpdates::Active(tx) = &self.updates {
292            tx.send(TransferCompleted { stats: f() }.into()).await?;
293        }
294        Ok(())
295    }
296
297    /// Transfer aborted for the previously reported blob.
298    pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
299        if let RequestUpdates::Active(tx) = &self.updates {
300            tx.send(TransferAborted { stats: f() }.into()).await?;
301        }
302        Ok(())
303    }
304}
305
306/// Client for progress notifications.
307///
308/// For most event types, the client can be configured to either send notifications or requests that
309/// can have a response.
310impl EventSender {
311    /// A client that does not send anything.
312    pub const DEFAULT: Self = Self {
313        mask: EventMask::DEFAULT,
314        inner: None,
315    };
316
317    pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
318        Self {
319            mask,
320            inner: Some(irpc::Client::from(client)),
321        }
322    }
323
324    pub fn channel(
325        capacity: usize,
326        mask: EventMask,
327    ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
328        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
329        (Self::new(tx, mask), rx)
330    }
331
332    /// Log request events at trace level.
333    pub fn tracing(&self, mask: EventMask) -> Self {
334        use tracing::trace;
335        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
336        n0_future::task::spawn(async move {
337            fn log_request_events(
338                mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
339                connection_id: u64,
340                request_id: u64,
341            ) {
342                n0_future::task::spawn(async move {
343                    while let Ok(Some(update)) = rx.recv().await {
344                        trace!(%connection_id, %request_id, "{update:?}");
345                    }
346                });
347            }
348            while let Some(msg) = rx.recv().await {
349                match msg {
350                    ProviderMessage::ClientConnected(msg) => {
351                        trace!("{:?}", msg.inner);
352                        msg.tx.send(Ok(())).await.ok();
353                    }
354                    ProviderMessage::ClientConnectedNotify(msg) => {
355                        trace!("{:?}", msg.inner);
356                    }
357                    ProviderMessage::ConnectionClosed(msg) => {
358                        trace!("{:?}", msg.inner);
359                    }
360                    ProviderMessage::GetRequestReceived(msg) => {
361                        trace!("{:?}", msg.inner);
362                        msg.tx.send(Ok(())).await.ok();
363                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
364                    }
365                    ProviderMessage::GetRequestReceivedNotify(msg) => {
366                        trace!("{:?}", msg.inner);
367                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
368                    }
369                    ProviderMessage::GetManyRequestReceived(msg) => {
370                        trace!("{:?}", msg.inner);
371                        msg.tx.send(Ok(())).await.ok();
372                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
373                    }
374                    ProviderMessage::GetManyRequestReceivedNotify(msg) => {
375                        trace!("{:?}", msg.inner);
376                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
377                    }
378                    ProviderMessage::PushRequestReceived(msg) => {
379                        trace!("{:?}", msg.inner);
380                        msg.tx.send(Ok(())).await.ok();
381                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
382                    }
383                    ProviderMessage::PushRequestReceivedNotify(msg) => {
384                        trace!("{:?}", msg.inner);
385                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
386                    }
387                    ProviderMessage::ObserveRequestReceived(msg) => {
388                        trace!("{:?}", msg.inner);
389                        msg.tx.send(Ok(())).await.ok();
390                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
391                    }
392                    ProviderMessage::ObserveRequestReceivedNotify(msg) => {
393                        trace!("{:?}", msg.inner);
394                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
395                    }
396                    ProviderMessage::Throttle(msg) => {
397                        trace!("{:?}", msg.inner);
398                        msg.tx.send(Ok(())).await.ok();
399                    }
400                }
401            }
402        });
403        Self {
404            mask,
405            inner: Some(irpc::Client::from(tx)),
406        }
407    }
408
409    /// A new client has been connected.
410    pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
411        if let Some(client) = &self.inner {
412            match self.mask.connected {
413                ConnectMode::None => {}
414                ConnectMode::Notify => client.notify(Notify(f())).await?,
415                ConnectMode::Intercept => client.rpc(f()).await??,
416            }
417        };
418        Ok(())
419    }
420
421    /// A connection has been closed.
422    pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
423        if let Some(client) = &self.inner {
424            client.notify(f()).await?;
425        };
426        Ok(())
427    }
428
429    /// Abstract request, to DRY the 3 to 4 request types.
430    ///
431    /// DRYing stuff with lots of bounds is no fun at all...
432    pub(crate) async fn request<Req>(
433        &self,
434        f: impl FnOnce() -> Req,
435        connection_id: u64,
436        request_id: u64,
437    ) -> Result<RequestTracker, ProgressError>
438    where
439        ProviderProto: From<RequestReceived<Req>>,
440        ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
441        RequestReceived<Req>: Channels<
442            ProviderProto,
443            Tx = oneshot::Sender<EventResult>,
444            Rx = mpsc::Receiver<RequestUpdate>,
445        >,
446        ProviderProto: From<Notify<RequestReceived<Req>>>,
447        ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
448        Notify<RequestReceived<Req>>:
449            Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
450    {
451        let client = self.inner.as_ref();
452        Ok(self.create_tracker((
453            match self.mask.get {
454                RequestMode::None => RequestUpdates::None,
455                RequestMode::Notify if client.is_some() => {
456                    let msg = RequestReceived {
457                        request: f(),
458                        connection_id,
459                        request_id,
460                    };
461                    RequestUpdates::Disabled(
462                        client.unwrap().notify_streaming(Notify(msg), 32).await?,
463                    )
464                }
465                RequestMode::Intercept if client.is_some() => {
466                    let msg = RequestReceived {
467                        request: f(),
468                        connection_id,
469                        request_id,
470                    };
471                    let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
472                    // bail out if the request is not allowed
473                    rx.await??;
474                    RequestUpdates::Disabled(tx)
475                }
476                RequestMode::NotifyLog if client.is_some() => {
477                    let msg = RequestReceived {
478                        request: f(),
479                        connection_id,
480                        request_id,
481                    };
482                    RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
483                }
484                RequestMode::InterceptLog if client.is_some() => {
485                    let msg = RequestReceived {
486                        request: f(),
487                        connection_id,
488                        request_id,
489                    };
490                    let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
491                    // bail out if the request is not allowed
492                    rx.await??;
493                    RequestUpdates::Active(tx)
494                }
495                RequestMode::Disabled => {
496                    return Err(ProgressError::Permission);
497                }
498                _ => RequestUpdates::None,
499            },
500            connection_id,
501            request_id,
502        )))
503    }
504
505    fn create_tracker(
506        &self,
507        (updates, connection_id, request_id): (RequestUpdates, u64, u64),
508    ) -> RequestTracker {
509        let throttle = match self.mask.throttle {
510            ThrottleMode::None => None,
511            ThrottleMode::Intercept => self
512                .inner
513                .clone()
514                .map(|client| (client, connection_id, request_id)),
515        };
516        RequestTracker::new(updates, throttle)
517    }
518}
519
520#[rpc_requests(message = ProviderMessage)]
521#[derive(Debug, Serialize, Deserialize)]
522pub enum ProviderProto {
523    /// A new client connected to the provider.
524    #[rpc(tx = oneshot::Sender<EventResult>)]
525    ClientConnected(ClientConnected),
526
527    /// A new client connected to the provider. Notify variant.
528    #[rpc(tx = NoSender)]
529    ClientConnectedNotify(Notify<ClientConnected>),
530
531    /// A client disconnected from the provider.
532    #[rpc(tx = NoSender)]
533    ConnectionClosed(ConnectionClosed),
534
535    /// A new get request was received from the provider.
536    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
537    GetRequestReceived(RequestReceived<GetRequest>),
538
539    /// A new get request was received from the provider (notify variant).
540    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
541    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),
542
543    /// A new get many request was received from the provider.
544    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
545    GetManyRequestReceived(RequestReceived<GetManyRequest>),
546
547    /// A new get many request was received from the provider (notify variant).
548    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
549    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),
550
551    /// A new push request was received from the provider.
552    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
553    PushRequestReceived(RequestReceived<PushRequest>),
554
555    /// A new push request was received from the provider (notify variant).
556    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
557    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),
558
559    /// A new observe request was received from the provider.
560    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
561    ObserveRequestReceived(RequestReceived<ObserveRequest>),
562
563    /// A new observe request was received from the provider (notify variant).
564    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
565    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),
566
567    /// Request to throttle sending for a specific data request.
568    #[rpc(tx = oneshot::Sender<EventResult>)]
569    Throttle(Throttle),
570}
571
572mod proto {
573    use iroh::NodeId;
574    use serde::{Deserialize, Serialize};
575
576    use crate::{provider::TransferStats, Hash};
577
578    #[derive(Debug, Serialize, Deserialize)]
579    pub struct ClientConnected {
580        pub connection_id: u64,
581        pub node_id: Option<NodeId>,
582    }
583
584    #[derive(Debug, Serialize, Deserialize)]
585    pub struct ConnectionClosed {
586        pub connection_id: u64,
587    }
588
589    /// A new get request was received from the provider.
590    #[derive(Debug, Serialize, Deserialize)]
591    pub struct RequestReceived<R> {
592        /// The connection id. Multiple requests can be sent over the same connection.
593        pub connection_id: u64,
594        /// The request id. There is a new id for each request.
595        pub request_id: u64,
596        /// The request
597        pub request: R,
598    }
599
600    /// Request to throttle sending for a specific request.
601    #[derive(Debug, Serialize, Deserialize)]
602    pub struct Throttle {
603        /// The connection id. Multiple requests can be sent over the same connection.
604        pub connection_id: u64,
605        /// The request id. There is a new id for each request.
606        pub request_id: u64,
607        /// Size of the chunk to be throttled. This will usually be 16 KiB.
608        pub size: u64,
609    }
610
611    #[derive(Debug, Serialize, Deserialize)]
612    pub struct TransferProgress {
613        /// The end offset of the chunk that was sent.
614        pub end_offset: u64,
615    }
616
617    #[derive(Debug, Serialize, Deserialize)]
618    pub struct TransferStarted {
619        pub index: u64,
620        pub hash: Hash,
621        pub size: u64,
622    }
623
624    #[derive(Debug, Serialize, Deserialize)]
625    pub struct TransferCompleted {
626        pub stats: Box<TransferStats>,
627    }
628
629    #[derive(Debug, Serialize, Deserialize)]
630    pub struct TransferAborted {
631        pub stats: Box<TransferStats>,
632    }
633
634    /// Stream of updates for a single request
635    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
636    pub enum RequestUpdate {
637        /// Start of transfer for a blob, mandatory event
638        Started(TransferStarted),
639        /// Progress for a blob - optional event
640        Progress(TransferProgress),
641        /// Successful end of transfer
642        Completed(TransferCompleted),
643        /// Aborted end of transfer
644        Aborted(TransferAborted),
645    }
646}
647pub use proto::*;
648
649mod irpc_ext {
650    use std::future::Future;
651
652    use irpc::{
653        channel::{mpsc, none::NoSender},
654        Channels, RpcMessage, Service, WithChannels,
655    };
656
657    pub trait IrpcClientExt<S: Service> {
658        fn notify_streaming<Req, Update>(
659            &self,
660            msg: Req,
661            local_update_cap: usize,
662        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
663        where
664            S: From<Req>,
665            S::Message: From<WithChannels<Req, S>>,
666            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
667            Update: RpcMessage;
668    }
669
670    impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
671        fn notify_streaming<Req, Update>(
672            &self,
673            msg: Req,
674            local_update_cap: usize,
675        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
676        where
677            S: From<Req>,
678            S::Message: From<WithChannels<Req, S>>,
679            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
680            Update: RpcMessage,
681        {
682            let client = self.clone();
683            async move {
684                let request = client.request().await?;
685                match request {
686                    irpc::Request::Local(local) => {
687                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
688                        local
689                            .send((msg, NoSender, req_rx))
690                            .await
691                            .map_err(irpc::Error::from)?;
692                        Ok(req_tx)
693                    }
694                    irpc::Request::Remote(remote) => {
695                        let (s, _) = remote.write(msg).await?;
696                        Ok(s.into())
697                    }
698                }
699            }
700        }
701    }
702}