use crate::components::AstraBuffer;
use futures::StreamExt;
use mlua::{ExternalError, UserData};
use reqwest_websocket::Upgrade;
use std::collections::HashMap;
impl UserData for super::HTTPClientRequest {
fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
methods.add_method("set_method", |_, this, method: String| {
let mut request = this.clone();
request.method = method;
Ok(request)
});
methods.add_method_mut("set_header", |_, this, (key, value): (String, String)| {
let mut request = this.clone();
request.headers.insert(key, value);
Ok(request)
});
methods.add_method_mut(
"set_headers",
|_, this, headers: HashMap<String, String>| {
let mut request = this.clone();
request.headers = headers;
Ok(request)
},
);
methods.add_method_mut("set_forms", |_, _, _body: mlua::Value| {
panic!("set_forms is deprecated, moved to set_form.");
#[allow(unreachable_code)]
Ok(())
});
methods.add_method_mut("set_form", |_, this, form: HashMap<String, String>| {
let mut request = this.clone();
request.form = form;
Ok(request)
});
methods.add_method_mut("set_body", |lua, this, body: mlua::Value| {
let mut request = this.clone();
request.body = Self::body_parser(lua, &mut request.headers, body)?;
if !request.headers.contains_key("Content-Type") {
request
.headers
.insert("Content-Type".to_string(), "text/plain".to_string());
}
Ok(request)
});
methods.add_method_mut("set_bytes", |_, _, _body: mlua::Value| {
panic!("set_bytes is deprecated, use the set_body instead.");
#[allow(unreachable_code)]
Ok(())
});
methods.add_method_mut("set_json", |_, _, _body: mlua::Value| {
panic!("set_json is deprecated, use the set_body instead.");
#[allow(unreachable_code)]
Ok(())
});
methods.add_method_mut("set_file", |_, this, file_path: String| {
let mut request = this.clone();
request.file = Some(file_path);
Ok(request)
});
methods.add_async_method("execute", |_, this, ()| async move {
let request = this.request_builder().await?;
match request.send().await {
Ok(response) => Ok(Self::response_to_http_client_response(response).await),
Err(e) => Err(e.into_lua_err()),
}
});
methods.add_method("execute_task", |_, _, _: ()| {
panic!("execute_task is deprecated, use execute within async task instead.");
#[allow(unreachable_code)]
Ok(())
});
methods.add_async_method(
"execute_streaming",
|_, this, callback: mlua::Function| async move {
tokio::spawn(async move {
let request = this.request_builder().await?;
let response = match request.send().await {
Ok(response) => response,
Err(e) => {
tracing::error!("HTTP Request did not execute successfully: {e}");
return mlua::Result::Ok(());
}
};
let headers = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
.collect();
let initial_response = HTTPClientResponse {
url: response.url().to_string(),
status_code: response.status().as_u16(),
remote_address: response.remote_addr().map(|i| i.to_string()),
body: AstraBuffer::new(bytes::Bytes::new()),
headers,
};
if let Err(e) = callback.call::<()>(initial_response.clone()) {
tracing::error!("Error running initial callback: {e}");
return Ok(());
}
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(chunk) => {
let mut chunk_response = initial_response.clone();
chunk_response.body = AstraBuffer::new(chunk);
if let Err(e) = callback.call::<()>(chunk_response) {
tracing::error!("Error running chunk callback: {e}");
break;
}
}
Err(e) => {
tracing::error!("Error receiving chunk: {e}");
break;
}
}
}
Ok(())
});
Ok(())
},
);
methods.add_async_method(
"execute_websocket",
|lua, this, callback: mlua::Function| async move {
tokio::spawn(async move {
let request = this.request_builder().await?;
let request = request.upgrade();
if let Ok(response) = request.send().await
&& let Ok(response) = response.into_websocket().await
{
if let Err(e) = callback
.call_async::<()>(lua.create_userdata(super::AstraWebSocket(response)))
.await
{
tracing::error!("Error running a task: {e}")
}
} else {
tracing::error!("Websocket request did not execute successfully");
};
mlua::Result::Ok(())
});
Ok(())
},
);
}
}
/// Snapshot of an HTTP response handed to Lua scripts.
///
/// In `execute_streaming`, the first instance passed to the callback carries an empty `body`.
/// Each later one is a clone of it whose `body` holds a single received chunk.
#[derive(Debug, Clone)]
pub struct HTTPClientResponse {
    /// Response URL rendered as a string.
    pub url: String,
    /// Numeric HTTP status code (e.g. 200, 404).
    pub status_code: u16,
    /// Remote peer address as a string, or `None` when the client did not report one.
    pub remote_address: Option<String>,
    /// Payload bytes: empty or a single chunk when streaming; presumably the whole body when
    /// built by `execute` — confirm against `response_to_http_client_response`.
    pub body: AstraBuffer,
    /// Header name/value pairs; a map keeps one value per name, so repeated headers collapse.
    pub headers: HashMap<String, String>,
}
/// Read-only Lua accessors for an `HTTPClientResponse`.
///
/// Every accessor is a method call (`res:url()`, `res:status_code()`, ...) that hands back an
/// owned copy of the matching field, so scripts can never alter the stored response.
impl UserData for HTTPClientResponse {
    fn add_methods<M: mlua::UserDataMethods<Self>>(methods: &mut M) {
        methods.add_method("url", |_, response, ()| {
            let url = response.url.clone();
            Ok(url)
        });
        methods.add_method("status_code", |_, response, ()| {
            let code = response.status_code;
            Ok(code)
        });
        // Becomes `nil` on the Lua side when no peer address was recorded.
        methods.add_method("remote_address", |_, response, ()| {
            let address = response.remote_address.as_ref().cloned();
            Ok(address)
        });
        methods.add_method("body", |_, response, ()| {
            let payload = response.body.clone();
            Ok(payload)
        });
        methods.add_method("headers", |_, response, ()| {
            let header_map = response.headers.clone();
            Ok(header_map)
        });
    }
}