use reqwest::Client;
use std::sync::OnceLock;
use std::time::Duration;
use super::allowlist::{NetworkAllowlist, UrlMatch};
use crate::error::{Error, Result};
/// Default upper bound on a buffered response body, in bytes (10 MiB).
pub const DEFAULT_MAX_RESPONSE_BYTES: usize = 10 << 20;
/// Default total request timeout, in seconds.
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Upper clamp applied to caller-supplied per-request timeouts, in seconds (10 minutes).
pub const MAX_TIMEOUT_SECS: u64 = 10 * 60;
/// Lower clamp applied to caller-supplied per-request timeouts, in seconds.
pub const MIN_TIMEOUT_SECS: u64 = 1;
/// Pluggable transport that `HttpClient` uses instead of its built-in
/// `reqwest` client once installed via `HttpClient::set_handler`.
///
/// The handler is only invoked after the URL has passed the allowlist check.
/// `HttpClient` returns the handler's result as-is, so its own timeout and
/// response-size limits are not applied to handler responses.
#[async_trait::async_trait]
pub trait HttpHandler: Send + Sync {
/// Perform a single HTTP request.
///
/// `method` is an upper-case method name such as `"GET"` or `"PATCH"`.
/// `headers` contains the caller's headers, followed by any request-signing
/// headers. An `Err` string is surfaced to callers as `Error::Network`.
async fn request(
&self,
method: &str,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
) -> std::result::Result<Response, String>;
}
/// HTTP client that checks every request URL against a `NetworkAllowlist`
/// before sending anything, and caps the size of buffered response bodies.
pub struct HttpClient {
/// Default `reqwest` client, built lazily on first use. The build result,
/// including a build failure, is cached.
client: OnceLock<std::result::Result<Client, String>>,
/// URLs that outbound requests are permitted to target.
allowlist: NetworkAllowlist,
/// Total request timeout used when building the default client.
default_timeout: Duration,
/// Maximum accepted response body size, in bytes.
max_response_bytes: usize,
/// Optional transport override. When set, requests bypass `client`.
handler: Option<Box<dyn HttpHandler>>,
/// Optional request-signing configuration (only with the `bot-auth` feature).
#[cfg(feature = "bot-auth")]
bot_auth: Option<super::bot_auth::BotAuthConfig>,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Method {
Get,
Post,
Put,
Delete,
Head,
Patch,
}
impl Method {
fn as_reqwest(self) -> reqwest::Method {
match self {
Method::Get => reqwest::Method::GET,
Method::Post => reqwest::Method::POST,
Method::Put => reqwest::Method::PUT,
Method::Delete => reqwest::Method::DELETE,
Method::Head => reqwest::Method::HEAD,
Method::Patch => reqwest::Method::PATCH,
}
}
}
/// A fully buffered HTTP response.
#[derive(Debug)]
pub struct Response {
    /// Numeric HTTP status code (e.g. 200, 404).
    pub status: u16,
    /// Response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}

impl Response {
    /// Decode the body as UTF-8, replacing invalid byte sequences with U+FFFD.
    pub fn body_string(&self) -> String {
        let decoded = String::from_utf8_lossy(&self.body);
        decoded.to_string()
    }

    /// Whether the status code is in the 2xx range.
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }
}
impl HttpClient {
pub fn new(allowlist: NetworkAllowlist) -> Self {
Self::with_config(
allowlist,
Duration::from_secs(DEFAULT_TIMEOUT_SECS),
DEFAULT_MAX_RESPONSE_BYTES,
)
}
pub fn with_timeout(allowlist: NetworkAllowlist, timeout: Duration) -> Self {
Self::with_config(allowlist, timeout, DEFAULT_MAX_RESPONSE_BYTES)
}
pub fn with_config(
allowlist: NetworkAllowlist,
timeout: Duration,
max_response_bytes: usize,
) -> Self {
Self {
client: OnceLock::new(),
allowlist,
default_timeout: timeout,
max_response_bytes,
handler: None,
#[cfg(feature = "bot-auth")]
bot_auth: None,
}
}
pub fn set_handler(&mut self, handler: Box<dyn HttpHandler>) {
self.handler = Some(handler);
}
#[cfg(feature = "bot-auth")]
pub fn set_bot_auth(&mut self, config: super::bot_auth::BotAuthConfig) {
self.bot_auth = Some(config);
}
#[cfg(feature = "bot-auth")]
fn bot_auth_headers(&self, url: &str) -> Vec<(String, String)> {
let Some(ref bot_auth) = self.bot_auth else {
return Vec::new();
};
let Ok(parsed) = url::Url::parse(url) else {
return Vec::new();
};
let Some(authority) = parsed.host_str() else {
return Vec::new();
};
match bot_auth.sign_request(authority) {
Ok(headers) => {
let mut result = vec![
("signature".to_string(), headers.signature),
("signature-input".to_string(), headers.signature_input),
];
if let Some(fqdn) = headers.signature_agent {
result.push(("signature-agent".to_string(), fqdn));
}
result
}
Err(_e) => {
Vec::new()
}
}
}
fn client(&self) -> Result<&Client> {
let client = self
.client
.get_or_init(|| build_client(self.default_timeout, None));
client
.as_ref()
.map_err(|err| Error::Internal(format!("failed to build HTTP client: {err}")))
}
pub async fn get(&self, url: &str) -> Result<Response> {
self.request(Method::Get, url, None).await
}
pub async fn post(&self, url: &str, body: Option<&[u8]>) -> Result<Response> {
self.request(Method::Post, url, body).await
}
pub async fn put(&self, url: &str, body: Option<&[u8]>) -> Result<Response> {
self.request(Method::Put, url, body).await
}
pub async fn delete(&self, url: &str) -> Result<Response> {
self.request(Method::Delete, url, None).await
}
pub async fn request(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
) -> Result<Response> {
self.request_with_headers(method, url, body, &[]).await
}
pub async fn request_with_headers(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
) -> Result<Response> {
match self.allowlist.check(url) {
UrlMatch::Allowed => {}
UrlMatch::Blocked { reason } => {
return Err(Error::Network(format!("access denied: {}", reason)));
}
UrlMatch::Invalid { reason } => {
return Err(Error::Network(format!("invalid URL: {}", reason)));
}
}
#[cfg(feature = "bot-auth")]
let signing_headers = self.bot_auth_headers(url);
#[cfg(not(feature = "bot-auth"))]
let signing_headers: Vec<(String, String)> = Vec::new();
if let Some(handler) = &self.handler {
let method_str = match method {
Method::Get => "GET",
Method::Post => "POST",
Method::Put => "PUT",
Method::Delete => "DELETE",
Method::Head => "HEAD",
Method::Patch => "PATCH",
};
if signing_headers.is_empty() {
return handler
.request(method_str, url, body, headers)
.await
.map_err(Error::Network);
}
let mut all_headers: Vec<(String, String)> = headers.to_vec();
all_headers.extend(signing_headers);
return handler
.request(method_str, url, body, &all_headers)
.await
.map_err(Error::Network);
}
let mut request = self.client()?.request(method.as_reqwest(), url);
for (name, value) in headers {
request = request.header(name.as_str(), value.as_str());
}
for (name, value) in &signing_headers {
request = request.header(name.as_str(), value.as_str());
}
if let Some(body_data) = body {
request = request.body(body_data.to_vec());
}
let response = request
.send()
.await
.map_err(|e| Error::network_sanitized("request failed", &e))?;
let status = response.status().as_u16();
let resp_headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
if let Some(content_length) = response.content_length()
&& usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes
{
return Err(Error::Network(format!(
"response too large: {} bytes (max: {} bytes)",
content_length, self.max_response_bytes
)));
}
let body = self.read_body_with_limit(response).await?;
Ok(Response {
status,
headers: resp_headers,
body,
})
}
async fn read_body_with_limit(&self, response: reqwest::Response) -> Result<Vec<u8>> {
use futures::StreamExt;
let mut body = Vec::new();
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result
.map_err(|e| Error::network_sanitized("failed to read response chunk", &e))?;
if body.len() + chunk.len() > self.max_response_bytes {
return Err(Error::Network(format!(
"response too large: exceeded {} bytes limit",
self.max_response_bytes
)));
}
body.extend_from_slice(&chunk);
}
Ok(body)
}
pub async fn head(&self, url: &str) -> Result<Response> {
self.request(Method::Head, url, None).await
}
pub fn max_response_bytes(&self) -> usize {
self.max_response_bytes
}
pub async fn request_with_timeout(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
timeout_secs: Option<u64>,
) -> Result<Response> {
self.request_with_timeouts(method, url, body, headers, timeout_secs, None)
.await
}
pub async fn request_with_timeouts(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
timeout_secs: Option<u64>,
connect_timeout_secs: Option<u64>,
) -> Result<Response> {
match self.allowlist.check(url) {
UrlMatch::Allowed => {}
UrlMatch::Blocked { reason } => {
return Err(Error::Network(format!("access denied: {}", reason)));
}
UrlMatch::Invalid { reason } => {
return Err(Error::Network(format!("invalid URL: {}", reason)));
}
}
#[cfg(feature = "bot-auth")]
let signing_headers = self.bot_auth_headers(url);
#[cfg(not(feature = "bot-auth"))]
let signing_headers: Vec<(String, String)> = Vec::new();
if let Some(handler) = &self.handler {
let method_str = match method {
Method::Get => "GET",
Method::Post => "POST",
Method::Put => "PUT",
Method::Delete => "DELETE",
Method::Head => "HEAD",
Method::Patch => "PATCH",
};
if signing_headers.is_empty() {
return handler
.request(method_str, url, body, headers)
.await
.map_err(Error::Network);
}
let mut all_headers: Vec<(String, String)> = headers.to_vec();
all_headers.extend(signing_headers);
return handler
.request(method_str, url, body, &all_headers)
.await
.map_err(Error::Network);
}
let client = if timeout_secs.is_some() || connect_timeout_secs.is_some() {
let clamp_timeout = |secs: u64| secs.clamp(MIN_TIMEOUT_SECS, MAX_TIMEOUT_SECS);
let timeout = timeout_secs.map_or(Duration::from_secs(DEFAULT_TIMEOUT_SECS), |s| {
Duration::from_secs(clamp_timeout(s))
});
let connect_timeout = connect_timeout_secs.map_or_else(
|| std::cmp::min(timeout, Duration::from_secs(10)),
|s| Duration::from_secs(clamp_timeout(s)),
);
build_client(timeout, Some(connect_timeout))
.map_err(|e| Error::network_sanitized("failed to create client", &e))?
} else {
self.client()?.clone()
};
let mut request = client.request(method.as_reqwest(), url);
for (name, value) in headers {
request = request.header(name.as_str(), value.as_str());
}
for (name, value) in &signing_headers {
request = request.header(name.as_str(), value.as_str());
}
if let Some(body_data) = body {
request = request.body(body_data.to_vec());
}
let response = request.send().await.map_err(|e| {
if e.is_timeout() {
Error::Network("operation timed out".to_string())
} else {
Error::network_sanitized("request failed", &e)
}
})?;
let status = response.status().as_u16();
let resp_headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
if let Some(content_length) = response.content_length()
&& usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes
{
return Err(Error::Network(format!(
"response too large: {} bytes (max: {} bytes)",
content_length, self.max_response_bytes
)));
}
let body = self.read_body_with_limit(response).await?;
Ok(Response {
status,
headers: resp_headers,
body,
})
}
}
/// Build a `reqwest` client with the given total timeout and connect timeout.
///
/// The connect timeout defaults to 10 seconds when `None`. Redirects are
/// never followed, and gzip/brotli/deflate decompression is disabled.
///
/// NOTE(review): presumably this is so a redirect cannot bypass the allowlist
/// check, and so the response size limit applies to bytes as received.
/// Confirm before changing.
///
/// Build errors are returned as their `Display` string.
fn build_client(
    timeout: Duration,
    connect_timeout: Option<Duration>,
) -> std::result::Result<Client, String> {
    let connect = connect_timeout.unwrap_or_else(|| Duration::from_secs(10));
    let builder = Client::builder()
        .user_agent("bashkit/0.1.2")
        .timeout(timeout)
        .connect_timeout(connect)
        .redirect(reqwest::redirect::Policy::none());
    let builder = builder.no_gzip().no_brotli().no_deflate();
    match builder.build() {
        Ok(client) => Ok(client),
        Err(err) => Err(err.to_string()),
    }
}
// Unit tests for `HttpClient`. These only exercise paths that fail before any
// network I/O (allowlist rejection, lazy client construction), so no server
// is required.
#[cfg(test)]
mod tests {
use super::*;
// A blocked request must fail with "access denied" and must not trigger
// construction of the lazily-built default client.
#[tokio::test]
async fn test_blocked_by_empty_allowlist() {
let client = HttpClient::new(NetworkAllowlist::new());
assert!(client.client.get().is_none());
let result = client.get("https://example.com").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
assert!(client.client.get().is_none());
}
// The default client is built on the first call to `client()`, not in the constructor.
#[test]
fn test_default_client_initializes_on_first_use() {
let client = HttpClient::new(NetworkAllowlist::allow_all());
assert!(client.client.get().is_none());
client.client().expect("client");
assert!(client.client.get().is_some());
}
// A URL outside a non-empty allowlist is rejected.
#[tokio::test]
async fn test_blocked_by_allowlist() {
let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
let client = HttpClient::new(allowlist);
let result = client.get("https://blocked.com").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// The allowlist is enforced even when a timeout override is supplied.
#[tokio::test]
async fn test_request_with_timeout_blocked_by_allowlist() {
let client = HttpClient::new(NetworkAllowlist::new());
let result = client
.request_with_timeout(Method::Get, "https://example.com", None, &[], Some(5))
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// NOTE(review): despite its name, this only checks that a subdomain of an
// allowlisted host is rejected when `timeout_secs` is `None`; it does not
// observe which timeout would be used.
#[tokio::test]
async fn test_request_with_timeout_none_uses_default() {
let allowlist = NetworkAllowlist::new().allow("https://blocked.com");
let client = HttpClient::new(allowlist);
let result = client
.request_with_timeout(Method::Get, "https://blocked.example.com", None, &[], None)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// A string that is not a URL is rejected (as blocked or invalid).
#[tokio::test]
async fn test_request_with_timeout_validates_url() {
let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
let client = HttpClient::new(allowlist);
let result = client
.request_with_timeout(Method::Get, "not-a-url", None, &[], Some(10))
.await;
assert!(result.is_err());
}
// The allowlist is checked before any per-request client is built when both timeouts are set.
#[tokio::test]
async fn test_request_with_timeouts_both_params() {
let client = HttpClient::new(NetworkAllowlist::new());
let result = client
.request_with_timeouts(
Method::Get,
"https://example.com",
None,
&[],
Some(30),
Some(10),
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// Same as above with only a connect-timeout override.
#[tokio::test]
async fn test_request_with_timeouts_connect_only() {
let client = HttpClient::new(NetworkAllowlist::new());
let result = client
.request_with_timeouts(Method::Get, "https://example.com", None, &[], None, Some(5))
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// Mirrors the Content-Length conversion in the request path: values that do
// not fit in `usize` saturate to `usize::MAX` instead of truncating.
#[test]
fn test_u64_to_usize_no_truncation() {
let large: u64 = 5_368_709_120; let result = usize::try_from(large).unwrap_or(usize::MAX);
assert!(result >= large.min(usize::MAX as u64) as usize);
}
}