1use axum::http::StatusCode;
2use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
3use di::{KeyedRef, KeyedRefMut, ServiceProvider};
4use std::any::type_name;
5use std::convert::Infallible;
6
7#[derive(Clone, Debug)]
9pub struct TryInjectWithKey<TKey, TSvc: ?Sized + 'static>(pub Option<KeyedRef<TKey, TSvc>>);
10
11#[derive(Clone, Debug)]
13pub struct InjectWithKey<TKey, TSvc: ?Sized + 'static>(pub KeyedRef<TKey, TSvc>);
14
15#[derive(Clone, Debug)]
17pub struct TryInjectWithKeyMut<TKey, TSvc: ?Sized + 'static>(pub Option<KeyedRefMut<TKey, TSvc>>);
18
19#[derive(Clone, Debug)]
21pub struct InjectWithKeyMut<TKey, TSvc: ?Sized + 'static>(pub KeyedRefMut<TKey, TSvc>);
22
23#[derive(Clone, Debug)]
25pub struct InjectAllWithKey<TKey, TSvc: ?Sized + 'static>(pub Vec<KeyedRef<TKey, TSvc>>);
26
27#[derive(Clone, Debug)]
29pub struct InjectAllWithKeyMut<TKey, TSvc: ?Sized + 'static>(pub Vec<KeyedRefMut<TKey, TSvc>>);
30
/// Builds the rejection message used when a keyed service cannot be resolved.
///
/// Both the requested service type `TSvc` and the key type `TKey` are rendered with
/// [`type_name`], e.g. `No service for type 'str' with the key 'u8' has been registered.`
///
/// Marked `#[cold]` rather than `#[inline]`: it only runs on the failure path of the required
/// (`InjectWithKey*`) extractors. Hinting the optimizer to keep it out of the hot path is more
/// useful than forcing it inline into every monomorphised extractor.
#[cold]
fn unregistered_type_with_key<TKey, TSvc: ?Sized>() -> String {
    format!(
        "No service for type '{}' with the key '{}' has been registered.",
        type_name::<TSvc>(),
        type_name::<TKey>()
    )
}
39
40#[async_trait]
41impl<TKey, TSvc, S> FromRequestParts<S> for TryInjectWithKey<TKey, TSvc>
42where
43 TSvc: ?Sized + 'static,
44 S: Send + Sync,
45{
46 type Rejection = Infallible;
47
48 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
49 if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
50 Ok(Self(provider.get_by_key::<TKey, TSvc>()))
51 } else {
52 Ok(Self(None))
53 }
54 }
55}
56
57#[async_trait]
58impl<TKey, TSvc, S> FromRequestParts<S> for InjectWithKey<TKey, TSvc>
59where
60 TSvc: ?Sized + 'static,
61 S: Send + Sync,
62{
63 type Rejection = (StatusCode, String);
64
65 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66 if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
67 if let Some(service) = provider.get_by_key::<TKey, TSvc>() {
68 return Ok(Self(service));
69 }
70 }
71
72 Err((
73 StatusCode::INTERNAL_SERVER_ERROR,
74 unregistered_type_with_key::<TKey, TSvc>(),
75 ))
76 }
77}
78
79#[async_trait]
80impl<TKey, TSvc, S> FromRequestParts<S> for TryInjectWithKeyMut<TKey, TSvc>
81where
82 TSvc: ?Sized + 'static,
83 S: Send + Sync,
84{
85 type Rejection = Infallible;
86
87 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
88 if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
89 Ok(Self(provider.get_by_key_mut::<TKey, TSvc>()))
90 } else {
91 Ok(Self(None))
92 }
93 }
94}
95
96#[async_trait]
97impl<TKey, TSvc, S> FromRequestParts<S> for InjectWithKeyMut<TKey, TSvc>
98where
99 TSvc: ?Sized + 'static,
100 S: Send + Sync,
101{
102 type Rejection = (StatusCode, String);
103
104 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
105 if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
106 if let Some(service) = provider.get_by_key_mut::<TKey, TSvc>() {
107 return Ok(Self(service));
108 }
109 }
110
111 Err((
112 StatusCode::INTERNAL_SERVER_ERROR,
113 unregistered_type_with_key::<TKey, TSvc>(),
114 ))
115 }
116}
117
118#[async_trait]
119impl<TKey, TSvc, S> FromRequestParts<S> for InjectAllWithKey<TKey, TSvc>
120where
121 TSvc: ?Sized + 'static,
122 S: Send + Sync,
123{
124 type Rejection = Infallible;
125
126 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
127 if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
128 Ok(Self(provider.get_all_by_key::<TKey, TSvc>().collect()))
129 } else {
130 Ok(Self(Vec::with_capacity(0)))
131 }
132 }
133}
134
135#[async_trait]
136impl<TKey, TSvc, S> FromRequestParts<S> for InjectAllWithKeyMut<TKey, TSvc>
137where
138 TSvc: ?Sized + 'static,
139 S: Send + Sync,
140{
141 type Rejection = Infallible;
142
143 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
144 if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
145 Ok(Self(provider.get_all_by_key_mut::<TKey, TSvc>().collect()))
146 } else {
147 Ok(Self(Vec::with_capacity(0)))
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RouterServiceProviderExtensions, TestClient};
    use axum::{
        extract::State,
        routing::{get, post},
        Router,
    };
    use di::{injectable, Injectable, ServiceCollection};
    use http::StatusCode;

    /// Marker types used purely as service keys.
    mod key {
        pub struct Basic;
        pub struct Advanced;
    }

    #[tokio::test]
    async fn request_should_fail_with_500_for_unregistered_service_with_key() {
        struct Service;

        impl Service {
            fn do_work(&self) -> String {
                "Test".into()
            }
        }

        async fn endpoint(InjectWithKey(svc): InjectWithKey<key::Basic, Service>) -> String {
            svc.do_work()
        }

        // The default provider has nothing registered, so the required extractor must reject.
        let router = Router::new()
            .route("/test", get(endpoint))
            .with_provider(ServiceProvider::default());
        let client = TestClient::new(router);

        let res = client.get("/test").send().await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn try_inject_with_key_into_handler() {
        #[injectable]
        struct Service;

        async fn endpoint(
            TryInjectWithKey(_svc): TryInjectWithKey<key::Advanced, Service>,
        ) -> StatusCode {
            StatusCode::NO_CONTENT
        }

        // The optional extractor must not reject even though nothing is registered.
        let router = Router::new()
            .route("/test", post(endpoint))
            .with_provider(ServiceProvider::default());
        let client = TestClient::new(router);

        let res = client.post("/test").send().await;

        assert_eq!(res.status(), StatusCode::NO_CONTENT);
    }

    #[tokio::test]
    async fn inject_with_key_into_handler() {
        trait Service: Send + Sync {
            fn do_work(&self) -> String;
        }

        #[injectable(Service)]
        struct ServiceImpl;

        impl Service for ServiceImpl {
            fn do_work(&self) -> String {
                "Test".into()
            }
        }

        async fn endpoint(InjectWithKey(svc): InjectWithKey<key::Basic, dyn Service>) -> String {
            svc.do_work()
        }

        let provider = ServiceCollection::new()
            .add(ServiceImpl::scoped().with_key::<key::Basic>())
            .build_provider()
            .unwrap();
        let router = Router::new()
            .route("/test", get(endpoint))
            .with_provider(provider);
        let client = TestClient::new(router);

        let res = client.get("/test").send().await;
        let body = res.text().await;

        assert_eq!(&body, "Test");
    }

    #[tokio::test]
    async fn inject_all_with_key_into_handler() {
        trait Thing: Send + Sync {}

        #[injectable(Thing)]
        struct Thing1;

        #[injectable(Thing)]
        struct Thing2;

        #[injectable(Thing)]
        struct Thing3;

        impl Thing for Thing1 {}
        impl Thing for Thing2 {}
        impl Thing for Thing3 {}

        async fn endpoint(
            InjectAllWithKey(things): InjectAllWithKey<key::Basic, dyn Thing>,
        ) -> String {
            things.len().to_string()
        }

        // Two things share `key::Basic`; the third uses a different key and must be excluded.
        let provider = ServiceCollection::new()
            .try_add_to_all(Thing1::scoped().with_key::<key::Basic>())
            .try_add_to_all(Thing2::scoped().with_key::<key::Basic>())
            .try_add_to_all(Thing3::scoped().with_key::<key::Advanced>())
            .build_provider()
            .unwrap();
        let router = Router::new()
            .route("/test", get(endpoint))
            .with_provider(provider);
        let client = TestClient::new(router);

        let res = client.get("/test").send().await;
        let body = res.text().await;

        assert_eq!(&body, "2");
    }

    #[tokio::test]
    async fn inject_with_key_and_state_into_handler() {
        trait Service: Send + Sync {
            fn do_work(&self) -> String;
        }

        #[injectable(Service)]
        struct ServiceImpl;

        impl Service for ServiceImpl {
            fn do_work(&self) -> String {
                "Test".into()
            }
        }

        #[derive(Clone)]
        struct AppState;

        async fn endpoint(
            InjectWithKey(svc): InjectWithKey<key::Basic, dyn Service>,
            State(_state): State<AppState>,
        ) -> String {
            svc.do_work()
        }

        let provider = ServiceCollection::new()
            .add(ServiceImpl::scoped().with_key::<key::Basic>())
            .build_provider()
            .unwrap();
        // Router state is applied before the provider so both extractors can coexist.
        let router = Router::new()
            .route("/test", get(endpoint))
            .with_state(AppState)
            .with_provider(provider);
        let client = TestClient::new(router);

        let res = client.get("/test").send().await;
        let body = res.text().await;

        assert_eq!(&body, "Test");
    }
}