intrepid_core/
axum.rs

1use std::{collections::HashMap, convert::Infallible};
2
3use axum::{
4    body::{to_bytes, Body},
5    extract::Request,
6    response::{IntoResponse, Response},
7};
8use bytes::Bytes;
9use futures::future::BoxFuture;
10use http::StatusCode;
11use tower::Service;
12
13use crate::{Error, Frame, MessageFrame, StatefulSystem, StatelessSystem};
14
15impl Service<Request<Body>> for StatelessSystem {
16    type Response = Response<Body>;
17    type Error = Infallible;
18    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
19
20    fn poll_ready(
21        &mut self,
22        _cx: &mut std::task::Context<'_>,
23    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
24        std::task::Poll::Ready(Ok(()))
25    }
26
27    fn call(&mut self, request: Request<Body>) -> Self::Future {
28        let system = self.clone();
29        let frame = HttpRequestFrame::from(request);
30
31        Box::pin(async move {
32            let frame = frame.into_frame().await;
33            let response: Result<HttpFrameResponse, HttpFrameResponseError> = system
34                .handle_frame(frame)
35                .await
36                .map(Into::into)
37                .map_err(Into::into);
38
39            Ok(response
40                .map(|response| response.into_json_response())
41                .into_response())
42        })
43    }
44}
45
46impl<State> Service<Request<Body>> for StatefulSystem<State>
47where
48    State: Clone + Sync + Send + 'static,
49{
50    type Response = Response<Body>;
51    type Error = Infallible;
52    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
53
54    fn poll_ready(
55        &mut self,
56        _cx: &mut std::task::Context<'_>,
57    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
58        std::task::Poll::Ready(Ok(()))
59    }
60
61    fn call(&mut self, request: Request<Body>) -> Self::Future {
62        let system = self.clone();
63        let frame = HttpRequestFrame::from(request);
64
65        Box::pin(async move {
66            let frame = frame.into_frame().await;
67            let response: Result<HttpFrameResponse, HttpFrameResponseError> = system
68                .handle_frame(frame)
69                .await
70                .map(Into::into)
71                .map_err(Into::into);
72
73            Ok(response
74                .map(|response| response.into_json_response())
75                .into_response())
76        })
77    }
78}
79
80/// An intrepid frame being turned into an axum json response.
81#[derive(Debug)]
82pub struct HttpFrameResponse(Frame);
83
84impl HttpFrameResponse {
85    /// Get the status code of the frame.
86    fn status_code(&self) -> StatusCode {
87        match &self.0 {
88            Frame::Anonymous(_) | Frame::Unit => StatusCode::OK,
89            Frame::Message(MessageFrame { meta, .. }) => {
90                match serde_json::from_slice::<HttpFrameMeta>(meta) {
91                    Ok(http_meta) => {
92                        StatusCode::from_u16(http_meta.status).unwrap_or(StatusCode::OK)
93                    }
94                    Err(_) => StatusCode::OK,
95                }
96            }
97            Frame::Error(_) => StatusCode::INTERNAL_SERVER_ERROR,
98        }
99    }
100
101    /// Get the body of the frame.
102    fn body(&self) -> Bytes {
103        self.0.clone().into_bytes()
104    }
105
106    /// Become a plain response.
107    pub fn into_plain_response(self) -> impl IntoResponse {
108        PlainHttpFrameResponse(self)
109    }
110
111    /// Become a JSON response.
112    pub fn into_json_response(self) -> impl IntoResponse {
113        JsonHttpFrameResponse(self)
114    }
115}
116
117impl From<HttpFrameResponse> for Frame {
118    fn from(frame: HttpFrameResponse) -> Self {
119        frame.0
120    }
121}
122
123impl From<Frame> for HttpFrameResponse {
124    fn from(frame: Frame) -> Self {
125        Self(frame)
126    }
127}
128
129/// An intrepid frame being turned into an axum json response.
130#[derive(Debug)]
131pub struct PlainHttpFrameResponse(HttpFrameResponse);
132
133impl IntoResponse for PlainHttpFrameResponse {
134    fn into_response(self) -> Response<Body> {
135        (self.0.status_code(), self.0.body()).into_response()
136    }
137}
138
139/// An intrepid frame being turned into an axum response.
140#[derive(Debug)]
141pub struct JsonHttpFrameResponse(HttpFrameResponse);
142
143impl IntoResponse for JsonHttpFrameResponse {
144    fn into_response(self) -> Response<Body> {
145        use axum::Json;
146        let body = self.0.body();
147        let body = if body.is_empty() {
148            return (self.0.status_code(), Json(serde_json::Value::Null)).into_response();
149        } else {
150            body
151        };
152
153        let value = match serde_json::from_slice::<serde_json::Value>(&body) {
154            Ok(value) => value,
155            Err(error) => {
156                return (
157                    StatusCode::INTERNAL_SERVER_ERROR,
158                    Json(serde_json::json!({ "error": format!("Failed to parse response: {error}") })),
159                )
160                    .into_response();
161            }
162        };
163
164        (self.0.status_code(), Json(value)).into_response()
165    }
166}
167
168/// An intrepid error being turned into an axum response.
169#[derive(Debug)]
170pub struct HttpFrameResponseError(Error);
171
172impl From<Error> for HttpFrameResponseError {
173    fn from(error: Error) -> Self {
174        Self(error)
175    }
176}
177
178impl IntoResponse for HttpFrameResponseError {
179    fn into_response(self) -> Response<Body> {
180        (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
181    }
182}
183
184/// Metadata for a HTTP frame built from an axum request.
185#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq)]
186pub struct HttpFrameMeta {
187    /// The HTTP status code.
188    #[serde(default = "default_status")]
189    pub status: u16,
190    /// The HTTP method.
191    #[serde(default = "default_method")]
192    pub method: String,
193    /// Any additional details, like headers
194    #[serde(flatten)]
195    pub details: HashMap<String, String>,
196}
197
/// Serde fallback for `HttpFrameMeta::status`: `200` (OK).
fn default_status() -> u16 {
    const HTTP_OK: u16 = 200;
    HTTP_OK
}
201
/// Serde fallback for `HttpFrameMeta::method`: `"GET"`.
fn default_method() -> String {
    String::from("GET")
}
205
206impl Default for HttpFrameMeta {
207    fn default() -> Self {
208        Self {
209            status: 200,
210            method: default_method(),
211            details: HashMap::new(),
212        }
213    }
214}
215
216/// An HTTP frame built from an axum request.
217#[derive(Default, Debug)]
218pub struct HttpRequestFrame {
219    uri: String,
220    meta: HttpFrameMeta,
221    body: Body,
222}
223
224impl HttpRequestFrame {
225    /// Turn the HTTP frame into an intrepid frame.
226    pub async fn into_frame(self) -> Frame {
227        let meta = serde_json::to_vec(&self.meta).unwrap();
228        let body = to_bytes(self.body, usize::MAX).await.unwrap();
229
230        Frame::message(self.uri, body, meta)
231    }
232}
233
234impl From<Request<Body>> for HttpRequestFrame {
235    fn from(request: Request<Body>) -> Self {
236        let (parts, body) = request.into_parts();
237        let mut http_frame = Self {
238            body,
239            uri: parts.uri.to_string(),
240            ..Default::default()
241        };
242
243        http_frame.meta.method = parts.method.to_string();
244        http_frame.meta.details = parts
245            .headers
246            .iter()
247            .map(|(key, value)| (key.to_string(), value.to_str().unwrap().to_string()))
248            .collect();
249
250        http_frame
251    }
252}
253
254// fn wat() -> BoxCloneService<Request<Body>, Response<Body>, Infallible> {
255//     use std::{iter::once, sync::Arc};
256//     use tower::ServiceBuilder;
257//     use tower_http::{
258//         add_extension::AddExtensionLayer, compression::CompressionLayer,
259//         propagate_header::PropagateHeaderLayer, sensitive_headers::SetSensitiveRequestHeadersLayer,
260//         trace::TraceLayer, validate_request::ValidateRequestHeaderLayer,
261//     };
262//     let service = ServiceBuilder::new()
263//         .boxed_clone()
264//         .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION)))
265//         .layer(TraceLayer::new_for_http())
266//         .layer(AddExtensionLayer::new(Arc::new(())))
267//         .layer(CompressionLayer::new())
268//         .layer(PropagateHeaderLayer::new(HeaderName::from_static(
269//             "x-request-id",
270//         )))
271//         .layer(ValidateRequestHeaderLayer::bearer("passwordlol"))
272//         .layer(ValidateRequestHeaderLayer::accept("application/json"))
273//         .service_fn(|_| async { Ok("hay gusy  lol".to_string().into_response()) });
274
275//     service
276// }
277
278// mod wut {
279//     use std::{convert::Infallible, iter::once, sync::Arc};
280
281//     use axum::{
282//         body::Body,
283//         extract::Request,
284//         response::{IntoResponse, Response},
285//     };
286//     use http::{
287//         header::{AUTHORIZATION, CONTENT_TYPE},
288//         HeaderName,
289//     };
290//     use tower::{util::BoxService, ServiceBuilder};
291//     use tower_http::{
292//         add_extension::AddExtensionLayer, compression::CompressionLayer,
293//         propagate_header::PropagateHeaderLayer, sensitive_headers::SetSensitiveRequestHeadersLayer,
294//         set_header::SetResponseHeaderLayer, trace::TraceLayer,
295//         validate_request::ValidateRequestHeaderLayer,
296//     };
297
298//     fn wat() -> BoxService<Request, Response, Infallible> {
299//         let service = ServiceBuilder::new()
300//             .boxed()
301//             // Mark the `Authorization` request header as sensitive so it doesn't show in logs
302//             .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION)))
303//             // High level logging of requests and responses
304//             .layer(TraceLayer::new_for_http())
305//             // Share an `Arc<State>` with all requests
306//             .layer(AddExtensionLayer::new(Arc::new(())))
307//             // Compress responses
308//             .layer(CompressionLayer::new())
309//             // Propagate `X-Request-Id`s from requests to responses
310//             .layer(PropagateHeaderLayer::new(HeaderName::from_static(
311//                 "x-request-id",
312//             )))
313//             // If the response has a known size set the `Content-Length` header
314//             // Authorize requests using a token
315//             .layer(ValidateRequestHeaderLayer::bearer("passwordlol"))
316//             // Accept only application/json, application/* and */* in a request's ACCEPT header
317//             .layer(ValidateRequestHeaderLayer::accept("application/json"))
318//             // Wrap a `Service` in our middleware stack
319//             .service_fn(|_| async { Ok("hay gusy  lol".to_string().into_response()) });
320
321//         service
322//     }
323// }