di_axum/
inject_keyed.rs

1use axum::http::StatusCode;
2use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
3use di::{KeyedRef, KeyedRefMut, ServiceProvider};
4use std::any::type_name;
5use std::convert::Infallible;
6
7/// Represents a container for an optional, injected, keyed service.
8#[derive(Clone, Debug)]
9pub struct TryInjectWithKey<TKey, TSvc: ?Sized + 'static>(pub Option<KeyedRef<TKey, TSvc>>);
10
11/// Represents a container for a required, injected, keyed service.
12#[derive(Clone, Debug)]
13pub struct InjectWithKey<TKey, TSvc: ?Sized + 'static>(pub KeyedRef<TKey, TSvc>);
14
15/// Represents a container for an optional, mutable, injected, keyed service.
16#[derive(Clone, Debug)]
17pub struct TryInjectWithKeyMut<TKey, TSvc: ?Sized + 'static>(pub Option<KeyedRefMut<TKey, TSvc>>);
18
19/// Represents a container for a required, mutable, injected, keyed service.
20#[derive(Clone, Debug)]
21pub struct InjectWithKeyMut<TKey, TSvc: ?Sized + 'static>(pub KeyedRefMut<TKey, TSvc>);
22
23/// Represents a container for a collection of injected, keyed services.
24#[derive(Clone, Debug)]
25pub struct InjectAllWithKey<TKey, TSvc: ?Sized + 'static>(pub Vec<KeyedRef<TKey, TSvc>>);
26
27/// Represents a container for a collection of mutable, injected, keyed services.
28#[derive(Clone, Debug)]
29pub struct InjectAllWithKeyMut<TKey, TSvc: ?Sized + 'static>(pub Vec<KeyedRefMut<TKey, TSvc>>);
30
/// Builds the rejection message used when a required keyed service cannot be resolved.
///
/// Both the service and the key are rendered with [`type_name`], whose exact output is a
/// best-effort description and is not guaranteed to be stable across compiler versions.
#[inline]
fn unregistered_type_with_key<TKey, TSvc: ?Sized>() -> String {
    let service_name = type_name::<TSvc>();
    let key_name = type_name::<TKey>();

    format!(
        "No service for type '{}' with the key '{}' has been registered.",
        service_name, key_name
    )
}
39
40#[async_trait]
41impl<TKey, TSvc, S> FromRequestParts<S> for TryInjectWithKey<TKey, TSvc>
42where
43    TSvc: ?Sized + 'static,
44    S: Send + Sync,
45{
46    type Rejection = Infallible;
47
48    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
49        if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
50            Ok(Self(provider.get_by_key::<TKey, TSvc>()))
51        } else {
52            Ok(Self(None))
53        }
54    }
55}
56
57#[async_trait]
58impl<TKey, TSvc, S> FromRequestParts<S> for InjectWithKey<TKey, TSvc>
59where
60    TSvc: ?Sized + 'static,
61    S: Send + Sync,
62{
63    type Rejection = (StatusCode, String);
64
65    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
66        if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
67            if let Some(service) = provider.get_by_key::<TKey, TSvc>() {
68                return Ok(Self(service));
69            }
70        }
71
72        Err((
73            StatusCode::INTERNAL_SERVER_ERROR,
74            unregistered_type_with_key::<TKey, TSvc>(),
75        ))
76    }
77}
78
79#[async_trait]
80impl<TKey, TSvc, S> FromRequestParts<S> for TryInjectWithKeyMut<TKey, TSvc>
81where
82    TSvc: ?Sized + 'static,
83    S: Send + Sync,
84{
85    type Rejection = Infallible;
86
87    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
88        if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
89            Ok(Self(provider.get_by_key_mut::<TKey, TSvc>()))
90        } else {
91            Ok(Self(None))
92        }
93    }
94}
95
96#[async_trait]
97impl<TKey, TSvc, S> FromRequestParts<S> for InjectWithKeyMut<TKey, TSvc>
98where
99    TSvc: ?Sized + 'static,
100    S: Send + Sync,
101{
102    type Rejection = (StatusCode, String);
103
104    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
105        if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
106            if let Some(service) = provider.get_by_key_mut::<TKey, TSvc>() {
107                return Ok(Self(service));
108            }
109        }
110
111        Err((
112            StatusCode::INTERNAL_SERVER_ERROR,
113            unregistered_type_with_key::<TKey, TSvc>(),
114        ))
115    }
116}
117
118#[async_trait]
119impl<TKey, TSvc, S> FromRequestParts<S> for InjectAllWithKey<TKey, TSvc>
120where
121    TSvc: ?Sized + 'static,
122    S: Send + Sync,
123{
124    type Rejection = Infallible;
125
126    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
127        if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
128            Ok(Self(provider.get_all_by_key::<TKey, TSvc>().collect()))
129        } else {
130            Ok(Self(Vec::with_capacity(0)))
131        }
132    }
133}
134
135#[async_trait]
136impl<TKey, TSvc, S> FromRequestParts<S> for InjectAllWithKeyMut<TKey, TSvc>
137where
138    TSvc: ?Sized + 'static,
139    S: Send + Sync,
140{
141    type Rejection = Infallible;
142
143    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
144        if let Some(provider) = parts.extensions.get::<ServiceProvider>() {
145            Ok(Self(provider.get_all_by_key_mut::<TKey, TSvc>().collect()))
146        } else {
147            Ok(Self(Vec::with_capacity(0)))
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RouterServiceProviderExtensions, TestClient};
    use axum::{
        extract::State,
        routing::{get, post},
        Router,
    };
    use di::{injectable, Injectable, ServiceCollection};
    use http::StatusCode;

    // Marker types used purely as service keys.
    mod key {
        pub struct Basic;
        pub struct Advanced;
    }

    #[tokio::test]
    async fn request_should_fail_with_500_for_unregistered_service_with_key() {
        // arrange
        struct Service;

        impl Service {
            fn do_work(&self) -> String {
                "Test".into()
            }
        }

        async fn handler(InjectWithKey(service): InjectWithKey<key::Basic, Service>) -> String {
            service.do_work()
        }

        let app = Router::new()
            .route("/test", get(handler))
            .with_provider(ServiceProvider::default());

        let client = TestClient::new(app);

        // act
        let response = client.get("/test").send().await;

        // assert
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn try_inject_with_key_into_handler() {
        // arrange
        #[injectable]
        struct Service;

        // Nothing is registered, so the optional extractor must yield `None` rather than
        // reject. Reporting a different status for `Some` lets the test detect a wrongly
        // resolved service instead of only checking that extraction did not fail.
        async fn handler(
            TryInjectWithKey(service): TryInjectWithKey<key::Advanced, Service>,
        ) -> StatusCode {
            if service.is_none() {
                StatusCode::NO_CONTENT
            } else {
                StatusCode::INTERNAL_SERVER_ERROR
            }
        }

        let app = Router::new()
            .route("/test", post(handler))
            .with_provider(ServiceProvider::default());

        let client = TestClient::new(app);

        // act
        let response = client.post("/test").send().await;

        // assert
        assert_eq!(response.status(), StatusCode::NO_CONTENT);
    }

    #[tokio::test]
    async fn inject_with_key_into_handler() {
        // arrange
        trait Service: Send + Sync {
            fn do_work(&self) -> String;
        }

        #[injectable(Service)]
        struct ServiceImpl;

        impl Service for ServiceImpl {
            fn do_work(&self) -> String {
                "Test".into()
            }
        }

        async fn handler(InjectWithKey(service): InjectWithKey<key::Basic, dyn Service>) -> String {
            service.do_work()
        }

        let provider = ServiceCollection::new()
            .add(ServiceImpl::scoped().with_key::<key::Basic>())
            .build_provider()
            .unwrap();

        let app = Router::new()
            .route("/test", get(handler))
            .with_provider(provider);

        let client = TestClient::new(app);

        // act
        let response = client.get("/test").send().await;
        let text = response.text().await;

        // assert
        assert_eq!(&text, "Test");
    }

    #[tokio::test]
    async fn inject_all_with_key_into_handler() {
        // arrange
        trait Thing: Send + Sync {}

        #[injectable(Thing)]
        struct Thing1;

        #[injectable(Thing)]
        struct Thing2;

        #[injectable(Thing)]
        struct Thing3;

        impl Thing for Thing1 {}
        impl Thing for Thing2 {}
        impl Thing for Thing3 {}

        async fn handler(
            InjectAllWithKey(things): InjectAllWithKey<key::Basic, dyn Thing>,
        ) -> String {
            things.len().to_string()
        }

        // Only the two `Basic`-keyed registrations should be collected; `Thing3` uses another key.
        let provider = ServiceCollection::new()
            .try_add_to_all(Thing1::scoped().with_key::<key::Basic>())
            .try_add_to_all(Thing2::scoped().with_key::<key::Basic>())
            .try_add_to_all(Thing3::scoped().with_key::<key::Advanced>())
            .build_provider()
            .unwrap();

        let app = Router::new()
            .route("/test", get(handler))
            .with_provider(provider);

        let client = TestClient::new(app);

        // act
        let response = client.get("/test").send().await;
        let text = response.text().await;

        // assert
        assert_eq!(&text, "2");
    }

    #[tokio::test]
    async fn inject_with_key_and_state_into_handler() {
        // arrange
        trait Service: Send + Sync {
            fn do_work(&self) -> String;
        }

        #[injectable(Service)]
        struct ServiceImpl;

        impl Service for ServiceImpl {
            fn do_work(&self) -> String {
                "Test".into()
            }
        }

        #[derive(Clone)]
        struct AppState;

        async fn handler(
            InjectWithKey(service): InjectWithKey<key::Basic, dyn Service>,
            State(_state): State<AppState>,
        ) -> String {
            service.do_work()
        }

        let provider = ServiceCollection::new()
            .add(ServiceImpl::scoped().with_key::<key::Basic>())
            .build_provider()
            .unwrap();

        let app = Router::new()
            .route("/test", get(handler))
            .with_state(AppState)
            .with_provider(provider);

        let client = TestClient::new(app);

        // act
        let response = client.get("/test").send().await;
        let text = response.text().await;

        // assert
        assert_eq!(&text, "Test");
    }
}