use reqwest::Client;
use std::sync::OnceLock;
use std::time::Duration;
use super::allowlist::{NetworkAllowlist, UrlMatch, is_private_ip};
use crate::error::{Error, Result};
/// Default cap on a response body, in bytes (10 MiB).
pub const DEFAULT_MAX_RESPONSE_BYTES: usize = 10 << 20;
/// Default total request timeout, in seconds.
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Upper bound applied to caller-supplied per-request timeouts, in seconds.
pub const MAX_TIMEOUT_SECS: u64 = 600;
/// Lower bound applied to caller-supplied per-request timeouts, in seconds.
pub const MIN_TIMEOUT_SECS: u64 = 1;
/// Custom transport that replaces the built-in `reqwest` client when installed
/// with `HttpClient::set_handler`.
///
/// `HttpClient` still performs its URL checks, hooks, timeout and
/// response-size checks around each call; an `Err(String)` returned here is
/// surfaced to the caller as `Error::Network`.
#[async_trait::async_trait]
pub trait HttpHandler: Send + Sync {
/// Performs one request. `method` is an upper-case method name such as
/// `"GET"`; `headers` are the caller's headers followed by any bot-auth
/// signature headers.
async fn request(
&self,
method: &str,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
) -> std::result::Result<Response, String>;
}
/// HTTP client that checks every URL against a network allowlist (plus an
/// optional private-IP / SSRF check) and bounds response size and duration.
///
/// Requests go through a lazily built `reqwest` client or, when one is
/// installed, a custom [`HttpHandler`].
pub struct HttpClient {
/// Shared `reqwest` client, built on first use. An `Err` holds the build
/// failure message; it is cached in the `OnceLock`, not retried.
client: OnceLock<std::result::Result<Client, String>>,
/// URL patterns every request is checked against before it is sent.
allowlist: NetworkAllowlist,
/// Default request timeout; used when building the shared `reqwest` client.
default_timeout: Duration,
/// Upper bound on a response body, in bytes.
max_response_bytes: usize,
/// When set, requests are sent through this handler instead of `reqwest`.
handler: Option<Box<dyn HttpHandler>>,
/// Signing configuration; when set, signature headers are added to requests.
#[cfg(feature = "bot-auth")]
bot_auth: Option<super::bot_auth::BotAuthConfig>,
/// Run in order before each request; may rewrite the URL/headers or cancel.
before_http: Vec<crate::hooks::Interceptor<crate::hooks::HttpRequestEvent>>,
/// Run after a response is received; their output is discarded.
after_http: Vec<crate::hooks::Interceptor<crate::hooks::HttpResponseEvent>>,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Method {
Get,
Post,
Put,
Delete,
Head,
Patch,
}
impl Method {
fn as_reqwest(self) -> reqwest::Method {
match self {
Method::Get => reqwest::Method::GET,
Method::Post => reqwest::Method::POST,
Method::Put => reqwest::Method::PUT,
Method::Delete => reqwest::Method::DELETE,
Method::Head => reqwest::Method::HEAD,
Method::Patch => reqwest::Method::PATCH,
}
}
fn as_str(self) -> &'static str {
match self {
Method::Get => "GET",
Method::Post => "POST",
Method::Put => "PUT",
Method::Delete => "DELETE",
Method::Head => "HEAD",
Method::Patch => "PATCH",
}
}
}
/// A fully buffered HTTP response.
#[derive(Debug)]
pub struct Response {
    /// Numeric status code.
    pub status: u16,
    /// Response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Response body bytes.
    pub body: Vec<u8>,
}

impl Response {
    /// Returns the body decoded as UTF-8, replacing invalid sequences with
    /// U+FFFD.
    pub fn body_string(&self) -> String {
        let decoded = String::from_utf8_lossy(&self.body);
        decoded.to_string()
    }

    /// Returns `true` for any 2xx status code.
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }
}
impl HttpClient {
pub fn new(allowlist: NetworkAllowlist) -> Self {
Self::with_config(
allowlist,
Duration::from_secs(DEFAULT_TIMEOUT_SECS),
DEFAULT_MAX_RESPONSE_BYTES,
)
}
pub fn with_timeout(allowlist: NetworkAllowlist, timeout: Duration) -> Self {
Self::with_config(allowlist, timeout, DEFAULT_MAX_RESPONSE_BYTES)
}
pub fn with_config(
allowlist: NetworkAllowlist,
timeout: Duration,
max_response_bytes: usize,
) -> Self {
Self {
client: OnceLock::new(),
allowlist,
default_timeout: timeout,
max_response_bytes,
handler: None,
#[cfg(feature = "bot-auth")]
bot_auth: None,
before_http: Vec::new(),
after_http: Vec::new(),
}
}
pub fn set_handler(&mut self, handler: Box<dyn HttpHandler>) {
self.handler = Some(handler);
}
#[cfg(feature = "bot-auth")]
pub fn set_bot_auth(&mut self, config: super::bot_auth::BotAuthConfig) {
self.bot_auth = Some(config);
}
#[cfg(feature = "bot-auth")]
fn bot_auth_headers(&self, method: Method, url: &str) -> Vec<(String, String)> {
let Some(ref bot_auth) = self.bot_auth else {
return Vec::new();
};
let Ok(parsed) = url::Url::parse(url) else {
return Vec::new();
};
match bot_auth.sign_request(method.as_str(), parsed.as_str()) {
Ok(headers) => {
let mut result = vec![
("signature".to_string(), headers.signature),
("signature-input".to_string(), headers.signature_input),
];
if let Some(fqdn) = headers.signature_agent {
result.push(("signature-agent".to_string(), fqdn));
}
result
}
Err(_e) => {
Vec::new()
}
}
}
pub fn set_before_http(
&mut self,
hooks: Vec<crate::hooks::Interceptor<crate::hooks::HttpRequestEvent>>,
) {
self.before_http = hooks;
}
pub fn set_after_http(
&mut self,
hooks: Vec<crate::hooks::Interceptor<crate::hooks::HttpResponseEvent>>,
) {
self.after_http = hooks;
}
fn fire_before_http(
&self,
event: crate::hooks::HttpRequestEvent,
) -> Option<crate::hooks::HttpRequestEvent> {
if self.before_http.is_empty() {
return Some(event);
}
let mut current = event;
for hook in &self.before_http {
match hook(current) {
crate::hooks::HookAction::Continue(e) => current = e,
crate::hooks::HookAction::Cancel(_) => return None,
}
}
Some(current)
}
fn fire_after_http(&self, event: crate::hooks::HttpResponseEvent) {
if self.after_http.is_empty() {
return;
}
let mut current = event;
for hook in &self.after_http {
match hook(current) {
crate::hooks::HookAction::Continue(e) => current = e,
crate::hooks::HookAction::Cancel(_) => return,
}
}
}
fn client(&self) -> Result<&Client> {
let client = self
.client
.get_or_init(|| build_client(self.default_timeout, None));
client
.as_ref()
.map_err(|err| Error::Internal(format!("failed to build HTTP client: {err}")))
}
pub async fn get(&self, url: &str) -> Result<Response> {
self.request(Method::Get, url, None).await
}
pub async fn post(&self, url: &str, body: Option<&[u8]>) -> Result<Response> {
self.request(Method::Post, url, body).await
}
pub async fn put(&self, url: &str, body: Option<&[u8]>) -> Result<Response> {
self.request(Method::Put, url, body).await
}
pub async fn delete(&self, url: &str) -> Result<Response> {
self.request(Method::Delete, url, None).await
}
pub async fn request(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
) -> Result<Response> {
self.request_with_headers(method, url, body, &[]).await
}
fn check_allowlist(&self, url: &str) -> Result<()> {
match self.allowlist.check(url) {
UrlMatch::Allowed => Ok(()),
UrlMatch::Blocked { reason } => {
Err(Error::Network(format!("access denied: {}", reason)))
}
UrlMatch::Invalid { reason } => Err(Error::Network(format!("invalid URL: {}", reason))),
}
}
async fn enforce_url_security(&self, url: &str) -> Result<()> {
self.check_allowlist(url)?;
if self.allowlist.is_blocking_private_ips() {
self.check_private_ip(url).await?;
}
Ok(())
}
async fn check_private_ip(&self, url: &str) -> Result<()> {
let parsed = match url::Url::parse(url) {
Ok(p) => p,
Err(_) => return Ok(()),
};
let Some(host) = parsed.host_str() else {
return Ok(());
};
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
if is_private_ip(&ip) {
return Err(Error::Network(format!(
"access denied: {} is a private IP (SSRF protection)",
host
)));
}
} else {
let port = parsed
.port()
.unwrap_or(if parsed.scheme() == "https" { 443 } else { 80 });
let addr = format!("{}:{}", host, port);
if let Ok(addrs) = tokio::net::lookup_host(&addr).await {
for a in addrs {
if is_private_ip(&a.ip()) {
return Err(Error::Network(format!(
"access denied: {} resolves to private IP {} (SSRF protection)",
host,
a.ip()
)));
}
}
}
}
Ok(())
}
pub async fn request_with_headers(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
) -> Result<Response> {
self.enforce_url_security(url).await?;
let (url, headers) = if !self.before_http.is_empty() {
let event = crate::hooks::HttpRequestEvent {
method: method.as_str().to_string(),
url: url.to_string(),
headers: headers.to_vec(),
};
match self.fire_before_http(event) {
Some(modified) => (
std::borrow::Cow::Owned(modified.url),
std::borrow::Cow::Owned(modified.headers),
),
None => {
return Err(Error::Network("cancelled by before_http hook".to_string()));
}
}
} else {
(
std::borrow::Cow::Borrowed(url),
std::borrow::Cow::Borrowed(headers),
)
};
let url: &str = &url;
let headers: &[(String, String)] = &headers;
self.enforce_url_security(url).await?;
#[cfg(feature = "bot-auth")]
let signing_headers = self.bot_auth_headers(method, url);
#[cfg(not(feature = "bot-auth"))]
let signing_headers: Vec<(String, String)> = Vec::new();
if let Some(handler) = &self.handler {
let method_str = method.as_str();
let mut all_headers: Vec<(String, String)> = headers.to_vec();
all_headers.extend(signing_headers);
let response = tokio::time::timeout(
self.default_timeout,
handler.request(method_str, url, body, &all_headers),
)
.await
.map_err(|_| Error::Network("operation timed out".to_string()))?
.map_err(Error::Network)?;
if response.body.len() > self.max_response_bytes {
return Err(Error::Network(format!(
"response too large: {} bytes (max: {} bytes)",
response.body.len(),
self.max_response_bytes
)));
}
self.fire_after_http(crate::hooks::HttpResponseEvent {
url: url.to_string(),
status: response.status,
headers: response.headers.clone(),
});
return Ok(response);
}
let mut request = self.client()?.request(method.as_reqwest(), url);
for (name, value) in headers {
request = request.header(name.as_str(), value.as_str());
}
for (name, value) in &signing_headers {
request = request.header(name.as_str(), value.as_str());
}
if let Some(body_data) = body {
request = request.body(body_data.to_vec());
}
let response = request
.send()
.await
.map_err(|e| Error::network_sanitized("request failed", &e))?;
let status = response.status().as_u16();
let resp_headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
self.fire_after_http(crate::hooks::HttpResponseEvent {
url: url.to_string(),
status,
headers: resp_headers.clone(),
});
if let Some(content_length) = response.content_length()
&& usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes
{
return Err(Error::Network(format!(
"response too large: {} bytes (max: {} bytes)",
content_length, self.max_response_bytes
)));
}
let body = self.read_body_with_limit(response).await?;
Ok(Response {
status,
headers: resp_headers,
body,
})
}
async fn read_body_with_limit(&self, response: reqwest::Response) -> Result<Vec<u8>> {
use futures_util::StreamExt;
let mut body = Vec::new();
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result
.map_err(|e| Error::network_sanitized("failed to read response chunk", &e))?;
if body.len() + chunk.len() > self.max_response_bytes {
return Err(Error::Network(format!(
"response too large: exceeded {} bytes limit",
self.max_response_bytes
)));
}
body.extend_from_slice(&chunk);
}
Ok(body)
}
pub async fn head(&self, url: &str) -> Result<Response> {
self.request(Method::Head, url, None).await
}
pub fn max_response_bytes(&self) -> usize {
self.max_response_bytes
}
pub async fn request_with_timeout(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
timeout_secs: Option<u64>,
) -> Result<Response> {
self.request_with_timeouts(method, url, body, headers, timeout_secs, None)
.await
}
pub async fn request_with_timeouts(
&self,
method: Method,
url: &str,
body: Option<&[u8]>,
headers: &[(String, String)],
timeout_secs: Option<u64>,
connect_timeout_secs: Option<u64>,
) -> Result<Response> {
self.enforce_url_security(url).await?;
let (url, headers) = if !self.before_http.is_empty() {
let event = crate::hooks::HttpRequestEvent {
method: method.as_str().to_string(),
url: url.to_string(),
headers: headers.to_vec(),
};
match self.fire_before_http(event) {
Some(modified) => (
std::borrow::Cow::Owned(modified.url),
std::borrow::Cow::Owned(modified.headers),
),
None => {
return Err(Error::Network("cancelled by before_http hook".to_string()));
}
}
} else {
(
std::borrow::Cow::Borrowed(url),
std::borrow::Cow::Borrowed(headers),
)
};
let url: &str = &url;
let headers: &[(String, String)] = &headers;
self.enforce_url_security(url).await?;
#[cfg(feature = "bot-auth")]
let signing_headers = self.bot_auth_headers(method, url);
#[cfg(not(feature = "bot-auth"))]
let signing_headers: Vec<(String, String)> = Vec::new();
let clamp_timeout = |secs: u64| secs.clamp(MIN_TIMEOUT_SECS, MAX_TIMEOUT_SECS);
let request_timeout = timeout_secs.map_or(Duration::from_secs(DEFAULT_TIMEOUT_SECS), |s| {
Duration::from_secs(clamp_timeout(s))
});
if let Some(handler) = &self.handler {
let method_str = method.as_str();
let mut all_headers: Vec<(String, String)> = headers.to_vec();
all_headers.extend(signing_headers);
let response = tokio::time::timeout(
request_timeout,
handler.request(method_str, url, body, &all_headers),
)
.await
.map_err(|_| Error::Network("operation timed out".to_string()))?
.map_err(Error::Network)?;
if response.body.len() > self.max_response_bytes {
return Err(Error::Network(format!(
"response too large: {} bytes (max: {} bytes)",
response.body.len(),
self.max_response_bytes
)));
}
self.fire_after_http(crate::hooks::HttpResponseEvent {
url: url.to_string(),
status: response.status,
headers: response.headers.clone(),
});
return Ok(response);
}
let client = if timeout_secs.is_some() || connect_timeout_secs.is_some() {
let connect_timeout = connect_timeout_secs.map_or_else(
|| std::cmp::min(request_timeout, Duration::from_secs(10)),
|s| Duration::from_secs(clamp_timeout(s)),
);
build_client(request_timeout, Some(connect_timeout))
.map_err(|e| Error::network_sanitized("failed to create client", &e))?
} else {
self.client()?.clone()
};
let mut request = client.request(method.as_reqwest(), url);
for (name, value) in headers {
request = request.header(name.as_str(), value.as_str());
}
for (name, value) in &signing_headers {
request = request.header(name.as_str(), value.as_str());
}
if let Some(body_data) = body {
request = request.body(body_data.to_vec());
}
let response = request.send().await.map_err(|e| {
if e.is_timeout() {
Error::Network("operation timed out".to_string())
} else {
Error::network_sanitized("request failed", &e)
}
})?;
let status = response.status().as_u16();
let resp_headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
self.fire_after_http(crate::hooks::HttpResponseEvent {
url: url.to_string(),
status,
headers: resp_headers.clone(),
});
if let Some(content_length) = response.content_length()
&& usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes
{
return Err(Error::Network(format!(
"response too large: {} bytes (max: {} bytes)",
content_length, self.max_response_bytes
)));
}
let body = self.read_body_with_limit(response).await?;
Ok(Response {
status,
headers: resp_headers,
body,
})
}
}
/// Registers rustls' `ring` crypto provider as the process-wide default, at
/// most once per process.
fn install_default_crypto_provider() {
    /// Attempts the install; an `Err` means a default provider is already
    /// registered, so it is ignored.
    fn install_ring() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }
    static INIT: std::sync::Once = std::sync::Once::new();
    INIT.call_once(install_ring);
}
/// Builds a hardened `reqwest` client.
///
/// `timeout` bounds the whole request; `connect_timeout` defaults to 10 s.
/// Redirects are never followed, gzip/brotli/deflate decoding is disabled
/// (presumably so the size limit applies to bytes on the wire), and no proxy
/// is used. Installs the default rustls crypto provider first.
///
/// NOTE(review): the user-agent version is hard-coded and may drift from the
/// crate version.
fn build_client(
    timeout: Duration,
    connect_timeout: Option<Duration>,
) -> std::result::Result<Client, String> {
    install_default_crypto_provider();
    let connect_timeout = connect_timeout.unwrap_or_else(|| Duration::from_secs(10));
    let builder = Client::builder()
        .timeout(timeout)
        .connect_timeout(connect_timeout)
        .user_agent("bashkit/0.1.2")
        .redirect(reqwest::redirect::Policy::none());
    let builder = builder.no_gzip().no_brotli().no_deflate().no_proxy();
    builder.build().map_err(|err| err.to_string())
}
// Unit tests for HttpClient. Every test either fails on the allowlist, only
// builds a client, or uses a custom HttpHandler, so no test is expected to
// send a real HTTP request (the SSRF check may still perform DNS lookups).
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration as StdDuration;
use tokio::time::sleep;
// Handler that returns a copy of a fixed response for every request.
struct StaticHandler {
response: Response,
}
#[async_trait::async_trait]
impl HttpHandler for StaticHandler {
async fn request(
&self,
_method: &str,
_url: &str,
_body: Option<&[u8]>,
_headers: &[(String, String)],
) -> std::result::Result<Response, String> {
Ok(Response {
status: self.response.status,
headers: self.response.headers.clone(),
body: self.response.body.clone(),
})
}
}
// Handler that sleeps for `delay` and then answers 200 "ok"; used to trigger
// the client-side request timeout.
struct SlowHandler {
delay: StdDuration,
}
#[async_trait::async_trait]
impl HttpHandler for SlowHandler {
async fn request(
&self,
_method: &str,
_url: &str,
_body: Option<&[u8]>,
_headers: &[(String, String)],
) -> std::result::Result<Response, String> {
sleep(self.delay).await;
Ok(Response {
status: 200,
headers: vec![],
body: b"ok".to_vec(),
})
}
}
// An empty allowlist denies everything, and the denial happens before the
// lazy reqwest client is built.
#[tokio::test]
async fn test_blocked_by_empty_allowlist() {
let client = HttpClient::new(NetworkAllowlist::new());
assert!(client.client.get().is_none());
let result = client.get("https://example.com").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
assert!(client.client.get().is_none());
}
// The shared client is built by the first `client()` call, not by `new`.
#[test]
fn test_default_client_initializes_on_first_use() {
let client = HttpClient::new(NetworkAllowlist::allow_all());
assert!(client.client.get().is_none());
client.client().expect("client");
assert!(client.client.get().is_some());
}
// Hosts that are not on the allowlist are denied.
#[tokio::test]
async fn test_blocked_by_allowlist() {
let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
let client = HttpClient::new(allowlist);
let result = client.get("https://blocked.com").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// request_with_timeout applies the allowlist as well.
#[tokio::test]
async fn test_request_with_timeout_blocked_by_allowlist() {
let client = HttpClient::new(NetworkAllowlist::new());
let result = client
.request_with_timeout(Method::Get, "https://example.com", None, &[], Some(5))
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// NOTE(review): despite its name, this only checks that a host outside the
// allowlist is denied when no timeout is given; it does not observe which
// timeout value is used.
#[tokio::test]
async fn test_request_with_timeout_none_uses_default() {
let allowlist = NetworkAllowlist::new().allow("https://blocked.com");
let client = HttpClient::new(allowlist);
let result = client
.request_with_timeout(Method::Get, "https://blocked.example.com", None, &[], None)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// A malformed URL is rejected rather than sent.
#[tokio::test]
async fn test_request_with_timeout_validates_url() {
let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
let client = HttpClient::new(allowlist);
let result = client
.request_with_timeout(Method::Get, "not-a-url", None, &[], Some(10))
.await;
assert!(result.is_err());
}
// Supplying both timeouts does not bypass the allowlist.
#[tokio::test]
async fn test_request_with_timeouts_both_params() {
let client = HttpClient::new(NetworkAllowlist::new());
let result = client
.request_with_timeouts(
Method::Get,
"https://example.com",
None,
&[],
Some(30),
Some(10),
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// Supplying only a connect timeout does not bypass the allowlist.
#[tokio::test]
async fn test_request_with_timeouts_connect_only() {
let client = HttpClient::new(NetworkAllowlist::new());
let result = client
.request_with_timeouts(Method::Get, "https://example.com", None, &[], None, Some(5))
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// The checked u64 -> usize conversion used for Content-Length saturates to
// usize::MAX instead of truncating.
#[test]
fn test_u64_to_usize_no_truncation() {
let large: u64 = 5_368_709_120; let result = usize::try_from(large).unwrap_or(usize::MAX);
assert!(result >= large.min(usize::MAX as u64) as usize);
}
// build_client succeeds with the no-proxy configuration.
#[test]
fn test_build_client_uses_no_proxy() {
let client = build_client(Duration::from_secs(30), None);
assert!(client.is_ok(), "build_client should succeed with no_proxy");
}
// After build_client a process-wide default crypto provider exists, so a
// second install attempt is rejected.
#[test]
fn test_build_client_installs_ring_crypto_provider() {
let _ = build_client(Duration::from_secs(30), None);
let second_install = rustls::crypto::ring::default_provider().install_default();
assert!(
second_install.is_err(),
"build_client must install a default crypto provider before \
returning, otherwise the first HTTPS request panics"
);
}
// Repeated calls must not panic.
#[test]
fn test_install_default_crypto_provider_is_idempotent() {
install_default_crypto_provider();
install_default_crypto_provider();
install_default_crypto_provider();
}
// The size limit also applies to bodies returned by a custom handler.
#[tokio::test]
async fn test_custom_handler_enforces_max_response_bytes() {
let mut client =
HttpClient::with_config(NetworkAllowlist::allow_all(), Duration::from_secs(30), 4);
client.set_handler(Box::new(StaticHandler {
response: Response {
status: 200,
headers: vec![],
body: b"too-large".to_vec(),
},
}));
let result = client.get("https://example.com").await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("response too large")
);
}
// A before_http hook that rewrites the URL to a non-allowlisted host is still
// denied (request_with_headers path).
#[tokio::test]
async fn test_before_http_hook_cannot_bypass_allowlist_request_with_headers() {
let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
let mut client = HttpClient::new(allowlist);
client.set_handler(Box::new(StaticHandler {
response: Response {
status: 200,
headers: vec![],
body: b"ok".to_vec(),
},
}));
client.set_before_http(vec![Box::new(|mut event| {
event.url = "https://blocked.com".to_string();
crate::hooks::HookAction::Continue(event)
})]);
let result = client
.request_with_headers(Method::Get, "https://allowed.com", None, &[])
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// Same as above for the request_with_timeouts path.
#[tokio::test]
async fn test_before_http_hook_cannot_bypass_allowlist_request_with_timeouts() {
let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
let mut client = HttpClient::new(allowlist);
client.set_handler(Box::new(StaticHandler {
response: Response {
status: 200,
headers: vec![],
body: b"ok".to_vec(),
},
}));
client.set_before_http(vec![Box::new(|mut event| {
event.url = "https://blocked.com".to_string();
crate::hooks::HookAction::Continue(event)
})]);
let result = client
.request_with_timeouts(Method::Get, "https://allowed.com", None, &[], Some(5), None)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("access denied"));
}
// A handler slower than the 1 s per-request timeout fails with a timeout
// error. Takes about one second of wall-clock time.
#[tokio::test]
async fn test_custom_handler_enforces_request_timeout() {
let mut client = HttpClient::with_config(
NetworkAllowlist::allow_all(),
Duration::from_secs(30),
DEFAULT_MAX_RESPONSE_BYTES,
);
client.set_handler(Box::new(SlowHandler {
delay: StdDuration::from_millis(1200),
}));
let result = client
.request_with_timeouts(Method::Get, "https://example.com", None, &[], Some(1), None)
.await;
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("operation timed out")
);
}
}