1use ::axum::{
2 body::Body,
3 extract::{self, FromRequest},
4 response::IntoResponse,
5 routing::{delete, get, post, put},
6 Json, Router,
7};
8use ajars_core::{HttpMethod, RestType};
9use serde::{de::DeserializeOwned, Serialize};
10use std::future::Future;
11
12pub mod axum {
13 pub use axum::*;
14}
15
16pub trait AxumHandler<I: Serialize + DeserializeOwned, O: Serialize + DeserializeOwned, T, H> {
17 fn to(&self, handler: H) -> Router;
18}
19
/// Generates an [`AxumHandler`] implementation for handlers that take the request payload
/// followed by the extractor types listed in the invocation.
///
/// `factory_tuple! { P0 P1 }` implements `AxumHandler<I, O, (P0, P1), H>` for every
/// `REST: RestType<I, O>`, where `H: Fn(I, P0, P1) -> R` and `R` is a future resolving to
/// `Result<O, E>`. The generated route reads the payload from the query string for
/// `DELETE`/`GET` and from a JSON body for `POST`/`PUT`. An `Ok(O)` result is returned
/// wrapped in `Json`. An `Err(E)` result is rendered through `E`'s `IntoResponse` impl.
macro_rules! factory_tuple ({ $($param:ident)* } => {
    // The extractor type parameters (`P0`, `P1`, ...) are reused as closure argument names
    // in `to` below, producing upper-case bindings that `non_snake_case` would flag.
    #[allow(non_snake_case)]
    impl<I, O, H, R, E, REST, $($param,)*> AxumHandler<I, O, ($($param,)*), H> for REST
    where
        I: Serialize + DeserializeOwned + Send + 'static,
        O: Serialize + DeserializeOwned + Send + 'static,
        REST: RestType<I, O>,
        R: Future<Output = Result<O, E>> + Send,
        E: IntoResponse + Send + 'static,
        H: 'static + Send + Sync + Clone + Fn(I, $($param,)*) -> R,
        $($param: FromRequest<Body> + Send + 'static,)*
    {
        fn to(&self, handler: H) -> Router {
            // DELETE/GET take the payload from the query string; POST/PUT from a JSON body.
            match self.method() {
                HttpMethod::DELETE => Router::new().route(
                    self.path(),
                    delete(|payload: extract::Query<I>, $($param: $param,)*| async move {
                        (handler)(payload.0, $($param,)*).await.map(Json)
                    }),
                ),
                HttpMethod::GET => Router::new().route(
                    self.path(),
                    get(|payload: extract::Query<I>, $($param: $param,)*| async move {
                        (handler)(payload.0, $($param,)*).await.map(Json)
                    }),
                ),
                HttpMethod::POST => Router::new().route(
                    self.path(),
                    post(|payload: Json<I>, $($param: $param,)*| async move {
                        (handler)(payload.0, $($param,)*).await.map(Json)
                    }),
                ),
                HttpMethod::PUT => Router::new().route(
                    self.path(),
                    put(|payload: Json<I>, $($param: $param,)*| async move {
                        (handler)(payload.0, $($param,)*).await.map(Json)
                    }),
                ),
            }
        }
    }
});
54
55factory_tuple! {}
90factory_tuple! { P0 }
91factory_tuple! { P0 P1 }
92factory_tuple! { P0 P1 P2 }
93factory_tuple! { P0 P1 P2 P3 }
94factory_tuple! { P0 P1 P2 P3 P4 }
95factory_tuple! { P0 P1 P2 P3 P4 P5 }
96factory_tuple! { P0 P1 P2 P3 P4 P5 P6 }
97factory_tuple! { P0 P1 P2 P3 P4 P5 P6 P7 }
98factory_tuple! { P0 P1 P2 P3 P4 P5 P6 P7 P8 }
99factory_tuple! { P0 P1 P2 P3 P4 P5 P6 P7 P8 P9 }
100
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use super::*;
    use ::axum::{
        body::{Body, BoxBody},
        extract::Extension,
        http::{header, Method, Request, Response, StatusCode},
    };
    use ajars_core::RestFluent;
    use serde::{Deserialize, Serialize};
    use tower::ServiceExt;

    /// Request payload whose `message` the `ping` handler echoes back.
    #[derive(Serialize, Deserialize, Debug)]
    pub struct PingRequest {
        pub message: String,
    }

    /// Response payload produced by the `ping` handler.
    #[derive(Serialize, Deserialize, Debug)]
    pub struct PingResponse {
        pub message: String,
    }

    /// Echo handler; the `Extension` argument exercises the extra-extractor support.
    async fn ping(body: PingRequest, _data: Extension<()>) -> Result<PingResponse, ServerError> {
        Ok(PingResponse { message: body.message })
    }

    /// Error type returned by the test handlers.
    #[derive(Debug, Clone)]
    struct ServerError {}

    impl Display for ServerError {
        fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Ok(())
        }
    }

    impl IntoResponse for ServerError {
        /// Renders the error as an empty `500 Internal Server Error` response.
        /// `Response::new` alone defaults to `200 OK`, which would make a failing handler
        /// indistinguishable from a successful one by status code.
        fn into_response(self) -> Response<BoxBody> {
            let mut response = Response::new(axum::body::boxed(Body::empty()));
            *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            response
        }
    }

    #[test]
    fn server_error_should_map_to_internal_server_error() {
        let response = ServerError {}.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// DELETE endpoints take the payload from the query string and answer with JSON.
    #[tokio::test]
    async fn should_create_a_delete_endpoint() {
        let rest =
            RestFluent::<PingRequest, PingResponse>::delete(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::DELETE)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(&format!("{}?message={}", rest.path(), payload.message))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!("application/json", response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap());

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body: PingResponse = serde_json::from_slice(&body).unwrap();

        assert_eq!(body.message, payload.message);
    }

    /// GET endpoints take the payload from the query string and answer with JSON.
    #[tokio::test]
    async fn should_create_a_get_endpoint() {
        let rest = RestFluent::<PingRequest, PingResponse>::get(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::GET)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(&format!("{}?message={}", rest.path(), payload.message))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!("application/json", response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap());

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body: PingResponse = serde_json::from_slice(&body).unwrap();

        assert_eq!(body.message, payload.message);
    }

    /// POST endpoints take the payload from a JSON body and answer with JSON.
    #[tokio::test]
    async fn should_create_a_post_endpoint() {
        let rest = RestFluent::<PingRequest, PingResponse>::post(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::POST)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(rest.path())
                    .body(Body::from(serde_json::to_vec(&payload).unwrap()))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!("application/json", response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap());

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body: PingResponse = serde_json::from_slice(&body).unwrap();

        assert_eq!(body.message, payload.message);
    }

    /// PUT endpoints take the payload from a JSON body and answer with JSON.
    #[tokio::test]
    async fn should_create_a_put_endpoint() {
        let rest = RestFluent::<PingRequest, PingResponse>::put(format!("/api/something/{}", rand::random::<usize>()));

        let app = rest.to(ping).layer(Extension(()));

        let payload = PingRequest { message: format!("message{}", rand::random::<usize>()) };

        let response = app
            .oneshot(
                Request::builder()
                    .method(Method::PUT)
                    .header(header::CONTENT_TYPE, "application/json")
                    .uri(rest.path())
                    .body(Body::from(serde_json::to_vec(&payload).unwrap()))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!("application/json", response.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap());

        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body: PingResponse = serde_json::from_slice(&body).unwrap();

        assert_eq!(body.message, payload.message);
    }

    /// Compile-time check that handlers with zero to four extra extractor arguments
    /// are accepted by `to`.
    #[tokio::test]
    async fn route_should_accept_variable_number_of_params() {
        let rest =
            RestFluent::<PingRequest, PingResponse>::delete(format!("/api/something/{}", rand::random::<usize>()));

        rest.to(|body: PingRequest| async { Result::<_, ServerError>::Ok(PingResponse { message: body.message }) });

        rest.to(|body: PingRequest, _: Extension<()>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });

        rest.to(|body: PingRequest, _: Extension<()>, _: Request<Body>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });

        rest.to(|body: PingRequest, _: Extension<()>, _: Request<Body>, _: Request<Body>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });

        rest.to(|body: PingRequest, _: Extension<()>, _: Request<Body>, _: Request<Body>, _: Request<Body>| async {
            Result::<_, ServerError>::Ok(PingResponse { message: body.message })
        });
    }
}