1use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, LazyLock};
8
9use async_trait::async_trait;
10use regex::Regex;
11use thiserror::Error;
12use tracing::{debug, warn};
13
14use apcore::context::Context;
15use apcore::errors::ModuleError;
16use apcore::module::Module;
17use apcore::Registry;
18
19use crate::output::types::WriteResult;
20use crate::types::ScannedModule;
21
/// Errors returned by [`HTTPProxyRegistryWriter::new`] when its configuration
/// is rejected.
#[derive(Debug, Error)]
pub enum HTTPProxyWriterError {
    /// `base_url` failed to parse, or uses a scheme other than `http`/`https`.
    /// `new` also reports a failure to build the underlying HTTP client through
    /// this variant.
    #[error("invalid base_url: {0}")]
    InvalidBaseUrl(String),
    /// `timeout_secs` was rejected (NaN, infinite, zero or negative).
    #[error("invalid timeout_secs: {0}")]
    InvalidTimeout(String),
}
32
/// Registers scanned modules into an apcore [`Registry`] as proxy modules that
/// forward every execution to an HTTP endpoint under `base_url`.
pub struct HTTPProxyRegistryWriter {
    /// Base URL of the target service; `new` only accepts `http`/`https`.
    base_url: String,
    /// Optional callback producing extra request headers (e.g. auth). It is
    /// shared with every proxy module this writer creates and called once per
    /// proxied request.
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
    /// HTTP client configured with the request timeout; cloned into each proxy.
    client: reqwest::Client,
}
42
43impl std::fmt::Debug for HTTPProxyRegistryWriter {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("HTTPProxyRegistryWriter")
46 .field("base_url", &self.base_url)
47 .field(
48 "auth_header_factory",
49 &self.auth_header_factory.as_ref().map(|_| "<factory>"),
50 )
51 .field("client", &self.client)
52 .finish()
53 }
54}
55
56impl HTTPProxyRegistryWriter {
57 pub fn new(
70 base_url: String,
71 auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
72 timeout_secs: f64,
73 ) -> Result<Self, HTTPProxyWriterError> {
74 let parsed = reqwest::Url::parse(&base_url)
75 .map_err(|e| HTTPProxyWriterError::InvalidBaseUrl(format!("'{}': {e}", base_url)))?;
76 if !matches!(parsed.scheme(), "http" | "https") {
77 return Err(HTTPProxyWriterError::InvalidBaseUrl(format!(
78 "scheme '{}' is not allowed — only http and https are permitted",
79 parsed.scheme()
80 )));
81 }
82
83 if !timeout_secs.is_finite() || timeout_secs <= 0.0 {
84 return Err(HTTPProxyWriterError::InvalidTimeout(format!(
85 "must be a positive finite number, got {timeout_secs}"
86 )));
87 }
88
89 let client = reqwest::Client::builder()
90 .timeout(std::time::Duration::from_secs_f64(timeout_secs))
91 .build()
92 .map_err(|e| {
93 HTTPProxyWriterError::InvalidBaseUrl(format!("failed to build HTTP client: {e}"))
94 })?;
95
96 Ok(Self {
97 base_url,
98 auth_header_factory: auth_header_factory.map(Arc::from),
99 client,
100 })
101 }
102
103 pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
105 let mut results: Vec<WriteResult> = Vec::new();
106
107 for module in modules {
108 let (http_method, url_path) = get_http_fields(module);
109 let path_params = extract_path_params(&url_path);
110 let proxy = ProxyModule {
111 base_url: self.base_url.clone(),
112 http_method,
113 url_path,
114 path_params,
115 input_schema: module.input_schema.clone(),
116 output_schema: module.output_schema.clone(),
117 description: module.description.clone(),
118 auth_header_factory: self.auth_header_factory.clone(),
119 client: self.client.clone(),
120 };
121
122 let descriptor = apcore::registry::registry::ModuleDescriptor {
123 module_id: module.module_id.clone(),
124 name: Some(module.module_id.clone()),
125 description: module.description.clone(),
126 documentation: module.documentation.clone(),
127 input_schema: module.input_schema.clone(),
128 output_schema: module.output_schema.clone(),
129 version: module.version.clone(),
130 tags: module.tags.clone(),
131 annotations: module.annotations.clone(),
132 examples: module.examples.clone(),
133 metadata: module.metadata.clone(),
134 display: module.display.clone(),
135 sunset_date: None,
136 dependencies: vec![],
137 enabled: true,
138 };
139
140 match registry.register(&module.module_id, Box::new(proxy), descriptor) {
141 Ok(()) => {
142 debug!("Registered HTTP proxy: {}", module.module_id);
143 results.push(WriteResult::new(module.module_id.clone()));
144 }
145 Err(e) => {
146 warn!(module_id = %module.module_id, error = %e, "HTTPProxyRegistryWriter registration failed");
147 results.push(WriteResult::failed(
148 module.module_id.clone(),
149 None,
150 e.to_string(),
151 ));
152 }
153 }
154 }
155
156 results
157 }
158}
159
160fn get_http_fields(module: &ScannedModule) -> (String, String) {
162 let http_method = module
163 .metadata
164 .get("http_method")
165 .and_then(|v| v.as_str())
166 .unwrap_or("GET")
167 .to_string();
168 let url_path = module
169 .metadata
170 .get("url_path")
171 .and_then(|v| v.as_str())
172 .unwrap_or("/")
173 .to_string();
174 (http_method, url_path)
175}
176
177static PATH_PARAM_RE: LazyLock<Regex> =
179 LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
180
181fn validate_path_params_filled(actual_path: &str) -> Result<(), String> {
185 if PATH_PARAM_RE.is_match(actual_path) {
186 let unfilled: Vec<&str> = PATH_PARAM_RE
187 .captures_iter(actual_path)
188 .filter_map(|cap| cap.get(1).map(|m| m.as_str()))
189 .collect();
190 Err(format!(
191 "Missing required path parameters {:?} — inputs must supply values for all path params in '{actual_path}'",
192 unfilled
193 ))
194 } else {
195 Ok(())
196 }
197}
198
/// Percent-encodes `s` for use as a single URL path segment.
///
/// Only the RFC 3986 "unreserved" characters (ASCII letters, digits, `-`, `.`,
/// `_`, `~`) are kept verbatim. Every other byte of the UTF-8 encoding,
/// including `/`, `?`, `#` and `%`, becomes `%XX` with uppercase hex digits, so
/// a value cannot introduce extra segments, a query, or a fragment.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for &b in s.as_bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(b));
        } else {
            // Table lookup instead of `format!` avoids a heap allocation per
            // escaped byte.
            out.push('%');
            out.push(char::from(HEX[usize::from(b >> 4)]));
            out.push(char::from(HEX[usize::from(b & 0x0F)]));
        }
    }
    out
}
211
212fn extract_path_params(url_path: &str) -> HashSet<String> {
214 PATH_PARAM_RE
215 .captures_iter(url_path)
216 .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
217 .collect()
218}
219
220fn extract_error_message(body: &str) -> String {
226 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
227 for key in &["error_message", "detail", "error", "message"] {
228 if let Some(val) = parsed.get(key) {
229 let msg = match val {
230 serde_json::Value::String(s) => s.clone(),
231 other => other.to_string(),
232 };
233 if !msg.is_empty() {
234 return msg;
235 }
236 }
237 }
238 }
239
240 safe_truncate(body, 200)
241}
242
/// Returns at most the first `max_chars` Unicode scalar values of `s`.
///
/// Counts `char`s, not bytes, so a multi-byte character is never split.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    // The byte offset of character number `max_chars` (0-based) is exactly
    // where the kept prefix ends; if there is no such character, `s` is short
    // enough to keep whole.
    match s.char_indices().nth(max_chars) {
        Some((cut_at, _)) => s[..cut_at].to_owned(),
        None => s.to_owned(),
    }
}
252
/// An apcore [`Module`] that forwards `execute` calls to one HTTP endpoint.
/// Instances are built by `HTTPProxyRegistryWriter::write`, one per scanned
/// module.
struct ProxyModule {
    /// Base URL of the target service; trailing `/` is trimmed when joining.
    base_url: String,
    /// HTTP method from the scanned module's `http_method` metadata (default `GET`).
    http_method: String,
    /// Path template from `url_path` metadata; may contain `{name}` placeholders.
    url_path: String,
    /// Placeholder names extracted from `url_path`; inputs with these keys are
    /// substituted into the path instead of being sent as query or body.
    path_params: HashSet<String>,
    /// Input schema copied from the scanned module.
    input_schema: serde_json::Value,
    /// Output schema copied from the scanned module.
    output_schema: serde_json::Value,
    /// Description copied from the scanned module.
    description: String,
    /// Optional per-request header producer shared with the writer.
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
    /// HTTP client cloned from the writer (carries the configured timeout).
    client: reqwest::Client,
}
266
267#[async_trait]
268impl Module for ProxyModule {
269 fn input_schema(&self) -> serde_json::Value {
270 self.input_schema.clone()
271 }
272
273 fn output_schema(&self) -> serde_json::Value {
274 self.output_schema.clone()
275 }
276
277 fn description(&self) -> &str {
278 &self.description
279 }
280
281 async fn execute(
282 &self,
283 inputs: serde_json::Value,
284 _ctx: &Context<serde_json::Value>,
285 ) -> Result<serde_json::Value, ModuleError> {
286 let mut actual_path = self.url_path.clone();
287 let mut query: HashMap<String, String> = HashMap::new();
288 let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
289
290 if let Some(obj) = inputs.as_object() {
291 for (key, value) in obj {
292 if self.path_params.contains(key) {
293 let val_str = match value {
294 serde_json::Value::String(s) => s.clone(),
295 other => other.to_string(),
296 };
297 actual_path = actual_path.replace(
298 &format!("{{{key}}}"),
299 &percent_encode_path_segment(&val_str),
300 );
301 } else if self.http_method == "GET" {
302 let val_str = match value {
303 serde_json::Value::String(s) => s.clone(),
304 other => other.to_string(),
305 };
306 query.insert(key.clone(), val_str);
307 } else {
308 body.insert(key.clone(), value.clone());
309 }
310 }
311 }
312
313 if let Err(msg) = validate_path_params_filled(&actual_path) {
314 return Err(ModuleError::new(
315 apcore::errors::ErrorCode::ModuleExecuteError,
316 msg,
317 ));
318 }
319
320 let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
321
322 let mut request = match self.http_method.as_str() {
323 "GET" => self.client.get(&url),
324 "POST" => self.client.post(&url),
325 "PUT" => self.client.put(&url),
326 "PATCH" => self.client.patch(&url),
327 "DELETE" => self.client.delete(&url),
328 other => {
329 return Err(ModuleError::new(
330 apcore::errors::ErrorCode::ModuleExecuteError,
331 format!("Unsupported HTTP method: {other}"),
332 ))
333 }
334 };
335
336 if let Some(ref factory) = self.auth_header_factory {
338 for (header_name, header_value) in factory() {
339 request = request.header(&header_name, &header_value);
340 }
341 }
342
343 if !query.is_empty() {
344 request = request.query(&query.iter().collect::<Vec<_>>());
345 }
346 if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
347 request = request.json(&body);
348 }
349
350 let resp = request.send().await.map_err(|e| {
351 ModuleError::new(
352 apcore::errors::ErrorCode::ModuleExecuteError,
353 format!("HTTP request failed: {e}"),
354 )
355 })?;
356
357 let status = resp.status();
358 if status.is_success() {
359 if status.as_u16() == 204 {
360 return Ok(serde_json::json!({}));
361 }
362 resp.json().await.map_err(|e| {
363 ModuleError::new(
364 apcore::errors::ErrorCode::ModuleExecuteError,
365 format!("Failed to parse response JSON: {e}"),
366 )
367 })
368 } else {
369 let error_text = resp.text().await.unwrap_or_default();
370 let message = extract_error_message(&error_text);
371 Err(ModuleError::new(
372 apcore::errors::ErrorCode::ModuleExecuteError,
373 format!("HTTP {}: {}", status.as_u16(), message),
374 ))
375 }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal `ScannedModule` with empty schemas and no metadata.
    fn bare_module() -> ScannedModule {
        ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        )
    }

    #[test]
    fn test_new_rejects_non_http_scheme() {
        let result = HTTPProxyRegistryWriter::new("file:///etc/passwd".into(), None, 30.0);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("scheme 'file' is not allowed"));
    }

    #[test]
    fn test_new_rejects_invalid_url() {
        let result = HTTPProxyRegistryWriter::new("not a url".into(), None, 30.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_nan_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::NAN);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_negative_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, -1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_zero_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, 0.0);
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_infinite_timeout() {
        let result =
            HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::INFINITY);
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_accepts_https_scheme() {
        let result = HTTPProxyRegistryWriter::new("https://api.example.com".into(), None, 30.0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_get_http_fields_defaults() {
        let module = bare_module();
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "GET");
        assert_eq!(path, "/");
    }

    #[test]
    fn test_get_http_fields_from_metadata() {
        let mut module = bare_module();
        module.metadata.insert(
            "http_method".into(),
            serde_json::Value::String("POST".into()),
        );
        module.metadata.insert(
            "url_path".into(),
            serde_json::Value::String("/users".into()),
        );
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "POST");
        assert_eq!(path, "/users");
    }

    #[test]
    fn test_extract_path_params() {
        let params = extract_path_params("/users/{user_id}/tasks/{task_id}");
        assert!(params.contains("user_id"));
        assert!(params.contains("task_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_extract_path_params_none() {
        let params = extract_path_params("/users");
        assert!(params.is_empty());
    }

    #[test]
    fn test_extract_path_params_repeated_name_counted_once() {
        let params = extract_path_params("/a/{id}/b/{id}");
        assert_eq!(params.len(), 1);
        assert!(params.contains("id"));
    }

    #[test]
    fn test_extract_error_message_json_error_message() {
        let body = r#"{"error_message": "not found"}"#;
        assert_eq!(extract_error_message(body), "not found");
    }

    #[test]
    fn test_extract_error_message_json_detail() {
        let body = r#"{"detail": "unauthorized"}"#;
        assert_eq!(extract_error_message(body), "unauthorized");
    }

    #[test]
    fn test_extract_error_message_json_error() {
        let body = r#"{"error": "bad request"}"#;
        assert_eq!(extract_error_message(body), "bad request");
    }

    #[test]
    fn test_extract_error_message_json_message() {
        let body = r#"{"message": "server error"}"#;
        assert_eq!(extract_error_message(body), "server error");
    }

    #[test]
    fn test_extract_error_message_json_priority() {
        // `error_message` outranks `message` when both are present.
        let body = r#"{"error_message": "first", "message": "second"}"#;
        assert_eq!(extract_error_message(body), "first");
    }

    #[test]
    fn test_extract_error_message_skips_empty_string() {
        let body = r#"{"error_message": "", "detail": "fallback"}"#;
        assert_eq!(extract_error_message(body), "fallback");
    }

    #[test]
    fn test_extract_error_message_non_string_value() {
        let body = r#"{"detail": [{"msg": "bad"}]}"#;
        assert_eq!(extract_error_message(body), r#"[{"msg":"bad"}]"#);
    }

    #[test]
    fn test_extract_error_message_plain_text_short() {
        let body = "plain text error";
        assert_eq!(extract_error_message(body), "plain text error");
    }

    #[test]
    fn test_extract_error_message_plain_text_truncated() {
        let body = "x".repeat(300);
        let result = extract_error_message(&body);
        assert_eq!(result.len(), 200);
    }

    #[test]
    fn test_validate_path_params_filled_no_placeholders() {
        assert!(validate_path_params_filled("/users/123/tasks/456").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_static_path() {
        assert!(validate_path_params_filled("/health").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_unfilled_placeholder() {
        let result = validate_path_params_filled("/users/{user_id}/tasks");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("user_id"),
            "error should name the unfilled param: {msg}"
        );
    }

    #[test]
    fn test_validate_path_params_filled_multiple_unfilled() {
        let result = validate_path_params_filled("/users/{user_id}/tasks/{task_id}");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        // Every unfilled placeholder must be named, not just one of them.
        assert!(msg.contains("user_id") && msg.contains("task_id"), "{msg}");
    }

    #[test]
    fn test_percent_encode_path_segment_unreserved_passthrough() {
        assert_eq!(percent_encode_path_segment("abcXYZ019-._~"), "abcXYZ019-._~");
    }

    #[test]
    fn test_percent_encode_path_segment_reserved_and_multibyte() {
        assert_eq!(percent_encode_path_segment("a b/c?d"), "a%20b%2Fc%3Fd");
        assert_eq!(percent_encode_path_segment("%"), "%25");
        assert_eq!(percent_encode_path_segment("é"), "%C3%A9");
    }

    #[test]
    fn test_safe_truncate_multibyte() {
        // 300 four-byte emoji: truncation must count chars, not bytes.
        let body = "\u{1F600}".repeat(300);
        let result = safe_truncate(&body, 200);
        assert_eq!(result.chars().count(), 200);
    }

    #[test]
    fn test_safe_truncate_boundaries() {
        assert_eq!(safe_truncate("abc", 3), "abc");
        assert_eq!(safe_truncate("abcd", 3), "abc");
        assert_eq!(safe_truncate("", 3), "");
    }
}