use std::{collections::HashMap, convert::Infallible};
use axum::{
body::{to_bytes, Body},
extract::Request,
response::{IntoResponse, Response},
};
use bytes::Bytes;
use futures::future::BoxFuture;
use http::StatusCode;
use tower::Service;
use crate::{Error, Frame, MessageFrame, StatefulSystem, StatelessSystem};
/// Exposes a [`StatelessSystem`] as a `tower` service over HTTP requests.
///
/// Every request is turned into a [`Frame`], handed to `handle_frame`, and the
/// outcome is rendered as an HTTP response. The service itself never fails
/// (`Error = Infallible`); handler errors are rendered as responses instead.
impl Service<Request<Body>> for StatelessSystem {
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always reports readiness; no back-pressure is applied here.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Handles one request and resolves to the rendered response.
    ///
    /// The system is cloned so the returned future owns everything it needs
    /// and can be `'static`.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let handler = self.clone();
        let incoming = HttpRequestFrame::from(request);
        Box::pin(async move {
            let frame = incoming.into_frame().await;
            let rendered = match handler.handle_frame(frame).await {
                Ok(reply) => {
                    // Successful frames are rendered through the JSON wrapper.
                    let reply: HttpFrameResponse = reply.into();
                    reply.into_json_response().into_response()
                }
                Err(failure) => {
                    let failure: HttpFrameResponseError = failure.into();
                    failure.into_response()
                }
            };
            Ok(rendered)
        })
    }
}
/// Exposes a [`StatefulSystem`] as a `tower` service over HTTP requests.
///
/// Every request is turned into a [`Frame`], handed to `handle_frame`, and the
/// outcome is rendered as an HTTP response. The service itself never fails
/// (`Error = Infallible`); handler errors are rendered as responses instead.
impl<State> Service<Request<Body>> for StatefulSystem<State>
where
    State: Clone + Sync + Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always reports readiness; no back-pressure is applied here.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Handles one request and resolves to the rendered response.
    ///
    /// The system is cloned so the returned future owns everything it needs
    /// and can be `'static`.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let handler = self.clone();
        let incoming = HttpRequestFrame::from(request);
        Box::pin(async move {
            let frame = incoming.into_frame().await;
            let rendered = match handler.handle_frame(frame).await {
                Ok(reply) => {
                    // Successful frames are rendered through the JSON wrapper.
                    let reply: HttpFrameResponse = reply.into();
                    reply.into_json_response().into_response()
                }
                Err(failure) => {
                    let failure: HttpFrameResponseError = failure.into();
                    failure.into_response()
                }
            };
            Ok(rendered)
        })
    }
}
/// Newtype around a [`Frame`] that is about to be rendered as an HTTP response.
///
/// Convert it with `into_plain_response` or `into_json_response` to obtain a
/// value implementing `IntoResponse`.
#[derive(Debug)]
pub struct HttpFrameResponse(Frame);
impl HttpFrameResponse {
    /// Derives the HTTP status code for the wrapped frame.
    ///
    /// * Error frames map to `500 Internal Server Error`.
    /// * Message frames use the `status` from their JSON-encoded
    ///   [`HttpFrameMeta`]; if the meta cannot be parsed, or the number is not
    ///   a valid status code, `200 OK` is used.
    /// * Anonymous and unit frames map to `200 OK`.
    fn status_code(&self) -> StatusCode {
        match &self.0 {
            Frame::Error(_) => StatusCode::INTERNAL_SERVER_ERROR,
            Frame::Message(MessageFrame { meta, .. }) => {
                serde_json::from_slice::<HttpFrameMeta>(meta)
                    .ok()
                    .and_then(|parsed| StatusCode::from_u16(parsed.status).ok())
                    .unwrap_or(StatusCode::OK)
            }
            Frame::Anonymous(_) | Frame::Unit => StatusCode::OK,
        }
    }

    /// Returns the frame's bytes; works on a clone so `self` stays usable.
    fn body(&self) -> Bytes {
        let frame = self.0.clone();
        frame.into_bytes()
    }

    /// Wraps `self` so that the body is sent as-is with the derived status.
    pub fn into_plain_response(self) -> impl IntoResponse {
        PlainHttpFrameResponse(self)
    }

    /// Wraps `self` so that the body is parsed and re-emitted as JSON.
    pub fn into_json_response(self) -> impl IntoResponse {
        JsonHttpFrameResponse(self)
    }
}
impl From<HttpFrameResponse> for Frame {
fn from(frame: HttpFrameResponse) -> Self {
frame.0
}
}
impl From<Frame> for HttpFrameResponse {
fn from(frame: Frame) -> Self {
Self(frame)
}
}
/// Response wrapper that sends the frame's bytes unchanged, paired with the
/// status code derived from the frame.
#[derive(Debug)]
pub struct PlainHttpFrameResponse(HttpFrameResponse);
impl IntoResponse for PlainHttpFrameResponse {
fn into_response(self) -> Response<Body> {
(self.0.status_code(), self.0.body()).into_response()
}
}
/// Response wrapper that parses the frame's bytes as JSON and re-emits them
/// through `axum::Json`, paired with the status code derived from the frame.
#[derive(Debug)]
pub struct JsonHttpFrameResponse(HttpFrameResponse);
impl IntoResponse for JsonHttpFrameResponse {
    /// Renders the wrapped frame as a JSON response.
    ///
    /// * An empty body becomes JSON `null` with the frame's status code.
    /// * A body that parses as JSON is re-emitted with the frame's status code.
    /// * A body that does not parse yields `500 Internal Server Error` with an
    ///   `{"error": ...}` object describing the parse failure.
    fn into_response(self) -> Response<Body> {
        use axum::Json;

        let payload = self.0.body();
        if payload.is_empty() {
            return (self.0.status_code(), Json(serde_json::Value::Null)).into_response();
        }

        match serde_json::from_slice::<serde_json::Value>(&payload) {
            Ok(value) => (self.0.status_code(), Json(value)).into_response(),
            Err(error) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Failed to parse response: {error}") })),
            )
                .into_response(),
        }
    }
}
/// Newtype around the crate's [`Error`] so that it can be rendered as an HTTP
/// response.
#[derive(Debug)]
pub struct HttpFrameResponseError(Error);
impl From<Error> for HttpFrameResponseError {
fn from(error: Error) -> Self {
Self(error)
}
}
impl IntoResponse for HttpFrameResponseError {
fn into_response(self) -> Response<Body> {
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
}
}
/// HTTP-specific metadata carried alongside a frame, (de)serialized via serde.
///
/// `status` and `method` fall back to `200` and `"GET"` when absent from the
/// serialized form; all remaining string key/value pairs are collected into
/// `details`.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq)]
pub struct HttpFrameMeta {
/// HTTP status code; defaults to 200 when missing during deserialization.
#[serde(default = "default_status")]
pub status: u16,
/// HTTP method name; defaults to "GET" when missing during deserialization.
#[serde(default = "default_method")]
pub method: String,
/// Extra string entries, flattened into the same map level as `status` and
/// `method` when (de)serialized. Request conversion stores the headers here.
#[serde(flatten)]
pub details: HashMap<String, String>,
}
/// Status code used when serialized metadata carries no `status` field.
fn default_status() -> u16 {
    const HTTP_OK: u16 = 200;
    HTTP_OK
}
/// Method name used when serialized metadata carries no `method` field.
fn default_method() -> String {
    String::from("GET")
}
/// Default metadata: status 200, the default method, and no extra details.
impl Default for HttpFrameMeta {
    fn default() -> Self {
        let method = default_method();
        let details = HashMap::new();
        Self {
            status: 200,
            method,
            details,
        }
    }
}
/// An incoming HTTP request captured in a form that can be turned into a
/// [`Frame`] via `into_frame`.
#[derive(Default, Debug)]
pub struct HttpRequestFrame {
// Request URI rendered as a string.
uri: String,
// Method and header information extracted from the request.
meta: HttpFrameMeta,
// Request body; it is only read when the value is converted into a frame.
body: Body,
}
impl HttpRequestFrame {
    /// Consumes the request and produces a message [`Frame`].
    ///
    /// The metadata is JSON-encoded, the body is read to completion, and both
    /// are passed to `Frame::message` together with the URI.
    ///
    /// # Panics
    ///
    /// Panics if the metadata cannot be serialized or if reading the body
    /// fails. The body is collected with a limit of `usize::MAX`, i.e. no
    /// effective size cap is applied here.
    pub async fn into_frame(self) -> Frame {
        let Self { uri, meta, body } = self;
        let encoded_meta = serde_json::to_vec(&meta).unwrap();
        let payload = to_bytes(body, usize::MAX).await.unwrap();
        Frame::message(uri, payload, encoded_meta)
    }
}
/// Captures an HTTP request's URI, method, headers and (still unread) body.
///
/// The status in the resulting metadata is left at its default. Headers are
/// stored in `meta.details`, keyed by header name; because the target is a
/// `HashMap`, only one value per header name is kept.
impl From<Request<Body>> for HttpRequestFrame {
    fn from(request: Request<Body>) -> Self {
        let (parts, body) = request.into_parts();
        let mut http_frame = Self {
            body,
            uri: parts.uri.to_string(),
            ..Default::default()
        };
        http_frame.meta.method = parts.method.to_string();
        http_frame.meta.details = parts
            .headers
            .iter()
            .map(|(key, value)| {
                // Header values are arbitrary bytes chosen by the client.
                // `to_str()` rejects anything that is not visible ASCII, so
                // unwrapping it would let a single odd header panic the
                // handler. Decode lossily instead: values `to_str()` accepts
                // come out unchanged, and invalid UTF-8 becomes U+FFFD.
                let text = String::from_utf8_lossy(value.as_bytes()).into_owned();
                (key.to_string(), text)
            })
            .collect();
        http_frame
    }
}