lua-astra 0.47.0

🔥 Blazingly Fast 🔥 runtime environment for Lua
use crate::components::{AstraBuffer, astra_serde::sanetize_lua_input};
use mlua::{ExternalResult, LuaSerdeExt};
use reqwest::{Client, RequestBuilder};
use std::collections::HashMap;

/// The kinds of payload an outgoing HTTP client request body can hold.
#[derive(Clone, Debug)]
pub enum HTTPClientRequestBodyTypes {
    /// Text sent verbatim as the request body.
    String(std::string::String),
    /// A JSON document, serialised into the body when the request is built.
    Json(serde_json::Value),
    /// Raw bytes sent verbatim as the request body.
    Bytes(std::vec::Vec<u8>),
}

#[derive(Debug, Clone)]
pub struct HTTPClientRequest {
    pub url: String,
    pub method: String,
    pub headers: HashMap<String, String>,
    pub body: Option<HTTPClientRequestBodyTypes>,
    pub file: Option<String>,
    pub form: HashMap<String, String>,
}

impl HTTPClientRequest {
    pub fn register_to_lua(lua: &mlua::Lua) -> mlua::Result<()> {
        let function = lua.create_function(|lua, details: mlua::Value| match details {
            mlua::Value::String(details) => Ok(Self {
                url: details.to_string_lossy(),
                method: "GET".to_string(),
                headers: HashMap::new(),
                body: None,
                file: None,
                form: HashMap::new(),
            }),
            mlua::Value::Table(details) => {
                let mut headers: HashMap<String, String> =
                    details.get("headers").unwrap_or(HashMap::new());
                let body = details.get::<mlua::Value>("body")?;
                let body = Self::body_parser(lua, &mut headers, body)?;

                Ok(Self {
                    url: details.get("url")?,
                    method: details
                        .get::<String>("method")
                        .map(|method| method.to_uppercase())
                        .unwrap_or("GET".to_string()),
                    headers,
                    body,
                    file: details.get::<String>("file").ok(),
                    form: details
                        .get::<HashMap<String, String>>("form")
                        .unwrap_or_default(),
                })
            }
            _ => Err(mlua::Error::runtime(
                "Bad argument, expected string or table",
            )),
        })?;
        lua.globals().set("astra_internal__http_request", function)
    }

    pub async fn request_builder(&self) -> mlua::Result<RequestBuilder> {
        let mut client = match self.method.to_uppercase().as_str() {
            "CONNECT" => Client::new().request(reqwest::Method::CONNECT, &self.url),
            "OPTIONS" => Client::new().request(reqwest::Method::OPTIONS, &self.url),
            "DELETE" => Client::new().request(reqwest::Method::DELETE, &self.url),
            "TRACE" => Client::new().request(reqwest::Method::TRACE, &self.url),
            "PATCH" => Client::new().request(reqwest::Method::PATCH, &self.url),
            "HEAD" => Client::new().request(reqwest::Method::HEAD, &self.url),
            "POST" => Client::new().request(reqwest::Method::POST, &self.url),
            "PUT" => Client::new().request(reqwest::Method::PUT, &self.url),
            "GET" => Client::new().request(reqwest::Method::GET, &self.url),
            _ => Client::new().request(
                reqwest::Method::from_bytes(self.method.to_uppercase().as_bytes())
                    .into_lua_err()?,
                &self.url,
            ),
        };

        if let Some(HTTPClientRequestBodyTypes::String(body)) = &self.body {
            client = client.body(body.clone())
        } else if let Some(HTTPClientRequestBodyTypes::Bytes(body)) = &self.body {
            client = client.body(body.clone())
        } else if let Some(HTTPClientRequestBodyTypes::Json(body)) = &self.body {
            client = client.json(&body)
        } else if let Some(file_body) = &self.file {
            let path = std::path::PathBuf::from(&file_body);
            let path_filename = path.clone();
            let file_form = reqwest::multipart::Form::new();

            let filename = path_filename
                .file_name()
                .and_then(|filename| filename.to_str())
                .unwrap_or("file.txt")
                .to_string();

            if let Ok(file_form) = file_form.file(filename, path).await {
                client = client.multipart(file_form)
            }
        }

        if !self.headers.is_empty() {
            for (key, value) in self.headers.iter() {
                client = client.header(key, value);
            }
        }
        if !self.form.is_empty() {
            client = client.form(&self.form);
        }

        Ok(client)
    }

    pub fn body_parser(
        lua: &mlua::Lua,
        headers: &mut HashMap<String, String>,
        body: mlua::Value,
    ) -> mlua::Result<Option<HTTPClientRequestBodyTypes>> {
        match body.clone() {
            mlua::Value::String(value) => {
                if !headers.contains_key("Content-Type") {
                    headers.insert("Content-Type".to_string(), "text/plain".to_string());
                }
                Ok(Some(HTTPClientRequestBodyTypes::String(
                    value.to_string_lossy(),
                )))
            }
            mlua::Value::Table(value) => {
                if crate::components::is_table_byte_array(&value)? {
                    return Ok(Some(HTTPClientRequestBodyTypes::Bytes(
                        lua.from_value::<Vec<u8>>(body.clone())?,
                    )));
                } else if crate::components::is_table_json(&value)? {
                    if !headers.contains_key("Content-Type") {
                        headers.insert("Content-Type".to_string(), "application/json".to_string());
                    }
                    return Ok(Some(HTTPClientRequestBodyTypes::Json(
                        lua.from_value::<serde_json::Value>(sanetize_lua_input(
                            lua,
                            body.clone(),
                        )?)?,
                    )));
                }
                Ok(None)
            }
            _ => Ok(None),
        }
    }

    pub fn headers_parser(header_map: &reqwest::header::HeaderMap) -> HashMap<String, String> {
        header_map
            .iter()
            .map(|(key, value)| {
                (
                    key.to_string(),
                    String::from_utf8_lossy(value.as_bytes()).to_string(),
                )
            })
            .collect::<std::collections::HashMap<String, String>>()
    }

    pub async fn response_to_http_client_response(
        response: reqwest::Response,
    ) -> super::HTTPClientResponse {
        super::HTTPClientResponse {
            remote_address: response.remote_addr().map(|i| i.to_string()),
            headers: Self::headers_parser(response.headers()),
            status_code: response.status().as_u16(),
            url: response.url().to_string(),
            body: if let Ok(bytes) = response.bytes().await {
                AstraBuffer::new(bytes)
            } else {
                AstraBuffer::new(bytes::Bytes::new())
            },
        }
    }
}