1use super::auth::HttpAuthProvider;
13use super::{join_url, HttpConnector, HttpConnectorError, Operation};
14use async_trait::async_trait;
15use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20
21#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
24#[serde(deny_unknown_fields)]
25pub struct HttpConfig {
26 #[serde(default = "default_timeout")]
28 pub timeout_seconds: u64,
29 #[serde(default = "default_retries")]
31 pub retries: u32,
32 #[serde(default = "default_retry_backoff")]
34 pub retry_backoff_ms: u64,
35 #[serde(default = "default_user_agent")]
37 pub user_agent: String,
38 #[serde(default)]
40 pub default_headers: HashMap<String, String>,
41}
42
/// Serde/`Default` value for `HttpConfig::timeout_seconds`: 30 seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
/// Serde/`Default` value for `HttpConfig::retries`: three retries after the first attempt.
fn default_retries() -> u32 {
    const DEFAULT_RETRIES: u32 = 3;
    DEFAULT_RETRIES
}
/// Serde/`Default` value for `HttpConfig::retry_backoff_ms`: one second before the first retry.
fn default_retry_backoff() -> u64 {
    const DEFAULT_BACKOFF_MS: u64 = 1_000;
    DEFAULT_BACKOFF_MS
}
52fn default_user_agent() -> String {
53 format!("pmcp-server-toolkit/{}", env!("CARGO_PKG_VERSION"))
54}
55
56impl Default for HttpConfig {
57 fn default() -> Self {
58 Self {
59 timeout_seconds: default_timeout(),
60 retries: default_retries(),
61 retry_backoff_ms: default_retry_backoff(),
62 user_agent: default_user_agent(),
63 default_headers: HashMap::new(),
64 }
65 }
66}
67
68pub struct HttpClient {
70 client: reqwest::Client,
71 base_url: url::Url,
72 auth: Arc<dyn HttpAuthProvider>,
73 http_config: HttpConfig,
74}
75
76impl HttpClient {
77 pub fn new(
84 client: reqwest::Client,
85 base_url: String,
86 auth: Arc<dyn HttpAuthProvider>,
87 ) -> Result<Self, HttpConnectorError> {
88 Self::with_config(client, base_url, auth, HttpConfig::default())
89 }
90
91 pub fn with_config(
97 client: reqwest::Client,
98 base_url: String,
99 auth: Arc<dyn HttpAuthProvider>,
100 http_config: HttpConfig,
101 ) -> Result<Self, HttpConnectorError> {
102 let base_url = url::Url::parse(&base_url)
103 .map_err(|_| HttpConnectorError::Backend("invalid base URL".to_string()))?;
104 Ok(Self {
105 client,
106 base_url,
107 auth,
108 http_config,
109 })
110 }
111
112 pub fn from_config(
119 base_url: String,
120 auth: Arc<dyn HttpAuthProvider>,
121 http_config: HttpConfig,
122 ) -> Result<Self, HttpConnectorError> {
123 let mut headers = HeaderMap::new();
124 if let Ok(ua) = HeaderValue::from_str(&http_config.user_agent) {
125 headers.insert(reqwest::header::USER_AGENT, ua);
126 }
127 for (key, value) in &http_config.default_headers {
128 if let (Ok(name), Ok(val)) = (
129 HeaderName::try_from(key.as_str()),
130 HeaderValue::try_from(value.as_str()),
131 ) {
132 headers.insert(name, val);
133 }
134 }
135 let client = reqwest::Client::builder()
136 .timeout(Duration::from_secs(http_config.timeout_seconds))
137 .default_headers(headers)
138 .build()
139 .map_err(|_| HttpConnectorError::Backend("failed to build HTTP client".to_string()))?;
140 Self::with_config(client, base_url, auth, http_config)
141 }
142
143 fn substitute_path(
151 operation: &Operation,
152 args: &serde_json::Map<String, serde_json::Value>,
153 ) -> Result<String, HttpConnectorError> {
154 let mut path = operation.path.clone();
155 for param in operation.path_parameters() {
156 let placeholder = format!("{{{}}}", param.name);
157 if let Some(value) = args.get(¶m.name) {
158 let value_str = render_scalar(¶m.name, value)?;
159 path = path.replace(&placeholder, &value_str);
160 }
161 }
162 Ok(path)
163 }
164
165 fn render_query_value(
175 param_name: &str,
176 value: &serde_json::Value,
177 ) -> Result<String, HttpConnectorError> {
178 if let serde_json::Value::Array(arr) = value {
179 let mut csv = String::new();
182 for (i, member) in arr.iter().enumerate() {
183 if i > 0 {
184 csv.push(',');
185 }
186 csv.push_str(&render_scalar(param_name, member)?);
187 }
188 Ok(csv)
189 } else {
190 render_scalar(param_name, value)
191 }
192 }
193
194 fn build_query(
202 operation: &Operation,
203 args: &serde_json::Map<String, serde_json::Value>,
204 ) -> Result<HashMap<String, String>, HttpConnectorError> {
205 let mut query = HashMap::new();
206 for param in operation.query_parameters() {
207 if let Some(value) = args.get(¶m.name) {
208 query.insert(
209 param.name.clone(),
210 Self::render_query_value(¶m.name, value)?,
211 );
212 }
213 }
214 Ok(query)
215 }
216
217 fn build_headers(
219 operation: &Operation,
220 args: &serde_json::Map<String, serde_json::Value>,
221 ) -> Result<HeaderMap, HttpConnectorError> {
222 let mut headers = HeaderMap::new();
223 for param in operation.header_parameters() {
224 if let Some(value) = args.get(¶m.name) {
225 let name = HeaderName::try_from(param.name.as_str()).map_err(|_| {
226 HttpConnectorError::InvalidHeader("invalid header name".to_string())
227 })?;
228 let rendered = render_scalar(¶m.name, value)?;
231 let val = HeaderValue::try_from(rendered).map_err(|_| {
232 HttpConnectorError::InvalidHeader("invalid header value".to_string())
233 })?;
234 headers.insert(name, val);
235 }
236 }
237 Ok(headers)
238 }
239
240 fn build_body(
242 operation: &Operation,
243 args: &serde_json::Map<String, serde_json::Value>,
244 ) -> Option<serde_json::Value> {
245 if !operation.has_request_body {
246 return None;
247 }
248 if let Some(body) = args.get("body") {
249 return Some(body.clone());
250 }
251 let declared: std::collections::HashSet<&str> = operation
252 .parameters
253 .iter()
254 .map(|p| p.name.as_str())
255 .collect();
256 let body: serde_json::Map<String, serde_json::Value> = args
257 .iter()
258 .filter(|(k, _)| !declared.contains(k.as_str()))
259 .map(|(k, v)| (k.clone(), v.clone()))
260 .collect();
261 if body.is_empty() {
262 None
263 } else {
264 Some(serde_json::Value::Object(body))
265 }
266 }
267
268 fn convert_method(method: &str) -> Result<reqwest::Method, HttpConnectorError> {
269 match method.to_uppercase().as_str() {
270 "GET" => Ok(reqwest::Method::GET),
271 "POST" => Ok(reqwest::Method::POST),
272 "PUT" => Ok(reqwest::Method::PUT),
273 "PATCH" => Ok(reqwest::Method::PATCH),
274 "DELETE" => Ok(reqwest::Method::DELETE),
275 "HEAD" => Ok(reqwest::Method::HEAD),
276 "OPTIONS" => Ok(reqwest::Method::OPTIONS),
277 _ => Err(HttpConnectorError::Backend(
278 "unknown HTTP method".to_string(),
279 )),
280 }
281 }
282
283 async fn send_with_retries(
285 &self,
286 request: reqwest::RequestBuilder,
287 ) -> Result<reqwest::Response, HttpConnectorError> {
288 let max_retries = self.http_config.retries;
289 let mut last_status: Option<u16> = None;
290 for attempt in 0..=max_retries {
291 if attempt > 0 {
292 let delay = self.http_config.retry_backoff_ms * (1u64 << (attempt - 1));
293 tokio::time::sleep(Duration::from_millis(delay)).await;
294 }
295 let Some(attempt_request) = request.try_clone() else {
296 return Err(HttpConnectorError::Request(
297 "request body is not retryable".to_string(),
298 ));
299 };
300 match attempt_request.send().await {
301 Ok(response) => {
302 let status = response.status();
303 if status.is_server_error() && attempt < max_retries {
304 last_status = Some(status.as_u16());
305 continue;
306 }
307 return Ok(response);
308 },
309 Err(e) => {
310 let retryable = e.is_connect() || e.is_timeout();
311 if retryable && attempt < max_retries {
312 continue;
313 }
314 return Err(HttpConnectorError::Request(
316 "transport error contacting backend".to_string(),
317 ));
318 },
319 }
320 }
321 Err(HttpConnectorError::Status {
322 status: last_status.unwrap_or(0),
323 })
324 }
325}
326
327fn render_scalar(
355 param_name: &str,
356 value: &serde_json::Value,
357) -> Result<String, HttpConnectorError> {
358 match value {
359 serde_json::Value::String(s) => Ok(s.clone()),
360 serde_json::Value::Number(n) => Ok(n.to_string()),
361 serde_json::Value::Bool(b) => Ok(b.to_string()),
362 serde_json::Value::Null => Ok("null".to_string()),
363 serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
366 Err(HttpConnectorError::Backend(format!(
367 "param '{param_name}' must be a scalar (non-scalar values are \
368 not supported in path/query/header position)"
369 )))
370 },
371 }
372}
373
374#[async_trait]
375impl HttpConnector for HttpClient {
376 async fn execute(
377 &self,
378 operation: &Operation,
379 args: &serde_json::Value,
380 ) -> Result<serde_json::Value, HttpConnectorError> {
381 let empty = serde_json::Map::new();
382 let args_map = args.as_object().unwrap_or(&empty);
383
384 let substituted = Self::substitute_path(operation, args_map)?;
388 let joined = join_url(self.base_url.as_str(), &substituted);
389 let mut url = url::Url::parse(&joined)
390 .map_err(|_| HttpConnectorError::Backend("constructed URL is invalid".to_string()))?;
391
392 let mut query = Self::build_query(operation, args_map)?;
393 let mut headers = Self::build_headers(operation, args_map)?;
394
395 self.auth.apply(&mut headers, &mut query, None).await?;
398
399 if !query.is_empty() {
405 let mut pairs = url.query_pairs_mut();
406 for (key, value) in &query {
407 pairs.append_pair(key, value);
408 }
409 drop(pairs);
410 }
411
412 let method = Self::convert_method(&operation.method)?;
413 let mut request = self.client.request(method, url);
414 request = request.headers(headers);
415 if let Some(body) = Self::build_body(operation, args_map) {
416 request = request.json(&body);
417 }
418
419 let response = self.send_with_retries(request).await?;
420 let status = response.status();
421 if !status.is_success() {
422 return Err(HttpConnectorError::Status {
423 status: status.as_u16(),
424 });
425 }
426 let body = response
427 .text()
428 .await
429 .map_err(|_| HttpConnectorError::Request("failed to read response body".to_string()))?;
430 if body.is_empty() {
431 return Ok(serde_json::Value::Null);
432 }
433 serde_json::from_str(&body).map_err(|_| {
434 HttpConnectorError::Backend("response body was not valid JSON".to_string())
435 })
436 }
437
438 fn base_url(&self) -> &str {
439 self.base_url.as_str()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::auth::NoAuth;
    use crate::http::{Parameter, ParameterLocation};

    /// `GET /users/{id}`: required path param `id`, optional query param `verbose`, no body.
    fn get_user_op() -> Operation {
        let parameters = vec![
            Parameter::new("id", ParameterLocation::Path, true),
            Parameter::new("verbose", ParameterLocation::Query, false),
        ];
        Operation {
            method: "GET".to_string(),
            path: "/users/{id}".to_string(),
            parameters,
            has_request_body: false,
            base_url: None,
        }
    }

    /// Collects `(name, value)` pairs into the argument map the client helpers take.
    fn args_from(
        entries: Vec<(&str, serde_json::Value)>,
    ) -> serde_json::Map<String, serde_json::Value> {
        entries
            .into_iter()
            .map(|(name, value)| (name.to_owned(), value))
            .collect()
    }

    /// Client for `base` that sends requests without authentication.
    fn unauthenticated_client(base: String) -> HttpClient {
        HttpClient::new(reqwest::Client::new(), base, Arc::new(NoAuth)).unwrap()
    }

    #[test]
    fn test_build_url_with_path_prefix() {
        let client =
            unauthenticated_client("https://xxx.execute-api.eu-west-1.amazonaws.com/v1/".into());
        let args = args_from(vec![("id", serde_json::json!("42"))]);
        let substituted = HttpClient::substitute_path(&get_user_op(), &args).unwrap();
        assert_eq!(
            join_url(client.base_url(), &substituted),
            "https://xxx.execute-api.eu-west-1.amazonaws.com/v1/users/42"
        );
    }

    #[test]
    fn test_substitute_path_replaces_placeholder() {
        let args = args_from(vec![("id", serde_json::json!(7))]);
        let path = HttpClient::substitute_path(&get_user_op(), &args).unwrap();
        assert_eq!(path, "/users/7");
    }

    #[test]
    fn test_build_query_skips_path_params() {
        let args = args_from(vec![
            ("id", serde_json::json!("42")),
            ("verbose", serde_json::json!(true)),
        ]);
        let query = HttpClient::build_query(&get_user_op(), &args).unwrap();
        assert_eq!(query.get("verbose").map(String::as_str), Some("true"));
        assert!(!query.contains_key("id"));
    }

    #[test]
    fn render_query_value_comma_joins_scalar_array() {
        let tags = serde_json::json!(["a", 2, true]);
        assert_eq!(
            HttpClient::render_query_value("tags", &tags).unwrap(),
            "a,2,true"
        );
    }

    #[test]
    fn render_query_value_scalar_passthrough() {
        for (name, value, expected) in [
            ("q", serde_json::json!("hi"), "hi"),
            ("n", serde_json::json!(7), "7"),
        ] {
            assert_eq!(
                HttpClient::render_query_value(name, &value).unwrap(),
                expected
            );
        }
    }

    #[test]
    fn render_scalar_null_is_bare_null() {
        let rendered = render_scalar("x", &serde_json::Value::Null).unwrap();
        assert_eq!(rendered, "null");
    }

    #[test]
    fn substitute_path_rejects_object_param() {
        let args = args_from(vec![("id", serde_json::json!({"nested": "x"}))]);
        let err = HttpClient::substitute_path(&get_user_op(), &args).unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));

        let rendered = err.to_string();
        assert!(
            rendered.contains("id"),
            "error must name the param: {rendered}"
        );
        // Neither JSON syntax nor the offending value may leak into the message.
        assert!(
            ['{', '[', '"'].iter().all(|c| !rendered.contains(*c)),
            "must not echo JSON: {rendered}"
        );
        assert!(
            !rendered.contains("nested"),
            "must not echo the value: {rendered}"
        );
    }

    #[test]
    fn build_query_rejects_object_param() {
        let args = args_from(vec![("verbose", serde_json::json!({"k": "v"}))]);
        let err = HttpClient::build_query(&get_user_op(), &args).unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        assert!(err.to_string().contains("verbose"));
    }

    #[test]
    fn render_query_value_rejects_array_with_object_member() {
        let mixed = serde_json::json!(["ok", {"bad": 1}]);
        let err = HttpClient::render_query_value("tags", &mixed).unwrap_err();
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        assert!(err.to_string().contains("tags"));
    }

    #[test]
    fn build_headers_rejects_non_scalar_param() {
        let op = Operation {
            method: "GET".to_string(),
            path: "/x".to_string(),
            parameters: vec![Parameter::new("x-trace", ParameterLocation::Header, false)],
            has_request_body: false,
            base_url: None,
        };

        // Arrays and objects are both refused, and the error names the header param.
        for bad in [serde_json::json!(["a", "b"]), serde_json::json!({"k": "v"})] {
            let args = args_from(vec![("x-trace", bad)]);
            let err = HttpClient::build_headers(&op, &args).unwrap_err();
            assert!(matches!(err, HttpConnectorError::Backend(_)));
            assert!(err.to_string().contains("x-trace"));
        }

        // A plain string still becomes a header value.
        let ok_args = args_from(vec![("x-trace", serde_json::json!("abc"))]);
        let headers = HttpClient::build_headers(&op, &ok_args).unwrap();
        assert_eq!(headers.get("x-trace").unwrap(), "abc");
    }

    #[test]
    fn test_new_is_lazy_and_rejects_bad_url() {
        let Err(err) = HttpClient::new(
            reqwest::Client::new(),
            "not a url".to_string(),
            Arc::new(NoAuth),
        ) else {
            panic!("bad URL should error");
        };
        assert!(matches!(err, HttpConnectorError::Backend(_)));
        assert!(
            !err.to_string().contains("not a url"),
            "must not echo the bad URL"
        );
    }

    #[tokio::test]
    async fn http_connector_get_returns_json() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        let user = serde_json::json!({"id": 42, "name": "Ada"});
        Mock::given(method("GET"))
            .and(path("/users/42"))
            .respond_with(ResponseTemplate::new(200).set_body_json(user))
            .mount(&server)
            .await;

        let client = unauthenticated_client(server.uri());
        let result = client
            .execute(&get_user_op(), &serde_json::json!({"id": "42"}))
            .await
            .unwrap();
        assert_eq!(result["id"], 42);
        assert_eq!(result["name"], "Ada");
    }

    #[tokio::test]
    async fn http_connector_post_sends_body_and_auth() {
        use wiremock::matchers::{body_json, header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/items"))
            .and(header("authorization", "Bearer tok"))
            .and(body_json(serde_json::json!({"name": "widget"})))
            .respond_with(ResponseTemplate::new(201).set_body_json(serde_json::json!({"ok": true})))
            .mount(&server)
            .await;

        let bearer = crate::http::AuthConfig::Bearer {
            token: "tok".to_string(),
            required: true,
        };
        let auth = crate::http::auth::create_auth_provider(&bearer).unwrap();
        let client = HttpClient::new(reqwest::Client::new(), server.uri(), auth).unwrap();
        let create_item = Operation {
            method: "POST".to_string(),
            path: "/items".to_string(),
            parameters: Vec::new(),
            has_request_body: true,
            base_url: None,
        };
        let result = client
            .execute(&create_item, &serde_json::json!({"name": "widget"}))
            .await
            .unwrap();
        assert_eq!(result["ok"], true);
    }

    #[tokio::test]
    async fn http_connector_maps_non_2xx_to_status_without_url() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/users/42"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let client = unauthenticated_client(server.uri());
        let err = client
            .execute(&get_user_op(), &serde_json::json!({"id": "42"}))
            .await
            .unwrap_err();
        assert!(matches!(err, HttpConnectorError::Status { status: 404 }));

        let rendered = err.to_string();
        assert!(rendered.contains("404"));
        assert!(
            !rendered.contains("http://"),
            "status error must not echo the URL"
        );
    }
}