1use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, LazyLock};
8
9use async_trait::async_trait;
10use regex::Regex;
11use thiserror::Error;
12use tracing::{debug, warn};
13
14use apcore::context::Context;
15use apcore::errors::ModuleError;
16use apcore::module::Module;
17use apcore::Registry;
18
19use crate::http_verb_map::extract_path_param_names;
20use crate::output::types::WriteResult;
21use crate::types::ScannedModule;
22
/// Errors returned when constructing an [`HTTPProxyRegistryWriter`].
#[derive(Debug, Error)]
pub enum HTTPProxyRegistryWriterError {
    /// The base URL was rejected: it failed to parse, used a scheme other than
    /// `http`/`https`, or the underlying HTTP client could not be built.
    #[error("invalid base_url: {0}")]
    InvalidBaseUrl(String),
    /// The requested timeout was rejected (for example NaN, infinite, zero or
    /// negative).
    #[error("invalid timeout_secs: {0}")]
    InvalidTimeout(String),
}
33
/// Registers scanned modules into an apcore [`Registry`] as HTTP proxy modules.
///
/// Every module registered by `write` forwards its inputs as an HTTP request to
/// `base_url` joined with the module's `url_path` metadata.
pub struct HTTPProxyRegistryWriter {
    /// Base URL prepended to every proxied request path; `new` only accepts
    /// `http` and `https` URLs.
    base_url: String,
    /// Optional closure whose returned name/value pairs are added as headers to
    /// every proxied request (it is invoked once per request).
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
    /// HTTP client built in `new` with the configured timeout; each proxy
    /// module receives a clone of it.
    client: reqwest::Client,
}
43
44impl std::fmt::Debug for HTTPProxyRegistryWriter {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("HTTPProxyRegistryWriter")
47 .field("base_url", &self.base_url)
48 .field(
49 "auth_header_factory",
50 &self.auth_header_factory.as_ref().map(|_| "<factory>"),
51 )
52 .field("client", &self.client)
53 .finish()
54 }
55}
56
57impl HTTPProxyRegistryWriter {
58 pub fn new(
71 base_url: String,
72 auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
73 timeout_secs: f64,
74 ) -> Result<Self, HTTPProxyRegistryWriterError> {
75 let parsed = reqwest::Url::parse(&base_url).map_err(|e| {
76 HTTPProxyRegistryWriterError::InvalidBaseUrl(format!("'{}': {e}", base_url))
77 })?;
78 if !matches!(parsed.scheme(), "http" | "https") {
79 return Err(HTTPProxyRegistryWriterError::InvalidBaseUrl(format!(
80 "scheme '{}' is not allowed — only http and https are permitted",
81 parsed.scheme()
82 )));
83 }
84
85 if !timeout_secs.is_finite() || timeout_secs <= 0.0 {
86 return Err(HTTPProxyRegistryWriterError::InvalidTimeout(format!(
87 "must be a positive finite number, got {timeout_secs}"
88 )));
89 }
90
91 let client = reqwest::Client::builder()
92 .timeout(std::time::Duration::from_secs_f64(timeout_secs))
93 .build()
94 .map_err(|e| {
95 HTTPProxyRegistryWriterError::InvalidBaseUrl(format!(
96 "failed to build HTTP client: {e}"
97 ))
98 })?;
99
100 Ok(Self {
101 base_url,
102 auth_header_factory: auth_header_factory.map(Arc::from),
103 client,
104 })
105 }
106
107 pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
109 let mut results: Vec<WriteResult> = Vec::new();
110
111 for module in modules {
112 let (http_method, url_path) = get_http_fields(module);
113 let path_params = extract_path_param_names(&url_path);
114 let proxy = ProxyModule {
115 base_url: self.base_url.clone(),
116 http_method,
117 url_path,
118 path_params,
119 input_schema: module.input_schema.clone(),
120 output_schema: module.output_schema.clone(),
121 description: module.description.clone(),
122 auth_header_factory: self.auth_header_factory.clone(),
123 client: self.client.clone(),
124 };
125
126 let descriptor = apcore::registry::registry::ModuleDescriptor {
127 module_id: module.module_id.clone(),
128 name: Some(module.module_id.clone()),
129 description: module.description.clone(),
130 documentation: module.documentation.clone(),
131 input_schema: module.input_schema.clone(),
132 output_schema: module.output_schema.clone(),
133 version: module.version.clone(),
134 tags: module.tags.clone(),
135 annotations: module.annotations.clone(),
136 examples: module.examples.clone(),
137 metadata: module.metadata.clone(),
138 display: module.display.clone(),
139 sunset_date: None,
140 dependencies: vec![],
141 enabled: true,
142 };
143
144 match registry.register(&module.module_id, Box::new(proxy), descriptor) {
145 Ok(()) => {
146 debug!("Registered HTTP proxy: {}", module.module_id);
147 results.push(WriteResult::new(module.module_id.clone()));
148 }
149 Err(e) => {
150 warn!(module_id = %module.module_id, error = %e, "HTTPProxyRegistryWriter registration failed");
151 results.push(WriteResult::failed(
152 module.module_id.clone(),
153 None,
154 e.to_string(),
155 ));
156 }
157 }
158 }
159
160 results
161 }
162}
163
164fn get_http_fields(module: &ScannedModule) -> (String, String) {
166 let http_method = module
167 .metadata
168 .get("http_method")
169 .and_then(|v| v.as_str())
170 .unwrap_or("GET")
171 .to_string();
172 let url_path = module
173 .metadata
174 .get("url_path")
175 .and_then(|v| v.as_str())
176 .unwrap_or("/")
177 .to_string();
178 (http_method, url_path)
179}
180
/// HTTP methods whose non-path inputs are sent as a JSON request body.
/// For every other method, non-path inputs are sent as query-string parameters.
const BODY_METHODS: &[&str] = &["POST", "PUT", "PATCH"];
186
/// Matches a `{name}` path placeholder (name made of word characters) and
/// captures the name in group 1. Compiled once, on first use.
static PATH_PARAM_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
190
191fn validate_path_params_filled(actual_path: &str) -> Result<(), String> {
195 if PATH_PARAM_RE.is_match(actual_path) {
196 let unfilled: Vec<&str> = PATH_PARAM_RE
197 .captures_iter(actual_path)
198 .filter_map(|cap| cap.get(1).map(|m| m.as_str()))
199 .collect();
200 Err(format!(
201 "Missing required path parameters {:?} — inputs must supply values for all path params in '{actual_path}'",
202 unfilled
203 ))
204 } else {
205 Ok(())
206 }
207}
208
/// Percent-encodes `s` for safe use as a single URL path segment.
///
/// Only RFC 3986 "unreserved" characters (ASCII letters and digits, `-`, `.`,
/// `_`, `~`) are kept as-is. Every other byte of the UTF-8 encoding — including
/// `/`, `?`, `#`, `%` and all non-ASCII bytes — becomes `%XX` with uppercase hex.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            let unreserved = byte.is_ascii_alphanumeric() || b"-._~".contains(&byte);
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
225
226fn extract_error_message(body: &str) -> String {
232 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
233 for key in &["error_message", "detail", "error", "message"] {
234 if let Some(val) = parsed.get(key) {
235 let msg = match val {
236 serde_json::Value::String(s) => s.clone(),
237 other => other.to_string(),
238 };
239 if !msg.is_empty() {
240 return msg;
241 }
242 }
243 }
244 }
245
246 safe_truncate(body, 200)
247}
248
/// Returns at most the first `max_chars` characters of `s` as an owned string.
///
/// Counts Unicode scalar values rather than bytes, so multi-byte characters are
/// never split.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    // Byte offset of the first character past the limit; the whole string fits
    // when there is no such character.
    let cut = s
        .char_indices()
        .nth(max_chars)
        .map_or(s.len(), |(byte_offset, _)| byte_offset);
    s[..cut].to_owned()
}
261
/// apcore module that forwards its inputs as a single HTTP request.
///
/// Inputs named in `path_params` are substituted into `url_path`. For methods
/// listed in `BODY_METHODS` the remaining inputs form a JSON body; for all other
/// methods they become query-string parameters.
struct ProxyModule {
    /// Base URL the (substituted) `url_path` is appended to.
    base_url: String,
    /// HTTP method taken from the scanned module's `http_method` metadata.
    http_method: String,
    /// Path template taken from the scanned module's `url_path` metadata.
    url_path: String,
    /// Names of the path parameters found in `url_path` by
    /// `extract_path_param_names`.
    path_params: HashSet<String>,
    /// Input schema reported through the `Module` trait.
    input_schema: serde_json::Value,
    /// Output schema reported through the `Module` trait.
    output_schema: serde_json::Value,
    /// Description reported through the `Module` trait.
    description: String,
    /// Optional closure whose name/value pairs are added as headers to each request.
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
    /// HTTP client (cloned from the writer) used to send the request.
    client: reqwest::Client,
}
275
276#[async_trait]
277impl Module for ProxyModule {
278 fn input_schema(&self) -> serde_json::Value {
279 self.input_schema.clone()
280 }
281
282 fn output_schema(&self) -> serde_json::Value {
283 self.output_schema.clone()
284 }
285
286 fn description(&self) -> &str {
287 &self.description
288 }
289
290 async fn execute(
291 &self,
292 inputs: serde_json::Value,
293 _ctx: &Context<serde_json::Value>,
294 ) -> Result<serde_json::Value, ModuleError> {
295 let mut actual_path = self.url_path.clone();
296 let mut query: HashMap<String, String> = HashMap::new();
297 let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
298
299 if let Some(obj) = inputs.as_object() {
300 let uses_body = BODY_METHODS.contains(&self.http_method.as_str());
301 for (key, value) in obj {
302 if self.path_params.contains(key) {
303 let val_str = match value {
304 serde_json::Value::String(s) => s.clone(),
305 other => other.to_string(),
306 };
307 actual_path = actual_path.replace(
308 &format!("{{{key}}}"),
309 &percent_encode_path_segment(&val_str),
310 );
311 } else if uses_body {
312 body.insert(key.clone(), value.clone());
313 } else {
314 let val_str = match value {
317 serde_json::Value::String(s) => s.clone(),
318 other => other.to_string(),
319 };
320 query.insert(key.clone(), val_str);
321 }
322 }
323 }
324
325 if let Err(msg) = validate_path_params_filled(&actual_path) {
326 return Err(ModuleError::new(
327 apcore::errors::ErrorCode::ModuleExecuteError,
328 msg,
329 ));
330 }
331
332 let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
333
334 let mut request = match self.http_method.as_str() {
335 "GET" => self.client.get(&url),
336 "POST" => self.client.post(&url),
337 "PUT" => self.client.put(&url),
338 "PATCH" => self.client.patch(&url),
339 "DELETE" => self.client.delete(&url),
340 other => {
341 return Err(ModuleError::new(
342 apcore::errors::ErrorCode::ModuleExecuteError,
343 format!("Unsupported HTTP method: {other}"),
344 ))
345 }
346 };
347
348 if let Some(ref factory) = self.auth_header_factory {
350 for (header_name, header_value) in factory() {
351 request = request.header(&header_name, &header_value);
352 }
353 }
354
355 if !query.is_empty() {
356 request = request.query(&query.iter().collect::<Vec<_>>());
357 }
358 if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
359 request = request.json(&body);
360 }
361
362 let resp = request.send().await.map_err(|e| {
363 ModuleError::new(
364 apcore::errors::ErrorCode::ModuleExecuteError,
365 format!("HTTP request failed: {e}"),
366 )
367 })?;
368
369 let status = resp.status();
370 if status.is_success() {
371 if status.as_u16() == 204 {
372 return Ok(serde_json::json!({}));
373 }
374 resp.json().await.map_err(|e| {
375 ModuleError::new(
376 apcore::errors::ErrorCode::ModuleExecuteError,
377 format!("Failed to parse response JSON: {e}"),
378 )
379 })
380 } else {
381 let error_text = resp.text().await.unwrap_or_default();
382 let message = extract_error_message(&error_text);
383 Err(ModuleError::new(
384 apcore::errors::ErrorCode::ModuleExecuteError,
385 format!("HTTP {}: {}", status.as_u16(), message),
386 ))
387 }
388 }
389}
390
/// Unit tests for writer construction and the pure helper functions.
///
/// The tests do not perform any network I/O.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- HTTPProxyRegistryWriter::new validation -------------------------------

    #[test]
    fn test_new_rejects_non_http_scheme() {
        let result = HTTPProxyRegistryWriter::new("file:///etc/passwd".into(), None, 30.0);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("scheme 'file' is not allowed"));
    }

    #[test]
    fn test_new_rejects_invalid_url() {
        let result = HTTPProxyRegistryWriter::new("not a url".into(), None, 30.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_nan_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::NAN);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_negative_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, -1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_accepts_https_scheme() {
        let result = HTTPProxyRegistryWriter::new("https://api.example.com".into(), None, 30.0);
        assert!(result.is_ok());
    }

    // --- get_http_fields: metadata lookup and defaults ---------------------------

    #[test]
    fn test_get_http_fields_defaults() {
        let module = ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        );
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "GET");
        assert_eq!(path, "/");
    }

    #[test]
    fn test_get_http_fields_from_metadata() {
        let mut module = ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        );
        module.metadata.insert(
            "http_method".into(),
            serde_json::Value::String("POST".into()),
        );
        module.metadata.insert(
            "url_path".into(),
            serde_json::Value::String("/users".into()),
        );
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "POST");
        assert_eq!(path, "/users");
    }

    // --- extract_path_param_names (from crate::http_verb_map) --------------------
    // Both `{name}` and `:name` placeholder styles are expected to be recognised.

    #[test]
    fn test_extract_path_params() {
        let params = extract_path_param_names("/users/{user_id}/tasks/{task_id}");
        assert!(params.contains("user_id"));
        assert!(params.contains("task_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_extract_path_params_none() {
        let params = extract_path_param_names("/users");
        assert!(params.is_empty());
    }

    #[test]
    fn test_extract_path_params_colon_style() {
        let params = extract_path_param_names("/users/:id");
        assert!(
            params.contains("id"),
            "colon-style param ':id' should be recognised; got: {params:?}"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_extract_path_params_mixed_styles() {
        let params = extract_path_param_names("/users/:user_id/tasks/{task_id}");
        assert!(params.contains("user_id"));
        assert!(params.contains("task_id"));
        assert_eq!(params.len(), 2);
    }

    // --- extract_error_message: JSON key priority and plain-text fallback --------

    #[test]
    fn test_extract_error_message_json_error_message() {
        let body = r#"{"error_message": "not found"}"#;
        assert_eq!(extract_error_message(body), "not found");
    }

    #[test]
    fn test_extract_error_message_json_detail() {
        let body = r#"{"detail": "unauthorized"}"#;
        assert_eq!(extract_error_message(body), "unauthorized");
    }

    #[test]
    fn test_extract_error_message_json_error() {
        let body = r#"{"error": "bad request"}"#;
        assert_eq!(extract_error_message(body), "bad request");
    }

    #[test]
    fn test_extract_error_message_json_message() {
        let body = r#"{"message": "server error"}"#;
        assert_eq!(extract_error_message(body), "server error");
    }

    // `error_message` is checked before `message`, so it wins when both exist.
    #[test]
    fn test_extract_error_message_json_priority() {
        let body = r#"{"error_message": "first", "message": "second"}"#;
        assert_eq!(extract_error_message(body), "first");
    }

    #[test]
    fn test_extract_error_message_plain_text_short() {
        let body = "plain text error";
        assert_eq!(extract_error_message(body), "plain text error");
    }

    // Non-JSON bodies are cut to 200 characters (ASCII here, so 200 bytes).
    #[test]
    fn test_extract_error_message_plain_text_truncated() {
        let body = "x".repeat(300);
        let result = extract_error_message(&body);
        assert_eq!(result.len(), 200);
    }

    // --- validate_path_params_filled: leftover `{name}` placeholders ------------

    #[test]
    fn test_validate_path_params_filled_no_placeholders() {
        assert!(validate_path_params_filled("/users/123/tasks/456").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_static_path() {
        assert!(validate_path_params_filled("/health").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_unfilled_placeholder() {
        let result = validate_path_params_filled("/users/{user_id}/tasks");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("user_id"),
            "error should name the unfilled param: {msg}"
        );
    }

    #[test]
    fn test_validate_path_params_filled_multiple_unfilled() {
        let result = validate_path_params_filled("/users/{user_id}/tasks/{task_id}");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("user_id") || msg.contains("task_id"), "{msg}");
    }

    // --- safe_truncate: counts characters, not bytes -----------------------------

    // Each emoji is 4 bytes in UTF-8; byte-based slicing would panic mid-character.
    #[test]
    fn test_safe_truncate_multibyte() {
        let body = "\u{1F600}".repeat(300);
        let result = safe_truncate(&body, 200);
        assert_eq!(result.chars().count(), 200);
    }

    // --- BODY_METHODS: which verbs send a JSON body ------------------------------

    #[test]
    fn test_body_methods_set_contents() {
        assert!(BODY_METHODS.contains(&"POST"));
        assert!(BODY_METHODS.contains(&"PUT"));
        assert!(BODY_METHODS.contains(&"PATCH"));
        assert!(!BODY_METHODS.contains(&"GET"));
        assert!(!BODY_METHODS.contains(&"DELETE"));
        assert!(!BODY_METHODS.contains(&"HEAD"));
        assert!(!BODY_METHODS.contains(&"OPTIONS"));
    }
}