rama_core/layer/
layer_fn.rs

1use super::Layer;
2use std::fmt;
3
4/// Returns a new [`LayerFn`] that implements [`Layer`] by calling the
5/// given function.
6///
7/// The [`Layer::layer`] method takes a type implementing [`Service`] and
8/// returns a different type implementing [`Service`]. In many cases, this can
9/// be implemented by a function or a closure. The [`LayerFn`] helper allows
10/// writing simple [`Layer`] implementations without needing the boilerplate of
11/// a new struct implementing [`Layer`].
12///
13/// [`Service`]: crate
14/// [`Layer::layer`]: crate::Layer::layer
15pub fn layer_fn<T>(f: T) -> LayerFn<T> {
16    LayerFn { f }
17}
18
/// A [`Layer`] backed by a function.
///
/// The function can be a closure, a plain `fn`, or a tuple-struct constructor.
/// It only has to be `FnOnce(S) -> Out + Clone`.
///
/// Build one with [`layer_fn`]; see its documentation for more details.
pub struct LayerFn<F> {
    /// The function applied to the inner service to produce the wrapped service.
    f: F,
}
23
24impl<F, S, Out> Layer<S> for LayerFn<F>
25where
26    F: FnOnce(S) -> Out + Clone,
27{
28    type Service = Out;
29
30    fn layer(&self, inner: S) -> Self::Service {
31        (self.f.clone())(inner)
32    }
33
34    fn into_layer(self, inner: S) -> Self::Service {
35        (self.f)(inner)
36    }
37}
38
39impl<F> Clone for LayerFn<F>
40where
41    F: Clone,
42{
43    fn clone(&self) -> Self {
44        Self { f: self.f.clone() }
45    }
46}
47
48impl<F> fmt::Debug for LayerFn<F> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("LayerFn")
51            .field("f", &format_args!("<{}>", std::any::type_name::<F>()))
52            .finish()
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Shows how to build a [`LayerFn`] that wraps a service.
    ///
    /// Passing a closure to `layer_fn` for this purpose currently works poorly:
    /// the compiler cannot infer the type of the inner service. Spelling that
    /// type out explicitly quickly becomes unwieldy and makes for poor UX.
    ///
    /// Passing the constructor of a generic wrapper type (here the tuple struct
    /// `ToUpper`) avoids that problem. This is the approach to document for
    /// users who want a [`Layer`] without implementing the trait themselves.
    #[tokio::test]
    async fn test_layer_fn() {
        use crate::{Context, Service, service::service_fn};
        use std::convert::Infallible;

        struct ToUpper<S>(S);

        impl<S> fmt::Debug for ToUpper<S>
        where
            S: fmt::Debug,
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("ToUpper").field(&self.0).finish()
            }
        }

        impl<S> Clone for ToUpper<S>
        where
            S: Clone,
        {
            fn clone(&self) -> Self {
                Self(self.0.clone())
            }
        }

        impl<S, State, Request> Service<State, Request> for ToUpper<S>
        where
            Request: Send + 'static,
            S: Service<State, Request, Response = &'static str>,
            State: Clone + Send + Sync + 'static,
        {
            type Response = String;
            type Error = S::Error;

            async fn serve(
                &self,
                ctx: Context<State>,
                req: Request,
            ) -> Result<Self::Response, Self::Error> {
                let res = self.0.serve(ctx, req).await;
                res.map(|msg| msg.to_uppercase())
            }
        }

        let layer = layer_fn(ToUpper);
        let f = async |_, req| Ok::<_, Infallible>(req);

        let res = layer
            .layer(service_fn(f))
            .serve(Context::default(), "hello")
            .await;
        assert_eq!(res, Ok("HELLO".to_owned()));

        // The layer can be cloned and reused, and the service it produces can be cloned too.
        let svc = layer.clone().layer(service_fn(f));
        let res = svc.serve(Context::default(), "hello").await;
        assert_eq!(res, Ok("HELLO".to_owned()));
        let res = svc.clone().serve(Context::default(), "hello").await;
        assert_eq!(res, Ok("HELLO".to_owned()));
    }

    #[test]
    fn layer_fn_has_useful_debug_impl() {
        // `inner` is never read: the struct only gives the closure a concrete output type.
        #[allow(dead_code)]
        struct WrappedService<S> {
            inner: S,
        }
        let layer = layer_fn(|svc| WrappedService { inner: svc });
        let _svc = layer.layer("foo");

        // `std::any::type_name` is best-effort and its exact output may change between
        // compiler versions. So only the format produced by `LayerFn`'s Debug impl is
        // pinned here, plus the fact that the closure's type mentions its defining fn.
        let debug = format!("{layer:?}");
        assert!(
            debug.starts_with("LayerFn { f: <"),
            "unexpected debug output: {debug}"
        );
        assert!(debug.ends_with("> }"), "unexpected debug output: {debug}");
        assert!(
            debug.contains("layer_fn_has_useful_debug_impl"),
            "unexpected debug output: {debug}"
        );
    }
}