use super::auth::HttpAuthProvider;
use super::{join_url, HttpConnector, HttpConnectorError, Operation};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Tuning knobs for outbound HTTP calls made by [`HttpClient`].
///
/// Deserialized with `deny_unknown_fields`, so a misspelled key is rejected
/// instead of silently ignored. Every field has a serde default, and those
/// defaults match [`HttpConfig::default`], so an empty object deserializes to
/// the default configuration.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct HttpConfig {
    /// Request timeout in seconds (default 30), passed to
    /// `reqwest::ClientBuilder::timeout`. Within this module it is only applied
    /// by [`HttpClient::from_config`].
    #[serde(default = "default_timeout")]
    pub timeout_seconds: u64,
    /// Number of extra attempts after the first one (default 3). A retry is
    /// made when the backend answers 5xx or the send fails with a
    /// connect/timeout error.
    #[serde(default = "default_retries")]
    pub retries: u32,
    /// Delay in milliseconds before the first retry (default 1000). The delay
    /// doubles for each further retry.
    #[serde(default = "default_retry_backoff")]
    pub retry_backoff_ms: u64,
    /// `User-Agent` value (default `pmcp-server-toolkit/<crate version>`).
    /// Only applied by [`HttpClient::from_config`], which skips it silently if
    /// it is not a valid header value.
    #[serde(default = "default_user_agent")]
    pub user_agent: String,
    /// Extra headers sent with every request. Only applied by
    /// [`HttpClient::from_config`], which skips entries with an invalid header
    /// name or value.
    #[serde(default)]
    pub default_headers: HashMap<String, String>,
}
/// Default request timeout, in seconds.
const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
/// Default number of retries after the initial attempt.
const DEFAULT_RETRIES: u32 = 3;
/// Default delay before the first retry, in milliseconds.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 1000;

/// Serde default for `HttpConfig::timeout_seconds`.
fn default_timeout() -> u64 {
    DEFAULT_TIMEOUT_SECONDS
}

/// Serde default for `HttpConfig::retries`.
fn default_retries() -> u32 {
    DEFAULT_RETRIES
}

/// Serde default for `HttpConfig::retry_backoff_ms`.
fn default_retry_backoff() -> u64 {
    DEFAULT_RETRY_BACKOFF_MS
}
/// Serde default for `HttpConfig::user_agent`: `pmcp-server-toolkit/` followed
/// by this crate's version as recorded by Cargo at compile time.
fn default_user_agent() -> String {
    concat!("pmcp-server-toolkit/", env!("CARGO_PKG_VERSION")).to_owned()
}
impl Default for HttpConfig {
    /// Returns the configuration used when nothing is specified. It uses the
    /// same helper functions serde applies to omitted fields, and no extra
    /// default headers.
    fn default() -> Self {
        let timeout_seconds = default_timeout();
        let retries = default_retries();
        let retry_backoff_ms = default_retry_backoff();
        let user_agent = default_user_agent();
        let default_headers = HashMap::new();
        Self {
            timeout_seconds,
            retries,
            retry_backoff_ms,
            user_agent,
            default_headers,
        }
    }
}
/// `reqwest`-backed implementation of [`HttpConnector`].
///
/// Operation paths are joined onto `base_url`. `auth` is applied to every
/// request. Only the retry settings of `http_config` are consulted when
/// sending; timeout, user agent and default headers take effect only if the
/// client was built by [`HttpClient::from_config`].
pub struct HttpClient {
    // Shared HTTP client used for every request.
    client: reqwest::Client,
    // Parsed once at construction; operation paths are joined onto it.
    base_url: url::Url,
    // Gets mutable access to each request's headers and query map before sending.
    auth: Arc<dyn HttpAuthProvider>,
    // Retry/backoff settings, read when sending requests.
    http_config: HttpConfig,
}
impl HttpClient {
    /// Wraps an existing `reqwest::Client`, using [`HttpConfig::default`] for
    /// the retry policy.
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Backend`] if `base_url` cannot be parsed
    /// by `url::Url::parse`. The rejected URL is not echoed in the message.
    pub fn new(
        client: reqwest::Client,
        base_url: String,
        auth: Arc<dyn HttpAuthProvider>,
    ) -> Result<Self, HttpConnectorError> {
        Self::with_config(client, base_url, auth, HttpConfig::default())
    }

    /// Wraps an existing `reqwest::Client` with an explicit [`HttpConfig`].
    ///
    /// Only `retries` and `retry_backoff_ms` are used from `http_config` here.
    /// Timeout, user agent and default headers are NOT applied to `client`;
    /// use [`HttpClient::from_config`] if those should take effect.
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Backend`] if `base_url` cannot be parsed.
    pub fn with_config(
        client: reqwest::Client,
        base_url: String,
        auth: Arc<dyn HttpAuthProvider>,
        http_config: HttpConfig,
    ) -> Result<Self, HttpConnectorError> {
        let base_url = url::Url::parse(&base_url)
            .map_err(|_| HttpConnectorError::Backend("invalid base URL".to_string()))?;
        Ok(Self {
            client,
            base_url,
            auth,
            http_config,
        })
    }

    /// Builds a dedicated `reqwest::Client` from `http_config` (timeout,
    /// `User-Agent`, default headers) and wraps it.
    ///
    /// A `user_agent` or `default_headers` entry that is not a valid HTTP
    /// header name/value is skipped rather than reported.
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Backend`] if the `reqwest` client cannot
    /// be built or `base_url` cannot be parsed.
    pub fn from_config(
        base_url: String,
        auth: Arc<dyn HttpAuthProvider>,
        http_config: HttpConfig,
    ) -> Result<Self, HttpConnectorError> {
        let mut headers = HeaderMap::new();
        if let Ok(ua) = HeaderValue::from_str(&http_config.user_agent) {
            headers.insert(reqwest::header::USER_AGENT, ua);
        }
        for (key, value) in &http_config.default_headers {
            if let (Ok(name), Ok(val)) = (
                HeaderName::try_from(key.as_str()),
                HeaderValue::try_from(value.as_str()),
            ) {
                headers.insert(name, val);
            }
        }
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(http_config.timeout_seconds))
            .default_headers(headers)
            .build()
            .map_err(|_| HttpConnectorError::Backend("failed to build HTTP client".to_string()))?;
        Self::with_config(client, base_url, auth, http_config)
    }

    /// Replaces each `{name}` placeholder in `operation.path` with the matching
    /// argument, percent-encoded as a single path segment.
    ///
    /// Placeholders without a matching argument are left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Backend`] if a value is an array/object
    /// (see [`render_scalar`]) or renders to `.` / `..`.
    fn substitute_path(
        operation: &Operation,
        args: &serde_json::Map<String, serde_json::Value>,
    ) -> Result<String, HttpConnectorError> {
        let mut path = operation.path.clone();
        for param in operation.path_parameters() {
            let Some(value) = args.get(&param.name) else {
                continue;
            };
            let rendered = render_scalar(&param.name, value)?;
            // Caller-supplied values must not change which endpoint is hit.
            // `/`, `?`, `#`, `%`, `{`, ... are therefore encoded below. Dot
            // segments are refused outright, because URL parsing resolves `.`,
            // `..` and their `%2E` spellings as relative path steps.
            if rendered == "." || rendered == ".." {
                return Err(HttpConnectorError::Backend(format!(
                    "param '{}' must not be '.' or '..' in path position",
                    param.name
                )));
            }
            let placeholder = format!("{{{}}}", param.name);
            path = path.replace(&placeholder, &Self::encode_path_segment(&rendered));
        }
        Ok(path)
    }

    /// Percent-encodes `raw` for use as exactly one URL path segment.
    ///
    /// Every byte outside the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`)
    /// becomes an uppercase `%XX` escape. Multi-byte UTF-8 characters are
    /// therefore encoded byte by byte.
    fn encode_path_segment(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }

    /// Renders a query-parameter value.
    ///
    /// Arrays become comma-separated scalars (`["a", 2, true]` -> `a,2,true`).
    /// Any other value goes through [`render_scalar`].
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Backend`] for an object, or for an array
    /// containing a non-scalar member.
    fn render_query_value(
        param_name: &str,
        value: &serde_json::Value,
    ) -> Result<String, HttpConnectorError> {
        match value {
            serde_json::Value::Array(members) => {
                let rendered = members
                    .iter()
                    .map(|member| render_scalar(param_name, member))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(rendered.join(","))
            },
            other => render_scalar(param_name, other),
        }
    }

    /// Collects the operation's declared query parameters that are present in
    /// `args`, keyed by parameter name.
    ///
    /// # Errors
    ///
    /// Propagates [`HttpConnectorError::Backend`] from value rendering.
    fn build_query(
        operation: &Operation,
        args: &serde_json::Map<String, serde_json::Value>,
    ) -> Result<HashMap<String, String>, HttpConnectorError> {
        let mut query = HashMap::new();
        for param in operation.query_parameters() {
            if let Some(value) = args.get(&param.name) {
                query.insert(
                    param.name.clone(),
                    Self::render_query_value(&param.name, value)?,
                );
            }
        }
        Ok(query)
    }

    /// Builds a header map from the operation's declared header parameters
    /// that are present in `args`.
    ///
    /// # Errors
    ///
    /// - [`HttpConnectorError::InvalidHeader`] if the parameter name or the
    ///   rendered value is not a legal HTTP header name/value.
    /// - [`HttpConnectorError::Backend`] if the value is an array/object.
    fn build_headers(
        operation: &Operation,
        args: &serde_json::Map<String, serde_json::Value>,
    ) -> Result<HeaderMap, HttpConnectorError> {
        let mut headers = HeaderMap::new();
        for param in operation.header_parameters() {
            if let Some(value) = args.get(&param.name) {
                let name = HeaderName::try_from(param.name.as_str()).map_err(|_| {
                    HttpConnectorError::InvalidHeader("invalid header name".to_string())
                })?;
                let rendered = render_scalar(&param.name, value)?;
                let val = HeaderValue::try_from(rendered).map_err(|_| {
                    HttpConnectorError::InvalidHeader("invalid header value".to_string())
                })?;
                headers.insert(name, val);
            }
        }
        Ok(headers)
    }

    /// Builds the JSON request body for `operation`.
    ///
    /// - `None` when `operation.has_request_body` is false.
    /// - An explicit `"body"` argument is sent verbatim.
    /// - Otherwise every argument whose name is not one of the operation's
    ///   declared parameters is gathered into a JSON object (`None` if empty).
    fn build_body(
        operation: &Operation,
        args: &serde_json::Map<String, serde_json::Value>,
    ) -> Option<serde_json::Value> {
        if !operation.has_request_body {
            return None;
        }
        if let Some(body) = args.get("body") {
            return Some(body.clone());
        }
        let declared: std::collections::HashSet<&str> = operation
            .parameters
            .iter()
            .map(|p| p.name.as_str())
            .collect();
        let body: serde_json::Map<String, serde_json::Value> = args
            .iter()
            .filter(|(k, _)| !declared.contains(k.as_str()))
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        if body.is_empty() {
            None
        } else {
            Some(serde_json::Value::Object(body))
        }
    }

    /// Maps a case-insensitive method name to a `reqwest::Method`.
    ///
    /// Only GET, POST, PUT, PATCH, DELETE, HEAD and OPTIONS are accepted.
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Backend`] for any other method.
    fn convert_method(method: &str) -> Result<reqwest::Method, HttpConnectorError> {
        match method.to_uppercase().as_str() {
            "GET" => Ok(reqwest::Method::GET),
            "POST" => Ok(reqwest::Method::POST),
            "PUT" => Ok(reqwest::Method::PUT),
            "PATCH" => Ok(reqwest::Method::PATCH),
            "DELETE" => Ok(reqwest::Method::DELETE),
            "HEAD" => Ok(reqwest::Method::HEAD),
            "OPTIONS" => Ok(reqwest::Method::OPTIONS),
            _ => Err(HttpConnectorError::Backend(
                "unknown HTTP method".to_string(),
            )),
        }
    }

    /// Sends `request`, retrying transient failures per [`HttpConfig`].
    ///
    /// Up to `retries + 1` attempts are made. A non-final attempt is retried
    /// when the response is 5xx, or when sending failed with a connect or
    /// timeout error. Before retry `k` (1-based) the task sleeps
    /// `retry_backoff_ms * 2^(k-1)` ms; the value saturates instead of
    /// overflowing. On the final attempt a 5xx response is returned as `Ok`,
    /// so the caller can map its status.
    ///
    /// NOTE(review): retries apply to every method, including non-idempotent
    /// ones such as POST. Confirm this is acceptable for the target backends.
    ///
    /// # Errors
    ///
    /// Returns [`HttpConnectorError::Request`] if the request cannot be cloned
    /// (e.g. a streaming body), or on a transport error that is not retryable
    /// or occurs on the last attempt. Transport details are not exposed.
    async fn send_with_retries(
        &self,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, HttpConnectorError> {
        let max_retries = self.http_config.retries;
        let mut last_status: Option<u16> = None;
        for attempt in 0..=max_retries {
            if attempt > 0 {
                // `1u64 << n` overflows for n >= 64 (a panic in debug builds),
                // and the multiply can overflow for large backoffs: saturate.
                let factor = 1u64.checked_shl(attempt - 1).unwrap_or(u64::MAX);
                let delay = self.http_config.retry_backoff_ms.saturating_mul(factor);
                tokio::time::sleep(Duration::from_millis(delay)).await;
            }
            // The builder is consumed by `send`, so each attempt sends a clone.
            let Some(attempt_request) = request.try_clone() else {
                return Err(HttpConnectorError::Request(
                    "request body is not retryable".to_string(),
                ));
            };
            match attempt_request.send().await {
                Ok(response) => {
                    let status = response.status();
                    if status.is_server_error() && attempt < max_retries {
                        last_status = Some(status.as_u16());
                        continue;
                    }
                    return Ok(response);
                },
                Err(e) => {
                    let retryable = e.is_connect() || e.is_timeout();
                    if retryable && attempt < max_retries {
                        continue;
                    }
                    return Err(HttpConnectorError::Request(
                        "transport error contacting backend".to_string(),
                    ));
                },
            }
        }
        // Defensive fallback: the final loop iteration always returns above.
        Err(HttpConnectorError::Status {
            status: last_status.unwrap_or(0),
        })
    }
}
fn render_scalar(
param_name: &str,
value: &serde_json::Value,
) -> Result<String, HttpConnectorError> {
match value {
serde_json::Value::String(s) => Ok(s.clone()),
serde_json::Value::Number(n) => Ok(n.to_string()),
serde_json::Value::Bool(b) => Ok(b.to_string()),
serde_json::Value::Null => Ok("null".to_string()),
serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
Err(HttpConnectorError::Backend(format!(
"param '{param_name}' must be a scalar (non-scalar values are \
not supported in path/query/header position)"
)))
},
}
}
#[async_trait]
impl HttpConnector for HttpClient {
    /// Executes `operation` against the backend using the caller's `args`.
    ///
    /// Steps:
    /// 1. Substitute path parameters and join onto the base URL.
    /// 2. Add query and header parameters and let the auth provider adjust
    ///    them.
    /// 3. Attach the JSON body, if any.
    /// 4. Send with retries.
    ///
    /// A non-object `args` is treated as "no arguments". A successful
    /// response with an empty or whitespace-only body yields
    /// `serde_json::Value::Null`.
    ///
    /// # Errors
    ///
    /// - [`HttpConnectorError::Status`] for a non-2xx final response (status
    ///   code only; the URL is not included).
    /// - [`HttpConnectorError::Backend`] for invalid parameters, an invalid
    ///   constructed URL, an unknown method, or a non-JSON response body.
    /// - [`HttpConnectorError::Request`] for transport or body-read failures.
    /// - Any error returned by the auth provider.
    async fn execute(
        &self,
        operation: &Operation,
        args: &serde_json::Value,
    ) -> Result<serde_json::Value, HttpConnectorError> {
        let empty = serde_json::Map::new();
        let args_map = args.as_object().unwrap_or(&empty);
        let substituted = Self::substitute_path(operation, args_map)?;
        let joined = join_url(self.base_url.as_str(), &substituted);
        let mut url = url::Url::parse(&joined)
            .map_err(|_| HttpConnectorError::Backend("constructed URL is invalid".to_string()))?;
        let mut query = Self::build_query(operation, args_map)?;
        let mut headers = Self::build_headers(operation, args_map)?;
        // The auth provider gets mutable access to both headers and query.
        self.auth.apply(&mut headers, &mut query, None).await?;
        if !query.is_empty() {
            // `HashMap` iteration order is randomized per process. Sort by key
            // so identical calls always produce an identical URL.
            let mut ordered: Vec<(&String, &String)> = query.iter().collect();
            ordered.sort_unstable_by(|left, right| left.0.cmp(right.0));
            // The serializer writes the query back into `url` when dropped at
            // the end of this block.
            let mut pairs = url.query_pairs_mut();
            for (key, value) in ordered {
                pairs.append_pair(key, value);
            }
        }
        let method = Self::convert_method(&operation.method)?;
        let mut request = self.client.request(method, url).headers(headers);
        if let Some(body) = Self::build_body(operation, args_map) {
            request = request.json(&body);
        }
        let response = self.send_with_retries(request).await?;
        let status = response.status();
        if !status.is_success() {
            return Err(HttpConnectorError::Status {
                status: status.as_u16(),
            });
        }
        let body = response
            .text()
            .await
            .map_err(|_| HttpConnectorError::Request("failed to read response body".to_string()))?;
        // Treat whitespace-only bodies like empty ones, rather than reporting
        // them as invalid JSON.
        if body.trim().is_empty() {
            return Ok(serde_json::Value::Null);
        }
        serde_json::from_str(&body).map_err(|_| {
            HttpConnectorError::Backend("response body was not valid JSON".to_string())
        })
    }

    /// Returns the configured base URL as a string.
    fn base_url(&self) -> &str {
        self.base_url.as_str()
    }
}
// Unit tests for path/query/header construction, plus wiremock-backed
// end-to-end tests of `execute`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::auth::NoAuth;
    use crate::http::{Parameter, ParameterLocation};

    // Fixture: `GET /users/{id}` with a path param `id` and a query param
    // `verbose`. The third `Parameter::new` argument is presumably the
    // `required` flag; confirm against `Parameter`.
    fn get_user_op() -> Operation {
        Operation {
            method: "GET".to_string(),
            path: "/users/{id}".to_string(),
            parameters: vec![
                Parameter::new("id", ParameterLocation::Path, true),
                Parameter::new("verbose", ParameterLocation::Query, false),
            ],
            has_request_body: false,
            base_url: None,
        }
    }

    // A base URL with a path prefix and trailing slash must keep the prefix
    // when the operation path is joined onto it.
    #[test]
    fn test_build_url_with_path_prefix() {
        let client = HttpClient::new(
            reqwest::Client::new(),
            "https://xxx.execute-api.eu-west-1.amazonaws.com/v1/".to_string(),
            Arc::new(NoAuth),
        )
        .unwrap();
        let op = get_user_op();
        let mut args = serde_json::Map::new();
        args.insert("id".to_string(), serde_json::json!("42"));
        let substituted = HttpClient::substitute_path(&op, &args).unwrap();
        let joined = join_url(client.base_url(), &substituted);
        assert_eq!(
            joined,
            "https://xxx.execute-api.eu-west-1.amazonaws.com/v1/users/42"
        );
    }

    // Numeric values are rendered into the path without quotes.
    #[test]
    fn test_substitute_path_replaces_placeholder() {
        let op = get_user_op();
        let mut args = serde_json::Map::new();
        args.insert("id".to_string(), serde_json::json!(7));
        assert_eq!(HttpClient::substitute_path(&op, &args).unwrap(), "/users/7");
    }

    // Only query-located parameters end up in the query map.
    #[test]
    fn test_build_query_skips_path_params() {
        let op = get_user_op();
        let mut args = serde_json::Map::new();
        args.insert("id".to_string(), serde_json::json!("42"));
        args.insert("verbose".to_string(), serde_json::json!(true));
        let query = HttpClient::build_query(&op, &args).unwrap();
        assert_eq!(query.get("verbose"), Some(&"true".to_string()));
        assert!(!query.contains_key("id"));
    }

    // Arrays of scalars become a comma-separated string.
    #[test]
    fn render_query_value_comma_joins_scalar_array() {
        let rendered =
            HttpClient::render_query_value("tags", &serde_json::json!(["a", 2, true])).unwrap();
        assert_eq!(rendered, "a,2,true");
    }

    // Non-array values are rendered exactly as `render_scalar` would.
    #[test]
    fn render_query_value_scalar_passthrough() {
        assert_eq!(
            HttpClient::render_query_value("q", &serde_json::json!("hi")).unwrap(),
            "hi"
        );
        assert_eq!(
            HttpClient::render_query_value("n", &serde_json::json!(7)).unwrap(),
            "7"
        );
    }

    // JSON null is rendered as the bare word `null` (pinned behaviour).
    #[test]
    fn render_scalar_null_is_bare_null() {
        assert_eq!(
            render_scalar("x", &serde_json::Value::Null).unwrap(),
            "null"
        );
    }

    // Object path values are rejected. The error names the param but must
    // not leak the JSON value.
    #[test]
    fn substitute_path_rejects_object_param() {
        let op = get_user_op();
        let mut args = serde_json::Map::new();
        args.insert("id".to_string(), serde_json::json!({"nested": "x"}));
        let err = HttpClient::substitute_path(&op, &args).unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        let rendered = err.to_string();
        assert!(
            rendered.contains("id"),
            "error must name the param: {rendered}"
        );
        for forbidden in ['{', '[', '"'] {
            assert!(
                !rendered.contains(forbidden),
                "must not echo JSON: {rendered}"
            );
        }
        assert!(
            !rendered.contains("nested"),
            "must not echo the value: {rendered}"
        );
    }

    // Object query values are rejected with an error naming the param.
    #[test]
    fn build_query_rejects_object_param() {
        let op = get_user_op();
        let mut args = serde_json::Map::new();
        args.insert("verbose".to_string(), serde_json::json!({"k": "v"}));
        let err = HttpClient::build_query(&op, &args).unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        assert!(err.to_string().contains("verbose"));
    }

    // A single non-scalar member poisons the whole array value.
    #[test]
    fn render_query_value_rejects_array_with_object_member() {
        let err = HttpClient::render_query_value("tags", &serde_json::json!(["ok", {"bad": 1}]))
            .unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        assert!(err.to_string().contains("tags"));
    }

    // Header params accept scalars only: arrays and objects are rejected, and
    // strings pass through unchanged.
    #[test]
    fn build_headers_rejects_non_scalar_param() {
        let op = Operation {
            method: "GET".to_string(),
            path: "/x".to_string(),
            parameters: vec![Parameter::new("x-trace", ParameterLocation::Header, false)],
            has_request_body: false,
            base_url: None,
        };
        let mut args = serde_json::Map::new();
        args.insert("x-trace".to_string(), serde_json::json!(["a", "b"]));
        let err = HttpClient::build_headers(&op, &args).unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        assert!(err.to_string().contains("x-trace"));
        let mut args2 = serde_json::Map::new();
        args2.insert("x-trace".to_string(), serde_json::json!({"k": "v"}));
        let err2 = HttpClient::build_headers(&op, &args2).unwrap_err();
        assert!(matches!(err2, HttpConnectorError::Backend(_)));
        assert!(err2.to_string().contains("x-trace"));
        let mut args3 = serde_json::Map::new();
        args3.insert("x-trace".to_string(), serde_json::json!("abc"));
        let headers = HttpClient::build_headers(&op, &args3).unwrap();
        assert_eq!(headers.get("x-trace").unwrap(), "abc");
    }

    // Construction parses the base URL up front (no network I/O) and must not
    // echo an invalid URL back in the error.
    #[test]
    fn test_new_is_lazy_and_rejects_bad_url() {
        let err = HttpClient::new(
            reqwest::Client::new(),
            "not a url".to_string(),
            Arc::new(NoAuth),
        )
        .err()
        .expect("bad URL should error");
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        let rendered = err.to_string();
        assert!(!rendered.contains("not a url"), "must not echo the bad URL");
    }

    // End-to-end GET: the path param is substituted and the JSON response is
    // returned as-is.
    #[tokio::test]
    async fn http_connector_get_returns_json() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/users/42"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"id": 42, "name": "Ada"})),
            )
            .mount(&server)
            .await;
        let client =
            HttpClient::new(reqwest::Client::new(), server.uri(), Arc::new(NoAuth)).unwrap();
        let op = get_user_op();
        let args = serde_json::json!({"id": "42"});
        let result = client.execute(&op, &args).await.unwrap();
        assert_eq!(result["id"], 42);
        assert_eq!(result["name"], "Ada");
    }

    // End-to-end POST: undeclared args become the JSON body, and the bearer
    // auth provider adds the Authorization header. The mock only matches if
    // both are present.
    #[tokio::test]
    async fn http_connector_post_sends_body_and_auth() {
        use wiremock::matchers::{body_json, header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/items"))
            .and(header("authorization", "Bearer tok"))
            .and(body_json(serde_json::json!({"name": "widget"})))
            .respond_with(ResponseTemplate::new(201).set_body_json(serde_json::json!({"ok": true})))
            .mount(&server)
            .await;
        let auth = crate::http::auth::create_auth_provider(&crate::http::AuthConfig::Bearer {
            token: "tok".to_string(),
            required: true,
        })
        .unwrap();
        let client = HttpClient::new(reqwest::Client::new(), server.uri(), auth).unwrap();
        let op = Operation {
            method: "POST".to_string(),
            path: "/items".to_string(),
            parameters: vec![],
            has_request_body: true,
            base_url: None,
        };
        let args = serde_json::json!({"name": "widget"});
        let result = client.execute(&op, &args).await.unwrap();
        assert_eq!(result["ok"], true);
    }

    // A 4xx (not retried) maps to `Status` carrying only the code; the
    // rendered error must not contain the request URL.
    #[tokio::test]
    async fn http_connector_maps_non_2xx_to_status_without_url() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/users/42"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;
        let client =
            HttpClient::new(reqwest::Client::new(), server.uri(), Arc::new(NoAuth)).unwrap();
        let op = get_user_op();
        let args = serde_json::json!({"id": "42"});
        let err = client.execute(&op, &args).await.unwrap_err();
        assert!(matches!(err, HttpConnectorError::Status { status: 404 }));
        let rendered = err.to_string();
        assert!(rendered.contains("404"));
        assert!(
            !rendered.contains("http://"),
            "status error must not echo the URL"
        );
    }
}