use std::{str::FromStr, sync::Arc};
use super::*;
use reqwest::{get as reqwest_get, Client, Response};
/// HTTP method (and, for `Post`, the payload) of a [`NetRequest`].
///
/// Equality compares the payload too, so two `Post`s with different bodies are unequal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RequestType {
    /// A `GET` request with no body.
    Get,
    /// A `POST` request whose body is the contained string.
    Post(String),
}

/// An HTTP request handed to Lua as userdata; nothing is sent until one of its
/// methods is called.
///
/// Field `0` is the target URL as given by the caller (it is parsed only when the
/// request is sent) and field `1` is the [`RequestType`].
#[derive(Debug)]
pub struct NetRequest(String, RequestType);
impl NetRequest {
    /// Checks that this request was built with exactly the given [`RequestType`].
    ///
    /// # Errors
    ///
    /// Returns a Lua runtime error (`"Incorrect request type"`) when the stored type
    /// differs from `t`. The comparison uses the derived `PartialEq`, so for `Post` the
    /// body must match as well.
    // NOTE(review): no caller is visible in this file; the allow presumably keeps the
    // helper available for other code without a dead-code warning — confirm.
    #[allow(dead_code)]
    pub fn ensure_type(&self, t: RequestType) -> LuaResult<()> {
        if self.1 == t {
            Ok(())
        } else {
            Err(LuaError::runtime("Incorrect request type"))
        }
    }

    /// Sends the HTTP request described by this value and returns the raw response.
    ///
    /// Every call creates a fresh HTTP client and sends the request anew. HTTP error
    /// statuses (4xx/5xx) are returned as ordinary responses, not as errors.
    ///
    /// # Errors
    ///
    /// Returns an external Lua error if the stored URL cannot be parsed, if the HTTP
    /// client cannot be built, or if sending the request fails.
    pub async fn request(&self) -> LuaResult<Response> {
        // Parse once up front so an invalid URL is reported the same way for both methods.
        let url = reqwest::Url::from_str(&self.0)
            .map_err(|e| LuaError::external(Arc::new(e)))?;
        let sent = match &self.1 {
            RequestType::Get => reqwest_get(url).await,
            RequestType::Post(body) => {
                // Build the client fallibly, like the GET path does, instead of via
                // `Client::new()`, which panics when the client cannot be constructed.
                let client = Client::builder()
                    .build()
                    .map_err(|e| LuaError::external(Arc::new(e)))?;
                // Reuse the URL parsed above rather than having reqwest re-parse `self.0`.
                client.post(url).body(body.clone()).send().await
            }
        };
        sent.map_err(|e| LuaError::external(Arc::new(e)))
    }
}
/// Lua-facing methods of [`NetRequest`].
///
/// Each method calls [`NetRequest::request`] itself, so every call from Lua sends the
/// HTTP request again (a `Post` body is re-sent each time).
impl LuaUserData for NetRequest {
    fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
        // `req:text()` -> the response body as a string.
        methods.add_async_method("text", |_, this, _: ()| async move {
            let response = this.request().await?;
            response
                .text()
                .await
                .map_err(|e| LuaError::external(Arc::new(e)))
        });
        // `req:json()` -> the body parsed as JSON and re-serialized to a compact JSON
        // string (not a Lua table). Fails if the body is not valid JSON.
        methods.add_async_method("json", |_, this, _: ()| async move {
            let response = this.request().await?;
            response
                .json::<serde_json::Value>()
                .await
                .map(|json| json.to_string())
                .map_err(|e| LuaError::external(Arc::new(e)))
        });
        // `req:status()` -> the status code's `Display` form, e.g. "200 OK".
        methods.add_async_method("status", |_, this, _: ()| async move {
            let response = this.request().await?;
            Ok(response.status().to_string())
        });
        // `req:headers()` -> table mapping each (lowercase) header name to its value.
        methods.add_async_method("headers", |lua, this, _: ()| async move {
            let response = this.request().await?;
            let headers = response.headers();
            let table = lua.create_table()?;
            for name in headers.keys() {
                // A header can occur several times: combine every occurrence with ", "
                // (RFC 9110 §5.3) instead of letting the last one overwrite the others.
                // Values are decoded lossily so a single non-ASCII value (e.g. a UTF-8
                // filename) no longer makes the whole call fail.
                let value = headers
                    .get_all(name)
                    .iter()
                    .map(|v| String::from_utf8_lossy(v.as_bytes()))
                    .collect::<Vec<_>>()
                    .join(", ");
                table.set(name.as_str(), value)?;
            }
            Ok(table)
        });
    }
}
/// Loader for the Lua module registered as `autosway.net`.
///
/// The module exposes two constructors, `get(url)` and `post(url, body)`. Each returns
/// a [`NetRequest`] userdata; no network traffic happens at construction time.
pub struct Net;

impl Module for Net {
    /// Builds the `autosway.net` table and registers it with `create_module`.
    async fn load(&self, state: &mut State) -> Result<(), LuaError> {
        let lua = &state.lua;

        let get_fn = lua.create_async_function(|_, url: String| async move {
            Ok(NetRequest(url, RequestType::Get))
        })?;
        let post_fn =
            lua.create_async_function(|_, (url, body): (String, String)| async move {
                Ok(NetRequest(url, RequestType::Post(body)))
            })?;

        let exports = lua.create_table()?;
        exports.set("get", get_fn)?;
        exports.set("post", post_fn)?;

        create_module(lua, "autosway.net", exports)?;
        Ok(())
    }
}