use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use reqwest::{Certificate, Client, Identity};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitState};
use barbacane_plugin_sdk::types::base64_body;
/// Optional TLS settings expressed as file paths.
///
/// Every field is optional; a field that is absent from the serialized form
/// deserializes to `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TlsConfig {
/// Path to the client certificate file. Must be paired with `client_key`
/// (see `validate`).
#[serde(default)]
pub client_cert: Option<PathBuf>,
/// Path to the client private key file. Must be paired with `client_cert`
/// (see `validate`).
#[serde(default)]
pub client_key: Option<PathBuf>,
/// Path to a CA certificate file.
#[serde(default)]
pub ca: Option<PathBuf>,
}
impl TlsConfig {
    /// Returns `true` when at least one of the three paths is set.
    pub fn is_configured(&self) -> bool {
        [&self.client_cert, &self.client_key, &self.ca]
            .iter()
            .any(|path| path.is_some())
    }

    /// Checks that the client certificate and client key are either both
    /// present or both absent.
    ///
    /// # Errors
    ///
    /// * [`TlsConfigError::MissingClientKey`] when only the certificate is set.
    /// * [`TlsConfigError::MissingClientCert`] when only the key is set.
    pub fn validate(&self) -> Result<(), TlsConfigError> {
        let has_cert = self.client_cert.is_some();
        let has_key = self.client_key.is_some();
        if has_cert && !has_key {
            return Err(TlsConfigError::MissingClientKey);
        }
        if has_key && !has_cert {
            return Err(TlsConfigError::MissingClientCert);
        }
        Ok(())
    }

    /// Builds the cache key for this configuration by cloning its three paths.
    fn cache_key(&self) -> TlsCacheKey {
        let Self {
            client_cert,
            client_key,
            ca,
        } = self;
        TlsCacheKey {
            client_cert: client_cert.clone(),
            client_key: client_key.clone(),
            ca: ca.clone(),
        }
    }
}
/// Errors produced while validating a [`TlsConfig`] or loading the files it
/// points to.
#[derive(Debug, Error)]
pub enum TlsConfigError {
/// `client_cert` is set but `client_key` is not.
#[error("client_cert specified but client_key is missing")]
MissingClientKey,
/// `client_key` is set but `client_cert` is not.
#[error("client_key specified but client_cert is missing")]
MissingClientCert,
/// The client certificate file could not be read.
#[error("failed to read certificate file: {0}")]
ReadCertificate(#[source] std::io::Error),
/// The client key file could not be read.
#[error("failed to read key file: {0}")]
ReadKey(#[source] std::io::Error),
/// The CA certificate file could not be read.
#[error("failed to read CA file: {0}")]
ReadCa(#[source] std::io::Error),
/// The combined certificate + key PEM was rejected when building the identity.
#[error("failed to parse PEM identity: {0}")]
ParseIdentity(#[source] reqwest::Error),
/// The CA file contents were rejected when parsed as a PEM certificate.
#[error("failed to parse CA certificate: {0}")]
ParseCaCert(#[source] reqwest::Error),
}
/// Map key identifying a TLS setup by its three optional file paths.
///
/// Two keys are equal when all three paths are equal; file contents are not
/// part of the key.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TlsCacheKey {
    client_cert: Option<PathBuf>,
    client_key: Option<PathBuf>,
    ca: Option<PathBuf>,
}

impl Hash for TlsCacheKey {
    /// Feeds the three paths into `state` in declaration order
    /// (`client_cert`, `client_key`, `ca`).
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Destructuring makes the compiler flag this impl if a field is added.
        let Self {
            client_cert,
            client_key,
            ca,
        } = self;
        for path in [client_cert, client_key, ca] {
            path.hash(state);
        }
    }
}
/// HTTP client wrapper that adds a plaintext guard, per-host circuit breakers
/// and a cache of TLS-specific clients.
///
/// The TLS client cache and the circuit breaker map are behind `Arc`, so clones
/// of this struct share them.
#[derive(Clone)]
pub struct HttpClient {
// Default client, used when a call supplies no (configured) TLS settings.
client: Client,
// Clients built for specific TLS settings, keyed by the configured paths.
tls_clients: Arc<RwLock<HashMap<TlsCacheKey, Client>>>,
// Kept so TLS clients are built with the same pool/timeout settings.
base_config: HttpClientConfig,
// Circuit breakers keyed by host string; hosts without an entry are treated as closed.
circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
// Timeout applied when a request does not carry its own.
default_timeout: Duration,
// When false, URLs with the `http` scheme are rejected.
allow_plaintext: bool,
}
impl HttpClient {
/// Creates a client from `config`.
///
/// The pool, connect-timeout and default-timeout settings are applied to the
/// underlying `reqwest` client; the TLS client cache and the circuit breaker
/// map start out empty.
///
/// # Errors
///
/// Returns [`HttpClientError::BuildError`] when the underlying client cannot
/// be built.
pub fn new(config: HttpClientConfig) -> Result<Self, HttpClientError> {
    let builder = Client::builder()
        .pool_max_idle_per_host(config.pool_max_idle_per_host)
        .pool_idle_timeout(config.pool_idle_timeout)
        .connect_timeout(config.connect_timeout)
        .timeout(config.default_timeout);
    let client = match builder.build() {
        Ok(built) => built,
        Err(e) => return Err(HttpClientError::BuildError(e)),
    };
    Ok(Self {
        client,
        tls_clients: Arc::new(RwLock::new(HashMap::new())),
        circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
        // The two scalar settings are copied out before `config` is moved below.
        default_timeout: config.default_timeout,
        allow_plaintext: config.allow_plaintext,
        base_config: config,
    })
}
/// Returns the cached client for `tls_config`, building and caching one on a
/// cache miss.
///
/// The cache is keyed by the configured paths only (see `TlsConfig::cache_key`).
///
/// # Errors
///
/// Propagates any error from `build_tls_client`.
fn get_or_create_tls_client(&self, tls_config: &TlsConfig) -> Result<Client, HttpClientError> {
    let cache_key = tls_config.cache_key();
    {
        // Fast path: shared lock only.
        let clients = self.tls_clients.read();
        if let Some(cached) = clients.get(&cache_key) {
            return Ok(cached.clone());
        }
    }
    // Build outside of any lock so file I/O does not block other callers.
    let built = self.build_tls_client(tls_config)?;
    let mut clients = self.tls_clients.write();
    // Another caller may have inserted a client for the same key while we
    // were building. Keep the first one so every caller shares a single
    // client (and its connection pool) instead of overwriting it.
    Ok(clients.entry(cache_key).or_insert(built).clone())
}
/// Builds a new client for `tls_config`, using the same pool and timeout
/// settings as the base configuration.
///
/// When both `client_cert` and `client_key` are set, the two files are read,
/// concatenated (certificate first) and used as the client identity. When `ca`
/// is set, the file is added as an extra root certificate.
///
/// # Errors
///
/// * [`HttpClientError::TlsConfig`] when validation fails, a file cannot be
///   read, or its contents cannot be parsed.
/// * [`HttpClientError::BuildError`] when the client itself cannot be built.
fn build_tls_client(&self, tls_config: &TlsConfig) -> Result<Client, HttpClientError> {
    tls_config.validate().map_err(HttpClientError::TlsConfig)?;
    let mut builder = Client::builder()
        .pool_max_idle_per_host(self.base_config.pool_max_idle_per_host)
        .pool_idle_timeout(self.base_config.pool_idle_timeout)
        .connect_timeout(self.base_config.connect_timeout)
        .timeout(self.base_config.default_timeout);
    if let (Some(cert_path), Some(key_path)) = (&tls_config.client_cert, &tls_config.client_key)
    {
        let cert_pem = std::fs::read(cert_path)
            .map_err(|e| HttpClientError::TlsConfig(TlsConfigError::ReadCertificate(e)))?;
        let key_pem = std::fs::read(key_path)
            .map_err(|e| HttpClientError::TlsConfig(TlsConfigError::ReadKey(e)))?;
        let mut pem = cert_pem;
        // A certificate file without a trailing newline would otherwise fuse
        // its END line with the key's BEGIN line and break PEM parsing.
        if !pem.ends_with(b"\n") {
            pem.push(b'\n');
        }
        pem.extend_from_slice(&key_pem);
        let identity = Identity::from_pem(&pem)
            .map_err(|e| HttpClientError::TlsConfig(TlsConfigError::ParseIdentity(e)))?;
        builder = builder.identity(identity);
    }
    if let Some(ca_path) = &tls_config.ca {
        let ca_pem = std::fs::read(ca_path)
            .map_err(|e| HttpClientError::TlsConfig(TlsConfigError::ReadCa(e)))?;
        let ca_cert = Certificate::from_pem(&ca_pem)
            .map_err(|e| HttpClientError::TlsConfig(TlsConfigError::ParseCaCert(e)))?;
        builder = builder.add_root_certificate(ca_cert);
    }
    builder.build().map_err(HttpClientError::BuildError)
}
pub async fn call(&self, request: HttpRequest) -> Result<HttpResponse, HttpClientError> {
self.call_with_tls(request, None).await
}
pub async fn stream_raw(
&self,
request: HttpRequest,
) -> Result<reqwest::Response, HttpClientError> {
let url = request
.url
.parse::<reqwest::Url>()
.map_err(|e| HttpClientError::InvalidUrl(e.to_string()))?;
if url.scheme() == "http" && !self.allow_plaintext {
return Err(HttpClientError::PlaintextNotAllowed);
}
let host = url
.host_str()
.ok_or_else(|| HttpClientError::InvalidUrl("missing host".into()))?
.to_string();
let circuit_state = self.get_circuit_state(&host);
if circuit_state == crate::circuit_breaker::CircuitState::Open {
return Err(HttpClientError::CircuitOpen(host));
}
let method = request
.method
.parse::<reqwest::Method>()
.map_err(|e| HttpClientError::InvalidMethod(e.to_string()))?;
let timeout = request.timeout.unwrap_or(self.default_timeout);
let mut req_builder = self.client.request(method, url).timeout(timeout);
for (key, value) in &request.headers {
req_builder = req_builder.header(key.as_str(), value.as_str());
}
if let Some(body) = request.body {
req_builder = req_builder.body(body);
}
match req_builder.send().await {
Ok(response) => Ok(response),
Err(e) => {
self.record_failure(&host);
if e.is_timeout() {
Err(HttpClientError::Timeout)
} else if e.is_connect() {
Err(HttpClientError::ConnectionFailed(e.to_string()))
} else {
Err(HttpClientError::RequestFailed(e.to_string()))
}
}
}
}
/// Sends `request` and returns the fully buffered response.
///
/// When `tls_config` is `Some` and has at least one path set, a cached
/// TLS-specific client is used; otherwise the default client is used. The
/// request's own timeout is used when present, otherwise the client's default.
///
/// Response header names are lower-cased. Header values for which `to_str()`
/// fails are skipped, and because headers are collected into a map only one
/// value is kept per header name.
///
/// The host's circuit breaker records a success only once the whole body has
/// been read; a send failure or a body read failure is recorded as a failure.
///
/// # Errors
///
/// * [`HttpClientError::InvalidUrl`] when the URL does not parse or has no host.
/// * [`HttpClientError::PlaintextNotAllowed`] for `http` URLs when plaintext is disabled.
/// * [`HttpClientError::CircuitOpen`] when the host's circuit breaker is open.
/// * [`HttpClientError::TlsConfig`] / [`HttpClientError::BuildError`] when the
///   TLS client cannot be created.
/// * [`HttpClientError::InvalidMethod`] when the method does not parse.
/// * [`HttpClientError::Timeout`], [`HttpClientError::ConnectionFailed`] or
///   [`HttpClientError::RequestFailed`] when sending fails.
/// * [`HttpClientError::ResponseReadError`] when the body cannot be read.
pub async fn call_with_tls(
    &self,
    request: HttpRequest,
    tls_config: Option<&TlsConfig>,
) -> Result<HttpResponse, HttpClientError> {
    let url = request
        .url
        .parse::<reqwest::Url>()
        .map_err(|e| HttpClientError::InvalidUrl(e.to_string()))?;
    if url.scheme() == "http" && !self.allow_plaintext {
        return Err(HttpClientError::PlaintextNotAllowed);
    }
    let host = url
        .host_str()
        .ok_or_else(|| HttpClientError::InvalidUrl("missing host".into()))?
        .to_string();
    if self.get_circuit_state(&host) == CircuitState::Open {
        return Err(HttpClientError::CircuitOpen(host));
    }
    let client = match tls_config {
        Some(tls) if tls.is_configured() => self.get_or_create_tls_client(tls)?,
        _ => self.client.clone(),
    };
    let method = request
        .method
        .parse::<reqwest::Method>()
        .map_err(|e| HttpClientError::InvalidMethod(e.to_string()))?;
    let timeout = request.timeout.unwrap_or(self.default_timeout);
    let mut req_builder = client.request(method, url).timeout(timeout);
    for (key, value) in &request.headers {
        req_builder = req_builder.header(key.as_str(), value.as_str());
    }
    if let Some(body) = request.body {
        req_builder = req_builder.body(body);
    }

    let response = match req_builder.send().await {
        Ok(response) => response,
        Err(e) => {
            self.record_failure(&host);
            return Err(if e.is_timeout() {
                HttpClientError::Timeout
            } else if e.is_connect() {
                HttpClientError::ConnectionFailed(e.to_string())
            } else {
                HttpClientError::RequestFailed(e.to_string())
            });
        }
    };

    let status = response.status().as_u16();
    let headers: HashMap<String, String> = response
        .headers()
        .iter()
        .filter_map(|(k, v)| {
            v.to_str()
                .ok()
                .map(|v| (k.as_str().to_lowercase(), v.to_string()))
        })
        .collect();

    // The exchange only counts as successful once the body has arrived; an
    // upstream that drops the connection mid-body is recorded as a failure.
    let body = match response.bytes().await {
        Ok(bytes) => bytes,
        Err(e) => {
            self.record_failure(&host);
            return Err(HttpClientError::ResponseReadError(e));
        }
    };
    self.record_success(&host);

    Ok(HttpResponse {
        status,
        headers,
        body: Some(body.to_vec()),
    })
}
/// Installs a fresh circuit breaker built from `config` for `host`, replacing
/// any breaker already registered under that host string.
pub fn configure_circuit_breaker(&self, host: &str, config: CircuitBreakerConfig) {
    self.circuit_breakers
        .write()
        .insert(host.to_string(), CircuitBreaker::new(config));
}
/// Returns the state of the breaker registered for `host`, or
/// `CircuitState::Closed` when no breaker is registered.
fn get_circuit_state(&self, host: &str) -> CircuitState {
    let registry = self.circuit_breakers.read();
    match registry.get(host) {
        Some(breaker) => breaker.state(),
        None => CircuitState::Closed,
    }
}
/// Reports a success to the breaker registered for `host`; does nothing when
/// no breaker is registered.
fn record_success(&self, host: &str) {
    // The write guard is a temporary that stays alive for the whole `if let`.
    if let Some(breaker) = self.circuit_breakers.write().get_mut(host) {
        breaker.record_success();
    }
}
/// Reports a failure to the breaker registered for `host`; does nothing when
/// no breaker is registered.
fn record_failure(&self, host: &str) {
    // The write guard is a temporary that stays alive for the whole `if let`.
    if let Some(breaker) = self.circuit_breakers.write().get_mut(host) {
        breaker.record_failure();
    }
}
}
/// Settings used to build an `HttpClient`.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Passed to the underlying client's `pool_max_idle_per_host` setting.
    pub pool_max_idle_per_host: usize,
    /// Passed to the underlying client's `pool_idle_timeout` setting.
    pub pool_idle_timeout: Duration,
    /// Passed to the underlying client's `connect_timeout` setting.
    pub connect_timeout: Duration,
    /// Timeout used for requests that do not specify their own.
    pub default_timeout: Duration,
    /// When `false`, requests to `http` URLs are rejected.
    pub allow_plaintext: bool,
}

impl Default for HttpClientConfig {
    /// Defaults: 10 idle connections per host, 90 s idle timeout, 10 s connect
    /// timeout, 30 s request timeout, plaintext disabled.
    fn default() -> Self {
        let secs = Duration::from_secs;
        HttpClientConfig {
            allow_plaintext: false,
            default_timeout: secs(30),
            connect_timeout: secs(10),
            pool_idle_timeout: secs(90),
            pool_max_idle_per_host: 10,
        }
    }
}
/// An outbound HTTP request in serializable form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpRequest {
/// HTTP method as text; parsed into a `reqwest::Method` when the request is sent.
pub method: String,
/// Target URL as text; parsed into a `reqwest::Url` when the request is sent.
pub url: String,
/// Request headers; defaults to an empty map when absent.
#[serde(default)]
pub headers: HashMap<String, String>,
/// Optional request body, (de)serialized through the `base64_body` helper.
#[serde(default, with = "base64_body")]
pub body: Option<Vec<u8>>,
/// Optional per-request timeout, (de)serialized as a number of seconds.
#[serde(default, with = "option_duration_serde")]
pub timeout: Option<Duration>,
}
/// A fully buffered HTTP response in serializable form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpResponse {
/// Numeric HTTP status code.
pub status: u16,
/// Response headers, one value per header name.
pub headers: HashMap<String, String>,
/// Optional response body, (de)serialized through the `base64_body` helper.
#[serde(default, with = "base64_body")]
pub body: Option<Vec<u8>>,
}
impl HttpResponse {
    /// Builds an error response whose body is a JSON object with the members
    /// `type`, `title`, `status` and `detail`, and whose only header is
    /// `content-type: application/problem+json`.
    pub fn error(status: u16, error_type: &str, title: &str, detail: &str) -> Self {
        let mut headers = HashMap::new();
        headers.insert(
            String::from("content-type"),
            String::from("application/problem+json"),
        );

        let problem = serde_json::json!({
            "type": error_type,
            "title": title,
            "status": status,
            "detail": detail
        });
        let payload = problem.to_string().into_bytes();

        Self {
            status,
            headers,
            body: Some(payload),
        }
    }
}
/// Errors returned by `HttpClient`.
#[derive(Debug, Error)]
pub enum HttpClientError {
/// The underlying `reqwest` client could not be built.
#[error("failed to build HTTP client: {0}")]
BuildError(#[source] reqwest::Error),
/// The request URL failed to parse or has no host.
#[error("invalid URL: {0}")]
InvalidUrl(String),
/// The request method failed to parse.
#[error("invalid HTTP method: {0}")]
InvalidMethod(String),
/// The URL uses the `http` scheme and plaintext is not allowed.
#[error("plaintext HTTP not allowed")]
PlaintextNotAllowed,
/// The circuit breaker for the named host is open.
#[error("circuit breaker open for host: {0}")]
CircuitOpen(String),
/// Sending the request failed with a timeout.
#[error("request timeout")]
Timeout,
/// Sending the request failed while connecting.
#[error("connection failed: {0}")]
ConnectionFailed(String),
/// Sending the request failed for any other reason.
#[error("request failed: {0}")]
RequestFailed(String),
/// The response body could not be read.
#[error("failed to read response: {0}")]
ResponseReadError(#[source] reqwest::Error),
/// The TLS configuration was invalid or its files could not be loaded.
#[error("TLS configuration error: {0}")]
TlsConfig(#[source] TlsConfigError),
}
mod option_duration_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::Duration;
pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match duration {
Some(d) => d.as_secs_f64().serialize(serializer),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let opt: Option<f64> = Option::deserialize(deserializer)?;
Ok(opt.map(Duration::from_secs_f64))
}
}
// Unit tests for the configuration types, the TLS config helpers, the
// `stream_raw` error paths and the serde encodings of request/response bodies.
#[cfg(test)]
mod tests {
use super::*;
// The default config uses 10 idle connections per host, a 30 s timeout and
// disallows plaintext.
#[test]
fn test_config_default() {
let config = HttpClientConfig::default();
assert_eq!(config.pool_max_idle_per_host, 10);
assert_eq!(config.default_timeout, Duration::from_secs(30));
assert!(!config.allow_plaintext);
}
// `HttpResponse::error` keeps the status and sets the problem+json content type.
#[test]
fn test_error_response() {
let resp = HttpResponse::error(
502,
"urn:barbacane:error:upstream-unavailable",
"Bad Gateway",
"Failed to connect to upstream",
);
assert_eq!(resp.status, 502);
assert_eq!(
resp.headers.get("content-type"),
Some(&"application/problem+json".to_string())
);
}
// A default `TlsConfig` has no paths and is not considered configured.
#[test]
fn test_tls_config_default() {
let tls = TlsConfig::default();
assert!(tls.client_cert.is_none());
assert!(tls.client_key.is_none());
assert!(tls.ca.is_none());
assert!(!tls.is_configured());
}
// Any single path being set makes the config "configured".
#[test]
fn test_tls_config_is_configured() {
let mut tls = TlsConfig::default();
assert!(!tls.is_configured());
tls.client_cert = Some(PathBuf::from("/path/to/cert.pem"));
assert!(tls.is_configured());
tls.client_cert = None;
tls.ca = Some(PathBuf::from("/path/to/ca.pem"));
assert!(tls.is_configured());
}
// Validation passes for: nothing set, CA only, and cert + key together.
#[test]
fn test_tls_config_validate_success() {
let tls = TlsConfig::default();
assert!(tls.validate().is_ok());
let tls = TlsConfig {
client_cert: None,
client_key: None,
ca: Some(PathBuf::from("/path/to/ca.pem")),
};
assert!(tls.validate().is_ok());
let tls = TlsConfig {
client_cert: Some(PathBuf::from("/path/to/cert.pem")),
client_key: Some(PathBuf::from("/path/to/key.pem")),
ca: None,
};
assert!(tls.validate().is_ok());
}
// A certificate without a key is rejected with `MissingClientKey`.
#[test]
fn test_tls_config_validate_missing_key() {
let tls = TlsConfig {
client_cert: Some(PathBuf::from("/path/to/cert.pem")),
client_key: None,
ca: None,
};
let err = tls.validate().unwrap_err();
assert!(matches!(err, TlsConfigError::MissingClientKey));
}
// A key without a certificate is rejected with `MissingClientCert`.
#[test]
fn test_tls_config_validate_missing_cert() {
let tls = TlsConfig {
client_cert: None,
client_key: Some(PathBuf::from("/path/to/key.pem")),
ca: None,
};
let err = tls.validate().unwrap_err();
assert!(matches!(err, TlsConfigError::MissingClientCert));
}
// All three paths deserialize from JSON strings.
#[test]
fn test_tls_config_serde() {
let json = r#"{
"client_cert": "/etc/certs/client.crt",
"client_key": "/etc/certs/client.key",
"ca": "/etc/certs/ca.crt"
}"#;
let tls: TlsConfig = serde_json::from_str(json).unwrap();
assert_eq!(
tls.client_cert,
Some(PathBuf::from("/etc/certs/client.crt"))
);
assert_eq!(tls.client_key, Some(PathBuf::from("/etc/certs/client.key")));
assert_eq!(tls.ca, Some(PathBuf::from("/etc/certs/ca.crt")));
}
// Fields missing from the JSON deserialize to `None`.
#[test]
fn test_tls_config_serde_partial() {
let json = r#"{"ca": "/etc/certs/ca.crt"}"#;
let tls: TlsConfig = serde_json::from_str(json).unwrap();
assert!(tls.client_cert.is_none());
assert!(tls.client_key.is_none());
assert_eq!(tls.ca, Some(PathBuf::from("/etc/certs/ca.crt")));
}
// An unparseable URL yields `InvalidUrl`.
#[tokio::test]
async fn stream_raw_rejects_invalid_url() {
let client = HttpClient::new(HttpClientConfig::default()).expect("client");
let req = HttpRequest {
method: "GET".into(),
url: "not a url".into(),
headers: Default::default(),
body: None,
timeout: None,
};
assert!(matches!(
client.stream_raw(req).await,
Err(HttpClientError::InvalidUrl(_))
));
}
// An `http` URL is rejected when plaintext is disallowed.
#[tokio::test]
async fn stream_raw_rejects_plaintext_when_disallowed() {
let config = HttpClientConfig {
allow_plaintext: false,
..Default::default()
};
let client = HttpClient::new(config).expect("client");
let req = HttpRequest {
method: "GET".into(),
url: "http://example.com/api".into(),
headers: Default::default(),
body: None,
timeout: None,
};
assert!(matches!(
client.stream_raw(req).await,
Err(HttpClientError::PlaintextNotAllowed)
));
}
// A method string that does not parse yields `InvalidMethod`.
#[tokio::test]
async fn stream_raw_rejects_invalid_method() {
let config = HttpClientConfig {
allow_plaintext: true,
..Default::default()
};
let client = HttpClient::new(config).expect("client");
let req = HttpRequest {
method: "NOT A METHOD!!!".into(),
url: "http://127.0.0.1:1/".into(),
headers: Default::default(),
body: None,
timeout: None,
};
assert!(matches!(
client.stream_raw(req).await,
Err(HttpClientError::InvalidMethod(_))
));
}
// Connecting to 127.0.0.1:1 (presumably nothing is listening there) must
// surface as a network error.
#[tokio::test]
async fn stream_raw_connection_refused() {
let config = HttpClientConfig {
allow_plaintext: true,
..Default::default()
};
let client = HttpClient::new(config).expect("client");
let req = HttpRequest {
method: "GET".into(),
url: "http://127.0.0.1:1/".into(), headers: Default::default(),
body: None,
timeout: None,
};
let err = client.stream_raw(req).await.unwrap_err();
assert!(
matches!(
err,
HttpClientError::ConnectionFailed(_) | HttpClientError::RequestFailed(_)
),
"expected network error, got: {err:?}"
);
}
// The listener is bound but never accepts or replies, so a request with a
// 50 ms timeout is expected to fail with `Timeout`.
#[tokio::test]
async fn stream_raw_timeout() {
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local_addr");
let config = HttpClientConfig {
allow_plaintext: true,
..Default::default()
};
let client = HttpClient::new(config).expect("client");
let req = HttpRequest {
method: "GET".into(),
url: format!("http://{addr}/slow"),
headers: Default::default(),
body: None,
timeout: Some(Duration::from_millis(50)),
};
let err = client.stream_raw(req).await.unwrap_err();
assert!(
matches!(err, HttpClientError::Timeout),
"expected Timeout, got: {err:?}"
);
// Keep the listener alive until the request has finished.
drop(listener);
}
// A hand-written HTTP/1.1 server replies "ok"; `stream_raw` must expose the
// status and body.
#[tokio::test]
async fn stream_raw_successful_request() {
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
let addr = listener.local_addr().expect("local_addr");
tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.expect("accept");
let mut buf = [0u8; 1024];
let _ = tokio::io::AsyncReadExt::read(&mut socket, &mut buf).await;
let response = "HTTP/1.1 200 OK\r\ncontent-length: 2\r\n\r\nok";
socket.write_all(response.as_bytes()).await.expect("write");
socket.shutdown().await.expect("shutdown");
});
let config = HttpClientConfig {
allow_plaintext: true,
..Default::default()
};
let client = HttpClient::new(config).expect("client");
let req = HttpRequest {
method: "GET".into(),
url: format!("http://{addr}/"),
headers: Default::default(),
body: None,
timeout: None,
};
let resp = client
.stream_raw(req)
.await
.expect("stream_raw should succeed");
assert_eq!(resp.status(), 200);
let body = resp.text().await.expect("body");
assert_eq!(body, "ok");
}
// Cache keys compare equal for identical paths and differ when a path differs.
#[test]
fn test_tls_cache_key_equality() {
let tls1 = TlsConfig {
client_cert: Some(PathBuf::from("/path/to/cert.pem")),
client_key: Some(PathBuf::from("/path/to/key.pem")),
ca: None,
};
let tls2 = TlsConfig {
client_cert: Some(PathBuf::from("/path/to/cert.pem")),
client_key: Some(PathBuf::from("/path/to/key.pem")),
ca: None,
};
let tls3 = TlsConfig {
client_cert: Some(PathBuf::from("/other/cert.pem")),
client_key: Some(PathBuf::from("/path/to/key.pem")),
ca: None,
};
assert_eq!(tls1.cache_key(), tls2.cache_key());
assert_ne!(tls1.cache_key(), tls3.cache_key());
}
// A binary request body survives a JSON round trip.
#[test]
fn http_request_base64_body_roundtrip() {
let binary_body: Vec<u8> = vec![0x89, 0x50, 0x4E, 0x47, 0xFF, 0xFE, 0x00, 0x01];
let req = HttpRequest {
method: "POST".into(),
url: "https://example.com/upload".into(),
headers: Default::default(),
body: Some(binary_body.clone()),
timeout: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(
!json.contains("\\u0089"),
"body should be base64-encoded, not escaped unicode"
);
let decoded: HttpRequest = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.body.unwrap(), binary_body);
}
// A binary response body survives a JSON round trip.
#[test]
fn http_response_base64_body_roundtrip() {
let binary_body: Vec<u8> = vec![0x89, 0x50, 0x4E, 0x47, 0xFF, 0xFE, 0x00, 0x01];
let resp = HttpResponse {
status: 200,
headers: Default::default(),
body: Some(binary_body.clone()),
};
let json = serde_json::to_string(&resp).unwrap();
let decoded: HttpResponse = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.body.unwrap(), binary_body);
}
// A missing body serializes as JSON `null` and reads back as `None`.
#[test]
fn http_request_null_body_roundtrip() {
let req = HttpRequest {
method: "GET".into(),
url: "https://example.com".into(),
headers: Default::default(),
body: None,
timeout: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains(r#""body":null"#));
let decoded: HttpRequest = serde_json::from_str(&json).unwrap();
assert!(decoded.body.is_none());
}
// A request with a standard-base64 body string and no `headers`/`timeout`
// members deserializes to the raw bytes.
#[test]
fn http_request_deserialize_from_plugin_json() {
use base64::Engine;
let raw_bytes: Vec<u8> = vec![0x00, 0x01, 0x80, 0xFF];
let b64 = base64::engine::general_purpose::STANDARD.encode(&raw_bytes);
let json = format!(
r#"{{
"method": "POST",
"url": "https://example.com/api",
"body": "{b64}"
}}"#
);
let req: HttpRequest = serde_json::from_str(&json).unwrap();
assert_eq!(req.body.unwrap(), raw_bytes);
}
// A response with a standard-base64 body string deserializes to the raw bytes.
#[test]
fn http_response_deserialize_from_host_json() {
use base64::Engine;
let raw_bytes: Vec<u8> = vec![0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A];
let b64 = base64::engine::general_purpose::STANDARD.encode(&raw_bytes);
let json = format!(
r#"{{
"status": 200,
"headers": {{}},
"body": "{b64}"
}}"#
);
let resp: HttpResponse = serde_json::from_str(&json).unwrap();
assert_eq!(resp.body.unwrap(), raw_bytes);
}
}