1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
/// Creates a new struct around a [`UntypedTransport`](crate::UntypedTransport) that routes incoming
/// and outgoing messages to different transports, enabling the ability to transform a singular
/// transport into multiple typed transports that can be combined with [`Client`](crate::Client)
/// and [`Server`](crate::Server) to mix having a variety of clients and servers available on the
/// same underlying [`UntypedTransport`](crate::UntypedTransport).
///
/// # Syntax
///
/// Each entry has the form `name: Incoming => Outgoing` and produces a public field `name` of
/// type `MpscTransport<Outgoing, Incoming>` on the generated struct. The struct name may be
/// preceded by a visibility modifier.
///
/// An optional `#[router(inbound = N, outbound = M)]` attribute placed before the struct name
/// overrides the capacity of the channels created for every entry. Both values default to
/// `10000` when not provided.
///
/// # Routing
///
/// Incoming data is deserialized into an untagged enum that has one variant per entry, so a
/// message is delivered to the first entry (in declaration order) whose incoming type accepts
/// it. Data that fails to be read as any of the incoming types is logged and dropped.
///
/// Outgoing data written to any of the produced transports is forwarded to the underlying
/// writer by a single background task.
///
/// # Examples
///
/// ```no_run
/// use distant_net::router;
///
/// # // To send, the data needs to be serializable
/// # // To receive, the data needs to be deserializable
/// # #[derive(serde::Serialize, serde::Deserialize)]
/// # struct CustomData(u8, u8);
///
/// // Create a router that produces three transports from one:
/// // 1. `Transport<u8, String>` - receives `String` and sends `u8`
/// // 2. `Transport<bool, CustomData>` - receives `CustomData` and sends `bool`
/// // 3. `Transport<Option<String>, u8>` - receives `u8` and sends `Option<String>`
/// router!(TestRouter {
///     one: String => u8,
///     two: CustomData => bool,
///     three: u8 => Option<String>,
/// });
///
/// router!(
///     #[router(inbound = 10, outbound = 20)]
///     TestRouterWithCustomBounds {
///         one: String => u8,
///         two: CustomData => bool,
///         three: u8 => Option<String>,
///     }
/// );
///
/// # let (transport, _) = distant_net::FramedTransport::pair(1);
///
/// let router = TestRouter::new(transport);
///
/// let one   = router.one;   // MpscTransport<u8, String>
/// let two   = router.two;   // MpscTransport<bool, CustomData>
/// let three = router.three; // MpscTransport<Option<String>, u8>
/// ```
#[macro_export]
macro_rules! router {
    (
        $(#[router($($mname:ident = $mvalue:literal),*)])?
        $vis:vis $name:ident {
            $($transport:ident : $res_ty:ty => $req_ty:ty),+ $(,)?
        }
    ) => {
        $crate::paste::paste! {
            #[doc = "Implements a message router for splitting out transport messages"]
            // Callers commonly destructure the router and keep only the transports, which
            // leaves the task handles unused
            #[allow(dead_code)]
            $vis struct $name {
                // Task that reads from the underlying transport and forwards to the entries
                reader_task: tokio::task::JoinHandle<()>,
                // Task that drains every entry's outbound channel into the underlying transport
                writer_task: tokio::task::JoinHandle<()>,
                $(
                    /// Typed transport split out of the underlying transport by this router
                    pub $transport: $crate::MpscTransport<$req_ty, $res_ty>,
                )+
            }

            // Not every generated helper is used by every invocation of the macro
            #[allow(dead_code)]
            impl $name {
                /// Returns the size of the inbound buffer used by this router
                pub const fn inbound_buffer_size() -> usize {
                    Self::buffer_sizes().0
                }

                /// Returns the size of the outbound buffer used by this router
                pub const fn outbound_buffer_size() -> usize {
                    Self::buffer_sizes().1
                }

                /// Returns the size of the inbound and outbound buffers used by this router
                /// in the form of `(inbound, outbound)`
                pub const fn buffer_sizes() -> (usize, usize) {
                    // Set defaults for inbound and outbound buffer sizes
                    let _inbound = 10000;
                    let _outbound = 10000;

                    // Each `key = value` of the router attribute shadows the matching default
                    // above; the leading underscore keeps unread bindings warning-free
                    $($(
                        let [<_ $mname:snake>] = $mvalue;
                    )*)?

                    (_inbound, _outbound)
                }

                #[doc = "Creates a new instance of [`" $name "`]"]
                pub fn new<T, W, R>(split: T) -> Self
                where
                    T: $crate::IntoSplit<Write = W, Read = R>,
                    W: $crate::UntypedTransportWrite + 'static,
                    R: $crate::UntypedTransportRead + 'static,
                {
                    let (writer, reader) = split.into_split();
                    Self::from_writer_and_reader(writer, reader)
                }

                #[doc = "Creates a new instance of [`" $name "`] from the given writer and reader"]
                pub fn from_writer_and_reader<W, R>(mut writer: W, mut reader: R) -> Self
                where
                    W: $crate::UntypedTransportWrite + 'static,
                    R: $crate::UntypedTransportRead + 'static,
                {
                    // For every entry, create one channel per direction and hand the
                    // user-facing halves to the typed transport stored on the struct
                    $(
                        let (
                            [<$transport:snake _inbound_tx>],
                            [<$transport:snake _inbound_rx>]
                        ) = tokio::sync::mpsc::channel(Self::inbound_buffer_size());
                        let (
                            [<$transport:snake _outbound_tx>],
                            mut [<$transport:snake _outbound_rx>]
                        ) = tokio::sync::mpsc::channel(Self::outbound_buffer_size());
                        let [<$transport:snake>]: $crate::MpscTransport<$req_ty, $res_ty> =
                            $crate::MpscTransport::new(
                                [<$transport:snake _outbound_tx>],
                                [<$transport:snake _inbound_rx>]
                            );
                    )+

                    // Untagged so that the shape of the incoming data alone decides which
                    // variant (and therefore which entry) receives it
                    #[derive(serde::Deserialize)]
                    #[serde(untagged)]
                    enum [<$name:camel In>] {
                        $([<$transport:camel>]($res_ty)),+
                    }

                    let reader_task = tokio::spawn(async move {
                        loop {
                            match $crate::UntypedTransportRead::read(&mut reader).await {
                                // Forward the payload to the inbound channel of the entry
                                // whose variant was produced
                                $(
                                    Ok(Some([<$name:camel In>]::[<$transport:camel>](x))) => {
                                        if let Err(x) = [<$transport:snake _inbound_tx>].send(x).await {
                                            $crate::log::error!(
                                                "Failed to forward received data from {} of {}: {}",
                                                std::stringify!($transport),
                                                std::stringify!($name),
                                                x
                                            );
                                        }
                                    }
                                )+

                                // Quit if the reader no longer has data
                                // NOTE(review): the original author observed an
                                //       `unreachable_patterns` warning on this arm; the arm is
                                //       what ends the task, so the lint is silenced instead
                                #[allow(unreachable_patterns)]
                                Ok(None) => {
                                    $crate::log::trace!(
                                        "Router {} has closed",
                                        std::stringify!($name),
                                    );
                                    break;
                                }

                                // Drop any received data that does not map to something
                                // NOTE(review): same `unreachable_patterns` observation as the
                                //       arm above; the lint is silenced for the same reason
                                #[allow(unreachable_patterns)]
                                Err(x) => {
                                    $crate::log::error!(
                                        "Failed to read from any transport of {}: {}",
                                        std::stringify!($name),
                                        x
                                    );
                                    continue;
                                }
                            }
                        }
                    });

                    let writer_task = tokio::spawn(async move {
                        loop {
                            tokio::select! {
                                // Relay whatever any entry wants to send; a failed write is
                                // logged and the task keeps serving the other entries
                                $(
                                    Some(x) = [<$transport:snake _outbound_rx>].recv() => {
                                        if let Err(x) = $crate::UntypedTransportWrite::write(
                                            &mut writer,
                                            x,
                                        ).await {
                                            $crate::log::error!(
                                                "Failed to write to {} of {}: {}",
                                                std::stringify!($transport),
                                                std::stringify!($name),
                                                x
                                            );
                                        }
                                    }
                                )+

                                // Reached once no outbound channel can yield a value anymore
                                else => break,
                            }
                        }
                    });

                    Self {
                        reader_task,
                        writer_task,
                        $([<$transport:snake>]),+
                    }
                }

                /// Aborts the background reader and writer tasks spawned by this router
                pub fn abort(&self) {
                    self.reader_task.abort();
                    self.writer_task.abort();
                }

                /// Returns true if both the reader and writer tasks of this router have finished
                pub fn is_finished(&self) -> bool {
                    self.reader_task.is_finished() && self.writer_task.is_finished()
                }
            }
        }
    };
}

#[cfg(test)]
mod tests {
    use crate::{FramedTransport, TypedAsyncRead, TypedAsyncWrite};
    use serde::{Deserialize, Serialize};

    // NOTE: Must implement deserialize for our router,
    //       but we also need serialize to send for our test
    #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
    struct CustomData(u8, String);

    // Creates a private `TestRouter` implementation
    //
    // 1. Transport receiving `CustomData` and sending `String`
    // 2. Transport receiving `String` and sending `u8`
    // 3. Transport receiving `bool` and sending `bool`
    // 4. Transport receiving `Result<String, bool>` and sending `Option<String>`
    //
    // The fourth entry is never exercised at runtime; it only needs to compile
    router!(TestRouter {
        one: CustomData => String,
        two: String => u8,
        three: bool => bool,
        should_compile: Result<String, bool> => Option<String>,
    });

    // Verifies the defaults and that each `#[router(...)]` key overrides only its own buffer
    #[test]
    fn router_buffer_sizes_should_support_being_overridden() {
        router!(DefaultSizes { data: u8 => u8 });
        router!(#[router(inbound = 5)] CustomInboundSize { data: u8 => u8 });
        router!(#[router(outbound = 5)] CustomOutboundSize { data: u8 => u8 });
        router!(#[router(inbound = 5, outbound = 6)] CustomSizes { data: u8 => u8 });

        assert_eq!(DefaultSizes::buffer_sizes(), (10000, 10000));
        assert_eq!(DefaultSizes::inbound_buffer_size(), 10000);
        assert_eq!(DefaultSizes::outbound_buffer_size(), 10000);

        assert_eq!(CustomInboundSize::buffer_sizes(), (5, 10000));
        assert_eq!(CustomInboundSize::inbound_buffer_size(), 5);
        assert_eq!(CustomInboundSize::outbound_buffer_size(), 10000);

        assert_eq!(CustomOutboundSize::buffer_sizes(), (10000, 5));
        assert_eq!(CustomOutboundSize::inbound_buffer_size(), 10000);
        assert_eq!(CustomOutboundSize::outbound_buffer_size(), 5);

        assert_eq!(CustomSizes::buffer_sizes(), (5, 6));
        assert_eq!(CustomSizes::inbound_buffer_size(), 5);
        assert_eq!(CustomSizes::outbound_buffer_size(), 6);
    }

    // Verifies that data arriving on the shared transport reaches the entry whose
    // incoming type matches it, regardless of the order in which it was sent
    #[tokio::test]
    async fn router_should_wire_transports_to_distinguish_incoming_data() {
        let (t1, mut t2) = FramedTransport::make_test_pair();
        let TestRouter {
            mut one,
            mut two,
            mut three,
            ..
        } = TestRouter::new(t1);

        // Send some data of different types that these transports expect
        t2.write(false).await.unwrap();
        t2.write("hello world".to_string()).await.unwrap();
        t2.write(CustomData(123, "goodbye world".to_string()))
            .await
            .unwrap();

        // Get that data through the appropriate transport
        let data = one.read().await.unwrap().unwrap();
        assert_eq!(
            data,
            CustomData(123, "goodbye world".to_string()),
            "string_custom_data_transport got unexpected result"
        );

        let data = two.read().await.unwrap().unwrap();
        assert_eq!(
            data, "hello world",
            "u8_string_transport got unexpected result"
        );

        let data = three.read().await.unwrap().unwrap();
        assert!(!data, "bool_bool_transport got unexpected result");
    }

    // Verifies that data matching none of the entries is skipped without preventing
    // delivery of the known data sent before and after it
    #[tokio::test]
    async fn router_should_wire_transports_to_ignore_unknown_incoming_data() {
        let (t1, mut t2) = FramedTransport::make_test_pair();
        let TestRouter {
            mut one, mut two, ..
        } = TestRouter::new(t1);

        #[derive(Serialize, Deserialize)]
        struct UnknownData(char, u8);

        // Send some known and unknown data
        t2.write("hello world".to_string()).await.unwrap();
        t2.write(UnknownData('a', 99)).await.unwrap();
        t2.write(CustomData(123, "goodbye world".to_string()))
            .await
            .unwrap();

        // Get that data through the appropriate transport
        let data = one.read().await.unwrap().unwrap();
        assert_eq!(
            data,
            CustomData(123, "goodbye world".to_string()),
            "string_custom_data_transport got unexpected result"
        );

        let data = two.read().await.unwrap().unwrap();
        assert_eq!(
            data, "hello world",
            "u8_string_transport got unexpected result"
        );
    }

    // Verifies that data written to each typed transport is funneled out through the
    // shared underlying transport
    #[tokio::test]
    async fn router_should_wire_transports_to_relay_outgoing_data() {
        let (t1, mut t2) = FramedTransport::make_test_pair();
        let TestRouter {
            mut one,
            mut two,
            mut three,
            ..
        } = TestRouter::new(t1);

        // NOTE: The router's writer task selects across the outbound channels of all
        //       transports, so the relative order of messages queued on different
        //       transports is NOT guaranteed. Rather than sleeping between sends and
        //       hoping the writer keeps up, read each message back from the other side
        //       before sending the next one so the expected order holds by construction.
        three.write(true).await.unwrap();
        let data: bool = t2.read().await.unwrap().unwrap();
        assert!(
            data,
            "Unexpected data received from bool_bool_transport output"
        );

        two.write(123).await.unwrap();
        let data: u8 = t2.read().await.unwrap().unwrap();
        assert_eq!(
            data, 123,
            "Unexpected data received from u8_string_transport output"
        );

        one.write("hello world".to_string()).await.unwrap();
        let data: String = t2.read().await.unwrap().unwrap();
        assert_eq!(
            data, "hello world",
            "Unexpected data received from string_custom_data_transport output"
        );
    }
}