rex/
ingress.rs

1use std::{collections::HashMap, fmt, sync::Arc};
2
3use bigerror::{Report, ResultIntoContext};
4use tokio::sync::mpsc::{self, UnboundedSender};
5use tracing::{debug, error, trace, warn};
6
7use crate::{
8    manager::{HashKind, Signal, SignalQueue},
9    notification::{GetTopic, Message, Notification, NotificationProcessor, Topic},
10    queue::StreamableDeque,
11    StateId, StateMachineError,
12};
13use bigerror::{ConversionError, LogError};
14
15pub trait StateRouter<K>: Send + Sync
16where
17    K: HashKind,
18{
19    type Inbound;
20    fn get_id(
21        &self,
22        input: &Self::Inbound,
23    ) -> Result<Option<StateId<K>>, Report<StateMachineError>>;
24    fn get_kind(&self) -> K;
25}
26
27type BoxedStateRouter<K, In> = Box<dyn StateRouter<K, Inbound = In>>;
28
29/// top level router that holds all [`Kind`] indexed [`StateRouter`]s
30pub struct PacketRouter<K, In>(HashMap<K, BoxedStateRouter<K, In>>)
31where
32    K: HashKind;
33
34impl<'a, K, P: 'a> PacketRouter<K, P>
35where
36    K: HashKind + TryFrom<&'a P, Error = Report<ConversionError>>,
37{
38    fn get_id(&self, packet: &'a P) -> Result<Option<StateId<K>>, Report<StateMachineError>> {
39        let kind = K::try_from(packet).into_ctx()?;
40        let Some(router) = self.0.get(&kind) else {
41            return Ok(None);
42        };
43        router.get_id(packet)
44    }
45}
46
47/// Represents a bidirectional network connection
48pub struct IngressAdapter<K, SI, In, Out, T, const U: usize>
49where
50    K: HashKind,
51    SI: Send + Sync + fmt::Debug,
52    In: Send + Sync + fmt::Debug,
53    Out: Send + Sync + fmt::Debug,
54{
55    outbound_tx: UnboundedSender<Out>,
56    signal_queue: Arc<StreamableDeque<Signal<K, SI>>>,
57    router: Arc<PacketRouter<K, In>>,
58    // Option<P> is used to guard against
59    // an invalid <IngressAdapter as NotificationProcessor>::init (one where
60    // IngressAdapter::init_packet_processor was not called)
61    inbound_tx: Option<UnboundedSender<In>>,
62    topics: [Topic<T>; U],
63}
64
65impl<K, SI, In, Out, T, const U: usize> IngressAdapter<K, SI, In, Out, T, U>
66where
67    for<'a> K: HashKind + TryFrom<&'a In, Error = Report<ConversionError>>,
68    SI: Send + Sync + fmt::Debug + TryFrom<In, Error = Report<ConversionError>> + 'static,
69    In: Send + Sync + fmt::Debug + 'static,
70    Out: Send + Sync + fmt::Debug + 'static,
71{
72    pub fn new<const N: usize>(
73        signal_queue: Arc<SignalQueue<K, SI>>,
74        outbound_tx: UnboundedSender<Out>,
75        state_routers: [BoxedStateRouter<K, In>; N],
76        topics: [Topic<T>; U],
77    ) -> Self {
78        let state_routers: HashMap<K, BoxedStateRouter<K, In>> = state_routers
79            .into_iter()
80            .map(|r| (r.get_kind(), r))
81            .collect();
82
83        Self {
84            signal_queue,
85            outbound_tx,
86            router: Arc::new(PacketRouter(state_routers)),
87            inbound_tx: None,
88            topics,
89        }
90    }
91
92    // This needs to be a precursor step for now
93    // TODO change to builder pattern
94    pub fn init_packet_processor(&mut self) -> UnboundedSender<In> {
95        let router = self.router.clone();
96        let signal_queue = self.signal_queue.clone();
97        let (packet_tx, mut packet_rx) = mpsc::unbounded_channel::<In>();
98        let _nw_handle = tokio::spawn(async move {
99            debug!(target: "state_machine", spawning = "IngressAdapter.packet_tx");
100            while let Some(packet) = packet_rx.recv().await {
101                trace!("receiving packet");
102                let id = match router.get_id(&packet) {
103                    Err(e) => {
104                        error!(err = ?e, ?packet, "could not get id from router");
105                        continue;
106                    }
107                    Ok(None) => {
108                        warn!(?packet, "unable to route packet");
109                        continue;
110                    }
111                    Ok(Some(state_id)) => state_id,
112                };
113                SI::try_from(packet)
114                    .map(|input| {
115                        signal_queue.push_back(Signal { id, input });
116                    })
117                    .log_attached_err("ia::processors from packet failed");
118            }
119        });
120        self.inbound_tx = Some(packet_tx.clone());
121
122        packet_tx
123    }
124
125    pub fn init_notification_processor<N>(&self) -> UnboundedSender<N>
126    where
127        N: TryInto<Out, Error = Report<ConversionError>> + Send + 'static + fmt::Debug,
128    {
129        debug!("starting IngressAdapter notification_tx");
130        self.inbound_tx
131            .as_ref()
132            .expect("IngressAdapter did not initialize packet_tx!");
133
134        let (input_tx, mut input_rx) = mpsc::unbounded_channel::<N>();
135        let outbound_tx = self.outbound_tx.clone();
136
137        let _notification_handle = tokio::spawn(async move {
138            debug!(target: "state_machine", spawning = "IngressAdapter.notification_tx");
139            while let Some(notification) = input_rx.recv().await {
140                notification
141                    .try_into()
142                    .map(|packet| {
143                        trace!("sending packet");
144                        outbound_tx.send(packet).log_err();
145                    })
146                    .log_attached_err("Invalid input");
147            }
148        });
149
150        input_tx
151    }
152}
153
154impl<K, T, M, SI, In, Out, const U: usize> NotificationProcessor<T, Notification<K, M>>
155    for IngressAdapter<K, SI, In, Out, T, U>
156where
157    for<'a> K: HashKind + TryFrom<&'a In, Error = Report<ConversionError>>,
158    SI: Send + Sync + fmt::Debug + TryFrom<In, Error = Report<ConversionError>> + 'static,
159    M: Message + GetTopic<T>,
160    T: Send + Sync + 'static,
161    In: Send + Sync + fmt::Debug + 'static,
162    Out: Send
163        + Sync
164        + fmt::Debug
165        + TryFrom<Notification<K, M>, Error = Report<ConversionError>>
166        + 'static,
167{
168    fn init(&self) -> UnboundedSender<Notification<K, M>> {
169        self.init_notification_processor()
170    }
171
172    fn get_topics(&self) -> &[Topic<T>] {
173        &self.topics
174    }
175}
176
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::mpsc::UnboundedReceiver;
    use tokio_stream::StreamExt;

    use super::*;
    use crate::{notification::NotificationManager, test_support::*, StateId, TestDefault};

    /// Upper bound when waiting for something that is expected to arrive.
    /// Generous on purpose so the tests do not flake on a loaded machine.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);
    /// How long to wait before concluding that nothing was produced.
    const QUIET_PERIOD: Duration = Duration::from_millis(10);

    type TestIngressAdapter = (
        IngressAdapter<TestKind, TestInput, InPacket, OutPacket, TestTopic, 1>,
        UnboundedReceiver<OutPacket>,
    );

    impl TestDefault for TestIngressAdapter {
        fn test_default() -> Self {
            let signal_queue = Arc::new(SignalQueue::new());
            let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();

            let nw_adapter = IngressAdapter::new(
                signal_queue,
                outbound_tx,
                [Box::new(TestStateRouter)],
                [Topic::Message(TestTopic::Ingress)],
            );
            (nw_adapter, outbound_rx)
        }
    }

    /// Waits (bounded by [`RECV_TIMEOUT`]) for the next outbound packet instead
    /// of sleeping a fixed amount of time and hoping the task has run.
    async fn recv_outbound(rx: &mut UnboundedReceiver<OutPacket>) -> Option<OutPacket> {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for an outbound packet")
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_to_network() {
        let (mut nw_adapter, mut network_rx) = TestIngressAdapter::test_default();
        let _inbound_tx = nw_adapter.init_packet_processor();

        let notification_manager: NotificationManager<TestTopic, Notification<_, TestMsg>> =
            NotificationManager::new([&nw_adapter]);
        let notification_tx = notification_manager.init();

        let unknown_packet = OutPacket(b"unknown_packet".to_vec());

        // Any packet should get to the GatewayClient since routing rules
        // are only used at the ingress of the state machine
        notification_tx.send(unknown_packet.clone().into()).unwrap();
        assert_eq!(Some(unknown_packet), recv_outbound(&mut network_rx).await);

        let unsupported_packet = OutPacket(b"unsupported_packet".to_vec());

        notification_tx
            .send(unsupported_packet.clone().into())
            .unwrap();
        assert_eq!(Some(unsupported_packet), recv_outbound(&mut network_rx).await);
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_from_network() {
        let (mut nw_adapter, _outbound_rx) = TestIngressAdapter::test_default();
        let signal_queue = nw_adapter.signal_queue.clone();
        let signal_rx = signal_queue.stream();
        tokio::pin!(signal_rx);

        let inbound_tx = nw_adapter.init_packet_processor();

        let notification_manager: NotificationManager<TestTopic, Notification<_, TestMsg>> =
            NotificationManager::new([&nw_adapter]);
        let _notification_tx = notification_manager.init();

        // An unknown packet should be unrouteable
        let unknown_packet = InPacket(b"unknown_packet".to_vec());
        inbound_tx.send(unknown_packet).unwrap();
        tokio::time::timeout(QUIET_PERIOD, signal_rx.next())
            .await
            .expect_err("an unknown packet must not produce a signal");

        let supported_packet = InPacket(b"new_state".to_vec());
        inbound_tx.send(supported_packet.clone()).unwrap();
        let signal = tokio::time::timeout(RECV_TIMEOUT, signal_rx.next())
            .await
            .expect("timed out waiting for a signal")
            .expect("signal stream ended unexpectedly");
        assert_eq!(
            Signal {
                id: StateId::new_with_u128(TestKind, 1),
                input: TestInput::Packet(supported_packet),
            },
            signal,
        );
    }
}