use std::sync::Arc;
use std::time::Duration;
use reqwest::multipart::{Form, Part};
use serde::Serialize;
use super::{APIError, APIResponse, HttpMethod};
/// Builder for a single HTTP request executed through a shared `reqwest::Client`.
///
/// The chained setter methods accumulate options on this value; nothing is
/// sent until the builder is consumed by `send` (or awaited directly).
#[derive(Debug)]
pub struct APIRequestBuilder {
/// Shared client used to create and execute the request.
client: Arc<reqwest::Client>,
/// HTTP method, converted to its reqwest equivalent when the request is built.
method: HttpMethod,
/// Target URL; either absolute (`http://` / `https://`) or relative to `base_url`.
url: String,
/// Base URL that a relative `url` is joined onto; `None` means only absolute URLs resolve.
base_url: Option<String>,
/// Headers added for this request through `header` / `headers`.
headers: Vec<(String, String)>,
/// Headers supplied at construction time, kept separate from the per-request `headers`.
default_headers: Vec<(String, String)>,
/// Query-string pairs, kept in insertion order.
query_params: Vec<(String, String)>,
/// Request payload, if one has been set; each body setter replaces the previous value.
body: Option<RequestBody>,
/// Per-request timeout; `None` leaves the request without a timeout override.
timeout: Option<Duration>,
/// Once `true`, sending fails with `APIError::Disposed` instead of issuing the request.
disposed: bool,
}
/// Payload attached to a request, stored until the request is actually built.
#[derive(Debug)]
pub(crate) enum RequestBody {
/// JSON payload that has already been serialized to bytes.
Json(Vec<u8>),
/// Key/value pairs handed to reqwest's form encoder.
Form(Vec<(String, String)>),
/// Fields assembled into a reqwest multipart form.
Multipart(Vec<MultipartField>),
/// Raw bytes used as the body as-is.
Bytes(Vec<u8>),
/// Plain-text payload.
Text(String),
}
/// One field of a multipart form body.
///
/// A field is either a text field (`value` is set) or a file field
/// (`file_content` is set). Use [`MultipartField::text`] or
/// [`MultipartField::file`] rather than filling the fields by hand.
#[derive(Debug, Clone)]
pub struct MultipartField {
    /// Form field name.
    pub name: String,
    /// Text payload of a text field; `None` for file fields.
    pub value: Option<String>,
    /// Raw bytes of a file field; `None` for text fields.
    pub file_content: Option<Vec<u8>>,
    /// File name attached to a file field.
    pub filename: Option<String>,
    /// Optional MIME type for the field, e.g. `image/png`.
    pub content_type: Option<String>,
}

impl MultipartField {
    /// Creates a text field carrying `value` under `name`.
    pub fn text(name: impl Into<String>, value: impl Into<String>) -> Self {
        let name = name.into();
        let value = Some(value.into());
        Self {
            name,
            value,
            file_content: None,
            filename: None,
            content_type: None,
        }
    }

    /// Creates a file field named `name` that uploads `content` as `filename`.
    pub fn file(name: impl Into<String>, filename: impl Into<String>, content: Vec<u8>) -> Self {
        let name = name.into();
        let filename = Some(filename.into());
        Self {
            name,
            value: None,
            file_content: Some(content),
            filename,
            content_type: None,
        }
    }

    /// Returns the field with its MIME type set to `content_type`,
    /// replacing any previously set one.
    #[must_use]
    pub fn content_type(self, content_type: impl Into<String>) -> Self {
        Self {
            content_type: Some(content_type.into()),
            ..self
        }
    }
}
impl APIRequestBuilder {
    /// Creates a builder for one request.
    ///
    /// `url` may be absolute or relative; a relative URL is joined onto
    /// `base_url` when the request is sent. `default_headers` are the
    /// client-wide headers that accompany the request.
    pub(crate) fn new(
        client: Arc<reqwest::Client>,
        method: HttpMethod,
        url: impl Into<String>,
        base_url: Option<String>,
        default_headers: Vec<(String, String)>,
    ) -> Self {
        Self {
            client,
            method,
            url: url.into(),
            base_url,
            headers: Vec::new(),
            default_headers,
            query_params: Vec::new(),
            body: None,
            timeout: None,
            disposed: false,
        }
    }

    /// Marks the builder as disposed; a later [`send`](Self::send) fails with
    /// `APIError::Disposed` without issuing a request.
    pub(crate) fn set_disposed(&mut self) {
        self.disposed = true;
    }

    /// Adds a single header to this request.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.push((name.into(), value.into()));
        self
    }

    /// Adds several headers to this request, keeping the ones already added.
    #[must_use]
    pub fn headers(mut self, headers: impl IntoIterator<Item = (String, String)>) -> Self {
        self.headers.extend(headers);
        self
    }

    /// Appends every `(key, value)` pair in `params` to the query string.
    #[must_use]
    pub fn query<K, V>(mut self, params: &[(K, V)]) -> Self
    where
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.query_params.extend(
            params
                .iter()
                .map(|(key, value)| (key.as_ref().to_string(), value.as_ref().to_string())),
        );
        self
    }

    /// Appends a single query-string parameter.
    #[must_use]
    pub fn query_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.query_params.push((key.into(), value.into()));
        self
    }

    /// Serializes `data` to JSON and uses it as the request body.
    ///
    /// If serialization fails, the error is logged with `tracing::error!` and
    /// the previously configured body (if any) is left untouched; the failure
    /// is NOT reported to the caller, so the request is still sent.
    // NOTE(review): surfacing this error from `send` needs extra state on the
    // builder struct, which lives outside this impl block.
    #[must_use]
    pub fn json<T: Serialize>(mut self, data: &T) -> Self {
        match serde_json::to_vec(data) {
            Ok(bytes) => {
                self.body = Some(RequestBody::Json(bytes));
            }
            Err(e) => {
                tracing::error!("Failed to serialize JSON: {}", e);
            }
        }
        self
    }

    /// Uses `data` as a URL-encoded form body, replacing any previous body.
    #[must_use]
    pub fn form<K, V>(mut self, data: &[(K, V)]) -> Self
    where
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let form_data: Vec<(String, String)> = data
            .iter()
            .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string()))
            .collect();
        self.body = Some(RequestBody::Form(form_data));
        self
    }

    /// Uses `fields` as a multipart form body, replacing any previous body.
    #[must_use]
    pub fn multipart(mut self, fields: Vec<MultipartField>) -> Self {
        self.body = Some(RequestBody::Multipart(fields));
        self
    }

    /// Uses raw bytes as the request body, replacing any previous body.
    ///
    /// No `Content-Type` is added for this body kind.
    #[must_use]
    pub fn body(mut self, data: Vec<u8>) -> Self {
        self.body = Some(RequestBody::Bytes(data));
        self
    }

    /// Uses `data` as a plain-text request body, replacing any previous body.
    #[must_use]
    pub fn text(mut self, data: impl Into<String>) -> Self {
        self.body = Some(RequestBody::Text(data.into()));
        self
    }

    /// Sets a timeout for this request only.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Resolves the request URL.
    ///
    /// An absolute `http://` / `https://` URL is returned unchanged. Anything
    /// else is joined onto the base URL with exactly one `/` between them.
    ///
    /// # Errors
    ///
    /// Returns `APIError::InvalidUrl` when the URL is relative and no base URL
    /// is configured.
    fn resolve_url(&self) -> Result<String, APIError> {
        /// Case-insensitive prefix test; URL schemes are case-insensitive, so
        /// `HTTP://host` must count as absolute just like `http://host`.
        fn has_scheme(url: &str, scheme: &str) -> bool {
            // `get` yields `None` for a too-short URL or a non-boundary cut.
            matches!(url.get(..scheme.len()), Some(prefix) if prefix.eq_ignore_ascii_case(scheme))
        }

        if has_scheme(&self.url, "http://") || has_scheme(&self.url, "https://") {
            return Ok(self.url.clone());
        }

        match self.base_url {
            Some(ref base) => {
                let base = base.trim_end_matches('/');
                let path = self.url.trim_start_matches('/');
                Ok(format!("{base}/{path}"))
            }
            None => Err(APIError::InvalidUrl(format!(
                "Relative URL '{}' requires a base URL",
                self.url
            ))),
        }
    }

    /// Builds the request and sends it.
    ///
    /// Header precedence (names compared case-insensitively):
    /// 1. headers added on this builder always win;
    /// 2. a `Content-Type` implied by the body (JSON, text, form, multipart)
    ///    replaces a default `Content-Type`;
    /// 3. remaining default headers are sent as-is.
    ///
    /// This matters because reqwest's `RequestBuilder::header` appends, so
    /// without the filtering a same-named header would be sent twice.
    ///
    /// In multipart bodies, a field's `value` takes precedence over its file
    /// content, file name and content type are applied to file parts only, and
    /// a field with neither value nor file content is skipped with a warning.
    ///
    /// # Errors
    ///
    /// * `APIError::Disposed` if the builder was disposed.
    /// * `APIError::InvalidUrl` if a relative URL has no base URL.
    /// * `APIError::BuildError` if a multipart content type is not a valid MIME string.
    /// * `APIError::Timeout` if the request timed out.
    /// * `APIError::Http` for any other reqwest failure.
    pub async fn send(self) -> Result<APIResponse, APIError> {
        const CONTENT_TYPE: &str = "Content-Type";

        /// Whether `headers` contains `name`, ignoring ASCII case.
        fn has_header(headers: &[(String, String)], name: &str) -> bool {
            headers
                .iter()
                .any(|(existing, _)| existing.eq_ignore_ascii_case(name))
        }

        if self.disposed {
            return Err(APIError::Disposed);
        }

        let url = self.resolve_url()?;
        // Copied out now so the error mapping at the end does not have to
        // touch `self` after `self.body` has been moved out of it.
        let timeout = self.timeout;

        let explicit_content_type = has_header(&self.headers, CONTENT_TYPE);
        // Every body kind except raw bytes gets a Content-Type, either from
        // this function (JSON, text) or from reqwest (form, multipart).
        let body_sets_content_type = matches!(
            self.body,
            Some(
                RequestBody::Json(_)
                    | RequestBody::Form(_)
                    | RequestBody::Multipart(_)
                    | RequestBody::Text(_)
            )
        );

        let mut request_builder = self.client.request(self.method.to_reqwest(), &url);

        for (name, value) in &self.default_headers {
            let overridden = has_header(&self.headers, name)
                || (body_sets_content_type && name.eq_ignore_ascii_case(CONTENT_TYPE));
            if !overridden {
                request_builder = request_builder.header(name.as_str(), value.as_str());
            }
        }
        for (name, value) in &self.headers {
            request_builder = request_builder.header(name.as_str(), value.as_str());
        }

        if !self.query_params.is_empty() {
            request_builder = request_builder.query(&self.query_params);
        }
        if let Some(timeout) = timeout {
            request_builder = request_builder.timeout(timeout);
        }

        match self.body {
            Some(RequestBody::Json(bytes)) => {
                if !explicit_content_type {
                    request_builder = request_builder.header(CONTENT_TYPE, "application/json");
                }
                request_builder = request_builder.body(bytes);
            }
            Some(RequestBody::Form(data)) => {
                request_builder = request_builder.form(&data);
            }
            Some(RequestBody::Multipart(fields)) => {
                let mut form = Form::new();
                for field in fields {
                    if let Some(value) = field.value {
                        form = form.text(field.name, value);
                    } else if let Some(content) = field.file_content {
                        let mut part = Part::bytes(content);
                        if let Some(filename) = field.filename {
                            part = part.file_name(filename);
                        }
                        if let Some(content_type) = field.content_type {
                            part = part.mime_str(&content_type).map_err(|e| {
                                APIError::BuildError(format!("Invalid content type: {e}"))
                            })?;
                        }
                        form = form.part(field.name, part);
                    } else {
                        // Still skipped as before, but no longer silently.
                        tracing::warn!(
                            "Skipping multipart field '{}': it has neither a value nor file content",
                            field.name
                        );
                    }
                }
                request_builder = request_builder.multipart(form);
            }
            Some(RequestBody::Bytes(data)) => {
                request_builder = request_builder.body(data);
            }
            Some(RequestBody::Text(data)) => {
                if !explicit_content_type {
                    request_builder = request_builder.header(CONTENT_TYPE, "text/plain");
                }
                request_builder = request_builder.body(data);
            }
            None => {}
        }

        let response = request_builder.send().await.map_err(|e| {
            if e.is_timeout() {
                // NOTE(review): with no per-request timeout the limit came from
                // the client; 30s is only a placeholder. Confirm it matches the
                // client's configured timeout.
                APIError::Timeout(timeout.unwrap_or(Duration::from_secs(30)))
            } else {
                APIError::Http(e)
            }
        })?;

        Ok(APIResponse::new(response))
    }
}
/// Lets a builder be awaited directly: `builder.await` is equivalent to
/// `builder.send().await`.
impl std::future::IntoFuture for APIRequestBuilder {
    type Output = Result<APIResponse, APIError>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;

    /// Consumes the builder and returns the boxed future produced by `send`.
    fn into_future(self) -> Self::IntoFuture {
        let pending = self.send();
        Box::pin(pending)
    }
}