1use parking_lot::Mutex;
2use std::panic::AssertUnwindSafe;
3use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use tokio::sync::{mpsc, Notify};
6use tracing::info_span;
7
8use crate::envelope::Envelope;
9use crate::error::{PanicError, RelayError};
10use crate::subscription::Subscription;
11use crate::tracker::CompletionTracker;
12
/// Subscriber channel capacity used by `Relay::channel` when the caller does
/// not pick a size explicitly.
const DEFAULT_CHANNEL_SIZE: usize = 65536;
14
/// Process-wide counter backing `next_relay_id`; the first id handed out is 1.
static RELAY_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns the next relay identifier.
///
/// The value returned is the counter's value before the increment. `Relaxed`
/// ordering is used: only the atomicity of the increment matters here.
fn next_relay_id() -> u64 {
    RELAY_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
21
/// State shared by every handle (`RelaySender`, `WeakSender`, `RelayReceiver`)
/// that belongs to one relay.
pub(crate) struct Inner {
    /// Identifier drawn from `next_relay_id`; stamped on envelopes as their origin.
    id: u64,
    /// One entry per registered subscription; senders fan envelopes out to these.
    subscribers: Mutex<Vec<SubscriberSender>>,
    /// Capacity passed to `mpsc::channel` for each new subscription.
    channel_size: usize,
    /// Number of `ReadyGuard`s created so far.
    pending_ready: AtomicUsize,
    /// Number of `ReadyGuard`s that have signalled so far.
    ready_count: AtomicUsize,
    /// Notified each time a `ReadyGuard` signals; awaited by `wait_ready`.
    ready_notify: Notify,
    /// Source of per-message ids; starts at 0.
    msg_id_counter: AtomicU64,
    /// Number of tracked handlers; read by the send paths as the expected
    /// completion count of each new tracker.
    handler_count: Arc<AtomicUsize>,
    /// Set once the relay is closed; send paths then return `SendError::Closed`.
    closed: AtomicBool,
}
41
42impl Inner {
43 fn new(channel_size: usize) -> Self {
44 Self {
45 id: next_relay_id(),
46 subscribers: Mutex::new(Vec::new()),
47 channel_size,
48 pending_ready: AtomicUsize::new(0),
49 ready_count: AtomicUsize::new(0),
50 ready_notify: Notify::new(),
51 msg_id_counter: AtomicU64::new(0),
52 handler_count: Arc::new(AtomicUsize::new(0)),
53 closed: AtomicBool::new(false),
54 }
55 }
56
57 async fn wait_ready(&self) {
59 loop {
60 let pending = self.pending_ready.load(Ordering::SeqCst);
62 let ready = self.ready_count.load(Ordering::SeqCst);
63 if ready >= pending {
64 return;
65 }
66
67 let notified = self.ready_notify.notified();
69
70 let pending = self.pending_ready.load(Ordering::SeqCst);
72 let ready = self.ready_count.load(Ordering::SeqCst);
73 if ready >= pending {
74 return;
75 }
76
77 notified.await;
78 }
79 }
80}
81
82impl Drop for Inner {
83 fn drop(&mut self) {
84 self.closed.store(true, Ordering::SeqCst);
86 self.subscribers.lock().clear();
87 }
88}
89
90pub struct ReadyGuard {
95 inner: Option<Arc<Inner>>,
96 signaled: bool,
97}
98
99impl ReadyGuard {
100 fn new(inner: Arc<Inner>) -> Self {
101 inner.pending_ready.fetch_add(1, Ordering::SeqCst);
102 Self {
103 inner: Some(inner),
104 signaled: false,
105 }
106 }
107
108 pub fn signal(&mut self) {
111 if !self.signaled {
112 self.signaled = true;
113 if let Some(inner) = self.inner.take() {
114 inner.ready_count.fetch_add(1, Ordering::SeqCst);
115 inner.ready_notify.notify_waiters();
116 }
117 }
118 }
119}
120
121impl Drop for ReadyGuard {
122 fn drop(&mut self) {
123 self.signal();
125 }
126}
127
/// Decrements a shared handler counter when dropped.
///
/// Construction does not touch the counter; the matching increment is done by
/// whoever hands the counter out.
struct HandlerGuard {
    /// Counter decremented on drop.
    count: Arc<AtomicUsize>,
}

impl HandlerGuard {
    /// Wraps `count` without modifying it.
    fn new(count: Arc<AtomicUsize>) -> Self {
        HandlerGuard { count }
    }
}

impl Drop for HandlerGuard {
    /// Subtracts one from the wrapped counter.
    fn drop(&mut self) {
        let _previous = self.count.fetch_sub(1, Ordering::SeqCst);
    }
}
145
146struct SubscriberSender {
147 tx: mpsc::Sender<Envelope>,
148}
149
150impl SubscriberSender {
151 fn new(tx: mpsc::Sender<Envelope>) -> Self {
152 Self { tx }
153 }
154
155 fn is_closed(&self) -> bool {
156 self.tx.is_closed()
157 }
158}
159
160#[derive(Clone)]
171pub struct Relay;
175
176impl Relay {
177 pub fn channel() -> (RelaySender, RelayReceiver) {
186 Self::channel_with_size(DEFAULT_CHANNEL_SIZE)
187 }
188
189 pub fn channel_with_size(channel_size: usize) -> (RelaySender, RelayReceiver) {
191 let inner = Arc::new(Inner::new(channel_size));
192 (
193 RelaySender {
194 inner: inner.clone(),
195 },
196 RelayReceiver { inner },
197 )
198 }
199
200 #[doc(hidden)]
203 pub fn new() -> TestRelay {
204 let (tx, rx) = Self::channel();
205 TestRelay { tx: Arc::new(tx), rx }
206 }
207
208 #[doc(hidden)]
210 pub fn with_channel_size(size: usize) -> TestRelay {
211 let (tx, rx) = Self::channel_with_size(size);
212 TestRelay { tx: Arc::new(tx), rx }
213 }
214}
215
/// Bundle of both relay halves, returned by the `#[doc(hidden)]` constructors
/// `Relay::new` and `Relay::with_channel_size`.
#[doc(hidden)]
pub struct TestRelay {
    /// Sending half, shared between clones through the `Arc`.
    tx: Arc<RelaySender>,
    /// Receiving half; publicly accessible.
    pub rx: RelayReceiver,
}
223
224impl TestRelay {
225 pub async fn send<T: 'static + Send + Sync>(&self, value: T) -> Result<(), SendError> {
226 self.tx.send(value).await
227 }
228
229 pub fn subscribe<T: 'static + Send + Sync>(&self) -> Subscription<T> {
230 self.rx.subscribe()
231 }
232
233 pub fn sink<T, F, R>(&self, f: F)
234 where
235 T: 'static + Send + Sync,
236 F: Fn(&T) -> R + Send + Sync + 'static,
237 R: IntoResult + 'static,
238 {
239 let (mut sub, handler_count) = self.rx.subscribe_tracked::<T>();
240 let weak_tx = self.tx.weak();
241 let msg_type = std::any::type_name::<T>();
242 let _handler_guard = HandlerGuard::new(handler_count);
243
244 tokio::spawn(async move {
245 let _guard = _handler_guard;
246 while let Some(msg) = sub.recv().await {
247 let tracker = sub.current_tracker();
248 let msg_id = sub.current_msg_id().unwrap_or(0);
249 let span = info_span!("pipedream.sink", msg_type = %msg_type, msg_id = %msg_id);
250 let _span_guard = span.enter();
251
252 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
253
254 match result {
255 Ok(Ok(())) => {
256 if let Some(t) = tracker {
257 t.complete_one();
258 }
259 sub.clear_tracker();
260 }
261 Ok(Err(e)) => {
262 let error = RelayError::new(msg_id, e, "sink");
263
264 tokio::spawn({
266 let weak_tx = weak_tx.clone();
267 let error = error.clone();
268 async move {
269 let _ = weak_tx.send(error).await;
270 }
271 });
272
273 if let Some(t) = tracker {
274 t.fail(error);
275 }
276 sub.clear_tracker();
277 }
278 Err(panic_info) => {
279 let error = RelayError::new(msg_id, PanicError::new(panic_info), "sink");
280
281 tokio::spawn({
283 let weak_tx = weak_tx.clone();
284 let error = error.clone();
285 async move {
286 let _ = weak_tx.send(error).await;
287 }
288 });
289
290 if let Some(t) = tracker {
291 t.fail(error);
292 }
293 sub.clear_tracker();
294 }
295 }
296 }
297 });
298 }
299
300 pub fn tap<T, F, R>(&self, f: F)
301 where
302 T: 'static + Send + Sync,
303 F: Fn(&T) -> R + Send + Sync + 'static,
304 R: IntoResult + 'static,
305 {
306 let (mut sub, handler_count) = self.rx.subscribe_tracked::<T>();
307 let weak_tx = self.tx.weak();
308 let msg_type = std::any::type_name::<T>();
309 let _handler_guard = HandlerGuard::new(handler_count);
310
311 tokio::spawn(async move {
312 let _guard = _handler_guard;
313 while let Some(msg) = sub.recv().await {
314 let tracker = sub.current_tracker();
315 let msg_id = sub.current_msg_id().unwrap_or(0);
316 let span = info_span!("pipedream.tap", msg_type = %msg_type, msg_id = %msg_id);
317 let _span_guard = span.enter();
318
319 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
320
321 match result {
322 Ok(Ok(())) => {
323 if let Some(t) = tracker {
324 t.complete_one();
325 }
326 sub.clear_tracker();
327 }
328 Ok(Err(e)) => {
329 let error = RelayError::new(msg_id, e, "tap");
330
331 tokio::spawn({
333 let weak_tx = weak_tx.clone();
334 let error = error.clone();
335 async move {
336 let _ = weak_tx.send(error).await;
337 }
338 });
339
340 if let Some(t) = tracker {
341 t.fail(error);
342 }
343 sub.clear_tracker();
344 }
345 Err(panic_info) => {
346 let error = RelayError::new(msg_id, PanicError::new(panic_info), "tap");
347
348 tokio::spawn({
350 let weak_tx = weak_tx.clone();
351 let error = error.clone();
352 async move {
353 let _ = weak_tx.send(error).await;
354 }
355 });
356
357 if let Some(t) = tracker {
358 t.fail(error);
359 }
360 sub.clear_tracker();
361 }
362 }
363 }
364 });
365 }
366
367 pub fn is_closed(&self) -> bool {
368 self.tx.is_closed()
369 }
370
371 pub fn handler_count(&self) -> usize {
372 self.tx.handler_count()
373 }
374
375 pub fn close(&self) {
376 self.tx.close()
377 }
378
379 pub async fn send_any(
380 &self,
381 value: Arc<dyn std::any::Any + Send + Sync>,
382 type_id: std::any::TypeId,
383 ) -> Result<(), SendError> {
384 self.tx.send_any(value, type_id).await
385 }
386
387 pub async fn send_envelope(&self, envelope: Envelope) -> Result<(), SendError> {
388 self.tx.send_envelope(envelope).await
389 }
390
391 pub fn within<F, Fut>(&self, f: F)
392 where
393 F: FnOnce() -> Fut + Send + 'static,
394 Fut: std::future::Future<Output = ()> + Send,
395 {
396 self.rx.within(f);
397 }
398}
399
400impl Clone for TestRelay {
401 fn clone(&self) -> Self {
402 Self {
403 tx: self.tx.clone(),
404 rx: self.rx.clone(),
405 }
406 }
407}
408
/// Error returned by the relay's send methods.
#[derive(Debug, Clone)]
pub enum SendError {
    /// The relay is closed or, for a `WeakSender`, its shared state is gone.
    Closed,
    /// An error was recorded on the message's completion tracker.
    Downstream(RelayError),
}
418
/// Conversion of a handler's return value into a `Result<(), Error>`.
///
/// Implemented for `()` and for `Result<(), E>`, so a handler may either
/// return nothing or return a result.
pub trait IntoResult {
    /// Error type produced when the handler fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Consumes the value and yields the equivalent `Result`.
    fn into_result(self) -> Result<(), Self::Error>;
}

/// A handler returning `()` always succeeds.
impl IntoResult for () {
    type Error = std::convert::Infallible;

    fn into_result(self) -> Result<(), std::convert::Infallible> {
        Ok(())
    }
}

/// A handler returning `Result<(), E>` is passed through unchanged.
impl<E> IntoResult for Result<(), E>
where
    E: std::error::Error + Send + Sync + 'static,
{
    type Error = E;

    fn into_result(self) -> Result<(), Self::Error> {
        self
    }
}
441
/// Sending half of a relay.
///
/// Dropping a `RelaySender` marks the relay closed and clears its subscriber
/// list.
pub struct RelaySender {
    /// State shared with the receiving half and with weak senders.
    inner: Arc<Inner>,
}
456
457impl RelaySender {
458 pub async fn send<T: 'static + Send + Sync>(&self, value: T) -> Result<(), SendError> {
460 if self.inner.closed.load(Ordering::SeqCst) {
461 return Err(SendError::Closed);
462 }
463
464 let msg_id = self.inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
465 let expected = self.inner.handler_count.load(Ordering::SeqCst);
466 let tracker = Arc::new(CompletionTracker::new(expected));
467 let envelope = Envelope::with_origin(value, msg_id, Some(tracker), self.inner.id);
468
469 self.send_envelope(envelope).await
470 }
471
472 pub async fn send_any(
474 &self,
475 value: Arc<dyn std::any::Any + Send + Sync>,
476 type_id: std::any::TypeId,
477 ) -> Result<(), SendError> {
478 if self.inner.closed.load(Ordering::SeqCst) {
479 return Err(SendError::Closed);
480 }
481
482 let msg_id = self.inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
483 let expected = self.inner.handler_count.load(Ordering::SeqCst);
484 let tracker = Arc::new(CompletionTracker::new(expected));
485 let envelope =
486 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), self.inner.id);
487
488 self.send_envelope(envelope).await
489 }
490
491 pub async fn send_any_with_origin(
493 &self,
494 value: Arc<dyn std::any::Any + Send + Sync>,
495 type_id: std::any::TypeId,
496 origin: u64,
497 ) -> Result<(), SendError> {
498 if self.inner.closed.load(Ordering::SeqCst) {
499 return Err(SendError::Closed);
500 }
501
502 let msg_id = self.inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
503 let expected = self.inner.handler_count.load(Ordering::SeqCst);
504 let tracker = Arc::new(CompletionTracker::new(expected));
505 let envelope =
506 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), origin);
507
508 self.send_envelope(envelope).await
509 }
510
511 pub async fn send_envelope(&self, envelope: Envelope) -> Result<(), SendError> {
514 if self.inner.closed.load(Ordering::SeqCst) {
515 return Err(SendError::Closed);
516 }
517
518 self.inner.wait_ready().await;
519
520 let envelope = if envelope.origin() == 0 {
521 envelope.with_new_origin(self.inner.id)
522 } else {
523 envelope
524 };
525
526 let tracker = envelope.tracker();
527 self.deliver_envelope(envelope).await?;
528
529 if let Some(tracker) = tracker {
530 tracker.clone().wait_owned().await;
531 if let Some(error) = tracker.take_error() {
532 return Err(SendError::Downstream(error));
533 }
534 }
535
536 Ok(())
537 }
538
539 async fn deliver_envelope(&self, envelope: Envelope) -> Result<(), SendError> {
540 let subs: Vec<_> = {
541 let mut subs = self.inner.subscribers.lock();
542 subs.retain(|s| !s.is_closed());
543 subs.iter().map(|s| s.tx.clone()).collect()
544 };
545
546 for tx in subs {
547 match tx.try_send(envelope.clone()) {
548 Ok(_) => {}
549 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
550 }
552 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
553 }
555 }
556 }
557
558 Ok(())
559 }
560
561 pub fn weak(&self) -> WeakSender {
566 WeakSender {
567 inner: Arc::downgrade(&self.inner),
568 }
569 }
570
571 pub fn is_closed(&self) -> bool {
573 self.inner.closed.load(Ordering::SeqCst)
574 }
575
576 pub fn id(&self) -> u64 {
578 self.inner.id
579 }
580
581 pub fn handler_count(&self) -> usize {
583 self.inner.handler_count.load(Ordering::SeqCst)
584 }
585
586 pub fn close(&self) {
589 self.inner.closed.store(true, Ordering::SeqCst);
590 self.inner.subscribers.lock().clear();
591 }
592
593 #[doc(hidden)]
596 pub fn clone_for_test(&self) -> Self {
597 Self {
598 inner: self.inner.clone(),
599 }
600 }
601}
602
603impl Drop for RelaySender {
604 fn drop(&mut self) {
605 self.inner.closed.store(true, Ordering::SeqCst);
607 self.inner.subscribers.lock().clear();
608 }
609}
610
/// Sender that holds only a weak reference to the relay's shared state.
///
/// Obtained from `RelaySender::weak`. Once the shared state has been dropped
/// every send fails with `SendError::Closed`.
#[derive(Clone)]
pub struct WeakSender {
    /// Weak handle to the relay state; upgraded on each operation.
    inner: Weak<Inner>,
}
625
626impl WeakSender {
627 pub async fn send<T: 'static + Send + Sync>(&self, value: T) -> Result<(), SendError> {
629 let inner = self.inner.upgrade().ok_or(SendError::Closed)?;
630
631 if inner.closed.load(Ordering::SeqCst) {
632 return Err(SendError::Closed);
633 }
634
635 let msg_id = inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
636 let expected = inner.handler_count.load(Ordering::SeqCst);
637 let tracker = Arc::new(CompletionTracker::new(expected));
638 let envelope = Envelope::with_origin(value, msg_id, Some(tracker), inner.id);
639
640 self.send_envelope_inner(&inner, envelope).await
641 }
642
643 pub async fn send_any(
645 &self,
646 value: Arc<dyn std::any::Any + Send + Sync>,
647 type_id: std::any::TypeId,
648 ) -> Result<(), SendError> {
649 let inner = self.inner.upgrade().ok_or(SendError::Closed)?;
650
651 if inner.closed.load(Ordering::SeqCst) {
652 return Err(SendError::Closed);
653 }
654
655 let msg_id = inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
656 let expected = inner.handler_count.load(Ordering::SeqCst);
657 let tracker = Arc::new(CompletionTracker::new(expected));
658 let envelope =
659 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), inner.id);
660
661 self.send_envelope_inner(&inner, envelope).await
662 }
663
664 pub async fn send_any_with_origin(
666 &self,
667 value: Arc<dyn std::any::Any + Send + Sync>,
668 type_id: std::any::TypeId,
669 origin: u64,
670 ) -> Result<(), SendError> {
671 let inner = self.inner.upgrade().ok_or(SendError::Closed)?;
672
673 if inner.closed.load(Ordering::SeqCst) {
674 return Err(SendError::Closed);
675 }
676
677 let msg_id = inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
678 let expected = inner.handler_count.load(Ordering::SeqCst);
679 let tracker = Arc::new(CompletionTracker::new(expected));
680 let envelope =
681 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), origin);
682
683 self.send_envelope_inner(&inner, envelope).await
684 }
685
686 async fn send_envelope_inner(
687 &self,
688 inner: &Arc<Inner>,
689 envelope: Envelope,
690 ) -> Result<(), SendError> {
691 if inner.closed.load(Ordering::SeqCst) {
692 return Err(SendError::Closed);
693 }
694
695 inner.wait_ready().await;
696
697 let envelope = if envelope.origin() == 0 {
698 envelope.with_new_origin(inner.id)
699 } else {
700 envelope
701 };
702
703 let tracker = envelope.tracker();
704
705 let subs: Vec<_> = {
707 let mut subs = inner.subscribers.lock();
708 subs.retain(|s| !s.is_closed());
709 subs.iter().map(|s| s.tx.clone()).collect()
710 };
711
712 for tx in subs {
713 let _ = tx.try_send(envelope.clone());
714 }
715
716 if let Some(tracker) = tracker {
717 tracker.clone().wait_owned().await;
718 if let Some(error) = tracker.take_error() {
719 return Err(SendError::Downstream(error));
720 }
721 }
722
723 Ok(())
724 }
725
726 pub fn is_closed(&self) -> bool {
728 match self.inner.upgrade() {
729 Some(inner) => inner.closed.load(Ordering::SeqCst),
730 None => true,
731 }
732 }
733
734}
735
/// Receiving half of a relay; used to register subscriptions and handlers.
///
/// Clones share the same relay state.
#[derive(Clone)]
pub struct RelayReceiver {
    /// State shared with the sending half.
    inner: Arc<Inner>,
}
748
749impl RelayReceiver {
750 pub fn subscribe<T: 'static + Send + Sync>(&self) -> Subscription<T> {
754 let (tx, rx) = mpsc::channel(self.inner.channel_size);
755 self.inner
756 .subscribers
757 .lock()
758 .push(SubscriberSender::new(tx));
759 Subscription::new(rx)
760 }
761
762 pub fn subscribe_all(&self) -> mpsc::Receiver<Envelope> {
766 let (tx, rx) = mpsc::channel(self.inner.channel_size);
767 self.inner
768 .subscribers
769 .lock()
770 .push(SubscriberSender::new(tx));
771 rx
772 }
773
774 pub fn subscribe_tracked<T: 'static + Send + Sync>(
778 &self,
779 ) -> (Subscription<T>, Arc<AtomicUsize>) {
780 let (tx, rx) = mpsc::channel(self.inner.channel_size);
781 self.inner
782 .subscribers
783 .lock()
784 .push(SubscriberSender::new(tx));
785 self.inner.handler_count.fetch_add(1, Ordering::SeqCst);
786
787 let mut ready_guard = ReadyGuard::new(self.inner.clone());
789 ready_guard.signal();
790
791 (
792 Subscription::new_tracked(rx),
793 self.inner.handler_count.clone(),
794 )
795 }
796
797 pub fn is_closed(&self) -> bool {
799 self.inner.closed.load(Ordering::SeqCst)
800 }
801
802 pub fn id(&self) -> u64 {
804 self.inner.id
805 }
806
807 pub fn sink<T, F, R>(&self, f: F)
812 where
813 T: 'static + Send + Sync,
814 F: Fn(&T) -> R + Send + Sync + 'static,
815 R: IntoResult + 'static,
816 {
817 let (mut sub, handler_count) = self.subscribe_tracked::<T>();
818 let msg_type = std::any::type_name::<T>();
819 let _handler_guard = HandlerGuard::new(handler_count);
820
821 tokio::spawn(async move {
822 let _guard = _handler_guard;
823 while let Some(msg) = sub.recv().await {
824 let tracker = sub.current_tracker();
825 let msg_id = sub.current_msg_id().unwrap_or(0);
826 let span = info_span!("pipedream.sink", msg_type = %msg_type, msg_id = %msg_id);
827 let _span_guard = span.enter();
828
829 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
830
831 match result {
832 Ok(Ok(())) => {
833 if let Some(t) = tracker {
834 t.complete_one();
835 }
836 sub.clear_tracker();
837 }
838 Ok(Err(e)) => {
839 let error = RelayError::new(msg_id, e, "sink");
840 if let Some(t) = tracker {
841 t.fail(error);
842 }
843 sub.clear_tracker();
844 }
845 Err(panic_info) => {
846 let error = RelayError::new(msg_id, PanicError::new(panic_info), "sink");
847 if let Some(t) = tracker {
848 t.fail(error);
849 }
850 sub.clear_tracker();
851 }
852 }
853 }
854 });
855 }
856
857 pub fn tap<T, F, R>(&self, f: F)
862 where
863 T: 'static + Send + Sync,
864 F: Fn(&T) -> R + Send + Sync + 'static,
865 R: IntoResult + 'static,
866 {
867 let (mut sub, handler_count) = self.subscribe_tracked::<T>();
868 let msg_type = std::any::type_name::<T>();
869 let _handler_guard = HandlerGuard::new(handler_count);
870
871 tokio::spawn(async move {
872 let _guard = _handler_guard;
873 while let Some(msg) = sub.recv().await {
874 let tracker = sub.current_tracker();
875 let msg_id = sub.current_msg_id().unwrap_or(0);
876 let span = info_span!("pipedream.tap", msg_type = %msg_type, msg_id = %msg_id);
877 let _span_guard = span.enter();
878
879 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
880
881 match result {
882 Ok(Ok(())) => {
883 if let Some(t) = tracker {
884 t.complete_one();
885 }
886 sub.clear_tracker();
887 }
888 Ok(Err(e)) => {
889 let error = RelayError::new(msg_id, e, "tap");
890 if let Some(t) = tracker {
891 t.fail(error);
892 }
893 sub.clear_tracker();
894 }
895 Err(panic_info) => {
896 let error = RelayError::new(msg_id, PanicError::new(panic_info), "tap");
897 if let Some(t) = tracker {
898 t.fail(error);
899 }
900 sub.clear_tracker();
901 }
902 }
903 }
904 });
905 }
906
907 pub fn within<F, Fut>(&self, f: F)
912 where
913 F: FnOnce() -> Fut + Send + 'static,
914 Fut: std::future::Future<Output = ()> + Send,
915 {
916 use futures::FutureExt;
917 tokio::spawn(async move {
918 let result = AssertUnwindSafe(f()).catch_unwind().await;
919 if let Err(panic_info) = result {
920 eprintln!("Panic in within(): {:?}", PanicError::new(panic_info));
921 }
922 });
923 }
924}