rex/
ingress.rs

1use std::{collections::HashMap, fmt, sync::Arc};
2
3use bigerror::{ConversionError, IntoContext, LogError, Report};
4use tokio::{
5    sync::{
6        mpsc,
7        mpsc::{UnboundedReceiver, UnboundedSender},
8    },
9    task::JoinSet,
10};
11use tracing::{debug, error, trace, warn, Instrument};
12
13use crate::{
14    manager::{HashKind, Signal, SignalQueue},
15    notification::{Notification, NotificationProcessor, RexMessage},
16    queue::StreamableDeque,
17    Rex, RexError, StateId,
18};
19
20pub trait StateRouter<K>: Send + Sync
21where
22    K: HashKind,
23{
24    type Inbound;
25    fn get_id(&self, input: &Self::Inbound) -> Result<Option<StateId<K>>, Report<RexError>>;
26    fn get_kind(&self) -> K;
27}
28
29pub type BoxedStateRouter<K, In> = Box<dyn StateRouter<K, Inbound = In>>;
30
31/// top level router that holds all [`Kind`] indexed [`StateRouter`]s
32pub struct PacketRouter<K, In>(Arc<HashMap<K, BoxedStateRouter<K, In>>>)
33where
34    K: HashKind;
35
36impl<K: HashKind, In> Clone for PacketRouter<K, In> {
37    fn clone(&self) -> Self {
38        Self(self.0.clone())
39    }
40}
41
42impl<K, In> PacketRouter<K, In>
43where
44    for<'a> K: HashKind + TryFrom<&'a In, Error = Report<ConversionError>>,
45{
46    #[must_use]
47    pub fn new(state_routers: Vec<BoxedStateRouter<K, In>>) -> Self {
48        let mut router_map: HashMap<K, BoxedStateRouter<K, In>> = HashMap::new();
49        for router in state_routers {
50            if let Some(old_router) = router_map.insert(router.get_kind(), router) {
51                panic!(
52                    "Found multiple routers for kind: {:?}",
53                    old_router.get_kind()
54                );
55            }
56        }
57        Self(Arc::new(router_map))
58    }
59
60    fn get_id(&self, packet: &In) -> Result<Option<StateId<K>>, Report<RexError>> {
61        let kind = K::try_from(packet);
62        let kind = kind.map_err(IntoContext::into_ctx)?;
63        let Some(router) = self.0.get(&kind) else {
64            return Ok(None);
65        };
66        router.get_id(packet)
67    }
68}
69
70/// Represents a bidirectional network connection
71pub struct IngressAdapter<K>
72where
73    K: Rex + Ingress,
74    K::Input: TryFrom<K::In, Error = Report<ConversionError>>,
75    K::Message: TryInto<K::Out, Error = Report<ConversionError>>,
76{
77    pub(crate) outbound_tx: UnboundedSender<K::Out>,
78    pub(crate) signal_queue: SignalQueue<K>,
79    pub(crate) router: PacketRouter<K, K::In>,
80    pub inbound_tx: UnboundedSender<K::In>,
81    // `self.inbound_rx.take()` will be used on initialization
82    pub(crate) inbound_rx: Option<UnboundedReceiver<K::In>>,
83    pub(crate) topic: <K::Message as RexMessage>::Topic,
84}
85
86impl<K> IngressAdapter<K>
87where
88    K: Rex + Ingress,
89    K::Input: TryFrom<K::In, Error = Report<ConversionError>>,
90    K::Message: TryInto<K::Out, Error = Report<ConversionError>>,
91{
92    #[must_use]
93    pub fn new(
94        signal_queue: SignalQueue<K>,
95        outbound_tx: UnboundedSender<K::Out>,
96        state_routers: Vec<BoxedStateRouter<K, K::In>>,
97        topic: impl Into<<K::Message as RexMessage>::Topic>,
98    ) -> Self {
99        let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::<K::In>();
100
101        Self {
102            signal_queue,
103            outbound_tx,
104            router: PacketRouter::new(state_routers),
105            inbound_tx,
106            inbound_rx: Some(inbound_rx),
107            topic: topic.into(),
108        }
109    }
110
111    pub(crate) fn spawn_inbound(&mut self, join_set: &mut JoinSet<()>) {
112        let router = self.router.clone();
113        let signal_queue = self.signal_queue.clone();
114        let inbound_rx = self.inbound_rx.take().expect("inbound_rx missing");
115        join_set.spawn(Self::process_inbound(router, signal_queue, inbound_rx).in_current_span());
116    }
117
118    async fn process_inbound(
119        router: PacketRouter<K, K::In>,
120        signal_queue: Arc<StreamableDeque<Signal<K>>>,
121        mut packet_rx: UnboundedReceiver<K::In>,
122    ) {
123        debug!(target: "state_machine", spawning = "IngressAdapter.packet_tx");
124        while let Some(packet) = packet_rx.recv().await {
125            trace!("receiving packet");
126            let id = match router.get_id(&packet) {
127                Err(e) => {
128                    error!(err = ?e, ?packet, "could not get id from router");
129                    continue;
130                }
131                Ok(None) => {
132                    warn!(?packet, "unable to route packet");
133                    continue;
134                }
135                Ok(Some(state_id)) => state_id,
136            };
137            K::Input::try_from(packet)
138                .map(|input| {
139                    signal_queue.push_back(Signal { id, input });
140                })
141                .log_attached_err("ia::processors from packet failed");
142        }
143    }
144}
145
146pub trait Ingress: Rex
147where
148    Self::Input: TryFrom<Self::In, Error = Report<ConversionError>>,
149    Self::Message: TryInto<Self::Out, Error = Report<ConversionError>>,
150    for<'a> Self: TryFrom<&'a Self::In, Error = Report<ConversionError>>,
151{
152    type In: Send + Sync + fmt::Debug + 'static;
153    type Out: Send + Sync + fmt::Debug + 'static;
154}
155
156impl<K> NotificationProcessor<K::Message> for IngressAdapter<K>
157where
158    K: Rex + Ingress,
159    K::Input: TryFrom<K::In, Error = Report<ConversionError>>,
160    K::Message: TryInto<K::Out, Error = Report<ConversionError>>,
161{
162    fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<K::Message>> {
163        debug!("calling IngressAdapter::process_inbound");
164        self.spawn_inbound(join_set);
165
166        debug!("starting IngressAdapter notification_tx");
167
168        let (input_tx, mut input_rx) = mpsc::unbounded_channel::<Notification<K::Message>>();
169        let outbound_tx = self.outbound_tx.clone();
170
171        let _notification_handle = join_set.spawn(
172            async move {
173                debug!(target: "state_machine", spawning = "IngressAdapter.notification_tx");
174                while let Some(notification) = input_rx.recv().await {
175                    notification
176                        .0
177                        .try_into()
178                        .map(|packet| {
179                            trace!("sending packet");
180                            outbound_tx.send(packet).log_err();
181                        })
182                        .log_attached_err("Invalid input");
183                }
184            }
185            .in_current_span(),
186        );
187
188        input_tx
189    }
190
191    fn get_topics(&self) -> &[<K::Message as RexMessage>::Topic] {
192        std::slice::from_ref(&self.topic)
193    }
194}
195
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::{sync::mpsc::UnboundedReceiver, task::JoinSet};
    use tokio_stream::StreamExt;

    use super::*;
    use crate::{
        notification::{NotificationManager, NotificationQueue},
        test_support::*,
        RexBuilder, StateId,
    };

    /// Upper bound for awaiting a packet that is expected to arrive. It is deliberately
    /// generous so the assertions do not depend on scheduler timing. The happy path
    /// returns as soon as the packet is delivered.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    type TestIngressAdapter = (IngressAdapter<TestKind>, UnboundedReceiver<OutPacket>);

    impl TestDefault for TestIngressAdapter {
        fn test_default() -> Self {
            let signal_queue = SignalQueue::default();
            let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();

            let adapter = IngressAdapter::new(
                signal_queue,
                outbound_tx,
                vec![Box::new(TestStateRouter)],
                TestTopic::Ingress,
            );
            (adapter, outbound_rx)
        }
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_to_network() {
        let (adapter, mut network_rx) = TestIngressAdapter::test_default();
        let mut join_set = JoinSet::new();

        let notification_manager: NotificationManager<TestMsg> = NotificationManager::new(
            vec![Box::new(adapter)],
            &mut join_set,
            NotificationQueue::new(),
        );
        let notification_tx = notification_manager.init(&mut join_set);

        let unknown_packet = OutPacket(b"unknown_packet".to_vec());

        // Every outbound notification should reach the network: routing rules
        // only apply to packets entering the state machine.
        notification_tx.send(Notification(unknown_packet.clone().into()));
        let received = tokio::time::timeout(RECV_TIMEOUT, network_rx.recv())
            .await
            .expect("timed out waiting for outbound packet");
        assert_eq!(Some(unknown_packet), received);

        let unsupported_packet = OutPacket(b"unsupported_packet".to_vec());

        notification_tx.send(Notification(unsupported_packet.clone().into()));
        let received = tokio::time::timeout(RECV_TIMEOUT, network_rx.recv())
            .await
            .expect("timed out waiting for outbound packet");
        assert_eq!(Some(unsupported_packet), received);
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn route_from_network() {
        let (adapter, _outbound_rx) = TestIngressAdapter::test_default();
        let signal_queue = adapter.signal_queue.clone();
        let signal_rx = signal_queue.stream().timeout(Duration::from_millis(2));
        tokio::pin!(signal_rx);

        let inbound_tx = adapter.inbound_tx.clone();
        let mut join_set = JoinSet::new();

        let notification_manager: NotificationManager<TestMsg> = NotificationManager::new(
            vec![Box::new(adapter)],
            &mut join_set,
            NotificationQueue::new(),
        );
        let _notification_tx = notification_manager.init(&mut join_set);

        // An unknown packet should be unrouteable
        let unknown_packet = InPacket(b"unknown_packet".to_vec());
        inbound_tx.send(unknown_packet).unwrap();
        signal_rx.next().await.unwrap().unwrap_err();

        let supported_packet = InPacket(b"new_state".to_vec());
        inbound_tx.send(supported_packet.clone()).unwrap();
        let signal = signal_rx.next().await.unwrap().unwrap();
        assert_eq!(
            Signal {
                id: StateId::new_with_u128(TestKind, 1),
                input: TestInput::Packet(supported_packet),
            },
            signal,
        );
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn rex_builder() {
        // TODO test outbound_rx
        let (outbound_tx, _outbound_rx) = mpsc::unbounded_channel::<OutPacket>();

        let (inbound_tx, builder) = RexBuilder::new_connected(outbound_tx);
        let ctx = builder
            .with_ingress_adapter(vec![Box::new(TestStateRouter)], TestTopic::Ingress)
            .build();

        let signal_rx = ctx.signal_queue.stream().timeout(Duration::from_millis(2));
        tokio::pin!(signal_rx);

        // An unknown packet should be unrouteable
        let unknown_packet = InPacket(b"unknown_packet".to_vec());
        inbound_tx.send(unknown_packet).unwrap();
        signal_rx.next().await.unwrap().unwrap_err();

        let supported_packet = InPacket(b"new_state".to_vec());
        inbound_tx.send(supported_packet.clone()).unwrap();
        let signal = signal_rx.next().await.unwrap().unwrap();
        assert_eq!(
            Signal {
                id: StateId::new_with_u128(TestKind, 1),
                input: TestInput::Packet(supported_packet),
            },
            signal,
        );
    }
}