use std::collections::{HashMap, HashSet};
use std::sync::{Arc, LazyLock};
use async_trait::async_trait;
use regex::Regex;
use thiserror::Error;
use tracing::{debug, warn};
use apcore::context::Context;
use apcore::errors::ModuleError;
use apcore::module::Module;
use apcore::Registry;
use crate::http_verb_map::extract_path_param_names;
use crate::output::types::WriteResult;
use crate::types::ScannedModule;
/// Errors returned by `HTTPProxyRegistryWriter::new` when its arguments are rejected.
#[derive(Debug, Error)]
pub enum HTTPProxyRegistryWriterError {
/// The base URL failed to parse or uses a scheme other than `http`/`https`.
/// NOTE(review): `new` also reports HTTP client build failures through this variant.
#[error("invalid base_url: {0}")]
InvalidBaseUrl(String),
/// The timeout was rejected as a request timeout in seconds.
#[error("invalid timeout_secs: {0}")]
InvalidTimeout(String),
}
/// Registers scanned modules in a `Registry` as proxy modules that forward
/// their execution to an HTTP service reachable under one base URL.
pub struct HTTPProxyRegistryWriter {
/// Base URL passed to `new`, stored as the original string after validation.
base_url: String,
/// Optional callback returning headers to add to each proxied request.
auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
/// HTTP client built in `new` with the configured timeout; cloned into every proxy module.
client: reqwest::Client,
}
// Hand-written `Debug`: the header factory is a closure and cannot be
// formatted, so it is rendered as a fixed placeholder when present.
impl std::fmt::Debug for HTTPProxyRegistryWriter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Some("<factory>")` when a factory is configured, `None` otherwise.
        let factory_marker: Option<&str> = self.auth_header_factory.is_some().then_some("<factory>");

        let mut builder = f.debug_struct("HTTPProxyRegistryWriter");
        builder.field("base_url", &self.base_url);
        builder.field("auth_header_factory", &factory_marker);
        builder.field("client", &self.client);
        builder.finish()
    }
}
impl HTTPProxyRegistryWriter {
    /// Builds a writer whose proxy modules send their requests to `base_url`.
    ///
    /// * `base_url` — must parse as a URL with an `http` or `https` scheme.
    /// * `auth_header_factory` — optional callback; the headers it returns are
    ///   added to each proxied request.
    /// * `timeout_secs` — request timeout configured on the shared HTTP client.
    ///
    /// # Errors
    ///
    /// Returns `InvalidBaseUrl` when the URL does not parse, uses another
    /// scheme, or the HTTP client cannot be built. Returns `InvalidTimeout`
    /// when `timeout_secs` is not a positive finite number or is too large to
    /// be represented as a `Duration`.
    pub fn new(
        base_url: String,
        auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
        timeout_secs: f64,
    ) -> Result<Self, HTTPProxyRegistryWriterError> {
        let parsed = reqwest::Url::parse(&base_url).map_err(|e| {
            HTTPProxyRegistryWriterError::InvalidBaseUrl(format!("'{base_url}': {e}"))
        })?;
        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(HTTPProxyRegistryWriterError::InvalidBaseUrl(format!(
                "scheme '{}' is not allowed — only http and https are permitted",
                parsed.scheme()
            )));
        }

        if !timeout_secs.is_finite() || timeout_secs <= 0.0 {
            return Err(HTTPProxyRegistryWriterError::InvalidTimeout(format!(
                "must be a positive finite number, got {timeout_secs}"
            )));
        }
        // A finite positive value can still overflow `Duration` (e.g. 1e300);
        // `Duration::from_secs_f64` would panic on it, so convert fallibly.
        let timeout = std::time::Duration::try_from_secs_f64(timeout_secs).map_err(|e| {
            HTTPProxyRegistryWriterError::InvalidTimeout(format!(
                "{timeout_secs} cannot be represented as a duration: {e}"
            ))
        })?;

        let client = reqwest::Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| {
                HTTPProxyRegistryWriterError::InvalidBaseUrl(format!(
                    "failed to build HTTP client: {e}"
                ))
            })?;

        Ok(Self {
            base_url,
            auth_header_factory: auth_header_factory.map(Arc::from),
            client,
        })
    }

    /// Registers one HTTP proxy module in `registry` for every entry of `modules`.
    ///
    /// Returns one `WriteResult` per input module, in input order. A failed
    /// registration is logged and reported as a failed result; it does not
    /// stop the remaining modules from being processed.
    pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
        let mut results: Vec<WriteResult> = Vec::with_capacity(modules.len());
        for module in modules {
            let (http_method, url_path) = get_http_fields(module);
            let path_params = extract_path_param_names(&url_path);
            let proxy = ProxyModule {
                base_url: self.base_url.clone(),
                http_method,
                url_path,
                path_params,
                input_schema: module.input_schema.clone(),
                output_schema: module.output_schema.clone(),
                description: module.description.clone(),
                auth_header_factory: self.auth_header_factory.clone(),
                client: self.client.clone(),
            };
            let descriptor = apcore::registry::registry::ModuleDescriptor {
                module_id: module.module_id.clone(),
                name: Some(module.module_id.clone()),
                description: module.description.clone(),
                documentation: module.documentation.clone(),
                input_schema: module.input_schema.clone(),
                output_schema: module.output_schema.clone(),
                version: module.version.clone(),
                tags: module.tags.clone(),
                annotations: module.annotations.clone(),
                examples: module.examples.clone(),
                metadata: module.metadata.clone(),
                display: module.display.clone(),
                sunset_date: None,
                dependencies: vec![],
                enabled: true,
            };
            match registry.register(&module.module_id, Box::new(proxy), descriptor) {
                Ok(()) => {
                    debug!("Registered HTTP proxy: {}", module.module_id);
                    results.push(WriteResult::new(module.module_id.clone()));
                }
                Err(e) => {
                    warn!(module_id = %module.module_id, error = %e, "HTTPProxyRegistryWriter registration failed");
                    results.push(WriteResult::failed(
                        module.module_id.clone(),
                        None,
                        e.to_string(),
                    ));
                }
            }
        }
        results
    }
}
/// Reads the HTTP method and URL path of `module` from its metadata.
///
/// Returns `(http_method, url_path)`. A missing or non-string `http_method`
/// falls back to `"GET"`, and a missing or non-string `url_path` to `"/"`.
///
/// The method is upper-cased: the proxy selects the request builder and the
/// body-vs-query routing by exact upper-case names, so a lower-case `"post"`
/// in metadata would otherwise be routed to the query string and then
/// rejected as unsupported.
fn get_http_fields(module: &ScannedModule) -> (String, String) {
    let http_method = module
        .metadata
        .get("http_method")
        .and_then(|v| v.as_str())
        .unwrap_or("GET")
        .to_ascii_uppercase();
    let url_path = module
        .metadata
        .get("url_path")
        .and_then(|v| v.as_str())
        .unwrap_or("/")
        .to_string();
    (http_method, url_path)
}
/// HTTP methods for which non-path inputs are sent as a JSON request body;
/// for every other method they are sent as query parameters.
const BODY_METHODS: &[&str] = &["POST", "PUT", "PATCH"];
/// Matches a brace-style path placeholder such as `{user_id}` and captures the
/// name between the braces. Compiled once on first use.
static PATH_PARAM_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
fn validate_path_params_filled(actual_path: &str) -> Result<(), String> {
if PATH_PARAM_RE.is_match(actual_path) {
let unfilled: Vec<&str> = PATH_PARAM_RE
.captures_iter(actual_path)
.filter_map(|cap| cap.get(1).map(|m| m.as_str()))
.collect();
Err(format!(
"Missing required path parameters {:?} — inputs must supply values for all path params in '{actual_path}'",
unfilled
))
} else {
Ok(())
}
}
/// Percent-encodes `s` so it can be placed into a single URL path segment.
///
/// ASCII letters, ASCII digits and `-`, `.`, `_`, `~` are copied unchanged.
/// Every other byte of the UTF-8 encoding is written as `%XX` using two
/// upper-case hexadecimal digits.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                    encoded.push(char::from(byte));
                }
                _ => {
                    // High nibble first, then low nibble.
                    encoded.push('%');
                    encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                    encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
                }
            }
            encoded
        })
}
/// Extracts a human-readable message from an HTTP error response body.
///
/// When `body` is JSON, the keys `error_message`, `detail`, `error` and
/// `message` are tried in that order and the first usable value is returned:
/// strings are returned as-is, other values in their JSON text form. Empty
/// strings and `null` are skipped so that a placeholder under an earlier key
/// does not hide a real message under a later one.
///
/// When `body` is not JSON, or no key yields a message, the body itself is
/// returned, truncated to 200 characters.
fn extract_error_message(body: &str) -> String {
    const MESSAGE_KEYS: [&str; 4] = ["error_message", "detail", "error", "message"];

    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
        let found = MESSAGE_KEYS.iter().find_map(|key| match parsed.get(*key)? {
            serde_json::Value::Null => None,
            serde_json::Value::String(text) if text.is_empty() => None,
            serde_json::Value::String(text) => Some(text.clone()),
            other => Some(other.to_string()),
        });
        if let Some(message) = found {
            return message;
        }
    }
    safe_truncate(body, 200)
}
/// Returns at most the first `max_chars` characters of `s` as a new `String`.
///
/// The cut is made on a character boundary, so multi-byte characters are
/// never split. Strings that are already short enough are copied unchanged.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    // The byte offset of character number `max_chars` (0-based) is exactly
    // where the kept prefix ends; if there is no such character, `s` fits.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_owned(),
        None => s.to_owned(),
    }
}
/// Module implementation that forwards each `execute` call to one HTTP endpoint.
struct ProxyModule {
/// Base URL the request path is appended to (a trailing `/` is trimmed first).
base_url: String,
/// HTTP method name; `execute` accepts `GET`, `POST`, `PUT`, `PATCH` and `DELETE`.
http_method: String,
/// Path template of the endpoint; may contain placeholders for path parameters.
url_path: String,
/// Placeholder names found in `url_path`; inputs with these keys are used as
/// path values instead of being sent as query or body fields.
path_params: HashSet<String>,
/// Schema returned by `input_schema()`.
input_schema: serde_json::Value,
/// Schema returned by `output_schema()`.
output_schema: serde_json::Value,
/// Text returned by `description()`.
description: String,
/// Optional callback whose returned headers are added to every request.
auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
/// HTTP client used to send the request.
client: reqwest::Client,
}
#[async_trait]
impl Module for ProxyModule {
fn input_schema(&self) -> serde_json::Value {
self.input_schema.clone()
}
fn output_schema(&self) -> serde_json::Value {
self.output_schema.clone()
}
fn description(&self) -> &str {
&self.description
}
async fn execute(
&self,
inputs: serde_json::Value,
_ctx: &Context<serde_json::Value>,
) -> Result<serde_json::Value, ModuleError> {
let mut actual_path = self.url_path.clone();
let mut query: HashMap<String, String> = HashMap::new();
let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
if let Some(obj) = inputs.as_object() {
let uses_body = BODY_METHODS.contains(&self.http_method.as_str());
for (key, value) in obj {
if self.path_params.contains(key) {
let val_str = match value {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
};
actual_path = actual_path.replace(
&format!("{{{key}}}"),
&percent_encode_path_segment(&val_str),
);
} else if uses_body {
body.insert(key.clone(), value.clone());
} else {
let val_str = match value {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
};
query.insert(key.clone(), val_str);
}
}
}
if let Err(msg) = validate_path_params_filled(&actual_path) {
return Err(ModuleError::new(
apcore::errors::ErrorCode::ModuleExecuteError,
msg,
));
}
let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
let mut request = match self.http_method.as_str() {
"GET" => self.client.get(&url),
"POST" => self.client.post(&url),
"PUT" => self.client.put(&url),
"PATCH" => self.client.patch(&url),
"DELETE" => self.client.delete(&url),
other => {
return Err(ModuleError::new(
apcore::errors::ErrorCode::ModuleExecuteError,
format!("Unsupported HTTP method: {other}"),
))
}
};
if let Some(ref factory) = self.auth_header_factory {
for (header_name, header_value) in factory() {
request = request.header(&header_name, &header_value);
}
}
if !query.is_empty() {
request = request.query(&query.iter().collect::<Vec<_>>());
}
if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
request = request.json(&body);
}
let resp = request.send().await.map_err(|e| {
ModuleError::new(
apcore::errors::ErrorCode::ModuleExecuteError,
format!("HTTP request failed: {e}"),
)
})?;
let status = resp.status();
if status.is_success() {
if status.as_u16() == 204 {
return Ok(serde_json::json!({}));
}
resp.json().await.map_err(|e| {
ModuleError::new(
apcore::errors::ErrorCode::ModuleExecuteError,
format!("Failed to parse response JSON: {e}"),
)
})
} else {
let error_text = resp.text().await.unwrap_or_default();
let message = extract_error_message(&error_text);
Err(ModuleError::new(
apcore::errors::ErrorCode::ModuleExecuteError,
format!("HTTP {}: {}", status.as_u16(), message),
))
}
}
}
// Unit tests for constructor validation and the private helpers of this file.
// No test here sends an HTTP request.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// --- HTTPProxyRegistryWriter::new -------------------------------------------
// A parseable URL with a scheme other than http/https is rejected.
#[test]
fn test_new_rejects_non_http_scheme() {
let result = HTTPProxyRegistryWriter::new("file:///etc/passwd".into(), None, 30.0);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("scheme 'file' is not allowed"));
}
// A string that does not parse as a URL is rejected.
#[test]
fn test_new_rejects_invalid_url() {
let result = HTTPProxyRegistryWriter::new("not a url".into(), None, 30.0);
assert!(result.is_err());
}
// NaN is not a finite timeout; the error text mentions the timeout.
#[test]
fn test_new_rejects_nan_timeout() {
let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::NAN);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("timeout"));
}
// Negative timeouts are rejected.
#[test]
fn test_new_rejects_negative_timeout() {
let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, -1.0);
assert!(result.is_err());
}
// An https base URL with a positive timeout is accepted.
#[test]
fn test_new_accepts_https_scheme() {
let result = HTTPProxyRegistryWriter::new("https://api.example.com".into(), None, 30.0);
assert!(result.is_ok());
}
// --- get_http_fields ----------------------------------------------------------
// Without metadata entries the method defaults to GET and the path to "/".
#[test]
fn test_get_http_fields_defaults() {
let module = ScannedModule::new(
"test".into(),
"test".into(),
json!({}),
json!({}),
vec![],
"app:func".into(),
);
let (method, path) = get_http_fields(&module);
assert_eq!(method, "GET");
assert_eq!(path, "/");
}
// Metadata entries `http_method` and `url_path` override the defaults.
#[test]
fn test_get_http_fields_from_metadata() {
let mut module = ScannedModule::new(
"test".into(),
"test".into(),
json!({}),
json!({}),
vec![],
"app:func".into(),
);
module.metadata.insert(
"http_method".into(),
serde_json::Value::String("POST".into()),
);
module.metadata.insert(
"url_path".into(),
serde_json::Value::String("/users".into()),
);
let (method, path) = get_http_fields(&module);
assert_eq!(method, "POST");
assert_eq!(path, "/users");
}
// --- extract_path_param_names (imported helper) ------------------------------
// Brace-style placeholders are all extracted.
#[test]
fn test_extract_path_params() {
let params = extract_path_param_names("/users/{user_id}/tasks/{task_id}");
assert!(params.contains("user_id"));
assert!(params.contains("task_id"));
assert_eq!(params.len(), 2);
}
// A path without placeholders yields no parameter names.
#[test]
fn test_extract_path_params_none() {
let params = extract_path_param_names("/users");
assert!(params.is_empty());
}
// Colon-style placeholders (`:id`) are expected to be recognised as well.
#[test]
fn test_extract_path_params_colon_style() {
let params = extract_path_param_names("/users/:id");
assert!(
params.contains("id"),
"colon-style param ':id' should be recognised; got: {params:?}"
);
assert_eq!(params.len(), 1);
}
// Colon-style and brace-style placeholders may be mixed in one path.
#[test]
fn test_extract_path_params_mixed_styles() {
let params = extract_path_param_names("/users/:user_id/tasks/{task_id}");
assert!(params.contains("user_id"));
assert!(params.contains("task_id"));
assert_eq!(params.len(), 2);
}
// --- extract_error_message -----------------------------------------------------
// Each of the four recognised JSON keys is picked up on its own.
#[test]
fn test_extract_error_message_json_error_message() {
let body = r#"{"error_message": "not found"}"#;
assert_eq!(extract_error_message(body), "not found");
}
#[test]
fn test_extract_error_message_json_detail() {
let body = r#"{"detail": "unauthorized"}"#;
assert_eq!(extract_error_message(body), "unauthorized");
}
#[test]
fn test_extract_error_message_json_error() {
let body = r#"{"error": "bad request"}"#;
assert_eq!(extract_error_message(body), "bad request");
}
#[test]
fn test_extract_error_message_json_message() {
let body = r#"{"message": "server error"}"#;
assert_eq!(extract_error_message(body), "server error");
}
// When several keys are present, `error_message` wins over `message`.
#[test]
fn test_extract_error_message_json_priority() {
let body = r#"{"error_message": "first", "message": "second"}"#;
assert_eq!(extract_error_message(body), "first");
}
// A short non-JSON body is returned unchanged.
#[test]
fn test_extract_error_message_plain_text_short() {
let body = "plain text error";
assert_eq!(extract_error_message(body), "plain text error");
}
// A long non-JSON body is cut to 200 characters (ASCII here, so `len()` equals the character count).
#[test]
fn test_extract_error_message_plain_text_truncated() {
let body = "x".repeat(300);
let result = extract_error_message(&body);
assert_eq!(result.len(), 200);
}
// --- validate_path_params_filled ------------------------------------------------
// Paths without `{name}` placeholders pass validation.
#[test]
fn test_validate_path_params_filled_no_placeholders() {
assert!(validate_path_params_filled("/users/123/tasks/456").is_ok());
}
#[test]
fn test_validate_path_params_filled_static_path() {
assert!(validate_path_params_filled("/health").is_ok());
}
// A leftover placeholder is an error whose message names the parameter.
#[test]
fn test_validate_path_params_filled_unfilled_placeholder() {
let result = validate_path_params_filled("/users/{user_id}/tasks");
assert!(result.is_err());
let msg = result.unwrap_err();
assert!(
msg.contains("user_id"),
"error should name the unfilled param: {msg}"
);
}
// Several leftover placeholders are an error.
// NOTE(review): the assertion passes when either name appears in the message;
// consider requiring both names.
#[test]
fn test_validate_path_params_filled_multiple_unfilled() {
let result = validate_path_params_filled("/users/{user_id}/tasks/{task_id}");
assert!(result.is_err());
let msg = result.unwrap_err();
assert!(msg.contains("user_id") || msg.contains("task_id"), "{msg}");
}
// --- safe_truncate ---------------------------------------------------------------
// Truncation counts characters, not bytes, so 4-byte emoji are never split.
#[test]
fn test_safe_truncate_multibyte() {
let body = "\u{1F600}".repeat(300);
let result = safe_truncate(&body, 200);
assert_eq!(result.chars().count(), 200);
}
// --- BODY_METHODS ----------------------------------------------------------------
// Only POST, PUT and PATCH are treated as body-carrying methods.
#[test]
fn test_body_methods_set_contents() {
assert!(BODY_METHODS.contains(&"POST"));
assert!(BODY_METHODS.contains(&"PUT"));
assert!(BODY_METHODS.contains(&"PATCH"));
assert!(!BODY_METHODS.contains(&"GET"));
assert!(!BODY_METHODS.contains(&"DELETE"));
assert!(!BODY_METHODS.contains(&"HEAD"));
assert!(!BODY_METHODS.contains(&"OPTIONS"));
}
}