1use parking_lot::Mutex;
2use std::panic::AssertUnwindSafe;
3use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
4use std::sync::{Arc, Weak};
5use tokio::sync::{mpsc, Notify};
6use tracing::info_span;
7
8use crate::envelope::Envelope;
9use crate::error::{PanicError, RelayError};
10use crate::subscription::Subscription;
11use crate::tracker::CompletionTracker;
12
/// Per-subscriber channel capacity used by `Relay::channel` when the caller does not
/// choose one (65 536 envelopes).
const DEFAULT_CHANNEL_SIZE: usize = 64 * 1024;
14
/// Process-wide source of relay ids. It starts at 1, so id 0 is never handed out.
static RELAY_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns the next unused relay id and advances the global counter by one.
///
/// Relaxed ordering is used: the counter only has to produce distinct values and is not
/// used to order any other memory access.
fn next_relay_id() -> u64 {
    let counter: &AtomicU64 = &RELAY_ID_COUNTER;
    counter.fetch_add(1, Ordering::Relaxed)
}
21
22pub(crate) struct Inner {
27 id: u64,
29 subscribers: Mutex<Vec<SubscriberSender>>,
30 channel_size: usize,
31 pending_ready: AtomicUsize,
33 ready_count: AtomicUsize,
35 ready_notify: Notify,
37 msg_id_counter: AtomicU64,
38 handler_count: Arc<AtomicUsize>,
39 closed: AtomicBool,
40}
41
42impl Inner {
43 fn new(channel_size: usize) -> Self {
44 Self {
45 id: next_relay_id(),
46 subscribers: Mutex::new(Vec::new()),
47 channel_size,
48 pending_ready: AtomicUsize::new(0),
49 ready_count: AtomicUsize::new(0),
50 ready_notify: Notify::new(),
51 msg_id_counter: AtomicU64::new(0),
52 handler_count: Arc::new(AtomicUsize::new(0)),
53 closed: AtomicBool::new(false),
54 }
55 }
56
57 async fn wait_ready(&self) {
59 loop {
60 let pending = self.pending_ready.load(Ordering::SeqCst);
62 let ready = self.ready_count.load(Ordering::SeqCst);
63 if ready >= pending {
64 return;
65 }
66
67 let notified = self.ready_notify.notified();
69
70 let pending = self.pending_ready.load(Ordering::SeqCst);
72 let ready = self.ready_count.load(Ordering::SeqCst);
73 if ready >= pending {
74 return;
75 }
76
77 notified.await;
78 }
79 }
80}
81
82impl Drop for Inner {
83 fn drop(&mut self) {
84 self.closed.store(true, Ordering::SeqCst);
86 self.subscribers.lock().clear();
87 }
88}
89
/// Scope guard that decrements a shared handler counter exactly once, when it is dropped.
///
/// Creating the guard does not increment the counter; the caller is expected to have
/// done that already.
struct HandlerGuard {
    count: Arc<AtomicUsize>,
}

impl HandlerGuard {
    /// Wraps `count` so that it is decremented when the returned guard goes out of scope.
    fn new(count: Arc<AtomicUsize>) -> Self {
        HandlerGuard { count }
    }
}

impl Drop for HandlerGuard {
    fn drop(&mut self) {
        let _previous = self.count.fetch_sub(1, Ordering::SeqCst);
    }
}
111
112struct SubscriberSender {
113 tx: mpsc::Sender<Envelope>,
114}
115
116impl SubscriberSender {
117 fn new(tx: mpsc::Sender<Envelope>) -> Self {
118 Self { tx }
119 }
120
121 fn is_closed(&self) -> bool {
122 self.tx.is_closed()
123 }
124}
125
126#[derive(Clone)]
137pub struct Relay;
141
142impl Relay {
143 pub fn channel() -> (RelaySender, RelayReceiver) {
152 Self::channel_with_size(DEFAULT_CHANNEL_SIZE)
153 }
154
155 pub fn channel_with_size(channel_size: usize) -> (RelaySender, RelayReceiver) {
157 let inner = Arc::new(Inner::new(channel_size));
158 (
159 RelaySender {
160 inner: inner.clone(),
161 },
162 RelayReceiver { inner },
163 )
164 }
165
166 #[doc(hidden)]
169 pub fn new() -> TestRelay {
170 let (tx, rx) = Self::channel();
171 TestRelay { tx: Arc::new(tx), rx }
172 }
173
174 #[doc(hidden)]
176 pub fn with_channel_size(size: usize) -> TestRelay {
177 let (tx, rx) = Self::channel_with_size(size);
178 TestRelay { tx: Arc::new(tx), rx }
179 }
180}
181
182#[doc(hidden)]
185pub struct TestRelay {
186 tx: Arc<RelaySender>,
187 pub rx: RelayReceiver,
188}
189
190impl TestRelay {
191 pub async fn send<T: 'static + Send + Sync>(&self, value: T) -> Result<(), SendError> {
192 self.tx.send(value).await
193 }
194
195 pub fn subscribe<T: 'static + Send + Sync>(&self) -> Subscription<T> {
196 self.rx.subscribe()
197 }
198
199 pub fn sink<T, F, R>(&self, f: F)
200 where
201 T: 'static + Send + Sync,
202 F: Fn(&T) -> R + Send + Sync + 'static,
203 R: IntoResult + 'static,
204 {
205 let (mut sub, handler_count) = self.rx.subscribe_tracked::<T>();
206 let weak_tx = self.tx.weak();
207 let msg_type = std::any::type_name::<T>();
208 let _handler_guard = HandlerGuard::new(handler_count);
209
210 tokio::spawn(async move {
211 let _guard = _handler_guard;
212 while let Some(msg) = sub.recv().await {
213 let tracker = sub.current_tracker();
214 let msg_id = sub.current_msg_id().unwrap_or(0);
215 let span = info_span!("pipedream.sink", msg_type = %msg_type, msg_id = %msg_id);
216 let _span_guard = span.enter();
217
218 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
219
220 match result {
221 Ok(Ok(())) => {
222 if let Some(t) = tracker {
223 t.complete_one();
224 }
225 sub.clear_tracker();
226 }
227 Ok(Err(e)) => {
228 let error = RelayError::new(msg_id, e, "sink");
229
230 tokio::spawn({
232 let weak_tx = weak_tx.clone();
233 let error = error.clone();
234 async move {
235 let _ = weak_tx.send(error).await;
236 }
237 });
238
239 if let Some(t) = tracker {
240 t.fail(error);
241 }
242 sub.clear_tracker();
243 }
244 Err(panic_info) => {
245 let error = RelayError::new(msg_id, PanicError::new(panic_info), "sink");
246
247 tokio::spawn({
249 let weak_tx = weak_tx.clone();
250 let error = error.clone();
251 async move {
252 let _ = weak_tx.send(error).await;
253 }
254 });
255
256 if let Some(t) = tracker {
257 t.fail(error);
258 }
259 sub.clear_tracker();
260 }
261 }
262 }
263 });
264 }
265
266 pub fn tap<T, F, R>(&self, f: F)
267 where
268 T: 'static + Send + Sync,
269 F: Fn(&T) -> R + Send + Sync + 'static,
270 R: IntoResult + 'static,
271 {
272 let (mut sub, handler_count) = self.rx.subscribe_tracked::<T>();
273 let weak_tx = self.tx.weak();
274 let msg_type = std::any::type_name::<T>();
275 let _handler_guard = HandlerGuard::new(handler_count);
276
277 tokio::spawn(async move {
278 let _guard = _handler_guard;
279 while let Some(msg) = sub.recv().await {
280 let tracker = sub.current_tracker();
281 let msg_id = sub.current_msg_id().unwrap_or(0);
282 let span = info_span!("pipedream.tap", msg_type = %msg_type, msg_id = %msg_id);
283 let _span_guard = span.enter();
284
285 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
286
287 match result {
288 Ok(Ok(())) => {
289 if let Some(t) = tracker {
290 t.complete_one();
291 }
292 sub.clear_tracker();
293 }
294 Ok(Err(e)) => {
295 let error = RelayError::new(msg_id, e, "tap");
296
297 tokio::spawn({
299 let weak_tx = weak_tx.clone();
300 let error = error.clone();
301 async move {
302 let _ = weak_tx.send(error).await;
303 }
304 });
305
306 if let Some(t) = tracker {
307 t.fail(error);
308 }
309 sub.clear_tracker();
310 }
311 Err(panic_info) => {
312 let error = RelayError::new(msg_id, PanicError::new(panic_info), "tap");
313
314 tokio::spawn({
316 let weak_tx = weak_tx.clone();
317 let error = error.clone();
318 async move {
319 let _ = weak_tx.send(error).await;
320 }
321 });
322
323 if let Some(t) = tracker {
324 t.fail(error);
325 }
326 sub.clear_tracker();
327 }
328 }
329 }
330 });
331 }
332
333 pub fn is_closed(&self) -> bool {
334 self.tx.is_closed()
335 }
336
337 pub fn handler_count(&self) -> usize {
338 self.tx.handler_count()
339 }
340
341 pub fn close(&self) {
342 self.tx.close()
343 }
344
345 pub async fn send_any(
346 &self,
347 value: Arc<dyn std::any::Any + Send + Sync>,
348 type_id: std::any::TypeId,
349 ) -> Result<(), SendError> {
350 self.tx.send_any(value, type_id).await
351 }
352
353 pub async fn send_envelope(&self, envelope: Envelope) -> Result<(), SendError> {
354 self.tx.send_envelope(envelope).await
355 }
356
357 pub fn within<F, Fut>(&self, f: F)
358 where
359 F: FnOnce() -> Fut + Send + 'static,
360 Fut: std::future::Future<Output = ()> + Send,
361 {
362 self.rx.within(f);
363 }
364}
365
366impl Clone for TestRelay {
367 fn clone(&self) -> Self {
368 Self {
369 tx: self.tx.clone(),
370 rx: self.rx.clone(),
371 }
372 }
373}
374
375#[derive(Debug, Clone)]
380pub enum SendError {
381 Closed,
382 Downstream(RelayError),
383}
384
/// Conversion from a handler's return value into a `Result<(), E>`.
///
/// Implemented for `()` (always succeeds) and for `Result<(), E>` (passed through
/// unchanged), so handlers may be written with or without a fallible return type.
pub trait IntoResult {
    /// Error produced when the handler fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Consumes the value and yields the equivalent `Result`.
    fn into_result(self) -> Result<(), Self::Error>;
}

impl IntoResult for () {
    type Error = std::convert::Infallible;

    fn into_result(self) -> Result<(), std::convert::Infallible> {
        Ok(self)
    }
}

impl<E> IntoResult for Result<(), E>
where
    E: std::error::Error + Send + Sync + 'static,
{
    type Error = E;

    fn into_result(self) -> Result<(), Self::Error> {
        self
    }
}
407
408pub struct RelaySender {
420 inner: Arc<Inner>,
421}
422
423impl RelaySender {
424 pub async fn send<T: 'static + Send + Sync>(&self, value: T) -> Result<(), SendError> {
426 if self.inner.closed.load(Ordering::SeqCst) {
427 return Err(SendError::Closed);
428 }
429
430 let msg_id = self.inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
431 let expected = self.inner.handler_count.load(Ordering::SeqCst);
432 let tracker = Arc::new(CompletionTracker::new(expected));
433 let envelope = Envelope::with_origin(value, msg_id, Some(tracker), self.inner.id);
434
435 self.send_envelope(envelope).await
436 }
437
438 pub async fn send_any(
440 &self,
441 value: Arc<dyn std::any::Any + Send + Sync>,
442 type_id: std::any::TypeId,
443 ) -> Result<(), SendError> {
444 if self.inner.closed.load(Ordering::SeqCst) {
445 return Err(SendError::Closed);
446 }
447
448 let msg_id = self.inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
449 let expected = self.inner.handler_count.load(Ordering::SeqCst);
450 let tracker = Arc::new(CompletionTracker::new(expected));
451 let envelope =
452 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), self.inner.id);
453
454 self.send_envelope(envelope).await
455 }
456
457 pub async fn send_any_with_origin(
459 &self,
460 value: Arc<dyn std::any::Any + Send + Sync>,
461 type_id: std::any::TypeId,
462 origin: u64,
463 ) -> Result<(), SendError> {
464 if self.inner.closed.load(Ordering::SeqCst) {
465 return Err(SendError::Closed);
466 }
467
468 let msg_id = self.inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
469 let expected = self.inner.handler_count.load(Ordering::SeqCst);
470 let tracker = Arc::new(CompletionTracker::new(expected));
471 let envelope =
472 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), origin);
473
474 self.send_envelope(envelope).await
475 }
476
477 pub async fn send_envelope(&self, envelope: Envelope) -> Result<(), SendError> {
480 if self.inner.closed.load(Ordering::SeqCst) {
481 return Err(SendError::Closed);
482 }
483
484 self.inner.wait_ready().await;
485
486 let envelope = if envelope.origin() == 0 {
487 envelope.with_new_origin(self.inner.id)
488 } else {
489 envelope
490 };
491
492 let tracker = envelope.tracker();
493 self.deliver_envelope(envelope).await?;
494
495 if let Some(tracker) = tracker {
496 tracker.clone().wait_owned().await;
497 if let Some(error) = tracker.take_error() {
498 return Err(SendError::Downstream(error));
499 }
500 }
501
502 Ok(())
503 }
504
505 async fn deliver_envelope(&self, envelope: Envelope) -> Result<(), SendError> {
506 let subs: Vec<_> = {
507 let mut subs = self.inner.subscribers.lock();
508 subs.retain(|s| !s.is_closed());
509 subs.iter().map(|s| s.tx.clone()).collect()
510 };
511
512 for tx in subs {
513 match tx.try_send(envelope.clone()) {
514 Ok(_) => {}
515 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
516 }
518 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
519 }
521 }
522 }
523
524 Ok(())
525 }
526
527 pub fn weak(&self) -> WeakSender {
532 WeakSender {
533 inner: Arc::downgrade(&self.inner),
534 }
535 }
536
537 pub fn is_closed(&self) -> bool {
539 self.inner.closed.load(Ordering::SeqCst)
540 }
541
542 pub fn id(&self) -> u64 {
544 self.inner.id
545 }
546
547 pub fn handler_count(&self) -> usize {
549 self.inner.handler_count.load(Ordering::SeqCst)
550 }
551
552 pub fn close(&self) {
555 self.inner.closed.store(true, Ordering::SeqCst);
556 self.inner.subscribers.lock().clear();
557 }
558
559 #[doc(hidden)]
562 pub fn clone_for_test(&self) -> Self {
563 Self {
564 inner: self.inner.clone(),
565 }
566 }
567}
568
569impl Drop for RelaySender {
570 fn drop(&mut self) {
571 self.inner.closed.store(true, Ordering::SeqCst);
573 self.inner.subscribers.lock().clear();
574 }
575}
576
577#[derive(Clone)]
588pub struct WeakSender {
589 inner: Weak<Inner>,
590}
591
592impl WeakSender {
593 pub async fn send<T: 'static + Send + Sync>(&self, value: T) -> Result<(), SendError> {
595 let inner = self.inner.upgrade().ok_or(SendError::Closed)?;
596
597 if inner.closed.load(Ordering::SeqCst) {
598 return Err(SendError::Closed);
599 }
600
601 let msg_id = inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
602 let expected = inner.handler_count.load(Ordering::SeqCst);
603 let tracker = Arc::new(CompletionTracker::new(expected));
604 let envelope = Envelope::with_origin(value, msg_id, Some(tracker), inner.id);
605
606 self.send_envelope_inner(&inner, envelope).await
607 }
608
609 pub async fn send_any(
611 &self,
612 value: Arc<dyn std::any::Any + Send + Sync>,
613 type_id: std::any::TypeId,
614 ) -> Result<(), SendError> {
615 let inner = self.inner.upgrade().ok_or(SendError::Closed)?;
616
617 if inner.closed.load(Ordering::SeqCst) {
618 return Err(SendError::Closed);
619 }
620
621 let msg_id = inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
622 let expected = inner.handler_count.load(Ordering::SeqCst);
623 let tracker = Arc::new(CompletionTracker::new(expected));
624 let envelope =
625 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), inner.id);
626
627 self.send_envelope_inner(&inner, envelope).await
628 }
629
630 pub async fn send_any_with_origin(
632 &self,
633 value: Arc<dyn std::any::Any + Send + Sync>,
634 type_id: std::any::TypeId,
635 origin: u64,
636 ) -> Result<(), SendError> {
637 let inner = self.inner.upgrade().ok_or(SendError::Closed)?;
638
639 if inner.closed.load(Ordering::SeqCst) {
640 return Err(SendError::Closed);
641 }
642
643 let msg_id = inner.msg_id_counter.fetch_add(1, Ordering::Relaxed);
644 let expected = inner.handler_count.load(Ordering::SeqCst);
645 let tracker = Arc::new(CompletionTracker::new(expected));
646 let envelope =
647 Envelope::from_any_with_origin(value, type_id, msg_id, Some(tracker), origin);
648
649 self.send_envelope_inner(&inner, envelope).await
650 }
651
652 async fn send_envelope_inner(
653 &self,
654 inner: &Arc<Inner>,
655 envelope: Envelope,
656 ) -> Result<(), SendError> {
657 if inner.closed.load(Ordering::SeqCst) {
658 return Err(SendError::Closed);
659 }
660
661 inner.wait_ready().await;
662
663 let envelope = if envelope.origin() == 0 {
664 envelope.with_new_origin(inner.id)
665 } else {
666 envelope
667 };
668
669 let tracker = envelope.tracker();
670
671 let subs: Vec<_> = {
673 let mut subs = inner.subscribers.lock();
674 subs.retain(|s| !s.is_closed());
675 subs.iter().map(|s| s.tx.clone()).collect()
676 };
677
678 for tx in subs {
679 let _ = tx.try_send(envelope.clone());
680 }
681
682 if let Some(tracker) = tracker {
683 tracker.clone().wait_owned().await;
684 if let Some(error) = tracker.take_error() {
685 return Err(SendError::Downstream(error));
686 }
687 }
688
689 Ok(())
690 }
691
692 pub fn is_closed(&self) -> bool {
694 match self.inner.upgrade() {
695 Some(inner) => inner.closed.load(Ordering::SeqCst),
696 None => true,
697 }
698 }
699
700}
701
702#[derive(Clone)]
711pub struct RelayReceiver {
712 inner: Arc<Inner>,
713}
714
715impl RelayReceiver {
716 pub fn subscribe<T: 'static + Send + Sync>(&self) -> Subscription<T> {
720 let (tx, rx) = mpsc::channel(self.inner.channel_size);
721 self.inner
722 .subscribers
723 .lock()
724 .push(SubscriberSender::new(tx));
725 Subscription::new(rx)
726 }
727
728 pub fn subscribe_all(&self) -> mpsc::Receiver<Envelope> {
732 let (tx, rx) = mpsc::channel(self.inner.channel_size);
733 self.inner
734 .subscribers
735 .lock()
736 .push(SubscriberSender::new(tx));
737 rx
738 }
739
740 pub fn subscribe_tracked<T: 'static + Send + Sync>(
742 &self,
743 ) -> (Subscription<T>, Arc<AtomicUsize>) {
744 let (tx, rx) = mpsc::channel(self.inner.channel_size);
745 self.inner
746 .subscribers
747 .lock()
748 .push(SubscriberSender::new(tx));
749 self.inner.handler_count.fetch_add(1, Ordering::SeqCst);
750 (
751 Subscription::new_tracked(rx),
752 self.inner.handler_count.clone(),
753 )
754 }
755
756 pub fn is_closed(&self) -> bool {
758 self.inner.closed.load(Ordering::SeqCst)
759 }
760
761 pub fn id(&self) -> u64 {
763 self.inner.id
764 }
765
766 pub fn sink<T, F, R>(&self, f: F)
771 where
772 T: 'static + Send + Sync,
773 F: Fn(&T) -> R + Send + Sync + 'static,
774 R: IntoResult + 'static,
775 {
776 let (mut sub, handler_count) = self.subscribe_tracked::<T>();
777 let msg_type = std::any::type_name::<T>();
778 let _handler_guard = HandlerGuard::new(handler_count);
779
780 tokio::spawn(async move {
781 let _guard = _handler_guard;
782 while let Some(msg) = sub.recv().await {
783 let tracker = sub.current_tracker();
784 let msg_id = sub.current_msg_id().unwrap_or(0);
785 let span = info_span!("pipedream.sink", msg_type = %msg_type, msg_id = %msg_id);
786 let _span_guard = span.enter();
787
788 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
789
790 match result {
791 Ok(Ok(())) => {
792 if let Some(t) = tracker {
793 t.complete_one();
794 }
795 sub.clear_tracker();
796 }
797 Ok(Err(e)) => {
798 let error = RelayError::new(msg_id, e, "sink");
799 if let Some(t) = tracker {
800 t.fail(error);
801 }
802 sub.clear_tracker();
803 }
804 Err(panic_info) => {
805 let error = RelayError::new(msg_id, PanicError::new(panic_info), "sink");
806 if let Some(t) = tracker {
807 t.fail(error);
808 }
809 sub.clear_tracker();
810 }
811 }
812 }
813 });
814 }
815
816 pub fn tap<T, F, R>(&self, f: F)
821 where
822 T: 'static + Send + Sync,
823 F: Fn(&T) -> R + Send + Sync + 'static,
824 R: IntoResult + 'static,
825 {
826 let (mut sub, handler_count) = self.subscribe_tracked::<T>();
827 let msg_type = std::any::type_name::<T>();
828 let _handler_guard = HandlerGuard::new(handler_count);
829
830 tokio::spawn(async move {
831 let _guard = _handler_guard;
832 while let Some(msg) = sub.recv().await {
833 let tracker = sub.current_tracker();
834 let msg_id = sub.current_msg_id().unwrap_or(0);
835 let span = info_span!("pipedream.tap", msg_type = %msg_type, msg_id = %msg_id);
836 let _span_guard = span.enter();
837
838 let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&*msg).into_result()));
839
840 match result {
841 Ok(Ok(())) => {
842 if let Some(t) = tracker {
843 t.complete_one();
844 }
845 sub.clear_tracker();
846 }
847 Ok(Err(e)) => {
848 let error = RelayError::new(msg_id, e, "tap");
849 if let Some(t) = tracker {
850 t.fail(error);
851 }
852 sub.clear_tracker();
853 }
854 Err(panic_info) => {
855 let error = RelayError::new(msg_id, PanicError::new(panic_info), "tap");
856 if let Some(t) = tracker {
857 t.fail(error);
858 }
859 sub.clear_tracker();
860 }
861 }
862 }
863 });
864 }
865
866 pub fn within<F, Fut>(&self, f: F)
871 where
872 F: FnOnce() -> Fut + Send + 'static,
873 Fut: std::future::Future<Output = ()> + Send,
874 {
875 use futures::FutureExt;
876 tokio::spawn(async move {
877 let result = AssertUnwindSafe(f()).catch_unwind().await;
878 if let Err(panic_info) = result {
879 eprintln!("Panic in within(): {:?}", PanicError::new(panic_info));
880 }
881 });
882 }
883}