Skip to main content

tightbeam/
router.rs

1#[cfg(not(feature = "std"))]
2extern crate alloc;
3#[cfg(not(feature = "std"))]
4use alloc::{sync::Arc, vec::Vec};
5
6#[cfg(feature = "std")]
7use std::sync::Arc;
8
9use crate::{Frame, Message};
10
11#[cfg(feature = "derive")]
12use crate::Errorizable;
13
14pub type Result<T> = core::result::Result<T, RouterError>;
15
16#[cfg_attr(feature = "derive", derive(Errorizable))]
17#[derive(Debug)]
18pub enum RouterError {
19	#[cfg_attr(feature = "derive", error("No route configured for provided message"))]
20	UnknownRoute,
21}
22
23crate::impl_error_display!(RouterError {
24	UnknownRoute => "No route configured for provided message",
25});
26
27pub trait RouterPolicy: Send + Sync {
28	fn dispatch<T: Message + Send + 'static>(&self, message: Arc<Frame>) -> Result<()>;
29}
30
31#[macro_export]
32macro_rules! routes {
33	// Helper for generating dispatch logic
34	(@dispatch $self:ident, $message:ident, [ $( ($MsgTy:ty, $this:ident, $arg:pat_param, $handler:block) ),* ]) => {
35		$(
36			if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
37				let $arg = $message;
38				let $this = $self;
39				{ $handler }
40				return Ok(());
41			}
42		)*
43		Err($crate::router::RouterError::UnknownRoute)
44	};
45
46	// Helper for generating dispatch logic (1-arg form)
47	(@dispatch $self:ident, $message:ident, [ $( ($MsgTy:ty, $arg:pat_param, $handler:block) ),* ]) => {
48		$(
49			if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
50				let $arg = $message;
51				{ $handler }
52			 return Ok(());
53			}
54		)*
55		Err($crate::router::RouterError::UnknownRoute)
56	};
57
58	(
59		$RouterName:ident { $( $field:ident : $fty:ty ),* $(,)? } :
60		$(
61			$MsgTy:ty | $($arg:ident),* | $handler:block
62		)+
63	) => {
64		struct $RouterName { $( $field : $fty ),* }
65		impl $crate::router::RouterPolicy for $RouterName {
66			#[cfg(not(feature = "std"))]
67			fn dispatch<T: $crate::Message + Send + 'static>(&self, message: alloc::sync::Arc<$crate::Frame>) -> $crate::router::Result<()> {
68				$(
69					if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
70						let ($($arg),*) = (self, message);
71						$handler;
72						return Ok(());
73					}
74				)*
75				Err($crate::router::RouterError::UnknownRoute)
76			}
77
78			#[cfg(feature = "std")]
79			fn dispatch<T: $crate::Message + Send + 'static>(&self, message: std::sync::Arc<$crate::Frame>) -> $crate::router::Result<()> {
80				$(
81					if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$MsgTy>() {
82						let ($($arg),*) = (self, message);
83						$handler;
84						return Ok(());
85					}
86				)*
87				Err($crate::router::RouterError::UnknownRoute)
88			}
89		}
90	};
91}
92
#[cfg(test)]
mod tests {
	use std::sync::{mpsc, Arc};
	use std::time::Duration;

	use crate::compose;
	use crate::der::Sequence;
	use crate::router::RouterPolicy;
	// Only referenced from the `cfg_attr(feature = "derive", ...)` attributes.
	#[cfg(feature = "derive")]
	use crate::Beamable;
	use crate::Frame;

	/// Minimal message type used to exercise one route.
	#[cfg_attr(feature = "derive", derive(Beamable))]
	#[derive(Sequence, Clone, Debug, PartialEq)]
	pub struct HealthCheck {
		pub uptime: u64,
	}

	#[cfg(not(feature = "derive"))]
	impl crate::Message for HealthCheck {
		const MUST_BE_CONFIDENTIAL: bool = false;
		const MUST_BE_NON_REPUDIABLE: bool = false;
		const MUST_BE_COMPRESSED: bool = false;
		const MUST_BE_PRIORITIZED: bool = false;
		const MIN_VERSION: crate::Version = crate::Version::V0;
	}

	/// Second message type, so routing has to discriminate between two routes.
	#[cfg_attr(feature = "derive", derive(Beamable))]
	#[derive(Sequence, Clone, Debug, PartialEq)]
	pub struct Payment {
		pub from: String,
		pub amount: u64,
	}

	#[cfg(not(feature = "derive"))]
	impl crate::Message for Payment {
		const MUST_BE_CONFIDENTIAL: bool = false;
		const MUST_BE_NON_REPUDIABLE: bool = false;
		const MUST_BE_COMPRESSED: bool = false;
		const MUST_BE_PRIORITIZED: bool = false;
		const MIN_VERSION: crate::Version = crate::Version::V0;
	}

	/// Dispatches interleaved `Payment` and `HealthCheck` frames and checks
	/// that each one arrives, in order, on the channel belonging to its type.
	#[test]
	fn test_mpsc_channel_routing() -> Result<(), Box<dyn std::error::Error>> {
		#[cfg(feature = "derive")]
		routes! {
			ChannelRouter {
				payment_tx: mpsc::Sender<Arc<Frame>>,
				health_tx: mpsc::Sender<Arc<Frame>>,
			}:
				Payment |router, msg| {
					let _ = router.payment_tx.send(msg);
				}
				HealthCheck |router, msg| {
					let _ = router.health_tx.send(msg);
				}
		}

		#[cfg(not(feature = "derive"))]
		struct ChannelRouter {
			payment_tx: mpsc::Sender<Arc<Frame>>,
			health_tx: mpsc::Sender<Arc<Frame>>,
		}

		// Hand-written equivalent of the `routes!` expansion above. The bounds
		// mirror the trait declaration; `'static` is required by `TypeId::of`.
		#[cfg(not(feature = "derive"))]
		impl super::RouterPolicy for ChannelRouter {
			fn dispatch<M: crate::Message + Send + 'static>(
				&self,
				message: Arc<Frame>,
			) -> crate::router::Result<()> {
				if std::any::TypeId::of::<M>() == std::any::TypeId::of::<Payment>() {
					let _ = self.payment_tx.send(message);
					return Ok(());
				}

				if std::any::TypeId::of::<M>() == std::any::TypeId::of::<HealthCheck>() {
					let _ = self.health_tx.send(message);
					return Ok(());
				}

				Err(super::RouterError::UnknownRoute)
			}
		}

		let (payment_tx, payment_rx) = mpsc::channel::<Arc<Frame>>();
		let (health_tx, health_rx) = mpsc::channel::<Arc<Frame>>();
		let router = ChannelRouter { payment_tx, health_tx };

		// Iterate over `u64` directly so the message fields need no cast.
		let n = 5u64;
		for i in 0..n {
			// Compose Payment
			let payment = compose! {
				V0: id: format!("p-{i}"),
					order: 1u64,
					message: Payment {
						from: "alice".into(),
						amount: i
					}
			}?;
			// Route
			router.dispatch::<Payment>(Arc::new(payment))?;

			// Compose HealthCheck
			let health = compose! {
				V0: id: format!("h-{i}"),
					order: 1u64,
					message: HealthCheck {
						uptime: i
					}
			}?;
			// Route
			router.dispatch::<HealthCheck>(Arc::new(health))?;
		}

		// Verify n messages per channel
		let timeout = Duration::from_millis(200);
		for i in 0..n {
			let received_payment = payment_rx.recv_timeout(timeout)?;
			let message: Payment = crate::decode(&received_payment.message)?;
			assert_eq!(received_payment.metadata.id, format!("p-{i}").as_bytes());
			assert_eq!(message, Payment { from: "alice".into(), amount: i });

			let received_health = health_rx.recv_timeout(timeout)?;
			let message: HealthCheck = crate::decode(&received_health.message)?;
			assert_eq!(received_health.metadata.id, format!("h-{i}").as_bytes());
			assert_eq!(message, HealthCheck { uptime: i });
		}

		Ok(())
	}
}