1use std::any::type_name;
7use std::ops::Deref;
8use std::sync::Arc;
9
10use nest_rs_core::{Container, RequestScope};
11use poem::http::StatusCode;
12use poem::{Endpoint, Error, FromRequest, IntoResponse, Request, RequestBody, Response, Result};
13
14pub struct RequestScopeEndpoint<E> {
19 inner: E,
20 container: Container,
21}
22
23impl<E> RequestScopeEndpoint<E> {
24 pub fn new(inner: E, container: Container) -> Self {
25 Self { inner, container }
26 }
27}
28
29impl<E> Endpoint for RequestScopeEndpoint<E>
30where
31 E: Endpoint,
32 E::Output: IntoResponse,
33{
34 type Output = Response;
35
36 async fn call(&self, mut req: Request) -> Result<Self::Output> {
37 req.extensions_mut()
38 .insert(Arc::new(RequestScope::new(self.container.clone())));
39 self.inner.call(req).await.map(IntoResponse::into_response)
40 }
41}
42
/// Handler argument carrying a shared `T` resolved from the request's scope.
///
/// Dereferences to `T`; call [`Scoped::into_inner`] to take ownership of the
/// underlying `Arc<T>`.
pub struct Scoped<T>(pub Arc<T>);

impl<T> Scoped<T> {
    /// Consumes the wrapper and hands back the shared `Arc<T>` it holds.
    pub fn into_inner(self) -> Arc<T> {
        let Scoped(shared) = self;
        shared
    }
}

impl<T> Deref for Scoped<T> {
    type Target = T;

    /// Borrows the value behind the wrapped `Arc`.
    fn deref(&self) -> &T {
        &*self.0
    }
}
60
61impl<'a, T: Send + Sync + 'static> FromRequest<'a> for Scoped<T> {
62 async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
63 let scope = req.extensions().get::<Arc<RequestScope>>().ok_or_else(|| {
64 Error::from_string(
65 "request scope not installed — RequestScopeEndpoint must wrap the route tree",
66 StatusCode::INTERNAL_SERVER_ERROR,
67 )
68 })?;
69 match scope.get::<T>() {
70 Some(value) => Ok(Scoped(value)),
71 None => Err(Error::from_string(
72 format!(
73 "no provider registered for `{}` — add it to a module's providers",
74 type_name::<T>()
75 ),
76 StatusCode::INTERNAL_SERVER_ERROR,
77 )),
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use poem::Body;
    use poem::handler;

    /// Plain provider type used to exercise `Scoped<T>` resolution.
    struct Marker(&'static str);

    /// Builds an empty request, optionally carrying an installed scope, and
    /// splits it into the `(Request, RequestBody)` pair that extractors take.
    fn split_request(scope: Option<Arc<RequestScope>>) -> (Request, RequestBody) {
        let mut req = Request::default();
        if let Some(scope) = scope {
            req.extensions_mut().insert(scope);
        }
        req.split()
    }

    /// Reads a response body fully and decodes it as (lossy) UTF-8 text.
    async fn body_text(resp: Response) -> String {
        let raw = resp.into_body().into_bytes().await.expect("body");
        String::from_utf8_lossy(&raw).into_owned()
    }

    #[test]
    fn scoped_into_inner_yields_the_arc() {
        let original: Arc<Marker> = Arc::new(Marker("hi"));
        let unwrapped = Scoped(Arc::clone(&original)).into_inner();
        assert!(Arc::ptr_eq(&unwrapped, &original));
    }

    #[test]
    fn scoped_deref_borrows_the_inner_value() {
        let wrapped = Scoped(Arc::new(Marker("bye")));
        assert_eq!(wrapped.0.as_ref().0, "bye");
        assert_eq!((*wrapped).0, "bye");
    }

    #[handler]
    async fn observe(req: &Request) -> &'static str {
        assert!(
            req.extensions().get::<Arc<RequestScope>>().is_some(),
            "RequestScopeEndpoint installed an Arc<RequestScope> per request",
        );
        "ok"
    }

    #[tokio::test]
    async fn endpoint_installs_a_request_scope_into_the_request_extensions() {
        let wrapped = RequestScopeEndpoint::new(observe, Container::builder().build());
        let resp = wrapped
            .call(Request::builder().body(Body::empty()))
            .await
            .expect("handler runs");
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn endpoint_propagates_the_inner_response_unchanged() {
        let wrapped = RequestScopeEndpoint::new(observe, Container::builder().build());
        let resp = wrapped
            .call(Request::builder().body(Body::empty()))
            .await
            .expect("ok");
        assert_eq!(resp.status(), StatusCode::OK);
        let payload = resp.into_body().into_bytes().await.expect("body");
        assert_eq!(payload.as_ref(), b"ok");
    }

    #[tokio::test]
    async fn scoped_from_request_resolves_a_registered_provider() {
        let container = Container::builder().provide(Marker("registered")).build();
        let (req, mut body) = split_request(Some(Arc::new(RequestScope::new(container))));

        let scoped: Scoped<Marker> = Scoped::from_request(&req, &mut body)
            .await
            .expect("resolves via singleton fallback");
        assert_eq!(scoped.0.0, "registered");
    }

    #[tokio::test]
    async fn scoped_from_request_returns_500_when_no_scope_is_installed() {
        let (req, mut body) = split_request(None);

        let Err(err) = Scoped::<Marker>::from_request(&req, &mut body).await else {
            panic!("no Arc<RequestScope> in extensions should reject");
        };
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let text = body_text(resp).await;
        assert!(
            text.contains("request scope not installed"),
            "diagnostic mentions the wiring bug: {text}",
        );
    }

    #[tokio::test]
    async fn scoped_from_request_returns_500_when_no_provider_is_registered() {
        let container = Container::builder().build();
        let (req, mut body) = split_request(Some(Arc::new(RequestScope::new(container))));

        let Err(err) = Scoped::<Marker>::from_request(&req, &mut body).await else {
            panic!("no provider for Marker should reject");
        };
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let text = body_text(resp).await;
        assert!(
            text.contains("no provider registered for"),
            "diagnostic names the missing type: {text}",
        );
        assert!(text.contains("Marker"), "the type name surfaces: {text}");
    }
}