use std::time::Duration;
/// Settings consumed by `APIRequestContext::new` when a context is created.
#[derive(Debug, Clone, Default)]
pub struct RequestContextOptions {
/// Prefix joined onto request URLs that are not absolute `http://`/`https://` URLs.
pub base_url: Option<String>,
/// Header name/value pairs the context attaches to the requests it sends.
pub extra_http_headers: Vec<(String, String)>,
/// Default per-request timeout; the context falls back to 30 seconds when unset.
pub timeout: Option<Duration>,
/// When `true`, the context's HTTP client is built to accept invalid TLS certificates.
pub ignore_https_errors: bool,
}
/// Per-request overrides accepted by `APIRequestContext::fetch` and the verb helpers.
///
/// At most one body is sent: `json_data` wins over `form`, which wins over `data`.
#[derive(Debug, Clone, Default)]
pub struct RequestOptions {
/// HTTP method name; `fetch` uses `GET` when unset. The verb helpers overwrite this field.
pub method: Option<String>,
/// Header name/value pairs added to this request after the context's own headers.
pub headers: Option<Vec<(String, String)>>,
/// Raw request body; only used when neither `json_data` nor `form` is set.
pub data: Option<Vec<u8>>,
/// JSON request body; takes precedence over `form` and `data`.
pub json_data: Option<serde_json::Value>,
/// Form fields sent as the request body; only used when `json_data` is unset.
pub form: Option<Vec<(String, String)>>,
/// Query-string name/value pairs added to the request URL.
pub params: Option<Vec<(String, String)>>,
/// Timeout for this request; overrides the context's default timeout when set.
pub timeout: Option<Duration>,
/// When `Some(true)`, a response whose status is outside 200..300 is returned as `Err`.
pub fail_on_status_code: Option<bool>,
/// NOTE(review): `fetch` reads this value but does not currently apply it to the request.
pub max_redirects: Option<u32>,
}
/// An HTTP response whose status, headers and body have already been read into memory.
///
/// Fields are private; use the accessor methods on this type to read them.
#[derive(Debug, Clone)]
pub struct APIResponse {
// Numeric HTTP status code.
status_code: u16,
// Reason phrase that accompanies `status_code`.
status_text: String,
// URL reported by the response (may differ from the requested URL).
response_url: String,
// Response headers as name/value pairs, in the order they were collected.
response_headers: Vec<(String, String)>,
// Complete response body.
body_bytes: bytes::Bytes,
}
impl APIResponse {
    /// Returns the numeric HTTP status code.
    pub fn status(&self) -> u16 {
        self.status_code
    }

    /// Returns the reason phrase stored alongside the status code.
    pub fn status_text(&self) -> &str {
        &self.status_text
    }

    /// Returns the URL recorded for this response.
    pub fn url(&self) -> &str {
        &self.response_url
    }

    /// Returns `true` when the status code is in the `200..300` range.
    pub fn ok(&self) -> bool {
        (200..300).contains(&self.status_code)
    }

    /// Returns all response headers as name/value pairs.
    pub fn headers(&self) -> &[(String, String)] {
        &self.response_headers
    }

    /// Returns the value of the first header whose name matches `name`.
    ///
    /// The comparison ignores ASCII case, which is how HTTP header names are
    /// compared. Returns `None` when no header matches.
    pub fn header(&self, name: &str) -> Option<&str> {
        // `eq_ignore_ascii_case` avoids allocating a lowercased copy of every
        // header name and cannot match through Unicode case folding.
        self.response_headers
            .iter()
            .find(|(key, _)| key.eq_ignore_ascii_case(name))
            .map(|(_, value)| value.as_str())
    }

    /// Decodes the body as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns a message describing the decoding failure when the body is not
    /// valid UTF-8.
    pub fn text(&self) -> Result<String, String> {
        // Validate the borrowed bytes first so the body is only copied when
        // it is actually valid UTF-8.
        std::str::from_utf8(&self.body_bytes)
            .map(str::to_owned)
            .map_err(|e| format!("response body is not UTF-8: {e}"))
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns a message describing the parse failure when the body is not
    /// valid JSON for `T`.
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, String> {
        serde_json::from_slice(&self.body_bytes).map_err(|e| format!("JSON parse error: {e}"))
    }

    /// Deserializes the body as an untyped JSON value.
    ///
    /// # Errors
    ///
    /// Same as [`APIResponse::json`].
    pub fn json_value(&self) -> Result<serde_json::Value, String> {
        self.json()
    }

    /// Returns the raw body bytes.
    pub fn body(&self) -> &[u8] {
        &self.body_bytes
    }

    /// Consumes the response and drops it.
    pub fn dispose(self) {
        drop(self);
    }
}
/// A reusable HTTP request context: one client plus the defaults applied to its requests.
#[derive(Clone)]
pub struct APIRequestContext {
// HTTP client used to send every request from this context.
client: reqwest::Client,
// Optional prefix for request URLs that are not absolute.
base_url: Option<String>,
// Header name/value pairs the context attaches to its requests.
extra_headers: Vec<(String, String)>,
// Timeout used when a request does not specify its own.
default_timeout: Duration,
}
impl APIRequestContext {
#[must_use]
pub fn new(options: RequestContextOptions) -> Self {
let mut builder = reqwest::Client::builder().cookie_store(true);
if options.ignore_https_errors {
builder = builder.danger_accept_invalid_certs(true);
}
let client = builder.build().unwrap_or_else(|_| reqwest::Client::new());
let default_timeout = options.timeout.unwrap_or(Duration::from_secs(30));
Self {
client,
base_url: options.base_url,
extra_headers: options.extra_http_headers,
default_timeout,
}
}
fn resolve_url(&self, url: &str) -> String {
if url.starts_with("http://") || url.starts_with("https://") {
return url.to_string();
}
match &self.base_url {
Some(base) => {
let base = base.trim_end_matches('/');
if url.starts_with('/') {
format!("{base}{url}")
} else {
format!("{base}/{url}")
}
},
None => url.to_string(),
}
}
pub async fn get(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
self
.fetch(
url,
Some(RequestOptions {
method: Some("GET".into()),
..options.unwrap_or_default()
}),
)
.await
}
pub async fn post(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
self
.fetch(
url,
Some(RequestOptions {
method: Some("POST".into()),
..options.unwrap_or_default()
}),
)
.await
}
pub async fn put(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
self
.fetch(
url,
Some(RequestOptions {
method: Some("PUT".into()),
..options.unwrap_or_default()
}),
)
.await
}
pub async fn delete(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
self
.fetch(
url,
Some(RequestOptions {
method: Some("DELETE".into()),
..options.unwrap_or_default()
}),
)
.await
}
pub async fn patch(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
self
.fetch(
url,
Some(RequestOptions {
method: Some("PATCH".into()),
..options.unwrap_or_default()
}),
)
.await
}
pub async fn head(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
self
.fetch(
url,
Some(RequestOptions {
method: Some("HEAD".into()),
..options.unwrap_or_default()
}),
)
.await
}
pub async fn fetch(&self, url: &str, options: Option<RequestOptions>) -> Result<APIResponse, String> {
let opts = options.unwrap_or_default();
let method_str = opts.method.as_deref().unwrap_or("GET");
let method: reqwest::Method = method_str
.parse()
.map_err(|_| format!("invalid HTTP method: {method_str}"))?;
let resolved_url = self.resolve_url(url);
let mut builder = self.client.request(method, &resolved_url);
for (k, v) in &self.extra_headers {
builder = builder.header(k, v);
}
if let Some(headers) = &opts.headers {
for (k, v) in headers {
builder = builder.header(k, v);
}
}
if let Some(params) = &opts.params {
builder = builder.query(params);
}
if let Some(json) = &opts.json_data {
builder = builder.json(json);
} else if let Some(form) = &opts.form {
builder = builder.form(form);
} else if let Some(data) = &opts.data {
builder = builder.body(data.clone());
}
let timeout = opts.timeout.unwrap_or(self.default_timeout);
builder = builder.timeout(timeout);
if let Some(max) = opts.max_redirects {
let _ = max;
}
let response = builder
.send()
.await
.map_err(|e| format!("request to {resolved_url} failed: {e}"))?;
let status_code = response.status().as_u16();
let status_text = response.status().canonical_reason().unwrap_or("Unknown").to_string();
let response_url = response.url().to_string();
let response_headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let body_bytes = response.bytes().await.map_err(|e| format!("read response body: {e}"))?;
let api_response = APIResponse {
status_code,
status_text,
response_url,
response_headers,
body_bytes,
};
if opts.fail_on_status_code.unwrap_or(false) && !api_response.ok() {
return Err(format!(
"{} {resolved_url} failed: {} {}",
method_str,
api_response.status(),
api_response.status_text()
));
}
Ok(api_response)
}
pub fn dispose(self) {
drop(self);
}
}