// axum_connect/handler/handler_stream.rs

1use std::pin::Pin;
2
3use axum::body::Body;
4use axum::http::Request;
5use axum::response::Response;
6use futures::{Future, Stream, StreamExt};
7use prost::Message;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10
11use crate::parts::RpcFromRequestParts;
12use crate::response::RpcIntoResponse;
13
14use super::codec::{decode_check_headers, decode_request_payload, ReqResInto, ResponseEncoder};
15
16pub trait RpcHandlerStream<TMReq, TMRes, TUid, TState>:
17    Clone + Send + Sync + Sized + 'static
18{
19    type Future: Future<Output = Response> + Send + 'static;
20
21    fn call(self, req: Request<Body>, state: TState) -> Self::Future;
22}
23
// TODO: Read "connect-timeout-ms" (a number encoded as a string) and apply it as a timeout.
// TODO: Parse request metadata from headers whose keys match:
//      - [0-9a-z]* not ending in "-bin": ASCII value
//      - [0-9a-z]* ending in "-bin": base64-encoded binary value
// TODO: Allow the response to send back both leading and trailing metadata.
29
// Implements `RpcHandlerStream` for every `FnOnce` whose arguments are the listed extractor
// types followed by the decoded request message (`TMReq`), and which returns a future that
// resolves to a `Stream` of items convertible into responses.
//
// The uid parameter `($($ty,)* TMReq)` only keeps the impls for different extractor arities
// from overlapping. With no extractors it expands to `(TMReq)`, a parenthesized type rather
// than a 1-tuple, which is why `unused_parens` is allowed. `non_snake_case` is allowed because
// each extracted value is bound to a local named after its type parameter (`T1`, `T2`, ...).
macro_rules! impl_handler {
    (
        [$($ty:ident),*]
    ) => {
        #[allow(unused_parens, non_snake_case)]
        impl<TMReq, TMRes, TInto, TFnItem, TFnFut, TFn, TState, $($ty,)*>
            RpcHandlerStream<TMReq, TMRes, ($($ty,)* TMReq), TState> for TFn
        where
            TMReq: Message + DeserializeOwned + Default + Send + 'static,
            TMRes: Message + Serialize + Send + 'static,
            TInto: RpcIntoResponse<TMRes>,
            TFnItem: Stream<Item = TInto> + Send + 'static,
            // Only `Send` is required: the handler future is awaited inside the boxed `Send`
            // future below and is never shared by reference, so demanding `Sync` would just
            // reject async fns whose futures are `!Sync`.
            TFnFut: Future<Output = TFnItem> + Send,
            TFn: FnOnce($($ty,)* TMReq) -> TFnFut + Clone + Send + Sync + 'static,
            TState: Send + Sync + 'static,
            $( $ty: RpcFromRequestParts<TMRes, TState> + Send, )*
        {
            type Future = Pin<Box<dyn Future<Output = Response> + Send>>;

            fn call(self, req: Request<Body>, state: TState) -> Self::Future {
                Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Validate the request headers and obtain the `binary` encoding flag.
                    // On failure the helper already returns a complete `Response`.
                    let ReqResInto { binary } = match decode_check_headers(&mut parts, true) {
                        Ok(into) => into,
                        Err(response) => return response,
                    };

                    let state = &state;

                    // Run each extractor in declaration order; the first rejection is
                    // encoded as an error response and short-circuits the request.
                    $(
                    let $ty = match $ty::rpc_from_request_parts(&mut parts, state).await {
                        Ok(value) => value,
                        Err(error) => {
                            return ResponseEncoder::error(error, true, binary).encode_response();
                        }
                    };
                    )*

                    // Reassemble the request so the body can be decoded into `TMReq`.
                    let req = Request::from_parts(parts, body);

                    let proto_req: TMReq = match decode_request_payload(req, state, binary, true).await {
                        Ok(value) => value,
                        Err(response) => return response,
                    };

                    // TODO: Support returning trailers (they would need to bundle in the error type).
                    let stream = self($($ty,)* proto_req)
                        .await
                        .map(RpcIntoResponse::rpc_into_response);
                    ResponseEncoder::<TMRes>::stream(stream.boxed(), binary).encode_response()
                })
            }
        }
    };
}
85
86impl_handler!([]);
87impl_handler!([T1]);
88impl_handler!([T1, T2]);
89impl_handler!([T1, T2, T3]);
90impl_handler!([T1, T2, T3, T4]);
91impl_handler!([T1, T2, T3, T4, T5]);
92impl_handler!([T1, T2, T3, T4, T5, T6]);
93impl_handler!([T1, T2, T3, T4, T5, T6, T7]);
94impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8]);
95impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9]);
96impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]);
97impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]);
98impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]);
99impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]);
100impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]);
101impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]);