use crate::script::LuaValueExt;
use mlua::{FromLua, Lua, Value};
use reqwest::ClientBuilder;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::redirect::Policy;
use std::collections::HashMap;
/// Maximum number of redirects followed when [`WebOptions::redirect_limit`] is not set.
pub const DEFAULT_REDIRECT_LIMIT: i32 = 5;

/// User-Agent sent when the caller passes `user_agent = true` or supplies no User-Agent at all.
pub const DEFAULT_UA_AIPACK: &str = "aipack";

/// A desktop Chrome 120 on Windows 10 (x64) User-Agent string, for callers that want a
/// browser-like User-Agent.
pub const DEFAULT_UA_BROWSER: &str = concat!(
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) ",
    "AppleWebKit/537.36 (KHTML, like Gecko) ",
    "Chrome/120.0.0.0 Safari/537.36"
);
/// Options controlling how an HTTP client is configured for a web request.
///
/// Every field is optional; `None` means "use the default" when the options are applied
/// to a `reqwest::ClientBuilder`.
#[derive(Debug, Clone, Default)]
pub struct WebOptions {
    /// User-Agent to send. `Some("")` means "send no User-Agent"; `None` falls back to a
    /// `User-Agent` entry in `headers`, then to [`DEFAULT_UA_AIPACK`].
    pub user_agent: Option<String>,
    /// Extra request headers: header name -> one or more values (each value is sent).
    pub headers: Option<HashMap<String, Vec<String>>>,
    /// Maximum number of redirects to follow; `None` uses [`DEFAULT_REDIRECT_LIMIT`].
    pub redirect_limit: Option<i32>,
    /// NOTE(review): not read when building the client; presumably tells the caller whether
    /// to parse the response body — confirm against call sites.
    pub parse: Option<bool>,
}
impl FromLua for WebOptions {
    /// Converts a Lua value into [`WebOptions`].
    ///
    /// Accepts `nil` (all defaults) or a table with these optional fields:
    /// - `user_agent`: a string (used verbatim), `true` (use [`DEFAULT_UA_AIPACK`]) or
    ///   `false` (empty string, i.e. send no User-Agent). Any other type is ignored.
    /// - `redirect_limit`: an integer, saturated into the `i32` range.
    /// - `parse`: a boolean.
    /// - `headers`: a table mapping header names to a string or an array of strings.
    ///
    /// # Errors
    /// Returns `FromLuaConversionError` when `value` is neither nil nor a table, when
    /// `headers` is present but not a table, or when a header value is not a string or an
    /// array of strings. Errors raised by Lua while reading the table are propagated.
    fn from_lua(value: Value, _lua: &Lua) -> mlua::Result<Self> {
        let table = match value {
            Value::Nil => return Ok(WebOptions::default()),
            Value::Table(table) => table,
            other => {
                return Err(mlua::Error::FromLuaConversionError {
                    from: other.type_name(),
                    to: "WebOptions".to_string(),
                    message: Some("Expected nil or a table for WebOptions".into()),
                });
            }
        };

        let user_agent = match table.get::<Value>("user_agent")? {
            Value::String(s) => Some(s.to_string_lossy()),
            Value::Boolean(true) => Some(DEFAULT_UA_AIPACK.to_owned()),
            Value::Boolean(false) => Some(String::new()),
            _ => None,
        };

        // Saturate rather than truncate: a plain `as i32` would silently wrap large values
        // (e.g. 2^32 + 5 would become 5).
        let redirect_limit = table
            .x_get_i64("redirect_limit")
            .map(|v| v.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32);

        let parse = table.x_get_bool("parse");

        // Only a missing `headers` means "no headers". Any other non-table value is a
        // caller mistake and is reported, matching the strict per-value checks below.
        let headers = match table.get::<Value>("headers")? {
            Value::Nil => None,
            Value::Table(headers_table) => Some(parse_lua_headers(headers_table)?),
            other => {
                return Err(mlua::Error::FromLuaConversionError {
                    from: other.type_name(),
                    to: "Table".to_string(),
                    message: Some("WebOptions.headers must be a table".into()),
                });
            }
        };

        Ok(WebOptions {
            user_agent,
            headers,
            redirect_limit,
            parse,
        })
    }
}

/// Reads a Lua `headers` table into a map of header name -> values.
///
/// A string value becomes a single-element list; an array value contributes its sequence
/// elements in order. Any other value type yields a `FromLuaConversionError`.
fn parse_lua_headers(headers_table: mlua::Table) -> mlua::Result<HashMap<String, Vec<String>>> {
    let mut headers_map = HashMap::new();
    for pair in headers_table.pairs::<String, Value>() {
        let (name, value) = pair?;
        let values = match value {
            Value::String(s) => vec![s.to_string_lossy()],
            Value::Table(t) => t
                .sequence_values::<String>()
                .collect::<mlua::Result<Vec<String>>>()?,
            other => {
                return Err(mlua::Error::FromLuaConversionError {
                    from: other.type_name(),
                    to: "String or Array".to_string(),
                    message: Some("Header values must be strings or arrays of strings".into()),
                });
            }
        };
        headers_map.insert(name, values);
    }
    Ok(headers_map)
}
impl WebOptions {
pub fn apply_to_reqwest_builder(mut self, mut client_builder: ClientBuilder) -> ClientBuilder {
let limit = self.redirect_limit.unwrap_or(DEFAULT_REDIRECT_LIMIT);
client_builder = client_builder.redirect(Policy::limited(limit as usize));
let mut user_agent_to_set: Option<String> = self.user_agent.take();
if user_agent_to_set.is_some()
&& let Some(headers) = self.headers.as_mut()
&& let Some(ua_key) = headers.keys().find(|k| k.eq_ignore_ascii_case("user-agent")).cloned()
{
headers.remove(&ua_key);
}
if user_agent_to_set.is_none()
&& let Some(headers) = self.headers.as_mut()
&& let Some(ua_key) = headers.keys().find(|k| k.eq_ignore_ascii_case("user-agent")).cloned()
&& let Some(ua_values) = headers.remove(&ua_key)
&& let Some(first_value) = ua_values.into_iter().next()
{
user_agent_to_set = Some(first_value);
}
if user_agent_to_set.is_none() {
user_agent_to_set = Some(DEFAULT_UA_AIPACK.to_owned());
}
if let Some(ua) = user_agent_to_set
&& !ua.is_empty()
{
client_builder = client_builder.user_agent(ua);
}
if let Some(headers) = self.headers {
let mut header_map = HeaderMap::new();
for (key, values) in headers {
if key.eq_ignore_ascii_case("user-agent") {
continue;
}
if let Ok(header_name) = HeaderName::from_bytes(key.as_bytes()) {
for value in values {
if let Ok(header_value) = HeaderValue::from_str(&value) {
header_map.append(header_name.clone(), header_value);
}
}
}
}
if !header_map.is_empty() {
client_builder = client_builder.default_headers(header_map);
}
}
client_builder
}
}