lambda_grpc_web/
lambda_server_builder.rs

#[cfg(feature = "deadline")]
use crate::deadline_layer::LambdaDeadlineLayer;
#[cfg(feature = "wire-log")]
use crate::wire_log::WireLogLayer;
use http::{Request, Response};
use lambda_runtime::Error;
#[cfg(feature = "catch-panic")]
use std::any::Any;
use std::convert::Infallible;
#[cfg(feature = "deadline")]
use std::time::Duration;
use tonic::body::Body;
use tonic::server::NamedService;
use tonic::service::Routes;
#[cfg(feature = "catch-panic")]
use tonic::Status;
use tonic_web::GrpcWebLayer;
use tower::layer::util::{Identity, Stack};
use tower::{Layer, Service, ServiceBuilder, ServiceExt};
#[cfg(feature = "catch-panic")]
use tower_http::catch_panic::CatchPanicLayer;
18
19type GrpcRequest = Request<Body>;
20type GrpcResponse = Response<Body>;
21
22#[derive(Clone)]
23pub struct LambdaServer<L = Identity> {
24    service_builder: ServiceBuilder<L>,
25}
26
27impl LambdaServer {
28    pub fn builder() -> Self {
29        Self {
30            service_builder: ServiceBuilder::new(),
31        }
32    }
33}
34
35pub struct LambdaRouter<L> {
36    routes: Routes,
37    service_builder: ServiceBuilder<L>,
38}
39
40impl<L> LambdaServer<L> {
41    pub fn layer<NewLayer>(self, new_layer: NewLayer) -> LambdaServer<Stack<NewLayer, L>> {
42        LambdaServer {
43            service_builder: self.service_builder.layer(new_layer),
44        }
45    }
46
47    pub fn add_service<S>(self, svc: S) -> LambdaRouter<L>
48    where
49        S: Service<Request<Body>, Error = Infallible>
50        + NamedService
51        + Clone
52        + Send
53        + Sync
54        + 'static,
55        S::Response: axum::response::IntoResponse,
56        S::Future: Send + 'static,
57        L: Clone,
58    {
59        LambdaRouter {
60            routes: Routes::new(svc),
61            service_builder: self.service_builder,
62        }
63    }
64}
65
66impl<L> LambdaRouter<L> {
67    pub fn add_service<S>(mut self, svc: S) -> Self
68    where
69        S: Service<Request<Body>, Error = Infallible>
70        + NamedService
71        + Clone
72        + Send
73        + Sync
74        + 'static,
75        S::Response: axum::response::IntoResponse,
76        S::Future: Send + 'static,
77    {
78        self.routes = self.routes.add_service(svc);
79        self
80    }
81
82    pub async fn serve(self) -> Result<(), Error>
83    where
84        L: Layer<Routes>,
85        L::Service: Service<
86                GrpcRequest,
87                Response = GrpcResponse,
88                Error = Infallible,
89                Future: Send + 'static,
90            > + Clone
91            + Send
92            + 'static,
93    {
94        let service_builder = ServiceBuilder::new();
95
96        #[cfg(feature = "wire-log")]
97        let service_builder = service_builder.layer(WireLogLayer);
98
99        let service_builder = service_builder.layer(GrpcWebLayer::new());
100
101        #[cfg(feature = "catch-panic")]
102        let service_builder = service_builder.layer(CatchPanicLayer::custom(
103            |err: Box<dyn Any + Send + 'static>| {
104                let details = if let Some(s) = err.downcast_ref::<String>() {
105                    s.clone()
106                } else if let Some(s) = err.downcast_ref::<&str>() {
107                    s.to_string()
108                } else {
109                    "Unknown panic message".to_string()
110                };
111
112                Status::internal(details).into_http::<Body>()
113            },
114        ));
115
116        #[cfg(feature = "deadline")]
117        let service_builder =
118            service_builder.layer(LambdaDeadlineLayer::new(Duration::from_millis(500)));
119
120        let svc = service_builder.service(self.service_builder.service(self.routes));
121
122        let handler = tower::service_fn(move |req: lambda_http::Request| {
123            let mut svc = svc.clone();
124            async move {
125                let req = req.map(|body| Body::new(tonic::service::AxumBody::new(body)));
126                let res = svc.call(req).await.expect("infallible");
127                let (parts, body) = res.into_parts();
128                let body =
129                    lambda_runtime::streaming::Body::new(body);
130                Ok::<_, Error>(Response::from_parts(parts, body))
131            }
132        });
133
134        lambda_http::run_with_streaming_response(handler).await
135    }
136}