1use std::{fmt::Debug, io, ops::Deref};
2
3use irpc::{
4 channel::{mpsc, none::NoSender, oneshot},
5 rpc_requests, Channels, WithChannels,
6};
7use serde::{Deserialize, Serialize};
8use snafu::Snafu;
9
10use crate::{
11 protocol::{
12 GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
13 ERR_PERMISSION,
14 },
15 provider::{events::irpc_ext::IrpcClientExt, TransferStats},
16 Hash,
17};
18
/// How client connection events are reported to the event receiver.
///
/// Consulted by `EventSender::client_connected`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ConnectMode {
    /// Connection events are not sent.
    #[default]
    None = 0,
    /// A `ClientConnected` notification is sent; no reply is awaited.
    Notify = 1,
    /// A `ClientConnected` request is sent and the receiver's reply is awaited,
    /// so the receiver can reject the connection.
    Intercept = 2,
}
31
/// How observe requests are reported to the event receiver.
///
/// NOTE(review): the code visible in this file stores this mode in
/// `EventMask::observe` but never branches on it. The variant descriptions below
/// are presumed by analogy with `ConnectMode` — confirm against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ObserveMode {
    /// Presumably: observe requests are not reported.
    #[default]
    None = 0,
    /// Presumably: observe requests are reported without awaiting a reply.
    Notify = 1,
    /// Presumably: observe requests are reported and the reply is awaited.
    Intercept = 2,
}
44
/// How a request and its transfer updates are reported to the event receiver.
///
/// Consulted by `EventSender::request`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum RequestMode {
    /// The request is not reported and no transfer updates are sent.
    #[default]
    None = 0,
    /// The request is reported without awaiting a reply; transfer updates are
    /// not sent.
    Notify = 1,
    /// The request is reported and the receiver's reply is awaited, so it can
    /// reject the request; transfer updates are not sent.
    Intercept = 2,
    /// Like `Notify`, and transfer updates are sent as well.
    NotifyLog = 3,
    /// Like `Intercept`, and transfer updates are sent as well.
    InterceptLog = 4,
    /// The request is rejected with a permission error without consulting the
    /// receiver.
    Disabled = 5,
}
67
/// Whether transfer progress is gated by the event receiver.
///
/// Consulted by `EventSender::create_tracker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ThrottleMode {
    /// Trackers carry no throttle target; progress is never gated.
    #[default]
    None = 0,
    /// Trackers carry the client, so each `transfer_progress` call sends a
    /// `Throttle` request and awaits the receiver's reply.
    Intercept = 1,
}
78
/// Reason an event receiver gives for rejecting an intercepted event.
///
/// This is the error half of [`EventResult`] and converts into a
/// [`ProgressError`] via `From<AbortReason>`.
///
/// The type is serialized; keep the variant order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbortReason {
    /// Converts to [`ProgressError::Limit`].
    RateLimited,
    /// Converts to [`ProgressError::Permission`].
    Permission,
}
86
/// Error returned when reporting an event fails or the receiver rejects it.
///
/// Each variant has a close code ([`ProgressError::code`]) and a short reason
/// ([`ProgressError::reason`]), and the whole error converts into an
/// [`io::Error`].
#[derive(Debug, Snafu)]
pub enum ProgressError {
    /// The receiver rejected the event with [`AbortReason::RateLimited`].
    Limit,
    /// The receiver rejected the event with [`AbortReason::Permission`], or the
    /// request type is disabled in the mask.
    Permission,
    /// The event could not be delivered or the reply could not be received.
    #[snafu(transparent)]
    Internal {
        /// The underlying irpc error.
        source: irpc::Error,
    },
}
97
98impl From<ProgressError> for io::Error {
99 fn from(value: ProgressError) -> Self {
100 match value {
101 ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
102 ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
103 ProgressError::Internal { source } => source.into(),
104 }
105 }
106}
107
108impl ProgressError {
109 pub fn code(&self) -> quinn::VarInt {
110 match self {
111 ProgressError::Limit => ERR_LIMIT,
112 ProgressError::Permission => ERR_PERMISSION,
113 ProgressError::Internal { .. } => ERR_INTERNAL,
114 }
115 }
116
117 pub fn reason(&self) -> &'static [u8] {
118 match self {
119 ProgressError::Limit => b"limit",
120 ProgressError::Permission => b"permission",
121 ProgressError::Internal { .. } => b"internal",
122 }
123 }
124}
125
126impl From<AbortReason> for ProgressError {
127 fn from(value: AbortReason) -> Self {
128 match value {
129 AbortReason::RateLimited => ProgressError::Limit,
130 AbortReason::Permission => ProgressError::Permission,
131 }
132 }
133}
134
135impl From<irpc::channel::RecvError> for ProgressError {
136 fn from(value: irpc::channel::RecvError) -> Self {
137 ProgressError::Internal {
138 source: value.into(),
139 }
140 }
141}
142
143impl From<irpc::channel::SendError> for ProgressError {
144 fn from(value: irpc::channel::SendError) -> Self {
145 ProgressError::Internal {
146 source: value.into(),
147 }
148 }
149}
150
/// Reply an event receiver sends for an intercepted event: `Ok(())` to let it
/// proceed, or the [`AbortReason`] to reject it.
pub type EventResult = Result<(), AbortReason>;
/// Result of reporting an event from the sending side.
pub type ClientResult = Result<(), ProgressError>;
153
/// Selects, per event category, whether and how events are reported.
///
/// NOTE(review): in the code visible in this file only `connected`, `get` and
/// `throttle` are read (by `EventSender`); the other fields are presumably
/// consulted by callers elsewhere — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EventMask {
    /// Mode for client connection events.
    pub connected: ConnectMode,
    /// Mode for get requests.
    pub get: RequestMode,
    /// Mode for get-many requests.
    pub get_many: RequestMode,
    /// Mode for push requests.
    pub push: RequestMode,
    /// Mode for observe requests.
    pub observe: ObserveMode,
    /// Mode for throttling transfer progress.
    pub throttle: ThrottleMode,
}
173
174impl Default for EventMask {
175 fn default() -> Self {
176 Self::DEFAULT
177 }
178}
179
180impl EventMask {
181 pub const DEFAULT: Self = Self {
183 connected: ConnectMode::None,
184 get: RequestMode::None,
185 get_many: RequestMode::None,
186 push: RequestMode::Disabled,
187 throttle: ThrottleMode::None,
188 observe: ObserveMode::None,
189 };
190
191 pub const ALL_READONLY: Self = Self {
197 connected: ConnectMode::Intercept,
198 get: RequestMode::InterceptLog,
199 get_many: RequestMode::InterceptLog,
200 push: RequestMode::Disabled,
201 throttle: ThrottleMode::Intercept,
202 observe: ObserveMode::Intercept,
203 };
204}
205
/// Newtype marking a message as a notification.
///
/// In `ProviderProto` the `Notify<..>` variants are declared with `NoSender`,
/// i.e. without a reply channel, so the same payload type can be sent either as
/// a request or as a notification.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notify<T>(T);
209
210impl<T> Deref for Notify<T> {
211 type Target = T;
212
213 fn deref(&self) -> &Self::Target {
214 &self.0
215 }
216}
217
/// Sends provider events to an optional receiver, filtered by an [`EventMask`].
#[derive(Debug, Default, Clone)]
pub struct EventSender {
    /// Which events are sent, and in which mode.
    mask: EventMask,
    /// Client used to deliver events; `None` means no receiver is attached.
    inner: Option<irpc::Client<ProviderProto>>,
}
223
/// State of the per-request update channel held by a [`RequestTracker`].
#[derive(Debug, Default)]
enum RequestUpdates {
    /// The request was not reported, so there is no update channel.
    #[default]
    None,
    /// The request was reported and transfer updates are sent on this channel.
    Active(mpsc::Sender<RequestUpdate>),
    /// The request was reported but transfer updates are not sent.
    ///
    /// The sender is never read, hence the `dead_code` allowance. Presumably it
    /// is held only so the receiver does not see the channel close before the
    /// tracker is dropped — TODO confirm.
    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
}
235
/// Reports the progress of a single request to the event receiver.
///
/// Created by `EventSender::request`.
#[derive(Debug)]
pub struct RequestTracker {
    /// Channel for transfer updates, if any.
    updates: RequestUpdates,
    /// Throttle target as `(client, connection_id, request_id)`, or `None` when
    /// progress is not throttled.
    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
}
241
242impl RequestTracker {
243 fn new(
244 updates: RequestUpdates,
245 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
246 ) -> Self {
247 Self { updates, throttle }
248 }
249
250 pub const NONE: Self = Self {
252 updates: RequestUpdates::None,
253 throttle: None,
254 };
255
256 pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
258 if let RequestUpdates::Active(tx) = &self.updates {
259 tx.send(
260 TransferStarted {
261 index,
262 hash: *hash,
263 size,
264 }
265 .into(),
266 )
267 .await?;
268 }
269 Ok(())
270 }
271
272 pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
274 if let RequestUpdates::Active(tx) = &mut self.updates {
275 tx.try_send(TransferProgress { end_offset }.into()).await?;
276 }
277 if let Some((throttle, connection_id, request_id)) = &self.throttle {
278 throttle
279 .rpc(Throttle {
280 connection_id: *connection_id,
281 request_id: *request_id,
282 size: len,
283 })
284 .await??;
285 }
286 Ok(())
287 }
288
289 pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
291 if let RequestUpdates::Active(tx) = &self.updates {
292 tx.send(TransferCompleted { stats: f() }.into()).await?;
293 }
294 Ok(())
295 }
296
297 pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
299 if let RequestUpdates::Active(tx) = &self.updates {
300 tx.send(TransferAborted { stats: f() }.into()).await?;
301 }
302 Ok(())
303 }
304}
305
306impl EventSender {
311 pub const DEFAULT: Self = Self {
313 mask: EventMask::DEFAULT,
314 inner: None,
315 };
316
317 pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
318 Self {
319 mask,
320 inner: Some(irpc::Client::from(client)),
321 }
322 }
323
324 pub fn channel(
325 capacity: usize,
326 mask: EventMask,
327 ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
328 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
329 (Self::new(tx, mask), rx)
330 }
331
332 pub fn tracing(&self, mask: EventMask) -> Self {
334 use tracing::trace;
335 let (tx, mut rx) = tokio::sync::mpsc::channel(32);
336 n0_future::task::spawn(async move {
337 fn log_request_events(
338 mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
339 connection_id: u64,
340 request_id: u64,
341 ) {
342 n0_future::task::spawn(async move {
343 while let Ok(Some(update)) = rx.recv().await {
344 trace!(%connection_id, %request_id, "{update:?}");
345 }
346 });
347 }
348 while let Some(msg) = rx.recv().await {
349 match msg {
350 ProviderMessage::ClientConnected(msg) => {
351 trace!("{:?}", msg.inner);
352 msg.tx.send(Ok(())).await.ok();
353 }
354 ProviderMessage::ClientConnectedNotify(msg) => {
355 trace!("{:?}", msg.inner);
356 }
357 ProviderMessage::ConnectionClosed(msg) => {
358 trace!("{:?}", msg.inner);
359 }
360 ProviderMessage::GetRequestReceived(msg) => {
361 trace!("{:?}", msg.inner);
362 msg.tx.send(Ok(())).await.ok();
363 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
364 }
365 ProviderMessage::GetRequestReceivedNotify(msg) => {
366 trace!("{:?}", msg.inner);
367 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
368 }
369 ProviderMessage::GetManyRequestReceived(msg) => {
370 trace!("{:?}", msg.inner);
371 msg.tx.send(Ok(())).await.ok();
372 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
373 }
374 ProviderMessage::GetManyRequestReceivedNotify(msg) => {
375 trace!("{:?}", msg.inner);
376 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
377 }
378 ProviderMessage::PushRequestReceived(msg) => {
379 trace!("{:?}", msg.inner);
380 msg.tx.send(Ok(())).await.ok();
381 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
382 }
383 ProviderMessage::PushRequestReceivedNotify(msg) => {
384 trace!("{:?}", msg.inner);
385 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
386 }
387 ProviderMessage::ObserveRequestReceived(msg) => {
388 trace!("{:?}", msg.inner);
389 msg.tx.send(Ok(())).await.ok();
390 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
391 }
392 ProviderMessage::ObserveRequestReceivedNotify(msg) => {
393 trace!("{:?}", msg.inner);
394 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
395 }
396 ProviderMessage::Throttle(msg) => {
397 trace!("{:?}", msg.inner);
398 msg.tx.send(Ok(())).await.ok();
399 }
400 }
401 }
402 });
403 Self {
404 mask,
405 inner: Some(irpc::Client::from(tx)),
406 }
407 }
408
409 pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
411 if let Some(client) = &self.inner {
412 match self.mask.connected {
413 ConnectMode::None => {}
414 ConnectMode::Notify => client.notify(Notify(f())).await?,
415 ConnectMode::Intercept => client.rpc(f()).await??,
416 }
417 };
418 Ok(())
419 }
420
421 pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
423 if let Some(client) = &self.inner {
424 client.notify(f()).await?;
425 };
426 Ok(())
427 }
428
429 pub(crate) async fn request<Req>(
433 &self,
434 f: impl FnOnce() -> Req,
435 connection_id: u64,
436 request_id: u64,
437 ) -> Result<RequestTracker, ProgressError>
438 where
439 ProviderProto: From<RequestReceived<Req>>,
440 ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
441 RequestReceived<Req>: Channels<
442 ProviderProto,
443 Tx = oneshot::Sender<EventResult>,
444 Rx = mpsc::Receiver<RequestUpdate>,
445 >,
446 ProviderProto: From<Notify<RequestReceived<Req>>>,
447 ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
448 Notify<RequestReceived<Req>>:
449 Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
450 {
451 let client = self.inner.as_ref();
452 Ok(self.create_tracker((
453 match self.mask.get {
454 RequestMode::None => RequestUpdates::None,
455 RequestMode::Notify if client.is_some() => {
456 let msg = RequestReceived {
457 request: f(),
458 connection_id,
459 request_id,
460 };
461 RequestUpdates::Disabled(
462 client.unwrap().notify_streaming(Notify(msg), 32).await?,
463 )
464 }
465 RequestMode::Intercept if client.is_some() => {
466 let msg = RequestReceived {
467 request: f(),
468 connection_id,
469 request_id,
470 };
471 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
472 rx.await??;
474 RequestUpdates::Disabled(tx)
475 }
476 RequestMode::NotifyLog if client.is_some() => {
477 let msg = RequestReceived {
478 request: f(),
479 connection_id,
480 request_id,
481 };
482 RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
483 }
484 RequestMode::InterceptLog if client.is_some() => {
485 let msg = RequestReceived {
486 request: f(),
487 connection_id,
488 request_id,
489 };
490 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
491 rx.await??;
493 RequestUpdates::Active(tx)
494 }
495 RequestMode::Disabled => {
496 return Err(ProgressError::Permission);
497 }
498 _ => RequestUpdates::None,
499 },
500 connection_id,
501 request_id,
502 )))
503 }
504
505 fn create_tracker(
506 &self,
507 (updates, connection_id, request_id): (RequestUpdates, u64, u64),
508 ) -> RequestTracker {
509 let throttle = match self.mask.throttle {
510 ThrottleMode::None => None,
511 ThrottleMode::Intercept => self
512 .inner
513 .clone()
514 .map(|client| (client, connection_id, request_id)),
515 };
516 RequestTracker::new(updates, throttle)
517 }
518}
519
/// Protocol spoken between an `EventSender` and its event receiver.
///
/// Variants declared with a `oneshot::Sender<EventResult>` expect a reply from
/// the receiver; variants declared with `NoSender` are notifications. Variants
/// with an `mpsc::Receiver<RequestUpdate>` additionally carry a stream of
/// transfer updates.
///
/// The enum is serialized; keep the variant order stable.
#[rpc_requests(message = ProviderMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub enum ProviderProto {
    /// A client connected; the receiver replies to accept or reject.
    #[rpc(tx = oneshot::Sender<EventResult>)]
    ClientConnected(ClientConnected),

    /// A client connected; notification only.
    #[rpc(tx = NoSender)]
    ClientConnectedNotify(Notify<ClientConnected>),

    /// A connection was closed; notification only.
    #[rpc(tx = NoSender)]
    ConnectionClosed(ConnectionClosed),

    /// A get request was received; the receiver replies to accept or reject.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    GetRequestReceived(RequestReceived<GetRequest>),

    /// A get request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),

    /// A get-many request was received; the receiver replies to accept or reject.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    GetManyRequestReceived(RequestReceived<GetManyRequest>),

    /// A get-many request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),

    /// A push request was received; the receiver replies to accept or reject.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    PushRequestReceived(RequestReceived<PushRequest>),

    /// A push request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),

    /// An observe request was received; the receiver replies to accept or reject.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    ObserveRequestReceived(RequestReceived<ObserveRequest>),

    /// An observe request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),

    /// Asks the receiver whether a transfer may continue; the receiver replies
    /// to accept or reject.
    #[rpc(tx = oneshot::Sender<EventResult>)]
    Throttle(Throttle),
}
571
/// Payload types of the provider event protocol.
///
/// All types here are serialized; keep field and variant order stable.
mod proto {
    use iroh::NodeId;
    use serde::{Deserialize, Serialize};

    use crate::{provider::TransferStats, Hash};

    /// A client connected.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ClientConnected {
        /// Identifier of the connection.
        pub connection_id: u64,
        /// Node id of the client, if available.
        pub node_id: Option<NodeId>,
    }

    /// A connection was closed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ConnectionClosed {
        /// Identifier of the connection.
        pub connection_id: u64,
    }

    /// A request of type `R` was received.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RequestReceived<R> {
        /// Identifier of the connection the request arrived on.
        pub connection_id: u64,
        /// Identifier of the request.
        pub request_id: u64,
        /// The request itself.
        pub request: R,
    }

    /// Asks the receiver whether a transfer may proceed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Throttle {
        /// Identifier of the connection.
        pub connection_id: u64,
        /// Identifier of the request.
        pub request_id: u64,
        /// The `len` passed to `RequestTracker::transfer_progress`; presumably
        /// the number of bytes about to be transferred — TODO confirm.
        pub size: u64,
    }

    /// Progress of the current transfer.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferProgress {
        /// The `end_offset` passed to `RequestTracker::transfer_progress`.
        pub end_offset: u64,
    }

    /// The transfer of a blob started.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferStarted {
        /// Presumably the index of the blob within the request — TODO confirm.
        pub index: u64,
        /// Hash of the blob.
        pub hash: Hash,
        /// Presumably the size of the blob in bytes — TODO confirm.
        pub size: u64,
    }

    /// The transfer completed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferCompleted {
        /// Statistics about the transfer.
        pub stats: Box<TransferStats>,
    }

    /// The transfer was aborted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferAborted {
        /// Statistics about the transfer up to the abort.
        pub stats: Box<TransferStats>,
    }

    /// An update sent on the per-request update channel.
    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
    pub enum RequestUpdate {
        /// See [`TransferStarted`].
        Started(TransferStarted),
        /// See [`TransferProgress`].
        Progress(TransferProgress),
        /// See [`TransferCompleted`].
        Completed(TransferCompleted),
        /// See [`TransferAborted`].
        Aborted(TransferAborted),
    }
}
647pub use proto::*;
648
649mod irpc_ext {
650 use std::future::Future;
651
652 use irpc::{
653 channel::{mpsc, none::NoSender},
654 Channels, RpcMessage, Service, WithChannels,
655 };
656
657 pub trait IrpcClientExt<S: Service> {
658 fn notify_streaming<Req, Update>(
659 &self,
660 msg: Req,
661 local_update_cap: usize,
662 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
663 where
664 S: From<Req>,
665 S::Message: From<WithChannels<Req, S>>,
666 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
667 Update: RpcMessage;
668 }
669
670 impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
671 fn notify_streaming<Req, Update>(
672 &self,
673 msg: Req,
674 local_update_cap: usize,
675 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
676 where
677 S: From<Req>,
678 S::Message: From<WithChannels<Req, S>>,
679 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
680 Update: RpcMessage,
681 {
682 let client = self.clone();
683 async move {
684 let request = client.request().await?;
685 match request {
686 irpc::Request::Local(local) => {
687 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
688 local
689 .send((msg, NoSender, req_rx))
690 .await
691 .map_err(irpc::Error::from)?;
692 Ok(req_tx)
693 }
694 irpc::Request::Remote(remote) => {
695 let (s, _) = remote.write(msg).await?;
696 Ok(s.into())
697 }
698 }
699 }
700 }
701 }
702}