1use std::{collections::HashMap, fmt, sync::Arc};
2
3use bigerror::{Report, ResultIntoContext};
4use tokio::sync::mpsc::{self, UnboundedSender};
5use tracing::{debug, error, trace, warn};
6
7use crate::{
8 manager::{HashKind, Signal, SignalQueue},
9 notification::{GetTopic, Message, Notification, NotificationProcessor, Topic},
10 queue::StreamableDeque,
11 StateId, StateMachineError,
12};
13use bigerror::{ConversionError, LogError};
14
15pub trait StateRouter<K>: Send + Sync
16where
17 K: HashKind,
18{
19 type Inbound;
20 fn get_id(
21 &self,
22 input: &Self::Inbound,
23 ) -> Result<Option<StateId<K>>, Report<StateMachineError>>;
24 fn get_kind(&self) -> K;
25}
26
27type BoxedStateRouter<K, In> = Box<dyn StateRouter<K, Inbound = In>>;
28
29pub struct PacketRouter<K, In>(HashMap<K, BoxedStateRouter<K, In>>)
31where
32 K: HashKind;
33
34impl<'a, K, P: 'a> PacketRouter<K, P>
35where
36 K: HashKind + TryFrom<&'a P, Error = Report<ConversionError>>,
37{
38 fn get_id(&self, packet: &'a P) -> Result<Option<StateId<K>>, Report<StateMachineError>> {
39 let kind = K::try_from(packet).into_ctx()?;
40 let Some(router) = self.0.get(&kind) else {
41 return Ok(None);
42 };
43 router.get_id(packet)
44 }
45}
46
47pub struct IngressAdapter<K, SI, In, Out, T, const U: usize>
49where
50 K: HashKind,
51 SI: Send + Sync + fmt::Debug,
52 In: Send + Sync + fmt::Debug,
53 Out: Send + Sync + fmt::Debug,
54{
55 outbound_tx: UnboundedSender<Out>,
56 signal_queue: Arc<StreamableDeque<Signal<K, SI>>>,
57 router: Arc<PacketRouter<K, In>>,
58 inbound_tx: Option<UnboundedSender<In>>,
62 topics: [Topic<T>; U],
63}
64
65impl<K, SI, In, Out, T, const U: usize> IngressAdapter<K, SI, In, Out, T, U>
66where
67 for<'a> K: HashKind + TryFrom<&'a In, Error = Report<ConversionError>>,
68 SI: Send + Sync + fmt::Debug + TryFrom<In, Error = Report<ConversionError>> + 'static,
69 In: Send + Sync + fmt::Debug + 'static,
70 Out: Send + Sync + fmt::Debug + 'static,
71{
72 pub fn new<const N: usize>(
73 signal_queue: Arc<SignalQueue<K, SI>>,
74 outbound_tx: UnboundedSender<Out>,
75 state_routers: [BoxedStateRouter<K, In>; N],
76 topics: [Topic<T>; U],
77 ) -> Self {
78 let state_routers: HashMap<K, BoxedStateRouter<K, In>> = state_routers
79 .into_iter()
80 .map(|r| (r.get_kind(), r))
81 .collect();
82
83 Self {
84 signal_queue,
85 outbound_tx,
86 router: Arc::new(PacketRouter(state_routers)),
87 inbound_tx: None,
88 topics,
89 }
90 }
91
92 pub fn init_packet_processor(&mut self) -> UnboundedSender<In> {
95 let router = self.router.clone();
96 let signal_queue = self.signal_queue.clone();
97 let (packet_tx, mut packet_rx) = mpsc::unbounded_channel::<In>();
98 let _nw_handle = tokio::spawn(async move {
99 debug!(target: "state_machine", spawning = "IngressAdapter.packet_tx");
100 while let Some(packet) = packet_rx.recv().await {
101 trace!("receiving packet");
102 let id = match router.get_id(&packet) {
103 Err(e) => {
104 error!(err = ?e, ?packet, "could not get id from router");
105 continue;
106 }
107 Ok(None) => {
108 warn!(?packet, "unable to route packet");
109 continue;
110 }
111 Ok(Some(state_id)) => state_id,
112 };
113 SI::try_from(packet)
114 .map(|input| {
115 signal_queue.push_back(Signal { id, input });
116 })
117 .log_attached_err("ia::processors from packet failed");
118 }
119 });
120 self.inbound_tx = Some(packet_tx.clone());
121
122 packet_tx
123 }
124
125 pub fn init_notification_processor<N>(&self) -> UnboundedSender<N>
126 where
127 N: TryInto<Out, Error = Report<ConversionError>> + Send + 'static + fmt::Debug,
128 {
129 debug!("starting IngressAdapter notification_tx");
130 self.inbound_tx
131 .as_ref()
132 .expect("IngressAdapter did not initialize packet_tx!");
133
134 let (input_tx, mut input_rx) = mpsc::unbounded_channel::<N>();
135 let outbound_tx = self.outbound_tx.clone();
136
137 let _notification_handle = tokio::spawn(async move {
138 debug!(target: "state_machine", spawning = "IngressAdapter.notification_tx");
139 while let Some(notification) = input_rx.recv().await {
140 notification
141 .try_into()
142 .map(|packet| {
143 trace!("sending packet");
144 outbound_tx.send(packet).log_err();
145 })
146 .log_attached_err("Invalid input");
147 }
148 });
149
150 input_tx
151 }
152}
153
154impl<K, T, M, SI, In, Out, const U: usize> NotificationProcessor<T, Notification<K, M>>
155 for IngressAdapter<K, SI, In, Out, T, U>
156where
157 for<'a> K: HashKind + TryFrom<&'a In, Error = Report<ConversionError>>,
158 SI: Send + Sync + fmt::Debug + TryFrom<In, Error = Report<ConversionError>> + 'static,
159 M: Message + GetTopic<T>,
160 T: Send + Sync + 'static,
161 In: Send + Sync + fmt::Debug + 'static,
162 Out: Send
163 + Sync
164 + fmt::Debug
165 + TryFrom<Notification<K, M>, Error = Report<ConversionError>>
166 + 'static,
167{
168 fn init(&self) -> UnboundedSender<Notification<K, M>> {
169 self.init_notification_processor()
170 }
171
172 fn get_topics(&self) -> &[Topic<T>] {
173 &self.topics
174 }
175}
176
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc::UnboundedReceiver;
    use tokio_stream::StreamExt;

    use super::*;
    use crate::{notification::NotificationManager, test_support::*, StateId, TestDefault};

    /// Upper bound on how long a test waits for a spawned task to deliver a message.
    ///
    /// The tests await the channel directly instead of sleeping a fixed interval and then
    /// polling. A passing test therefore returns as soon as the message arrives, and the
    /// generous bound keeps slow machines from producing spurious failures.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    type TestIngressAdapter = (
        IngressAdapter<TestKind, TestInput, InPacket, OutPacket, TestTopic, 1>,
        UnboundedReceiver<OutPacket>,
    );

    impl TestDefault for TestIngressAdapter {
        fn test_default() -> Self {
            let signal_queue = Arc::new(SignalQueue::new());
            let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();

            let nw_adapter = IngressAdapter::new(
                signal_queue,
                outbound_tx,
                [Box::new(TestStateRouter)],
                [Topic::Message(TestTopic::Ingress)],
            );
            (nw_adapter, outbound_rx)
        }
    }

    /// Awaits the next packet on `rx`, failing the test if none arrives within
    /// [`RECV_TIMEOUT`].
    async fn next_outbound(rx: &mut UnboundedReceiver<OutPacket>) -> Option<OutPacket> {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for an outbound packet")
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_to_network() {
        let (mut nw_adapter, mut network_rx) = TestIngressAdapter::test_default();
        let _inbound_tx = nw_adapter.init_packet_processor();

        let notification_manager: NotificationManager<TestTopic, Notification<_, TestMsg>> =
            NotificationManager::new([&nw_adapter]);
        let notification_tx = notification_manager.init();

        let unknown_packet = OutPacket(b"unknown_packet".to_vec());
        notification_tx.send(unknown_packet.clone().into()).unwrap();
        assert_eq!(Some(unknown_packet), next_outbound(&mut network_rx).await);

        let unsupported_packet = OutPacket(b"unsupported_packet".to_vec());
        notification_tx
            .send(unsupported_packet.clone().into())
            .unwrap();
        assert_eq!(
            Some(unsupported_packet),
            next_outbound(&mut network_rx).await
        );
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_from_network() {
        let (mut nw_adapter, _outbound_rx) = TestIngressAdapter::test_default();
        let signal_queue = nw_adapter.signal_queue.clone();
        let signal_rx = signal_queue.stream().timeout(Duration::from_millis(2));
        tokio::pin!(signal_rx);

        let inbound_tx = nw_adapter.init_packet_processor();

        let notification_manager: NotificationManager<TestTopic, Notification<_, TestMsg>> =
            NotificationManager::new([&nw_adapter]);
        let _notification_tx = notification_manager.init();

        // The unknown packet must not produce a signal, so the stream's own timeout fires.
        let unknown_packet = InPacket(b"unknown_packet".to_vec());
        inbound_tx.send(unknown_packet).unwrap();
        signal_rx.next().await.unwrap().unwrap_err();

        let supported_packet = InPacket(b"new_state".to_vec());
        inbound_tx.send(supported_packet.clone()).unwrap();
        // Bound this wait explicitly so a routing regression fails the test rather than
        // relying on the stream's timeout behaviour after it has already elapsed once.
        let signal = tokio::time::timeout(RECV_TIMEOUT, signal_rx.next())
            .await
            .expect("timed out waiting for a signal")
            .unwrap()
            .unwrap();
        assert_eq!(
            Signal {
                id: StateId::new_with_u128(TestKind, 1),
                input: TestInput::Packet(supported_packet),
            },
            signal,
        );
    }
}