async_middleware/
lib.rs

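//! Composable async middleware: [`Transform`] steps are chained with
//! [`convert`] or [`pipe`] into [`Middleware`] pipelines.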
2
3use async_trait::async_trait;
4use std::{future::Future, marker::PhantomData, sync::Arc};
5
6/// Middleware that transforms around an input to output type.
7#[async_trait]
8pub trait Transform<Args, T, O>: Send + Sync + 'static {
9    /// Asynchronously execute this handler to modify state
10    async fn transform(&self, input: T) -> O;
11}
12
13/// Middleware implementation for an async function that produces an output
14#[async_trait]
15impl<Func, Fut, O> Transform<(), (), O> for Func
16where
17    Func: Send + Sync + 'static + Fn() -> Fut,
18    Fut: Future<Output = O> + Send + Sync + 'static,
19    O: Send + Sync + 'static,
20{
21    async fn transform(&self, _input: ()) -> O {
22        (self)().await
23    }
24}
25
26/// Middleware implementation for an async function that returns nothing
27#[async_trait]
28impl<Func, Fut, T, O> Transform<(T, O), T, O> for Func
29where
30    Func: Send + Sync + 'static + Fn(T) -> Fut,
31    Fut: Future<Output = O> + Send + Sync + 'static,
32    T: Send + Sync + 'static,
33    O: Send + Sync + 'static,
34{
35    async fn transform(&self, input: T) -> O {
36        (self)(input).await
37    }
38}
39
40/// Middleware that performs an operation.
41#[async_trait]
42pub trait Middleware<I, O>: Send + Sync + 'static {
43    async fn call(&self, input: I) -> O;
44}
45
46/// Encapsulates the conversion between two different transform types
47pub struct ConvertMiddleware<T, T2, A, B, C> {
48    t: Arc<dyn Transform<T, A, B>>,
49    t2: Arc<dyn Transform<T2, B, C>>,
50}
51
52/// Implements the transform trait on the conversion middleware (for downstream)
53#[async_trait]
54impl<T, T2, A, B, C> Transform<(A, C), A, C> for ConvertMiddleware<T, T2, A, B, C>
55where
56    T: Send + Sync + 'static,
57    T2: Send + Sync + 'static,
58    A: Send + Sync + 'static,
59    B: Send + Sync + 'static,
60    C: Send + Sync + 'static,
61{
62    async fn transform(&self, input: A) -> C {
63        let input = self.t.transform(input).await;
64        self.t2.transform(input).await
65    }
66}
67
68/// Implements the middleware trait on the conversion middleware to make it A -> C
69#[async_trait]
70impl<T, T2, A, B, C> Middleware<A, C> for ConvertMiddleware<T, T2, A, B, C>
71where
72    T: Send + Sync + 'static,
73    T2: Send + Sync + 'static,
74    A: Send + Sync + 'static,
75    B: Send + Sync + 'static,
76    C: Send + Sync + 'static,
77{
78    async fn call(&self, input: A) -> C {
79        self.transform(input).await
80    }
81}
82
83/// Creates a new conversion middleware from two existing transforms
84pub fn convert<T, T2, A, B, C>(
85    t: impl Transform<T, A, B>,
86    t2: impl Transform<T2, B, C>,
87) -> ConvertMiddleware<T, T2, A, B, C>
88where
89    T: Send + Sync + 'static,
90    T2: Send + Sync + 'static,
91    A: Send + Sync + 'static,
92    B: Send + Sync + 'static,
93    C: Send + Sync + 'static,
94{
95    ConvertMiddleware {
96        t: Arc::new(t),
97        t2: Arc::new(t2),
98    }
99}
100
101/// Pied constructs the way we pipe between lots of functions via middleware
102pub struct Pied<T, Args, I, O> {
103    middleware: Arc<dyn Middleware<I, O>>,
104    _phantom: PhantomData<T>,
105    _phantom2: PhantomData<Args>,
106}
107
108/// Implements the middleware trait for the main Pied structure
109#[async_trait]
110impl<T, Args, I, O> Middleware<I, O> for Pied<T, Args, I, O>
111where
112    T: Send + Sync + 'static,
113    Args: Send + Sync + 'static,
114    I: Send + Sync + 'static,
115    O: Send + Sync + 'static,
116{
117    async fn call(&self, input: I) -> O {
118        self.middleware.call(input).await
119    }
120}
121
122#[async_trait]
123impl<T, Args, I, O> Transform<(I, O), I, O> for Pied<T, Args, I, O>
124where
125    T: Send + Sync + 'static,
126    Args: Send + Sync + 'static,
127    I: Send + Sync + 'static,
128    O: Send + Sync + 'static,
129{
130    async fn transform(&self, input: I) -> O {
131        self.middleware.call(input).await
132    }
133}
134
135/// Common pipe trait used to create implementations for each tuple
136pub trait Piper<T, Args, I, O> {
137    fn pipe(self) -> Pied<T, Args, I, O>;
138}
139
140/// Helper utility to execute the .pipe on a Pipe implementation and returns a middleware
141pub fn pipe<T, Args, I, O>(f: impl Piper<T, Args, I, O>) -> Pied<T, Args, I, O>
142where
143    T: Send + Sync + 'static,
144    Args: Send + Sync + 'static,
145    I: Send + Sync + 'static,
146    O: Send + Sync + 'static,
147{
148    f.pipe()
149}
150
151// Pipe middleware for source -> transform from (A, B)
152impl<T, O, A, B> Piper<(T, O), (A, B), (), O> for (A, B)
153where
154    A: Transform<(), (), T>,
155    B: Transform<(T, O), T, O>,
156    T: Send + Sync + 'static,
157    O: Send + Sync + 'static,
158{
159    fn pipe(self) -> Pied<(T, O), (A, B), (), O> {
160        let args = self;
161        Pied {
162            middleware: Arc::new(convert(args.0, args.1)),
163            _phantom: PhantomData::default(),
164            _phantom2: PhantomData::default(),
165        }
166    }
167}
168
169// Pipe middleware for transform -> transform from (A, B)
170impl<T, T2, O, A, B> Piper<(T, T2, O), (A, B), T, O> for (A, B)
171where
172    A: Transform<(T, T2), T, T2>,
173    B: Transform<(T2, O), T2, O>,
174    T: Send + Sync + 'static,
175    T2: Send + Sync + 'static,
176    O: Send + Sync + 'static,
177{
178    fn pipe(self) -> Pied<(T, T2, O), (A, B), T, O> {
179        let args = self;
180        Pied {
181            middleware: Arc::new(convert(args.0, args.1)),
182            _phantom: PhantomData::default(),
183            _phantom2: PhantomData::default(),
184        }
185    }
186}
187
188// Pipe middleware for source -> transform -> transform for (A, B, C)
189impl<T, T2, O, A, B, C> Piper<(T, T2, O), (A, B, C), (), O> for (A, B, C)
190where
191    A: Transform<(), (), T>,
192    B: Transform<(T, T2), T, T2>,
193    C: Transform<(T2, O), T2, O>,
194    T: Send + Sync + 'static,
195    T2: Send + Sync + 'static,
196    O: Send + Sync + 'static,
197{
198    fn pipe(self) -> Pied<(T, T2, O), (A, B, C), (), O> {
199        let args = self;
200        Pied {
201            middleware: Arc::new(convert(convert(args.0, args.1), args.2)),
202            _phantom: PhantomData::default(),
203            _phantom2: PhantomData::default(),
204        }
205    }
206}
207
208// Pipe middleware for transform -> transform -> transform for (A, B, C)
209impl<T, T2, T3, O, A, B, C> Piper<(T, T2, T3, O), (A, B, C), T, O> for (A, B, C)
210where
211    A: Transform<(T, T2), T, T2>,
212    B: Transform<(T2, T3), T2, T3>,
213    C: Transform<(T3, O), T3, O>,
214    T: Send + Sync + 'static,
215    T2: Send + Sync + 'static,
216    T3: Send + Sync + 'static,
217    O: Send + Sync + 'static,
218{
219    fn pipe(self) -> Pied<(T, T2, T3, O), (A, B, C), T, O> {
220        let args = self;
221        Pied {
222            middleware: Arc::new(convert(convert(args.0, args.1), args.2)),
223            _phantom: PhantomData::default(),
224            _phantom2: PhantomData::default(),
225        }
226    }
227}
228
229// Pipe middleware for source -> transform -> transform -> transform for (A, B, C, D)
230impl<T, T2, T3, O, A, B, C, D> Piper<(T, T2, T3, O), (A, B, C, D), (), O> for (A, B, C, D)
231where
232    A: Transform<(), (), T>,
233    B: Transform<(T, T2), T, T2>,
234    C: Transform<(T2, T3), T2, T3>,
235    D: Transform<(T3, O), T3, O>,
236    T: Send + Sync + 'static,
237    T2: Send + Sync + 'static,
238    T3: Send + Sync + 'static,
239    O: Send + Sync + 'static,
240{
241    fn pipe(self) -> Pied<(T, T2, T3, O), (A, B, C, D), (), O> {
242        let args = self;
243        Pied {
244            middleware: Arc::new(convert(convert(convert(args.0, args.1), args.2), args.3)),
245            _phantom: PhantomData::default(),
246            _phantom2: PhantomData::default(),
247        }
248    }
249}
250
251// Pipe middleware for transform -> transform -> transform -> transform for (A, B, C, D)
252impl<T, T2, T3, T4, O, A, B, C, D> Piper<(T, T2, T3, T4, O), (A, B, C, D), T, O> for (A, B, C, D)
253where
254    A: Transform<(T, T2), T, T2>,
255    B: Transform<(T2, T3), T2, T3>,
256    C: Transform<(T3, T4), T3, T4>,
257    D: Transform<(T4, O), T4, O>,
258    T: Send + Sync + 'static,
259    T2: Send + Sync + 'static,
260    T3: Send + Sync + 'static,
261    T4: Send + Sync + 'static,
262    O: Send + Sync + 'static,
263{
264    fn pipe(self) -> Pied<(T, T2, T3, T4, O), (A, B, C, D), T, O> {
265        let args = self;
266        Pied {
267            middleware: Arc::new(convert(convert(convert(args.0, args.1), args.2), args.3)),
268            _phantom: PhantomData::default(),
269            _phantom2: PhantomData::default(),
270        }
271    }
272}
273
274// Pipe middleware for source -> transform -> transform -> transform -> transform for (A, B, C, D, E)
275impl<T, T2, T3, T4, O, A, B, C, D, E> Piper<(T, T2, T3, T4, O), (A, B, C, D, E), (), O>
276    for (A, B, C, D, E)
277where
278    A: Transform<(), (), T>,
279    B: Transform<(T, T2), T, T2>,
280    C: Transform<(T2, T3), T2, T3>,
281    D: Transform<(T3, T4), T3, T4>,
282    E: Transform<(T4, O), T4, O>,
283    T: Send + Sync + 'static,
284    T2: Send + Sync + 'static,
285    T3: Send + Sync + 'static,
286    T4: Send + Sync + 'static,
287    O: Send + Sync + 'static,
288{
289    fn pipe(self) -> Pied<(T, T2, T3, T4, O), (A, B, C, D, E), (), O> {
290        let args = self;
291        Pied {
292            middleware: Arc::new(convert(
293                convert(convert(convert(args.0, args.1), args.2), args.3),
294                args.4,
295            )),
296            _phantom: PhantomData::default(),
297            _phantom2: PhantomData::default(),
298        }
299    }
300}
301
302// Pipe middleware for transform -> transform -> transform -> transform -> transform for (A, B, C, D, E)
303impl<T, T2, T3, T4, T5, O, A, B, C, D, E> Piper<(T, T2, T3, T4, T5, O), (A, B, C, D, E), T, O>
304    for (A, B, C, D, E)
305where
306    A: Transform<(T, T2), T, T2>,
307    B: Transform<(T2, T3), T2, T3>,
308    C: Transform<(T3, T4), T3, T4>,
309    D: Transform<(T4, T5), T4, T5>,
310    E: Transform<(T5, O), T5, O>,
311    T: Send + Sync + 'static,
312    T2: Send + Sync + 'static,
313    T3: Send + Sync + 'static,
314    T4: Send + Sync + 'static,
315    T5: Send + Sync + 'static,
316    O: Send + Sync + 'static,
317{
318    fn pipe(self) -> Pied<(T, T2, T3, T4, T5, O), (A, B, C, D, E), T, O> {
319        let args = self;
320        Pied {
321            middleware: Arc::new(convert(
322                convert(convert(convert(args.0, args.1), args.2), args.3),
323                args.4,
324            )),
325            _phantom: PhantomData::default(),
326            _phantom2: PhantomData::default(),
327        }
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    async fn producer() -> i32 {
        3
    }

    async fn multiplier(i: i32) -> i32 {
        i * 32
    }

    async fn stringer(i: i32) -> String {
        i.to_string()
    }

    async fn logger(s: String) {
        println!("{}", s);
    }

    async fn log_nums(i: i32) {
        println!("{}", i);
    }

    #[async_std::test]
    async fn test_piper_tuple() {
        // Every supported tuple shape must type-check via the free function...
        pipe((producer, log_nums));
        pipe((producer, stringer, logger));
        pipe((producer, multiplier, stringer, logger));
        pipe((multiplier, multiplier, multiplier));
        pipe((multiplier, multiplier, stringer));

        // ...and via the method syntax.
        (producer, log_nums).pipe();
        (producer, stringer, logger).pipe();
        (producer, multiplier, stringer, logger).pipe();
        (multiplier, multiplier, multiplier).pipe();
        (multiplier, multiplier, stringer).pipe();

        // A finished pipeline can be a step in another pipeline.
        let m = (producer, multiplier).pipe(); // 3 * 32 = 96
        let m = (m, multiplier).pipe(); // 96 * 32 = 3072
        let m = pipe((m, stringer)); // "3072"

        assert_eq!(String::from("3072"), m.call(()).await);
    }

    #[async_std::test]
    async fn test_piper_tuple_inputs() {
        let m = (multiplier, multiplier, stringer).pipe();
        assert_eq!(String::from("1024"), m.call(1).await);
        assert_eq!(String::from("2048"), m.call(2).await);
        assert_eq!(String::from("3072"), m.call(3).await);
    }

    // Not yet supported: a pipeline's first step takes a single input value,
    // so a two-argument function such as `multi` cannot start one. Downstream
    // steps can only ever receive one value, because a future has a single
    // output. The helper and its test are compiled out with `cfg(any())`,
    // which is always false and, unlike an ad-hoc `cfg(todo)`, does not trip
    // rustc's `unexpected_cfgs` lint. Re-enable both once variadic source
    // inputs land.
    #[cfg(any())]
    async fn multi(a: i32, b: i32) -> i32 {
        a + b
    }

    #[cfg(any())]
    #[async_std::test]
    async fn test_piper_multiple_tuple_inputs() {
        // Sketch of the intended behaviour; the tuple-input API is not final.
        let m = (multi, multiplier, stringer).pipe();
        assert_eq!(String::from("96"), m.call((1, 2)).await); // (1 + 2) * 32
        assert_eq!(String::from("320"), m.call((4, 6)).await); // (4 + 6) * 32
    }

    #[async_std::test]
    async fn test_convert_transform() {
        assert_eq!(String::from("64"), convert(multiplier, stringer).call(2).await);
        assert_eq!(1024, convert(multiplier, multiplier).call(1).await);
    }

    #[async_std::test]
    async fn test_source_transform() {
        assert_eq!(96, convert(producer, multiplier).call(()).await);
    }

    #[async_std::test]
    async fn test_source_sink() {
        convert(producer, log_nums).call(()).await;
    }

    #[async_std::test]
    async fn test_transform() {
        let m = convert(convert(producer, multiplier), stringer);
        assert_eq!(String::from("96"), m.call(()).await);
    }

    #[async_std::test]
    async fn test_transform_source_transform_sink() {
        convert(convert(convert(producer, multiplier), stringer), logger)
            .call(())
            .await;
    }
}