Skip to main content

nest_rs_http/
scope.rs

//! HTTP binding for request-scoped providers. [`RequestScopeEndpoint`] installs
//! a fresh [`RequestScope`] on every request; [`Scoped<T>`] reads it back to
//! resolve an `#[injectable(scope = request)]` provider. Resolution also falls
//! through to singletons, but prefer plain `#[inject]` for those.
5
6use std::any::type_name;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use nest_rs_core::{Container, RequestScope};
11use poem::http::StatusCode;
12use poem::{Endpoint, Error, FromRequest, IntoResponse, Request, RequestBody, Response, Result};
13
14/// Installs a fresh [`RequestScope`] (over the singleton container) into each
15/// request's extensions before delegating inward, so guards and handlers can
16/// resolve request-scoped providers via [`Scoped<T>`]. Applied outermost by
17/// [`HttpTransport`](crate::HttpTransport).
18pub struct RequestScopeEndpoint<E> {
19    inner: E,
20    container: Container,
21}
22
23impl<E> RequestScopeEndpoint<E> {
24    pub fn new(inner: E, container: Container) -> Self {
25        Self { inner, container }
26    }
27}
28
29impl<E> Endpoint for RequestScopeEndpoint<E>
30where
31    E: Endpoint,
32    E::Output: IntoResponse,
33{
34    type Output = Response;
35
36    async fn call(&self, mut req: Request) -> Result<Self::Output> {
37        req.extensions_mut()
38            .insert(Arc::new(RequestScope::new(self.container.clone())));
39        self.inner.call(req).await.map(IntoResponse::into_response)
40    }
41}
42
/// Extractor that hands a handler the provider of type `T` resolved from the
/// current request's [`RequestScope`].
///
/// Extraction rejects with `500` when the scope is missing from the request
/// (a transport wiring bug) or when no provider is registered for `T`.
pub struct Scoped<T>(pub Arc<T>);

impl<T> Scoped<T> {
    /// Consumes the wrapper and returns the shared handle to the provider.
    pub fn into_inner(self) -> Arc<T> {
        let Scoped(provider) = self;
        provider
    }
}

impl<T> Deref for Scoped<T> {
    type Target = T;

    /// Borrows the provider stored behind the `Arc`.
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
60
61impl<'a, T: Send + Sync + 'static> FromRequest<'a> for Scoped<T> {
62    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
63        let scope = req.extensions().get::<Arc<RequestScope>>().ok_or_else(|| {
64            Error::from_string(
65                "request scope not installed — RequestScopeEndpoint must wrap the route tree",
66                StatusCode::INTERNAL_SERVER_ERROR,
67            )
68        })?;
69        match scope.get::<T>() {
70            Some(value) => Ok(Scoped(value)),
71            None => Err(Error::from_string(
72                format!(
73                    "no provider registered for `{}` — add it to a module's providers",
74                    type_name::<T>()
75                ),
76                StatusCode::INTERNAL_SERVER_ERROR,
77            )),
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use poem::Body;
    use poem::handler;

    /// Minimal provider type used to exercise resolution.
    struct Marker(&'static str);

    /// Builds the head/body pair `from_request` takes. When `container` is
    /// given, a scope over it is placed in the request's extensions first.
    fn split_request(container: Option<Container>) -> (Request, RequestBody) {
        let mut req = Request::default();
        if let Some(container) = container {
            let scope = Arc::new(RequestScope::new(container));
            req.extensions_mut().insert(scope);
        }
        req.split()
    }

    /// Asserts that `err` renders as a `500` and returns the response body as
    /// text so the caller can check the diagnostic.
    async fn internal_error_body(err: Error) -> String {
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let bytes = resp.into_body().into_bytes().await.expect("body");
        String::from_utf8_lossy(&bytes).into_owned()
    }

    #[test]
    fn scoped_into_inner_yields_the_arc() {
        let original: Arc<Marker> = Arc::new(Marker("hi"));
        let unwrapped = Scoped(Arc::clone(&original)).into_inner();
        assert!(Arc::ptr_eq(&unwrapped, &original));
    }

    #[test]
    fn scoped_deref_borrows_the_inner_value() {
        let scoped = Scoped(Arc::new(Marker("bye")));
        // Straight through the public `Arc` field.
        assert_eq!(scoped.0.as_ref().0, "bye");
        // Through `Deref`: `&scoped` coerces to `&Marker`.
        let marker: &Marker = &scoped;
        assert_eq!(marker.0, "bye");
    }

    /// Handler that fails the test unless the scope was installed upstream.
    #[handler]
    async fn observe(req: &Request) -> &'static str {
        let scope = req.extensions().get::<Arc<RequestScope>>();
        assert!(
            scope.is_some(),
            "RequestScopeEndpoint installed an Arc<RequestScope> per request",
        );
        "ok"
    }

    #[tokio::test]
    async fn endpoint_installs_a_request_scope_into_the_request_extensions() {
        let endpoint = RequestScopeEndpoint::new(observe, Container::builder().build());

        let resp = endpoint
            .call(Request::builder().body(Body::empty()))
            .await
            .expect("handler runs");
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn endpoint_propagates_the_inner_response_unchanged() {
        // The inner endpoint's output goes through `IntoResponse::into_response`,
        // so a plain `&str` comes back as a 200 carrying that text as its body.
        let endpoint = RequestScopeEndpoint::new(observe, Container::builder().build());
        let req = Request::builder().body(Body::empty());

        let resp = endpoint.call(req).await.expect("ok");
        assert_eq!(resp.status(), StatusCode::OK);

        let bytes = resp.into_body().into_bytes().await.expect("body");
        assert_eq!(bytes.as_ref(), b"ok");
    }

    #[tokio::test]
    async fn scoped_from_request_resolves_a_registered_provider() {
        // `Marker` is provided as a singleton; `RequestScope::get` falling
        // through to it is the documented escape hatch for `Scoped<T>` when no
        // scoped factory exists for `T`.
        let container = Container::builder().provide(Marker("registered")).build();
        let (req, mut body) = split_request(Some(container));

        let scoped: Scoped<Marker> = Scoped::from_request(&req, &mut body)
            .await
            .expect("resolves via singleton fallback");
        assert_eq!(scoped.0.0, "registered");
    }

    #[tokio::test]
    async fn scoped_from_request_returns_500_when_no_scope_is_installed() {
        let (req, mut body) = split_request(None);

        let Err(err) = Scoped::<Marker>::from_request(&req, &mut body).await else {
            panic!("no Arc<RequestScope> in extensions should reject");
        };

        let text = internal_error_body(err).await;
        assert!(
            text.contains("request scope not installed"),
            "diagnostic mentions the wiring bug: {text}",
        );
    }

    #[tokio::test]
    async fn scoped_from_request_returns_500_when_no_provider_is_registered() {
        // A scope is installed, but nothing ever provided `Marker`.
        let (req, mut body) = split_request(Some(Container::builder().build()));

        let Err(err) = Scoped::<Marker>::from_request(&req, &mut body).await else {
            panic!("no provider for Marker should reject");
        };

        let text = internal_error_body(err).await;
        assert!(
            text.contains("no provider registered for"),
            "diagnostic names the missing type: {text}",
        );
        assert!(text.contains("Marker"), "the type name surfaces: {text}");
    }
}
201}