lua-astra 0.47.0

🔥 Blazingly Fast 🔥 runtime environment for Lua
use crate::{
    LUA,
    components::http::server::{
        configs::RouteConfiguration,
        requests::{self, RequestLua},
        responses::{self, CookieOperation},
        routes,
        websocket::AstraWebSocket,
    },
};
use axum::{
    Router,
    body::Body,
    extract::{DefaultBodyLimit, WebSocketUpgrade},
    http::Request,
    response::IntoResponse,
    routing::{any, delete, get, options, patch, post, put, trace},
};
use axum_extra::extract::{CookieJar, cookie::Cookie};
use mlua::LuaSerdeExt;

/// Kind of route a Lua entry is registered as: an HTTP verb or one of the special
/// route kinds (static content, WebSocket, fallback).
///
/// Values are (de)serialized in `snake_case`, so Lua code spells them `"get"`, `"post"`,
/// `"put"`, `"delete"`, `"options"`, `"patch"`, `"trace"`, `"static_dir"`,
/// `"static_file"`, `"web_socket"` and `"fallback"`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(mlua::FromLua, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Method {
    /// HTTP `GET`.
    Get,
    /// HTTP `POST`.
    Post,
    /// HTTP `PUT`.
    Put,
    /// HTTP `DELETE`.
    Delete,
    /// HTTP `OPTIONS`.
    Options,
    /// HTTP `PATCH`.
    Patch,
    /// HTTP `TRACE`.
    Trace,
    /// Serves a directory from disk (the route's `static_dir`).
    StaticDir,
    /// Serves a single file from disk (the route's `static_file`).
    StaticFile,
    /// Upgrades the connection to a WebSocket and hands it to the Lua function.
    WebSocket,
    /// Handles requests that no other route matched.
    Fallback,
}
/// One route definition, as read from an entry of the Lua server's `routes` table.
#[derive(Debug, Clone)]
#[derive(mlua::FromLua)]
pub struct Route {
    /// URL path the route is mounted at.
    pub path: String,
    /// Which kind of route this is (HTTP verb, static content, WebSocket or fallback).
    pub method: Method,
    /// Lua function invoked for requests (or WebSocket connections) on this route.
    pub function: mlua::Function,
    /// Directory to serve when `method` is [`Method::StaticDir`].
    pub static_dir: Option<String>,
    /// File to serve when `method` is [`Method::StaticFile`].
    pub static_file: Option<String>,
    /// Per-route settings such as body limit, compression and static response headers.
    pub config: RouteConfiguration,
}

pub async fn route(
    lua: &mlua::Lua,
    details: Route,
    request: Request<Body>,
) -> Result<(CookieJar, axum::response::Response), axum::http::StatusCode> {
    let request = requests::RequestLua::new(request).await;
    // find a way to add keys here
    let cookie_jar = request.cookie_jar.clone();

    async fn route_inner(
        lua: &mlua::Lua,
        details: Route,
        cookie_jar: CookieJar,
        request: RequestLua,
    ) -> mlua::Result<(CookieJar, axum::response::Response)> {
        let request = lua.create_userdata(request)?;
        let response = lua.create_userdata(responses::ResponseLua::default())?;
        let mut cookie_jar = cookie_jar.clone();

        // if a response userdata can be created
        let result = details
            .function
            .call_async::<mlua::Value>((request, response.clone()))
            .await?;

        let response_details = response.borrow::<responses::ResponseLua>()?;

        if let Some(redirect_to) = &response_details.redirect {
            return Ok((cookie_jar, redirect_to.clone().into_response()));
        }

        let mut resulting_response = match result {
            mlua::Value::String(plain) => plain.to_string_lossy().into_response(),
            mlua::Value::Table(ref table) => {
                if let Ok(true) = crate::components::is_table_byte_array(table) {
                    let bytes: Vec<u8> = lua.from_value(result.clone())?;
                    Body::from(bytes).into_response()
                } else {
                    axum::Json(lua.from_value::<serde_json::Value>(result.clone())?).into_response()
                }
            }
            _ => axum::http::StatusCode::OK.into_response(),
        };
        *resulting_response.status_mut() = response_details.status_code;

        for (key, value) in response_details.headers.iter() {
            resulting_response.headers_mut().insert(key, value.clone());
        }

        for cookie_operation in response_details.cookie_operations.clone().into_iter() {
            match cookie_operation {
                CookieOperation::Add(cookie) => {
                    cookie_jar = cookie_jar.clone().remove(cookie.0.clone());
                    cookie_jar = cookie_jar.clone().add(cookie.0);
                }
                CookieOperation::Remove { key } => {
                    cookie_jar = cookie_jar.clone().remove(Cookie::from(key));
                }
            };
        }

        Ok((cookie_jar, resulting_response))
    }

    match route_inner(lua, details, cookie_jar.clone(), request).await {
        Ok(response) => Ok(response),
        Err(e) => {
            tracing::error!("Error executing the route: {e}");

            Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}

pub fn load_routes(server: mlua::Table) -> Router {
    let mut router = Router::new();
    #[allow(clippy::expect_used)]
    let lua = LUA.get().expect("Could not get access to the global VM");

    let mut routes = Vec::new();
    let mut parse_route = |entry: &mlua::Table| -> mlua::Result<()> {
        routes.push(routes::Route {
            path: lua.from_value(entry.get("path")?)?,
            method: lua.from_value(entry.get("method")?)?,
            function: entry.get::<mlua::Function>("func")?,
            static_dir: lua.from_value(entry.get("static_dir")?)?,
            static_file: lua.from_value(entry.get("static_file")?)?,
            config: lua.from_value(entry.get("config")?)?,
        });

        Ok(())
    };

    if let Ok(server) = server.get::<mlua::Table>("routes") {
        #[allow(clippy::expect_used)]
        server
            .for_each(|_key: mlua::Value, entry: mlua::Value| {
                if let Some(entry) = entry.as_table() {
                    let _ = parse_route(entry);
                }

                Ok(())
            })
            .expect("Could not parse the routes");

        for route_values in routes.clone() {
            let path = route_values.path.clone();
            let path = path.as_str();

            let config = route_values.config.clone();
            let body_limit = config.body_limit;
            let compression = config.compression;

            macro_rules! match_routes {
                ($route_function:expr) => {{
                    let mut route_function =
                        $route_function(move |request: Request<Body>| async move {
                            route(lua, route_values, request).await
                        });

                    if let Some(body_limit) = body_limit {
                        route_function = route_function.layer(DefaultBodyLimit::max(body_limit))
                    }
                    if let Some(compression) = compression
                        && compression
                    {
                        route_function = route_function.layer(
                            tower::ServiceBuilder::new()
                                .layer(tower_http::decompression::RequestDecompressionLayer::new())
                                .layer(tower_http::compression::CompressionLayer::new()),
                        )
                    }

                    router.route(path, route_function)
                }};
            }

            router = match route_values.method {
                Method::Get => match_routes!(get),
                Method::Post => match_routes!(post),
                Method::Put => match_routes!(put),
                Method::Delete => match_routes!(delete),
                Method::Options => match_routes!(options),
                Method::Patch => match_routes!(patch),
                Method::Trace => match_routes!(trace),
                Method::StaticDir => {
                    if let Some(serve_path) = route_values.static_dir {
                        let service = tower_http::services::ServeDir::new(serve_path);
                        let mut router_part = if path == "/" {
                            router.fallback_service(service)
                        } else {
                            router.nest_service(path, service)
                        };
                        if let Some(headers) = &route_values.config.headers {
                            for (k, v) in headers {
                                if let Ok(header_name) = k.parse::<axum::http::HeaderName>()
                                    && let Ok(header_value) = v.parse::<axum::http::HeaderValue>()
                                {
                                    router_part = router_part.layer(
                                        tower_http::set_header::SetResponseHeaderLayer::overriding(
                                            header_name,
                                            header_value,
                                        ),
                                    );
                                }
                            }
                        }
                        router_part
                    } else {
                        router
                    }
                }
                Method::StaticFile => {
                    if let Some(serve_path) = route_values.static_file {
                        let service = tower_http::services::ServeFile::new(serve_path);
                        let mut router_part = if path == "/" {
                            router.fallback_service(service)
                        } else {
                            router.nest_service(path, service)
                        };
                        if let Some(headers) = &route_values.config.headers {
                            for (k, v) in headers {
                                if let Ok(header_name) = k.parse::<axum::http::HeaderName>()
                                    && let Ok(header_value) = v.parse::<axum::http::HeaderValue>()
                                {
                                    router_part = router_part.layer(
                                        tower_http::set_header::SetResponseHeaderLayer::overriding(
                                            header_name,
                                            header_value,
                                        ),
                                    );
                                }
                            }
                        }
                        router_part
                    } else {
                        router
                    }
                }
                Method::WebSocket => router.route(
                    &route_values.path,
                    any(|ws: WebSocketUpgrade| async {
                        ws.on_failed_upgrade(|err| {
                            mlua::Error::runtime(format!("failed to upgrade connection: {err}"));
                        })
                        .on_upgrade(|socket| async move {
                            let lua_socket = AstraWebSocket(socket);
                            let _ = route_values.function.call_async::<()>(lua_socket).await;
                        })
                    }),
                ),
                Method::Fallback => {
                    router.fallback(|request: Request<Body>| route(lua, route_values, request))
                }
            }
        }

        if let Ok(should_compress) = server.get::<bool>("compression")
            && should_compress
        {
            router = router.layer(
                tower::ServiceBuilder::new()
                    .layer(tower_http::decompression::RequestDecompressionLayer::new())
                    .layer(tower_http::compression::CompressionLayer::new()),
            );
        }
    }

    router
}