1use std::{fmt::Debug, io, ops::Deref};
2
3use irpc::{
4 channel::{mpsc, none::NoSender, oneshot},
5 rpc_requests, Channels, WithChannels,
6};
7use serde::{Deserialize, Serialize};
8use snafu::Snafu;
9
10use crate::{
11 protocol::{
12 GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
13 ERR_PERMISSION,
14 },
15 provider::{events::irpc_ext::IrpcClientExt, TransferStats},
16 Hash,
17};
18
/// Controls how a newly connected client is reported to the event handler.
///
/// Configured through the `connected` field of [`EventMask`].
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ConnectMode {
    /// New connections are not reported.
    #[default]
    None = 0,
    /// A notification is sent; the handler has no way to answer.
    Notify = 1,
    /// The handler is asked via RPC; answering with an [`AbortReason`] makes
    /// `EventSender::client_connected` return an error.
    Intercept = 2,
}
31
/// Reporting mode for observe requests, configured through the `observe`
/// field of [`EventMask`].
///
/// NOTE(review): the variants mirror the matching [`RequestMode`] variants;
/// confirm against the code that reads `EventMask::observe`.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ObserveMode {
    /// Observe requests are not reported.
    #[default]
    None = 0,
    /// Notification-only reporting.
    Notify = 1,
    /// Reporting where the handler answers with an [`EventResult`].
    Intercept = 2,
}
44
/// Controls how a received request is reported to the event handler and
/// whether the updates of its transfer are forwarded.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RequestMode {
    /// The request is not reported and is allowed.
    #[default]
    None = 0,
    /// A notification is sent; transfer updates are not forwarded.
    Notify = 1,
    /// The handler is asked and its answer awaited; an [`AbortReason`]
    /// rejects the request. Transfer updates are not forwarded.
    Intercept = 2,
    /// Like [`RequestMode::Notify`], additionally forwarding transfer
    /// updates (started, progress, completed, aborted).
    NotifyLog = 3,
    /// Like [`RequestMode::Intercept`], additionally forwarding transfer
    /// updates.
    InterceptLog = 4,
    /// The request is rejected with a permission error without contacting
    /// the handler.
    Disabled = 5,
}
67
/// Controls whether transfer progress is throttled by the event handler.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleMode {
    /// No throttling.
    #[default]
    None = 0,
    /// Every progress step sends a [`Throttle`] RPC and waits for the answer;
    /// an [`AbortReason`] makes `RequestTracker::transfer_progress` fail.
    Intercept = 1,
}
78
/// Reason an event handler gives when rejecting an intercepted event
/// (a connection, a request, or a throttled transfer step).
///
/// Converted into [`ProgressError`] via `From<AbortReason>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbortReason {
    /// Rejected because of rate limiting; becomes `ProgressError::Limit`.
    RateLimited,
    /// Rejected for lack of permission; becomes `ProgressError::Permission`.
    Permission,
}
86
/// Error reported to the provider when an event handler rejects an operation
/// or when communication with the handler fails.
///
/// Convertible into [`io::Error`] and mapped to a wire error code via
/// [`HasErrorCode`].
#[derive(Debug, Snafu)]
pub enum ProgressError {
    // The handler answered `AbortReason::RateLimited`.
    // (Plain `//` comments on the variants on purpose: snafu uses variant doc
    // comments as the default `Display` text.)
    Limit,
    // The handler answered `AbortReason::Permission`, or the request mode is
    // `RequestMode::Disabled`.
    Permission,
    // Communication with the event handler failed; display and source are
    // those of the wrapped `irpc::Error`.
    #[snafu(transparent)]
    Internal {
        source: irpc::Error,
    },
}
97
98impl From<ProgressError> for io::Error {
99 fn from(value: ProgressError) -> Self {
100 match value {
101 ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
102 ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
103 ProgressError::Internal { source } => source.into(),
104 }
105 }
106}
107
/// Types that map to a numeric wire error code.
pub trait HasErrorCode {
    /// Returns the error code (a `quinn::VarInt`) associated with this value.
    fn code(&self) -> quinn::VarInt;
}
111
112impl HasErrorCode for ProgressError {
113 fn code(&self) -> quinn::VarInt {
114 match self {
115 ProgressError::Limit => ERR_LIMIT,
116 ProgressError::Permission => ERR_PERMISSION,
117 ProgressError::Internal { .. } => ERR_INTERNAL,
118 }
119 }
120}
121
122impl ProgressError {
123 pub fn reason(&self) -> &'static [u8] {
124 match self {
125 ProgressError::Limit => b"limit",
126 ProgressError::Permission => b"permission",
127 ProgressError::Internal { .. } => b"internal",
128 }
129 }
130}
131
132impl From<AbortReason> for ProgressError {
133 fn from(value: AbortReason) -> Self {
134 match value {
135 AbortReason::RateLimited => ProgressError::Limit,
136 AbortReason::Permission => ProgressError::Permission,
137 }
138 }
139}
140
141impl From<irpc::channel::mpsc::RecvError> for ProgressError {
142 fn from(value: irpc::channel::mpsc::RecvError) -> Self {
143 ProgressError::Internal {
144 source: value.into(),
145 }
146 }
147}
148
149impl From<irpc::channel::oneshot::RecvError> for ProgressError {
150 fn from(value: irpc::channel::oneshot::RecvError) -> Self {
151 ProgressError::Internal {
152 source: value.into(),
153 }
154 }
155}
156
157impl From<irpc::channel::SendError> for ProgressError {
158 fn from(value: irpc::channel::SendError) -> Self {
159 ProgressError::Internal {
160 source: value.into(),
161 }
162 }
163}
164
/// Answer an event handler gives to an intercepted event: `Ok(())` lets the
/// operation continue, `Err(reason)` aborts it.
pub type EventResult = Result<(), AbortReason>;
/// Result of reporting an event from the provider side.
pub type ClientResult = Result<(), ProgressError>;
167
/// Selects which provider events are reported to the event handler and how.
///
/// Each field configures one kind of event. `EventSender::request` decides
/// which request field applies to a given request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EventMask {
    /// Setting for client connection events.
    pub connected: ConnectMode,
    /// Setting for get requests.
    pub get: RequestMode,
    /// Setting for get-many requests.
    pub get_many: RequestMode,
    /// Setting for push requests.
    pub push: RequestMode,
    /// Setting for observe requests.
    pub observe: ObserveMode,
    /// Setting for throttling of transfer progress.
    pub throttle: ThrottleMode,
}
187
188impl Default for EventMask {
189 fn default() -> Self {
190 Self::DEFAULT
191 }
192}
193
194impl EventMask {
195 pub const DEFAULT: Self = Self {
197 connected: ConnectMode::None,
198 get: RequestMode::None,
199 get_many: RequestMode::None,
200 push: RequestMode::Disabled,
201 throttle: ThrottleMode::None,
202 observe: ObserveMode::None,
203 };
204
205 pub const ALL_READONLY: Self = Self {
211 connected: ConnectMode::Intercept,
212 get: RequestMode::InterceptLog,
213 get_many: RequestMode::InterceptLog,
214 push: RequestMode::Disabled,
215 throttle: ThrottleMode::Intercept,
216 observe: ObserveMode::Intercept,
217 };
218}
219
220#[derive(Debug, Serialize, Deserialize)]
222pub struct Notify<T>(T);
223
224impl<T> Deref for Notify<T> {
225 type Target = T;
226
227 fn deref(&self) -> &Self::Target {
228 &self.0
229 }
230}
231
/// Sends provider events to an optional event handler, filtered by an
/// [`EventMask`].
///
/// Without a handler (`inner` is `None`, as in [`EventSender::DEFAULT`]) no
/// events are sent, but request modes set to [`RequestMode::Disabled`] still
/// reject requests.
#[derive(Debug, Default, Clone)]
pub struct EventSender {
    /// Which events to send, and how.
    mask: EventMask,
    /// Client connected to the event handler, if any.
    inner: Option<irpc::Client<ProviderProto>>,
}
237
/// Where the transfer updates of one request go.
#[derive(Debug, Default)]
enum RequestUpdates {
    /// No update channel exists; updates are dropped.
    #[default]
    None,
    /// Updates are forwarded to the event handler through this sender.
    Active(mpsc::Sender<RequestUpdate>),
    /// An update channel was opened, but updates are not forwarded. The
    /// sender is only held (hence the `dead_code` allowance), presumably so
    /// the handler's receiving end stays open while the tracker lives.
    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
}
249
/// Per-request handle the provider uses to report transfer progress.
///
/// [`RequestTracker::NONE`] reports nothing and never throttles.
#[derive(Debug)]
pub struct RequestTracker {
    /// Destination, if any, for transfer updates.
    updates: RequestUpdates,
    /// When throttling is enabled: the handler client plus the connection id
    /// and request id sent with every `Throttle` RPC.
    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
}
255
256impl RequestTracker {
257 fn new(
258 updates: RequestUpdates,
259 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
260 ) -> Self {
261 Self { updates, throttle }
262 }
263
264 pub const NONE: Self = Self {
266 updates: RequestUpdates::None,
267 throttle: None,
268 };
269
270 pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
272 if let RequestUpdates::Active(tx) = &self.updates {
273 tx.send(
274 TransferStarted {
275 index,
276 hash: *hash,
277 size,
278 }
279 .into(),
280 )
281 .await?;
282 }
283 Ok(())
284 }
285
286 pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
288 if let RequestUpdates::Active(tx) = &mut self.updates {
289 tx.try_send(TransferProgress { end_offset }.into()).await?;
290 }
291 if let Some((throttle, connection_id, request_id)) = &self.throttle {
292 throttle
293 .rpc(Throttle {
294 connection_id: *connection_id,
295 request_id: *request_id,
296 size: len,
297 })
298 .await??;
299 }
300 Ok(())
301 }
302
303 pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
305 if let RequestUpdates::Active(tx) = &self.updates {
306 tx.send(TransferCompleted { stats: f() }.into()).await?;
307 }
308 Ok(())
309 }
310
311 pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
313 if let RequestUpdates::Active(tx) = &self.updates {
314 tx.send(TransferAborted { stats: f() }.into()).await?;
315 }
316 Ok(())
317 }
318}
319
320impl EventSender {
325 pub const DEFAULT: Self = Self {
327 mask: EventMask::DEFAULT,
328 inner: None,
329 };
330
331 pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
332 Self {
333 mask,
334 inner: Some(irpc::Client::from(client)),
335 }
336 }
337
338 pub fn channel(
339 capacity: usize,
340 mask: EventMask,
341 ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
342 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
343 (Self::new(tx, mask), rx)
344 }
345
346 pub fn tracing(&self, mask: EventMask) -> Self {
348 use tracing::trace;
349 let (tx, mut rx) = tokio::sync::mpsc::channel(32);
350 n0_future::task::spawn(async move {
351 fn log_request_events(
352 mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
353 connection_id: u64,
354 request_id: u64,
355 ) {
356 n0_future::task::spawn(async move {
357 while let Ok(Some(update)) = rx.recv().await {
358 trace!(%connection_id, %request_id, "{update:?}");
359 }
360 });
361 }
362 while let Some(msg) = rx.recv().await {
363 match msg {
364 ProviderMessage::ClientConnected(msg) => {
365 trace!("{:?}", msg.inner);
366 msg.tx.send(Ok(())).await.ok();
367 }
368 ProviderMessage::ClientConnectedNotify(msg) => {
369 trace!("{:?}", msg.inner);
370 }
371 ProviderMessage::ConnectionClosed(msg) => {
372 trace!("{:?}", msg.inner);
373 }
374 ProviderMessage::GetRequestReceived(msg) => {
375 trace!("{:?}", msg.inner);
376 msg.tx.send(Ok(())).await.ok();
377 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
378 }
379 ProviderMessage::GetRequestReceivedNotify(msg) => {
380 trace!("{:?}", msg.inner);
381 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
382 }
383 ProviderMessage::GetManyRequestReceived(msg) => {
384 trace!("{:?}", msg.inner);
385 msg.tx.send(Ok(())).await.ok();
386 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
387 }
388 ProviderMessage::GetManyRequestReceivedNotify(msg) => {
389 trace!("{:?}", msg.inner);
390 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
391 }
392 ProviderMessage::PushRequestReceived(msg) => {
393 trace!("{:?}", msg.inner);
394 msg.tx.send(Ok(())).await.ok();
395 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
396 }
397 ProviderMessage::PushRequestReceivedNotify(msg) => {
398 trace!("{:?}", msg.inner);
399 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
400 }
401 ProviderMessage::ObserveRequestReceived(msg) => {
402 trace!("{:?}", msg.inner);
403 msg.tx.send(Ok(())).await.ok();
404 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
405 }
406 ProviderMessage::ObserveRequestReceivedNotify(msg) => {
407 trace!("{:?}", msg.inner);
408 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
409 }
410 ProviderMessage::Throttle(msg) => {
411 trace!("{:?}", msg.inner);
412 msg.tx.send(Ok(())).await.ok();
413 }
414 }
415 }
416 });
417 Self {
418 mask,
419 inner: Some(irpc::Client::from(tx)),
420 }
421 }
422
423 pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
425 if let Some(client) = &self.inner {
426 match self.mask.connected {
427 ConnectMode::None => {}
428 ConnectMode::Notify => client.notify(Notify(f())).await?,
429 ConnectMode::Intercept => client.rpc(f()).await??,
430 }
431 };
432 Ok(())
433 }
434
435 pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
437 if let Some(client) = &self.inner {
438 client.notify(f()).await?;
439 };
440 Ok(())
441 }
442
443 pub(crate) async fn request<Req>(
447 &self,
448 f: impl FnOnce() -> Req,
449 connection_id: u64,
450 request_id: u64,
451 ) -> Result<RequestTracker, ProgressError>
452 where
453 ProviderProto: From<RequestReceived<Req>>,
454 ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
455 RequestReceived<Req>: Channels<
456 ProviderProto,
457 Tx = oneshot::Sender<EventResult>,
458 Rx = mpsc::Receiver<RequestUpdate>,
459 >,
460 ProviderProto: From<Notify<RequestReceived<Req>>>,
461 ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
462 Notify<RequestReceived<Req>>:
463 Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
464 {
465 let client = self.inner.as_ref();
466 Ok(self.create_tracker((
467 match self.mask.get {
468 RequestMode::None => RequestUpdates::None,
469 RequestMode::Notify if client.is_some() => {
470 let msg = RequestReceived {
471 request: f(),
472 connection_id,
473 request_id,
474 };
475 RequestUpdates::Disabled(
476 client.unwrap().notify_streaming(Notify(msg), 32).await?,
477 )
478 }
479 RequestMode::Intercept if client.is_some() => {
480 let msg = RequestReceived {
481 request: f(),
482 connection_id,
483 request_id,
484 };
485 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
486 rx.await??;
488 RequestUpdates::Disabled(tx)
489 }
490 RequestMode::NotifyLog if client.is_some() => {
491 let msg = RequestReceived {
492 request: f(),
493 connection_id,
494 request_id,
495 };
496 RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
497 }
498 RequestMode::InterceptLog if client.is_some() => {
499 let msg = RequestReceived {
500 request: f(),
501 connection_id,
502 request_id,
503 };
504 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
505 rx.await??;
507 RequestUpdates::Active(tx)
508 }
509 RequestMode::Disabled => {
510 return Err(ProgressError::Permission);
511 }
512 _ => RequestUpdates::None,
513 },
514 connection_id,
515 request_id,
516 )))
517 }
518
519 fn create_tracker(
520 &self,
521 (updates, connection_id, request_id): (RequestUpdates, u64, u64),
522 ) -> RequestTracker {
523 let throttle = match self.mask.throttle {
524 ThrottleMode::None => None,
525 ThrottleMode::Intercept => self
526 .inner
527 .clone()
528 .map(|client| (client, connection_id, request_id)),
529 };
530 RequestTracker::new(updates, throttle)
531 }
532}
533
#[rpc_requests(message = ProviderMessage)]
#[derive(Debug, Serialize, Deserialize)]
/// Messages the provider sends to an event handler.
///
/// Variants with `tx = oneshot::Sender<EventResult>` are intercepted: the
/// handler answers with an [`EventResult`]. Variants with `tx = NoSender`
/// are notifications without an answer. Request variants also carry an
/// `rx` stream of [`RequestUpdate`]s for the request's transfer.
pub enum ProviderProto {
    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// A client connected; the handler answers whether to continue.
    ClientConnected(ClientConnected),

    #[rpc(tx = NoSender)]
    /// A client connected (notification only).
    ClientConnectedNotify(Notify<ClientConnected>),

    #[rpc(tx = NoSender)]
    /// A connection was closed (notification only).
    ConnectionClosed(ConnectionClosed),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get request was received; the handler answers whether to serve it.
    GetRequestReceived(RequestReceived<GetRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get request was received (notification only).
    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get-many request was received; the handler answers whether to serve it.
    GetManyRequestReceived(RequestReceived<GetManyRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get-many request was received (notification only).
    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A push request was received; the handler answers whether to accept it.
    PushRequestReceived(RequestReceived<PushRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A push request was received (notification only).
    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// An observe request was received; the handler answers whether to serve it.
    ObserveRequestReceived(RequestReceived<ObserveRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// An observe request was received (notification only).
    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),

    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// Asks whether a transfer step may proceed; see [`Throttle`].
    Throttle(Throttle),
}
585
mod proto {
    //! Payload types exchanged between the provider and the event handler.
    use iroh::NodeId;
    use serde::{Deserialize, Serialize};

    use crate::{provider::TransferStats, Hash};

    /// A client connected to the provider.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ClientConnected {
        /// Id of the connection.
        pub connection_id: u64,
        /// Node id of the remote peer, if available.
        pub node_id: Option<NodeId>,
    }

    /// A connection was closed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ConnectionClosed {
        /// Id of the closed connection.
        pub connection_id: u64,
    }

    /// A request of type `R` received on a connection.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RequestReceived<R> {
        /// Id of the connection the request arrived on.
        pub connection_id: u64,
        /// Id of the request.
        pub request_id: u64,
        /// The request itself.
        pub request: R,
    }

    /// Asks the handler whether a transfer step of `size` may proceed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Throttle {
        /// Id of the connection the transfer belongs to.
        pub connection_id: u64,
        /// Id of the request the transfer belongs to.
        pub request_id: u64,
        /// Size of the step; the `len` the provider passed to
        /// `RequestTracker::transfer_progress`.
        pub size: u64,
    }

    /// Progress of an ongoing transfer.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferProgress {
        /// End offset reported by the provider for this progress step.
        pub end_offset: u64,
    }

    /// The transfer of one blob started.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferStarted {
        /// Index of the blob, as given by the provider.
        pub index: u64,
        /// Hash of the blob.
        pub hash: Hash,
        /// Size of the blob, as given by the provider.
        pub size: u64,
    }

    /// The transfer completed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferCompleted {
        /// Statistics of the finished transfer.
        pub stats: Box<TransferStats>,
    }

    /// The transfer was aborted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferAborted {
        /// Statistics of the transfer up to the abort.
        pub stats: Box<TransferStats>,
    }

    /// An update about an ongoing request transfer.
    ///
    /// `derive_more::From` lets each payload type be converted with `.into()`.
    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
    pub enum RequestUpdate {
        /// See [`TransferStarted`].
        Started(TransferStarted),
        /// See [`TransferProgress`].
        Progress(TransferProgress),
        /// See [`TransferCompleted`].
        Completed(TransferCompleted),
        /// See [`TransferAborted`].
        Aborted(TransferAborted),
    }
}
pub use proto::*;
662
663mod irpc_ext {
664 use std::future::Future;
665
666 use irpc::{
667 channel::{mpsc, none::NoSender},
668 Channels, RpcMessage, Service, WithChannels,
669 };
670
671 pub trait IrpcClientExt<S: Service> {
672 fn notify_streaming<Req, Update>(
673 &self,
674 msg: Req,
675 local_update_cap: usize,
676 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
677 where
678 S: From<Req>,
679 S::Message: From<WithChannels<Req, S>>,
680 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
681 Update: RpcMessage;
682 }
683
684 impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
685 fn notify_streaming<Req, Update>(
686 &self,
687 msg: Req,
688 local_update_cap: usize,
689 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
690 where
691 S: From<Req>,
692 S::Message: From<WithChannels<Req, S>>,
693 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
694 Update: RpcMessage,
695 {
696 let client = self.clone();
697 async move {
698 let request = client.request().await?;
699 match request {
700 irpc::Request::Local(local) => {
701 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
702 local
703 .send((msg, NoSender, req_rx))
704 .await
705 .map_err(irpc::Error::from)?;
706 Ok(req_tx)
707 }
708 irpc::Request::Remote(remote) => {
709 let (s, _) = remote.write(msg).await?;
710 Ok(s.into())
711 }
712 }
713 }
714 }
715 }
716}