1use std::{collections::HashMap, convert::Infallible};
2
3use axum::{
4 body::{to_bytes, Body},
5 extract::Request,
6 response::{IntoResponse, Response},
7};
8use bytes::Bytes;
9use futures::future::BoxFuture;
10use http::StatusCode;
11use tower::Service;
12
13use crate::{Error, Frame, MessageFrame, StatefulSystem, StatelessSystem};
14
15impl Service<Request<Body>> for StatelessSystem {
16 type Response = Response<Body>;
17 type Error = Infallible;
18 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
19
20 fn poll_ready(
21 &mut self,
22 _cx: &mut std::task::Context<'_>,
23 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
24 std::task::Poll::Ready(Ok(()))
25 }
26
27 fn call(&mut self, request: Request<Body>) -> Self::Future {
28 let system = self.clone();
29 let frame = HttpRequestFrame::from(request);
30
31 Box::pin(async move {
32 let frame = frame.into_frame().await;
33 let response: Result<HttpFrameResponse, HttpFrameResponseError> = system
34 .handle_frame(frame)
35 .await
36 .map(Into::into)
37 .map_err(Into::into);
38
39 Ok(response
40 .map(|response| response.into_json_response())
41 .into_response())
42 })
43 }
44}
45
46impl<State> Service<Request<Body>> for StatefulSystem<State>
47where
48 State: Clone + Sync + Send + 'static,
49{
50 type Response = Response<Body>;
51 type Error = Infallible;
52 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
53
54 fn poll_ready(
55 &mut self,
56 _cx: &mut std::task::Context<'_>,
57 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
58 std::task::Poll::Ready(Ok(()))
59 }
60
61 fn call(&mut self, request: Request<Body>) -> Self::Future {
62 let system = self.clone();
63 let frame = HttpRequestFrame::from(request);
64
65 Box::pin(async move {
66 let frame = frame.into_frame().await;
67 let response: Result<HttpFrameResponse, HttpFrameResponseError> = system
68 .handle_frame(frame)
69 .await
70 .map(Into::into)
71 .map_err(Into::into);
72
73 Ok(response
74 .map(|response| response.into_json_response())
75 .into_response())
76 })
77 }
78}
79
80#[derive(Debug)]
82pub struct HttpFrameResponse(Frame);
83
84impl HttpFrameResponse {
85 fn status_code(&self) -> StatusCode {
87 match &self.0 {
88 Frame::Anonymous(_) | Frame::Unit => StatusCode::OK,
89 Frame::Message(MessageFrame { meta, .. }) => {
90 match serde_json::from_slice::<HttpFrameMeta>(meta) {
91 Ok(http_meta) => {
92 StatusCode::from_u16(http_meta.status).unwrap_or(StatusCode::OK)
93 }
94 Err(_) => StatusCode::OK,
95 }
96 }
97 Frame::Error(_) => StatusCode::INTERNAL_SERVER_ERROR,
98 }
99 }
100
101 fn body(&self) -> Bytes {
103 self.0.clone().into_bytes()
104 }
105
106 pub fn into_plain_response(self) -> impl IntoResponse {
108 PlainHttpFrameResponse(self)
109 }
110
111 pub fn into_json_response(self) -> impl IntoResponse {
113 JsonHttpFrameResponse(self)
114 }
115}
116
117impl From<HttpFrameResponse> for Frame {
118 fn from(frame: HttpFrameResponse) -> Self {
119 frame.0
120 }
121}
122
123impl From<Frame> for HttpFrameResponse {
124 fn from(frame: Frame) -> Self {
125 Self(frame)
126 }
127}
128
129#[derive(Debug)]
131pub struct PlainHttpFrameResponse(HttpFrameResponse);
132
133impl IntoResponse for PlainHttpFrameResponse {
134 fn into_response(self) -> Response<Body> {
135 (self.0.status_code(), self.0.body()).into_response()
136 }
137}
138
139#[derive(Debug)]
141pub struct JsonHttpFrameResponse(HttpFrameResponse);
142
143impl IntoResponse for JsonHttpFrameResponse {
144 fn into_response(self) -> Response<Body> {
145 use axum::Json;
146 let body = self.0.body();
147 let body = if body.is_empty() {
148 return (self.0.status_code(), Json(serde_json::Value::Null)).into_response();
149 } else {
150 body
151 };
152
153 let value = match serde_json::from_slice::<serde_json::Value>(&body) {
154 Ok(value) => value,
155 Err(error) => {
156 return (
157 StatusCode::INTERNAL_SERVER_ERROR,
158 Json(serde_json::json!({ "error": format!("Failed to parse response: {error}") })),
159 )
160 .into_response();
161 }
162 };
163
164 (self.0.status_code(), Json(value)).into_response()
165 }
166}
167
168#[derive(Debug)]
170pub struct HttpFrameResponseError(Error);
171
172impl From<Error> for HttpFrameResponseError {
173 fn from(error: Error) -> Self {
174 Self(error)
175 }
176}
177
178impl IntoResponse for HttpFrameResponseError {
179 fn into_response(self) -> Response<Body> {
180 (StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
181 }
182}
183
/// HTTP-specific metadata exchanged as JSON in a message frame's `meta` bytes.
///
/// Requests serialize this into the frame they produce; responses read
/// `status` back out of it to pick the HTTP status code.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq)]
pub struct HttpFrameMeta {
    /// HTTP status code; `default_status()` is used when the key is missing
    /// during deserialization.
    #[serde(default = "default_status")]
    pub status: u16,
    /// HTTP method name; `default_method()` is used when the key is missing
    /// during deserialization.
    #[serde(default = "default_method")]
    pub method: String,
    /// Additional string key/value pairs, flattened into the same JSON object
    /// as `status` and `method`. Requests store their headers here.
    // NOTE(review): because of `flatten`, an entry named "status" or "method"
    // would presumably collide with the fields above — confirm how serde
    // handles that for this type.
    #[serde(flatten)]
    pub details: HashMap<String, String>,
}
197
/// Serde fallback for a missing `status` key: HTTP 200.
fn default_status() -> u16 {
    const HTTP_OK: u16 = 200;
    HTTP_OK
}
201
/// Serde fallback for a missing `method` key: `"GET"`.
fn default_method() -> String {
    String::from("GET")
}
205
206impl Default for HttpFrameMeta {
207 fn default() -> Self {
208 Self {
209 status: 200,
210 method: default_method(),
211 details: HashMap::new(),
212 }
213 }
214}
215
216#[derive(Default, Debug)]
218pub struct HttpRequestFrame {
219 uri: String,
220 meta: HttpFrameMeta,
221 body: Body,
222}
223
224impl HttpRequestFrame {
225 pub async fn into_frame(self) -> Frame {
227 let meta = serde_json::to_vec(&self.meta).unwrap();
228 let body = to_bytes(self.body, usize::MAX).await.unwrap();
229
230 Frame::message(self.uri, body, meta)
231 }
232}
233
234impl From<Request<Body>> for HttpRequestFrame {
235 fn from(request: Request<Body>) -> Self {
236 let (parts, body) = request.into_parts();
237 let mut http_frame = Self {
238 body,
239 uri: parts.uri.to_string(),
240 ..Default::default()
241 };
242
243 http_frame.meta.method = parts.method.to_string();
244 http_frame.meta.details = parts
245 .headers
246 .iter()
247 .map(|(key, value)| (key.to_string(), value.to_str().unwrap().to_string()))
248 .collect();
249
250 http_frame
251 }
252}
253
254