use crate::hub::get_hub;
use crate::runtime::Runtime;
use crate::support::StrExt as _;
use crate::{Error, Result};
use mlua::{Lua, LuaSerdeExt, Table, Value};
use reqwest::redirect::Policy;
use reqwest::{Client, Response, header};
/// Builds the `aip.web` Lua module table.
///
/// Registers two functions:
/// - `get(url)` — HTTP GET, see `web_get`.
/// - `post(url, data)` — HTTP POST with a string or table body, see `web_post`.
///
/// `_runtime_context` is currently unused.
pub fn init_module(lua: &Lua, _runtime_context: &Runtime) -> Result<Table> {
    let module = lua.create_table()?;

    let get = lua.create_function(|lua, (url,): (String,)| web_get(lua, url))?;
    let post = lua.create_function(|lua, (url, data): (String, Value)| web_post(lua, url, data))?;

    module.set("get", get)?;
    module.set("post", post)?;

    Ok(module)
}
fn web_get(lua: &Lua, url: String) -> mlua::Result<Value> {
let rt = tokio::runtime::Handle::try_current().map_err(Error::TokioTryCurrent)?;
let res: mlua::Result<Value> = tokio::task::block_in_place(|| {
rt.block_on(async {
let client = Client::builder()
.redirect(Policy::limited(5)) .build()
.map_err(crate::Error::from)?;
let res: mlua::Result<Value> = match client.get(&url).send().await {
Ok(response) => get_lua_response_value(lua, response, &url).await,
Err(err) => Err(crate::Error::custom(format!(
"\
Fail to do aip.web.get for url: {url}
Cause: {err}"
))
.into()),
};
if res.is_ok() {
get_hub().publish_sync(format!("-> lua web::get OK ({}) ", url));
}
res
})
});
res
}
fn web_post(lua: &Lua, url: String, data: Value) -> mlua::Result<Value> {
let rt = tokio::runtime::Handle::try_current().map_err(Error::TokioTryCurrent)?;
let res: mlua::Result<Value> = tokio::task::block_in_place(|| {
rt.block_on(async {
let client = Client::builder()
.redirect(Policy::limited(5)) .build()
.map_err(crate::Error::from)?;
let mut request_builder = client.post(&url);
match data {
Value::String(s) => {
request_builder = request_builder
.header(header::CONTENT_TYPE, "plain/text")
.body(s.to_string_lossy());
}
Value::Table(table) => {
let json: serde_json::Value = serde_json::to_value(table).map_err(|err| {
crate::Error::custom(format!(
"Cannot searlize to json the argument given to the post.\n Cause: {err}"
))
})?;
request_builder = request_builder
.header(header::CONTENT_TYPE, "application/json")
.body(json.to_string());
}
_ => {
return Err(mlua::Error::RuntimeError(
"Data must be a string or a table".to_string(),
));
}
}
let res: mlua::Result<Value> = match request_builder.send().await {
Ok(response) => get_lua_response_value(lua, response, &url).await,
Err(err) => Err(crate::Error::custom(format!(
"\
Fail to do aip.web.post for url: {url}
Cause: {err}"
))
.into()),
};
if res.is_ok() {
get_hub().publish_sync(format!("-> lua web::post OK ({}) ", url));
}
res
})
});
res
}
async fn get_lua_response_value(lua: &Lua, response: Response, url: &str) -> mlua::Result<Value> {
let content_type = get_content_type(&response);
let status = response.status();
let success = status.is_success();
let status_code = status.as_u16() as i64;
if success {
let res = lua.create_table()?;
res.set("success", true)?;
res.set("status", status_code)?;
res.set("url", url)?;
let content = response.text().await.map_err(Error::Reqwest)?;
let content = get_content_value_for_content_type(lua, content_type, &content)?;
res.set("content", content)?;
Ok(Value::Table(res))
} else {
let res = lua.create_table()?;
res.set("success", false)?;
res.set("status", status_code)?;
res.set("url", url)?;
let content = response.text().await.unwrap_or_default();
let content = Value::String(lua.create_string(&content)?);
res.set("content", content)?;
res.set("error", format!("Not a 2xx status code ({status_code})"))?;
Ok(Value::Table(res))
}
}
/// Converts a response body into a Lua value according to its content type.
///
/// When `content_type` contains `application/json`, the body is parsed as JSON and converted
/// with `LuaSerdeExt::to_value`. A JSON parse failure is returned as an error. For any other
/// or missing content type, the body is returned unchanged as a Lua string.
fn get_content_value_for_content_type(lua: &Lua, content_type: Option<String>, content: &str) -> Result<Value> {
    if !content_type.x_contains("application/json") {
        return Ok(Value::String(lua.create_string(content)?));
    }

    let json: serde_json::Value = serde_json::from_str(content)
        .map_err(|err| crate::Error::custom(format!("Fail to parse web response as json.\n Cause: {err}")))?;
    Ok(lua.to_value(&json)?)
}
/// Returns the response's `Content-Type` header value, lowercased.
///
/// Returns `None` when the header is absent. A header value that `HeaderValue::to_str` rejects
/// (it must be visible ASCII) yields `Some(String::new())` rather than `None`.
fn get_content_type(response: &Response) -> Option<String> {
    let raw = response.headers().get(header::CONTENT_TYPE)?;
    let text = raw.to_str().unwrap_or_default();
    Some(text.to_lowercase())
}
#[cfg(test)]
mod tests {
    type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>;

    use crate::_test_support::{assert_contains, eval_lua, setup_lua};
    use crate::script::lua_script::aip_web;
    use value_ext::JsonValueExt;

    /// A GET on a static text page yields a 2xx table whose `content` is the page body.
    /// Requires network access.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_lua_web_get_simple_ok() -> Result<()> {
        // -- Setup & Fixtures
        let lua = setup_lua(aip_web::init_module, "web")?;
        let script = r#"
local url = "https://phet-dev.colorado.edu/html/build-an-atom/0.0.0-3/simple-text-only-test-page.html"
return aip.web.get(url)
"#;

        // -- Exec
        let res = eval_lua(&lua, script)?;

        // -- Check
        assert_contains(res.x_get_str("content")?, "This page tests that simple text can be");
        assert_eq!(res.x_get_i64("status")?, 200, "status code");
        assert!(res.x_get_bool("success")?, "success should be true");

        Ok(())
    }

    /// A POST with a Lua table sends JSON; the echo service reflects it back under `json`.
    /// Requires network access.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_lua_web_post_json_ok() -> Result<()> {
        // -- Setup & Fixtures
        let lua = setup_lua(aip_web::init_module, "web")?;
        let script = r#"
local url = "https://postman-echo.com/post"
return aip.web.post(url, {some = "stuff"})
"#;

        // -- Exec
        let res = eval_lua(&lua, script)?;

        // -- Check
        let echoed = res.pointer("/content").ok_or("Should have content")?;
        assert_eq!(echoed.x_get_str("/json/some")?, "stuff");

        Ok(())
    }

    /// An unreachable host surfaces as a Lua error naming the call and the URL.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_lua_web_get_invalid_url() -> Result<()> {
        // -- Setup & Fixtures
        let lua = setup_lua(aip_web::init_module, "web")?;
        let script = r#"
local url = "https://this-cannot-go/anywhere-or-can-it.aip"
return aip.web.get(url)
"#;

        // -- Exec
        let Err(err) = eval_lua(&lua, script) else {
            return Err("Should have returned an error".into());
        };

        // -- Check
        let message = err.to_string();
        assert_contains(&message, "Fail to do aip.web.get");
        assert_contains(&message, "https://this-cannot-go/anywhere-or-can-it.aip");

        Ok(())
    }
}