use std::marker::PhantomData;
use axum::{
async_trait,
body::{Body, Bytes},
extract::FromRequest,
response::IntoResponse,
};
use futures::future::BoxFuture;
use http::{Request, Response, StatusCode};
use tower::Service;
use crate::{Error, Frame, FrameFuture, Handler, StatefulSystem};
/// Raw HTTP request body, captured as [`Bytes`] through axum's extractor machinery.
///
/// It converts into a [`Frame`] via [`Frame::new`]. The `FromRequest` impl for this type
/// describes how it is extracted.
#[derive(Debug)]
pub struct FrameRequest(Bytes);

impl From<FrameRequest> for Frame {
    /// Wraps the captured body bytes in a new [`Frame`].
    fn from(frame: FrameRequest) -> Self {
        Frame::new(frame.0)
    }
}
#[async_trait]
impl<State> FromRequest<State> for FrameRequest
where
Bytes: FromRequest<State>,
State: Send + Sync,
{
type Rejection = Response<Body>;
async fn from_request(req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
let body = Bytes::from_request(req, state)
.await
.map_err(IntoResponse::into_response)?;
Ok(Self(body))
}
}
/// Successful handler output that is rendered as an HTTP response.
///
/// The response always has status `200 OK`. Its body is the frame's `into_bytes()` payload.
pub struct FrameResponse(Frame);

impl From<Frame> for FrameResponse {
    /// Wraps a handler-produced [`Frame`] so it can be returned from an axum handler.
    fn from(inner: Frame) -> Self {
        FrameResponse(inner)
    }
}

impl IntoResponse for FrameResponse {
    /// Emits `200 OK` with the frame's bytes as the body.
    fn into_response(self) -> Response<Body> {
        let Self(frame) = self;
        let payload = frame.into_bytes();
        (StatusCode::OK, payload).into_response()
    }
}
#[derive(Debug)]
pub struct FrameResponseError(Error);
impl From<Error> for FrameResponseError {
fn from(error: Error) -> Self {
Self(error)
}
}
impl IntoResponse for FrameResponseError {
fn into_response(self) -> Response<Body> {
(StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response()
}
}
#[derive(Clone)]
pub struct RequestHandler<GivenHandler, Args, State>
where
GivenHandler: crate::Handler<Args, State> + Clone + Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
Args: Clone + Send + Sync + 'static,
{
handler: GivenHandler,
_state: PhantomData<State>,
_args: PhantomData<Args>,
}
impl<GivenHandler, Args, State> RequestHandler<GivenHandler, Args, State>
where
GivenHandler: crate::Handler<Args, State> + Clone + Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
Args: Clone + Send + Sync + 'static,
{
pub fn new(handler: GivenHandler) -> Self {
Self {
handler,
_state: PhantomData,
_args: PhantomData,
}
}
}
impl<GivenHandler, Args, State> Handler<Args, State> for RequestHandler<GivenHandler, Args, State>
where
    GivenHandler: Handler<Args, State> + Clone + Send + Sync + 'static,
    State: Clone + Send + Sync + 'static,
    Args: Clone + Send + Sync + 'static,
{
    type Future = FrameFuture;

    /// Forwards `frame` and `state` to the wrapped handler.
    ///
    /// The async block owns its own clone of the wrapped handler, so the returned
    /// [`FrameFuture`] holds no borrow of `self`.
    fn invoke(&self, frame: impl Into<Frame>, state: State) -> Self::Future {
        let inner = Clone::clone(&self.handler);
        let frame: Frame = frame.into();
        let forwarded = async move { inner.invoke(frame, state).await };
        FrameFuture::from_async_block(forwarded)
    }
}
impl<GivenHandler, Args, State> axum::handler::Handler<Args, State>
for RequestHandler<GivenHandler, Args, State>
where
GivenHandler: Handler<Args, State> + Clone + Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
Args: Clone + Send + Sync + 'static,
{
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = http::Response<axum::body::Body>> + Send>,
>;
fn call(self, request: http::Request<axum::body::Body>, state: State) -> Self::Future {
use axum::extract::FromRequest;
use axum::response::IntoResponse;
let handler = self.handler.clone();
Box::pin(async move {
match crate::FrameRequest::from_request(request, &state).await {
Ok(frame_request) => handler
.invoke(frame_request, state)
.await
.map(crate::FrameResponse::from)
.map_err(crate::FrameResponseError::from)
.into_response(),
Err(rejection) => rejection.into_response(),
}
})
}
}
impl<State> Service<FrameRequest> for StatefulSystem<State>
where
    State: Clone + Send + Sync + 'static,
{
    type Response = FrameResponse;
    type Error = FrameResponseError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Reports readiness unconditionally. This service never applies backpressure and never
    /// fails here.
    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Runs `handle_frame` on a clone of this system and converts its result.
    ///
    /// `Ok` values are converted into [`FrameResponse`] and errors into
    /// [`FrameResponseError`], both via `Into`.
    fn call(&mut self, request: FrameRequest) -> Self::Future {
        let system = self.clone();
        Box::pin(async move {
            match system.handle_frame(request.into()).await {
                Ok(output) => Ok(output.into()),
                Err(failure) => Err(failure.into()),
            }
        })
    }
}