toni 0.1.1

Fast and modular web framework for scalable applications
Documentation
use std::{collections::HashMap, str::FromStr};

use anyhow::{Result, anyhow};
use axum::{
    RequestPartsExt,
    body::to_bytes,
    extract::{Path, Query},
    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode}
};
use serde_json::Value;

use crate::http_helpers::{self, Body, HttpRequest, HttpResponse};

use super::RouteAdapter;

/// Stateless [`RouteAdapter`] implementation that converts between axum's
/// request/response types and this crate's `HttpRequest` / `HttpResponse`.
pub struct AxumRouteAdapter;

impl RouteAdapter for AxumRouteAdapter {
    type Request = Request<axum::body::Body>;
    type Response = Response<axum::body::Body>;

    /// Converts an incoming axum request into the framework's [`HttpRequest`].
    ///
    /// The body is buffered completely and then classified:
    /// * valid UTF-8 that parses as JSON becomes `Body::Json`;
    /// * any other valid UTF-8 becomes `Body::Text`;
    /// * invalid UTF-8 becomes `Body::Text` with the offending bytes replaced
    ///   by U+FFFD.
    ///
    /// The classification looks only at the payload, not at the request's
    /// `Content-Type`, so a textual payload such as `42` or `true` is reported
    /// as JSON.
    ///
    /// Path and query parameters are extracted into string maps, and every
    /// request header is copied as a `(name, value)` pair of strings.
    ///
    /// # Errors
    ///
    /// Returns an error if the body cannot be read, or if the path or query
    /// parameters cannot be extracted.
    async fn adapt_request(request: Self::Request) -> Result<HttpRequest> {
        let (mut parts, body) = request.into_parts();

        // NOTE(review): `usize::MAX` means the body is buffered without any size
        // cap; confirm that a limit is enforced elsewhere (e.g. by a layer in
        // front of this adapter) before exposing this to untrusted clients.
        let body_bytes = to_bytes(body, usize::MAX).await?;

        // Validate the UTF-8 in place instead of copying the buffer first; the
        // only allocation left is the one for the final `String`.
        let body = match std::str::from_utf8(&body_bytes) {
            Ok(text) => match serde_json::from_str::<Value>(text) {
                Ok(json) => Body::Json(json),
                Err(_) => Body::Text(text.to_owned()),
            },
            Err(_) => Body::Text(String::from_utf8_lossy(&body_bytes).into_owned()),
        };

        let Path(path_params) = parts
            .extract::<Path<HashMap<String, String>>>()
            .await
            .map_err(|e| anyhow!("Failed to extract path parameters: {:?}", e))?;

        let Query(query_params) = parts
            .extract::<Query<HashMap<String, String>>>()
            .await
            .map_err(|e| anyhow!("Failed to extract query parameters: {:?}", e))?;

        // Decode header values lossily from their raw bytes: `HeaderValue::to_str`
        // rejects anything outside visible ASCII, which would otherwise turn
        // such values into empty strings.
        let headers = parts
            .headers
            .iter()
            .map(|(name, value)| {
                (
                    name.to_string(),
                    String::from_utf8_lossy(value.as_bytes()).into_owned(),
                )
            })
            .collect();

        Ok(HttpRequest {
            body,
            headers,
            method: parts.method.to_string(),
            uri: parts.uri.to_string(),
            query_params,
            path_params,
        })
    }

    /// Converts a framework response into an axum response.
    ///
    /// * A status value that is not a valid HTTP status code falls back to
    ///   `500 Internal Server Error`.
    /// * `Body::Text` is sent as-is, `Body::Json` is serialized to bytes, and
    ///   any other body value produces an empty body.
    /// * Response headers whose name or value is not a valid HTTP header are
    ///   skipped; repeated names are all kept.
    /// * If the response headers do not already contain `Content-Type`, it
    ///   defaults to `application/json` for JSON bodies and `text/plain`
    ///   otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if the JSON body cannot be serialized, if the
    /// `Content-Type` header name cannot be parsed, or if the response cannot
    /// be built.
    fn adapt_response(
        response: Box<dyn http_helpers::IntoResponse<Response = HttpResponse>>,
    ) -> Result<Self::Response> {
        let response = response.to_response();

        let status =
            StatusCode::from_u16(response.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

        // Produce the wire body together with the content type to use when the
        // handler did not choose one itself.
        let (body, default_content_type) = match response.body {
            Some(Body::Text(text)) => (axum::body::Body::from(text), "text/plain"),
            Some(Body::Json(json)) => {
                let vec = serde_json::to_vec(&json)
                    .map_err(|e| anyhow::anyhow!("Failed to serialize JSON: {}", e))?;
                (axum::body::Body::from(vec), "application/json")
            }
            _ => (axum::body::Body::empty(), "text/plain"),
        };

        let mut headers = HeaderMap::new();
        for (k, v) in &response.headers {
            if let Ok(header_name) = HeaderName::from_bytes(k.as_bytes()) {
                if let Ok(header_value) = HeaderValue::from_str(v) {
                    // `append` rather than `insert`, so that a name occurring
                    // more than once (e.g. several `Set-Cookie` entries) keeps
                    // every value instead of only the last one.
                    headers.append(header_name, header_value);
                }
            }
        }

        // Only fall back to the default when the handler did not set its own
        // `Content-Type`; an explicit value must not be overwritten.
        if !headers.contains_key("content-type") {
            headers.insert(
                HeaderName::from_str("Content-Type").map_err(|e| {
                    anyhow::anyhow!("Failed to parse header name Content-Type: {}", e)
                })?,
                HeaderValue::from_static(default_content_type),
            );
        }

        let mut res = Response::builder()
            .status(status)
            .body(body)
            .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e))?;

        res.headers_mut().extend(headers);

        Ok(res)
    }
}