ajars_axum/
lib.rs

1use ::axum::{
2    body::Body,
3    extract::{self, FromRequest},
4    response::IntoResponse,
5    routing::{delete, get, post, put},
6    Json, Router,
7};
8use ajars_core::{HttpMethod, RestType};
9use serde::{de::DeserializeOwned, Serialize};
10use std::future::Future;
11
12pub mod axum {
13    pub use axum::*;
14}
15
16pub trait AxumHandler<I: Serialize + DeserializeOwned, O: Serialize + DeserializeOwned, T, H> {
17    fn to(&self, handler: H) -> Router;
18}
19
/// Implements [`AxumHandler`] for every [`RestType`] whose handler takes the request payload
/// `I` followed by the extractors named in the invocation (e.g. `factory_tuple! { P0 P1 }`).
///
/// The generated route reads the payload from the query string (`extract::Query<I>`) for
/// `DELETE`/`GET` endpoints and from a JSON body (`Json<I>`) for `POST`/`PUT` endpoints.
/// The handler's `Ok(O)` is returned as `Json<O>`; its `Err(E)` is rendered through
/// `IntoResponse`.
macro_rules! factory_tuple {
    ($($param:ident)*) => {
        // The extractor type parameters (`P0`, `P1`, ...) double as the closure argument
        // names below, which would otherwise trip the `non_snake_case` lint.
        #[allow(non_snake_case)]
        impl<I, O, H, R, E, REST, $($param,)*> AxumHandler<I, O, ($($param,)*), H> for REST
        where
            I: Serialize + DeserializeOwned + Send + 'static,
            O: Serialize + DeserializeOwned + Send + 'static,
            REST: RestType<I, O>,
            R: Future<Output = Result<O, E>> + Send,
            E: IntoResponse + Send + 'static,
            H: 'static + Send + Sync + Clone + Fn(I, $($param,)*) -> R,
            $($param: FromRequest<Body> + Send + 'static,)*
        {
            fn to(&self, handler: H) -> Router {
                // The payload extractor differs by method: query string for DELETE/GET,
                // JSON body for POST/PUT. Every arm forwards the extra extractors unchanged.
                match self.method() {
                    HttpMethod::DELETE => Router::new().route(
                        self.path(),
                        delete(|payload: extract::Query<I>, $($param: $param,)*| async move {
                            (handler)(payload.0, $($param,)*).await.map(Json)
                        }),
                    ),
                    HttpMethod::GET => Router::new().route(
                        self.path(),
                        get(|payload: extract::Query<I>, $($param: $param,)*| async move {
                            (handler)(payload.0, $($param,)*).await.map(Json)
                        }),
                    ),
                    HttpMethod::POST => Router::new().route(
                        self.path(),
                        post(|payload: Json<I>, $($param: $param,)*| async move {
                            (handler)(payload.0, $($param,)*).await.map(Json)
                        }),
                    ),
                    HttpMethod::PUT => Router::new().route(
                        self.path(),
                        put(|payload: Json<I>, $($param: $param,)*| async move {
                            (handler)(payload.0, $($param,)*).await.map(Json)
                        }),
                    ),
                }
            }
        }
    };
}
54
//
// REFERENCE IMPL USED AS THE TEMPLATE FOR THE MACRO ABOVE (one extra extractor `P`; not compiled)
//
58// impl <I: Serialize + DeserializeOwned + Send + 'static, O: Serialize + DeserializeOwned + Send + 'static, H, R, E, REST: RestType<I, O>, P> AxumHandler<I, O, P, H>
59// for REST
60// where
61// R: Future<Output = Result<O, E>> + Send,
62// E: IntoResponse + Send + 'static,
// H: 'static + Send + Sync + Clone + Fn(I, P) -> R,
64// P: FromRequest<Body> + Send + 'static,
65// {
66//     fn to(&self, handler: H) -> Router {
67//         let route = match self.method() {
68//             HttpMethod::DELETE => Router::new().route(self.path(), delete(
69//                 |payload: extract::Query<I>, p: P| async move {
70//                     (handler)(payload.0, p).await.map(Json)
71//             })),
72//             HttpMethod::GET => Router::new().route(self.path(), get(
73//                 |payload: extract::Query<I>, p: P| async move {
74//                     (handler)(payload.0, p).await.map(Json)
75//                 })),
76//             HttpMethod::POST => Router::new().route(self.path(), post(
77//                 |payload: Json<I>, p: P| async move {
78//                     (handler)(payload.0, p).await.map(Json)
79//                 })),
80//             HttpMethod::PUT => Router::new().route(self.path(), put(
81//                 |payload: Json<I>, p: P| async move {
82//                     (handler)(payload.0, p).await.map(Json)
83//                 })),
84//         };
85//         route
86//     }
87// }
88
89factory_tuple! {}
90factory_tuple! { P0 }
91factory_tuple! { P0 P1 }
92factory_tuple! { P0 P1 P2 }
93factory_tuple! { P0 P1 P2 P3 }
94factory_tuple! { P0 P1 P2 P3 P4 }
95factory_tuple! { P0 P1 P2 P3 P4 P5 }
96factory_tuple! { P0 P1 P2 P3 P4 P5 P6 }
97factory_tuple! { P0 P1 P2 P3 P4 P5 P6 P7 }
98factory_tuple! { P0 P1 P2 P3 P4 P5 P6 P7 P8 }
99factory_tuple! { P0 P1 P2 P3 P4 P5 P6 P7 P8 P9 }
100
#[cfg(test)]
mod tests {

    use std::fmt::Display;

    use super::*;
    use ::axum::{
        body::{Body, BoxBody},
        extract::Extension,
        http::{header, Method, Request, Response, StatusCode},
    };
    use ajars_core::RestFluent;
    use serde::{Deserialize, Serialize};
    use tower::ServiceExt; // for `app.oneshot()`

    #[derive(Serialize, Deserialize, Debug)]
    pub struct PingRequest {
        pub message: String,
    }

    #[derive(Serialize, Deserialize, Debug)]
    pub struct PingResponse {
        pub message: String,
    }

    /// Echo handler: answers with the message it received.
    async fn ping(body: PingRequest, _data: Extension<()>) -> Result<PingResponse, ServerError> {
        Ok(PingResponse { message: body.message })
    }

    #[derive(Debug, Clone)]
    struct ServerError {}

    impl Display for ServerError {
        fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Ok(())
        }
    }

    impl IntoResponse for ServerError {
        fn into_response(self) -> Response<BoxBody> {
            // Report handler failures as 500 so an error can never be mistaken for a
            // successful (200 OK) response by the assertions in these tests.
            let mut response = Response::new(axum::body::boxed(Body::empty()));
            *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            response
        }
    }

    /// Asserts that `response` is a `200 OK` JSON response whose decoded `PingResponse`
    /// carries `expected_message`.
    async fn assert_ping_response(response: Response<BoxBody>, expected_message: &str) {
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!("application/json", response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap());

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body: PingResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(body.message, expected_message);
    }

    #[tokio::test]
    async fn should_create_a_delete_endpoint() {
        // Arrange
        let rest =
            RestFluent::<PingRequest, PingResponse>::delete(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        // Act
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::DELETE)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(&format!("{}?message={}", rest.path(), payload.message))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        // Assert
        assert_ping_response(response, &payload.message).await;
    }

    #[tokio::test]
    async fn should_create_a_get_endpoint() {
        // Arrange
        let rest = RestFluent::<PingRequest, PingResponse>::get(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        // Act
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(&format!("{}?message={}", rest.path(), payload.message))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        // Assert
        assert_ping_response(response, &payload.message).await;
    }

    #[tokio::test]
    async fn should_create_a_post_endpoint() {
        // Arrange
        let rest = RestFluent::<PingRequest, PingResponse>::post(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        // Act
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(rest.path())
                    .body(Body::from(serde_json::to_vec(&payload).unwrap()))
                    .unwrap(),
            )
            .await
            .unwrap();

        // Assert
        assert_ping_response(response, &payload.message).await;
    }

    #[tokio::test]
    async fn should_create_a_put_endpoint() {
        // Arrange
        let rest = RestFluent::<PingRequest, PingResponse>::put(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        // Act
        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::PUT)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(rest.path())
                    .body(Body::from(serde_json::to_vec(&payload).unwrap()))
                    .unwrap(),
            )
            .await
            .unwrap();

        // Assert
        assert_ping_response(response, &payload.message).await;
    }

    /// Compile-time check that handlers with 1..=5 arguments (payload + extractors) are accepted.
    #[tokio::test]
    async fn route_should_accept_variable_number_of_params() {
        // Arrange
        let rest =
            RestFluent::<PingRequest, PingResponse>::delete(format!("/api/something/{}", rand::random::<usize>()));

        // Accept 1 param
        rest.to(|body: PingRequest| async { Result::<_, ServerError>::Ok(PingResponse { message: body.message }) });

        // Accept 2 param
        rest.to(|body: PingRequest, _: Extension<()>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });

        // Accept 3 param
        rest.to(|body: PingRequest, _: Extension<()>, _: Request<Body>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });

        // Accept 4 param
        rest.to(|body: PingRequest, _: Extension<()>, _: Request<Body>, _: Request<Body>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });

        // Accept 5 param
        rest.to(|body: PingRequest, _: Extension<()>, _: Request<Body>, _: Request<Body>, _: Request<Body>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });
    }
}