1use std::{collections::HashMap, fmt, sync::Arc};
2
3use bigerror::{ConversionError, IntoContext, LogError, Report};
4use tokio::{
5 sync::{
6 mpsc,
7 mpsc::{UnboundedReceiver, UnboundedSender},
8 },
9 task::JoinSet,
10};
11use tracing::{debug, error, trace, warn, Instrument};
12
13use crate::{
14 manager::{HashKind, Signal, SignalQueue},
15 notification::{Notification, NotificationProcessor, RexMessage},
16 queue::StreamableDeque,
17 Rex, RexError, StateId,
18};
19
20pub trait StateRouter<K>: Send + Sync
21where
22 K: HashKind,
23{
24 type Inbound;
25 fn get_id(&self, input: &Self::Inbound) -> Result<Option<StateId<K>>, Report<RexError>>;
26 fn get_kind(&self) -> K;
27}
28
29pub type BoxedStateRouter<K, In> = Box<dyn StateRouter<K, Inbound = In>>;
30
31pub struct PacketRouter<K, In>(Arc<HashMap<K, BoxedStateRouter<K, In>>>)
33where
34 K: HashKind;
35
36impl<K: HashKind, In> Clone for PacketRouter<K, In> {
37 fn clone(&self) -> Self {
38 Self(self.0.clone())
39 }
40}
41
42impl<K, In> PacketRouter<K, In>
43where
44 for<'a> K: HashKind + TryFrom<&'a In, Error = Report<ConversionError>>,
45{
46 #[must_use]
47 pub fn new(state_routers: Vec<BoxedStateRouter<K, In>>) -> Self {
48 let mut router_map: HashMap<K, BoxedStateRouter<K, In>> = HashMap::new();
49 for router in state_routers {
50 if let Some(old_router) = router_map.insert(router.get_kind(), router) {
51 panic!(
52 "Found multiple routers for kind: {:?}",
53 old_router.get_kind()
54 );
55 }
56 }
57 Self(Arc::new(router_map))
58 }
59
60 fn get_id(&self, packet: &In) -> Result<Option<StateId<K>>, Report<RexError>> {
61 let kind = K::try_from(packet);
62 let kind = kind.map_err(IntoContext::into_ctx)?;
63 let Some(router) = self.0.get(&kind) else {
64 return Ok(None);
65 };
66 router.get_id(packet)
67 }
68}
69
70pub struct IngressAdapter<K>
72where
73 K: Rex + Ingress,
74 K::Input: TryFrom<K::In, Error = Report<ConversionError>>,
75 K::Message: TryInto<K::Out, Error = Report<ConversionError>>,
76{
77 pub(crate) outbound_tx: UnboundedSender<K::Out>,
78 pub(crate) signal_queue: SignalQueue<K>,
79 pub(crate) router: PacketRouter<K, K::In>,
80 pub inbound_tx: UnboundedSender<K::In>,
81 pub(crate) inbound_rx: Option<UnboundedReceiver<K::In>>,
83 pub(crate) topic: <K::Message as RexMessage>::Topic,
84}
85
86impl<K> IngressAdapter<K>
87where
88 K: Rex + Ingress,
89 K::Input: TryFrom<K::In, Error = Report<ConversionError>>,
90 K::Message: TryInto<K::Out, Error = Report<ConversionError>>,
91{
92 #[must_use]
93 pub fn new(
94 signal_queue: SignalQueue<K>,
95 outbound_tx: UnboundedSender<K::Out>,
96 state_routers: Vec<BoxedStateRouter<K, K::In>>,
97 topic: impl Into<<K::Message as RexMessage>::Topic>,
98 ) -> Self {
99 let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::<K::In>();
100
101 Self {
102 signal_queue,
103 outbound_tx,
104 router: PacketRouter::new(state_routers),
105 inbound_tx,
106 inbound_rx: Some(inbound_rx),
107 topic: topic.into(),
108 }
109 }
110
111 pub(crate) fn spawn_inbound(&mut self, join_set: &mut JoinSet<()>) {
112 let router = self.router.clone();
113 let signal_queue = self.signal_queue.clone();
114 let inbound_rx = self.inbound_rx.take().expect("inbound_rx missing");
115 join_set.spawn(Self::process_inbound(router, signal_queue, inbound_rx).in_current_span());
116 }
117
118 async fn process_inbound(
119 router: PacketRouter<K, K::In>,
120 signal_queue: Arc<StreamableDeque<Signal<K>>>,
121 mut packet_rx: UnboundedReceiver<K::In>,
122 ) {
123 debug!(target: "state_machine", spawning = "IngressAdapter.packet_tx");
124 while let Some(packet) = packet_rx.recv().await {
125 trace!("receiving packet");
126 let id = match router.get_id(&packet) {
127 Err(e) => {
128 error!(err = ?e, ?packet, "could not get id from router");
129 continue;
130 }
131 Ok(None) => {
132 warn!(?packet, "unable to route packet");
133 continue;
134 }
135 Ok(Some(state_id)) => state_id,
136 };
137 K::Input::try_from(packet)
138 .map(|input| {
139 signal_queue.push_back(Signal { id, input });
140 })
141 .log_attached_err("ia::processors from packet failed");
142 }
143 }
144}
145
146pub trait Ingress: Rex
147where
148 Self::Input: TryFrom<Self::In, Error = Report<ConversionError>>,
149 Self::Message: TryInto<Self::Out, Error = Report<ConversionError>>,
150 for<'a> Self: TryFrom<&'a Self::In, Error = Report<ConversionError>>,
151{
152 type In: Send + Sync + fmt::Debug + 'static;
153 type Out: Send + Sync + fmt::Debug + 'static;
154}
155
156impl<K> NotificationProcessor<K::Message> for IngressAdapter<K>
157where
158 K: Rex + Ingress,
159 K::Input: TryFrom<K::In, Error = Report<ConversionError>>,
160 K::Message: TryInto<K::Out, Error = Report<ConversionError>>,
161{
162 fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
163 debug!("calling IngressAdapter::process_inbound");
164 self.spawn_inbound(join_set);
165
166 debug!("starting IngressAdapter notification_tx");
167
168 let (input_tx, mut input_rx) = mpsc::unbounded_channel::<Notification<K::Message>>();
169 let outbound_tx = self.outbound_tx.clone();
170
171 let _notification_handle = join_set.spawn(
172 async move {
173 debug!(target: "state_machine", spawning = "IngressAdapter.notification_tx");
174 while let Some(notification) = input_rx.recv().await {
175 notification
176 .0
177 .try_into()
178 .map(|packet| {
179 trace!("sending packet");
180 outbound_tx.send(packet).log_err();
181 })
182 .log_attached_err("Invalid input");
183 }
184 }
185 .in_current_span(),
186 );
187
188 input_tx
189 }
190
191 fn get_topics(&self) -> &[<K::Message as RexMessage>::Topic] {
192 std::slice::from_ref(&self.topic)
193 }
194}
195
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::{sync::mpsc::UnboundedReceiver, task::JoinSet};
    use tokio_stream::StreamExt;

    use super::*;
    use crate::{
        notification::{NotificationManager, NotificationQueue},
        test_support::*,
        RexBuilder, StateId,
    };

    /// Upper bound when waiting for an outbound packet. It is only reached when a packet is
    /// never delivered, so it can be generous without slowing down passing runs.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    type TestIngressAdapter = (IngressAdapter<TestKind>, UnboundedReceiver<OutPacket>);

    impl TestDefault for TestIngressAdapter {
        fn test_default() -> Self {
            let signal_queue = SignalQueue::default();
            let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();

            let adapter = IngressAdapter::new(
                signal_queue,
                outbound_tx,
                vec![Box::new(TestStateRouter)],
                TestTopic::Ingress,
            );
            (adapter, outbound_rx)
        }
    }

    /// Await the next outbound packet.
    ///
    /// The previous approach (sleep 1ms, then `try_recv`) was flaky whenever the forwarding
    /// task had not been scheduled within that window.
    async fn next_outbound(rx: &mut UnboundedReceiver<OutPacket>) -> Option<OutPacket> {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for outbound packet")
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_to_network() {
        let (adapter, mut network_rx) = TestIngressAdapter::test_default();
        let mut join_set = JoinSet::new();

        let notification_manager: NotificationManager<TestMsg> = NotificationManager::new(
            vec![Box::new(adapter)],
            &mut join_set,
            NotificationQueue::new(),
        );
        let notification_tx = notification_manager.init(&mut join_set);

        let unknown_packet = OutPacket(b"unknown_packet".to_vec());

        notification_tx.send(Notification(unknown_packet.clone().into()));
        assert_eq!(Some(unknown_packet), next_outbound(&mut network_rx).await);

        let unsupported_packet = OutPacket(b"unsupported_packet".to_vec());

        notification_tx.send(Notification(unsupported_packet.clone().into()));
        assert_eq!(Some(unsupported_packet), next_outbound(&mut network_rx).await);
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_from_network() {
        let (adapter, _outbound_rx) = TestIngressAdapter::test_default();
        let signal_queue = adapter.signal_queue.clone();
        // The short timeout is what lets the test observe "no signal was produced".
        let signal_rx = signal_queue.stream().timeout(Duration::from_millis(2));
        tokio::pin!(signal_rx);

        let inbound_tx = adapter.inbound_tx.clone();
        let mut join_set = JoinSet::new();

        let notification_manager: NotificationManager<TestMsg> = NotificationManager::new(
            vec![Box::new(adapter)],
            &mut join_set,
            NotificationQueue::new(),
        );
        let _notification_tx = notification_manager.init(&mut join_set);

        let unknown_packet = InPacket(b"unknown_packet".to_vec());
        inbound_tx.send(unknown_packet).unwrap();
        signal_rx.next().await.unwrap().unwrap_err();

        let supported_packet = InPacket(b"new_state".to_vec());
        inbound_tx.send(supported_packet.clone()).unwrap();
        let signal = signal_rx.next().await.unwrap().unwrap();
        assert_eq!(
            Signal {
                id: StateId::new_with_u128(TestKind, 1),
                input: TestInput::Packet(supported_packet),
            },
            signal,
        );
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn rex_builder() {
        let (outbound_tx, _outbound_rx) = mpsc::unbounded_channel::<OutPacket>();

        let (inbound_tx, builder) = RexBuilder::new_connected(outbound_tx);
        let ctx = builder
            .with_ingress_adapter(vec![Box::new(TestStateRouter)], TestTopic::Ingress)
            .build();

        // The short timeout is what lets the test observe "no signal was produced".
        let signal_rx = ctx.signal_queue.stream().timeout(Duration::from_millis(2));
        tokio::pin!(signal_rx);

        let unknown_packet = InPacket(b"unknown_packet".to_vec());
        inbound_tx.send(unknown_packet).unwrap();
        signal_rx.next().await.unwrap().unwrap_err();

        let supported_packet = InPacket(b"new_state".to_vec());
        inbound_tx.send(supported_packet.clone()).unwrap();
        let signal = signal_rx.next().await.unwrap().unwrap();
        assert_eq!(
            Signal {
                id: StateId::new_with_u128(TestKind, 1),
                input: TestInput::Packet(supported_packet),
            },
            signal,
        );
    }
}