Skip to main content

apcore_toolkit/output/
http_proxy_writer.rs

// HTTP proxy registry writer.
//
// Registers scanned modules as HTTP proxy implementations that forward
// requests to a running web API. Feature-gated behind `http-proxy`.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, LazyLock};
8
9use async_trait::async_trait;
10use regex::Regex;
11use thiserror::Error;
12use tracing::{debug, warn};
13
14use apcore::context::Context;
15use apcore::errors::ModuleError;
16use apcore::module::Module;
17use apcore::Registry;
18
19use crate::output::types::WriteResult;
20use crate::types::ScannedModule;
21
22/// Errors returned by [`HTTPProxyRegistryWriter::new`].
23#[derive(Debug, Error)]
24pub enum HTTPProxyWriterError {
25    /// `base_url` is not a valid URL or uses a non-http(s) scheme.
26    #[error("invalid base_url: {0}")]
27    InvalidBaseUrl(String),
28    /// `timeout_secs` is not a valid positive finite number.
29    #[error("invalid timeout_secs: {0}")]
30    InvalidTimeout(String),
31}
32
33/// Register scanned modules as HTTP proxy modules in the registry.
34///
35/// Each module's `execute()` sends an HTTP request to the target API
36/// instead of calling the handler directly.
37pub struct HTTPProxyRegistryWriter {
38    base_url: String,
39    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
40    client: reqwest::Client,
41}
42
43impl std::fmt::Debug for HTTPProxyRegistryWriter {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("HTTPProxyRegistryWriter")
46            .field("base_url", &self.base_url)
47            .field(
48                "auth_header_factory",
49                &self.auth_header_factory.as_ref().map(|_| "<factory>"),
50            )
51            .field("client", &self.client)
52            .finish()
53    }
54}
55
56impl HTTPProxyRegistryWriter {
57    /// Create a new HTTP proxy writer.
58    ///
59    /// - `base_url`: Base URL of the target API (must be `http://` or `https://`).
60    /// - `auth_header_factory`: Optional callable returning HTTP headers for auth.
61    /// - `timeout_secs`: HTTP request timeout in seconds (must be a positive finite number).
62    ///
63    /// # Errors
64    ///
65    /// Returns [`HTTPProxyWriterError::InvalidBaseUrl`] if `base_url` is not a valid URL
66    /// or its scheme is not `http` or `https` (SSRF prevention).
67    /// Returns [`HTTPProxyWriterError::InvalidTimeout`] if `timeout_secs` is not a
68    /// positive finite number.
69    pub fn new(
70        base_url: String,
71        auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
72        timeout_secs: f64,
73    ) -> Result<Self, HTTPProxyWriterError> {
74        let parsed = reqwest::Url::parse(&base_url)
75            .map_err(|e| HTTPProxyWriterError::InvalidBaseUrl(format!("'{}': {e}", base_url)))?;
76        if !matches!(parsed.scheme(), "http" | "https") {
77            return Err(HTTPProxyWriterError::InvalidBaseUrl(format!(
78                "scheme '{}' is not allowed — only http and https are permitted",
79                parsed.scheme()
80            )));
81        }
82
83        if !timeout_secs.is_finite() || timeout_secs <= 0.0 {
84            return Err(HTTPProxyWriterError::InvalidTimeout(format!(
85                "must be a positive finite number, got {timeout_secs}"
86            )));
87        }
88
89        let client = reqwest::Client::builder()
90            .timeout(std::time::Duration::from_secs_f64(timeout_secs))
91            .build()
92            .map_err(|e| {
93                HTTPProxyWriterError::InvalidBaseUrl(format!("failed to build HTTP client: {e}"))
94            })?;
95
96        Ok(Self {
97            base_url,
98            auth_header_factory: auth_header_factory.map(Arc::from),
99            client,
100        })
101    }
102
103    /// Register each ScannedModule as an HTTP proxy module.
104    pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
105        let mut results: Vec<WriteResult> = Vec::new();
106
107        for module in modules {
108            let (http_method, url_path) = get_http_fields(module);
109            let path_params = extract_path_params(&url_path);
110            let proxy = ProxyModule {
111                base_url: self.base_url.clone(),
112                http_method,
113                url_path,
114                path_params,
115                input_schema: module.input_schema.clone(),
116                output_schema: module.output_schema.clone(),
117                description: module.description.clone(),
118                auth_header_factory: self.auth_header_factory.clone(),
119                client: self.client.clone(),
120            };
121
122            let descriptor = apcore::registry::registry::ModuleDescriptor {
123                module_id: module.module_id.clone(),
124                name: Some(module.module_id.clone()),
125                description: module.description.clone(),
126                documentation: module.documentation.clone(),
127                input_schema: module.input_schema.clone(),
128                output_schema: module.output_schema.clone(),
129                version: module.version.clone(),
130                tags: module.tags.clone(),
131                annotations: module.annotations.clone(),
132                examples: module.examples.clone(),
133                metadata: module.metadata.clone(),
134                display: module.display.clone(),
135                sunset_date: None,
136                dependencies: vec![],
137                enabled: true,
138            };
139
140            match registry.register(&module.module_id, Box::new(proxy), descriptor) {
141                Ok(()) => {
142                    debug!("Registered HTTP proxy: {}", module.module_id);
143                    results.push(WriteResult::new(module.module_id.clone()));
144                }
145                Err(e) => {
146                    warn!(module_id = %module.module_id, error = %e, "HTTPProxyRegistryWriter registration failed");
147                    results.push(WriteResult::failed(
148                        module.module_id.clone(),
149                        None,
150                        e.to_string(),
151                    ));
152                }
153            }
154        }
155
156        results
157    }
158}
159
160/// Extract http_method and url_path from a ScannedModule's metadata.
161fn get_http_fields(module: &ScannedModule) -> (String, String) {
162    let http_method = module
163        .metadata
164        .get("http_method")
165        .and_then(|v| v.as_str())
166        .unwrap_or("GET")
167        .to_string();
168    let url_path = module
169        .metadata
170        .get("url_path")
171        .and_then(|v| v.as_str())
172        .unwrap_or("/")
173        .to_string();
174    (http_method, url_path)
175}
176
177/// Regex matching URL path parameters like `{user_id}`.
178static PATH_PARAM_RE: LazyLock<Regex> =
179    LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
180
181/// Validate that all `{param}` placeholders in `actual_path` were substituted.
182///
183/// Returns `Err` with the list of still-unfilled parameter names if any remain.
184fn validate_path_params_filled(actual_path: &str) -> Result<(), String> {
185    if PATH_PARAM_RE.is_match(actual_path) {
186        let unfilled: Vec<&str> = PATH_PARAM_RE
187            .captures_iter(actual_path)
188            .filter_map(|cap| cap.get(1).map(|m| m.as_str()))
189            .collect();
190        Err(format!(
191            "Missing required path parameters {:?} — inputs must supply values for all path params in '{actual_path}'",
192            unfilled
193        ))
194    } else {
195        Ok(())
196    }
197}
198
/// Percent-encode `s` for use as a single URL path segment.
///
/// RFC 3986 unreserved bytes (ASCII letters, digits, `-`, `.`, `_`, `~`) are kept.
/// Every other byte of the UTF-8 encoding, including `/`, `{` and `}`, becomes
/// `%XX` with upper-case hex digits.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        let unreserved = byte.is_ascii_alphanumeric() || b"-._~".contains(&byte);
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
        encoded
    })
}
211
212/// Extract path parameter names from a URL pattern like `/users/{user_id}`.
213fn extract_path_params(url_path: &str) -> HashSet<String> {
214    PATH_PARAM_RE
215        .captures_iter(url_path)
216        .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
217        .collect()
218}
219
220/// Extract a human-readable error message from an HTTP error response body.
221///
222/// Tries to parse the body as JSON and looks for common error fields
223/// (`error_message`, `detail`, `error`, `message`) before falling back
224/// to a safely-truncated version of the raw text.
225fn extract_error_message(body: &str) -> String {
226    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
227        for key in &["error_message", "detail", "error", "message"] {
228            if let Some(val) = parsed.get(key) {
229                let msg = match val {
230                    serde_json::Value::String(s) => s.clone(),
231                    other => other.to_string(),
232                };
233                if !msg.is_empty() {
234                    return msg;
235                }
236            }
237        }
238    }
239
240    safe_truncate(body, 200)
241}
242
/// Return at most the first `max_chars` Unicode scalar values of `s`.
///
/// Slicing happens on a char boundary found via `char_indices`, so multi-byte
/// UTF-8 input never causes a panic.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    match s.char_indices().nth(max_chars) {
        // A char exists at position `max_chars`: keep everything before it.
        Some((cut, _)) => s[..cut].to_owned(),
        // `s` has no more than `max_chars` chars: keep it whole.
        None => s.to_owned(),
    }
}
252
253/// A module that proxies requests to an HTTP API.
254struct ProxyModule {
255    base_url: String,
256    http_method: String,
257    url_path: String,
258    path_params: HashSet<String>,
259    input_schema: serde_json::Value,
260    output_schema: serde_json::Value,
261    description: String,
262    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
263    // Shared HTTP client — cloned from HTTPProxyRegistryWriter to reuse connection pool.
264    client: reqwest::Client,
265}
266
267#[async_trait]
268impl Module for ProxyModule {
269    fn input_schema(&self) -> serde_json::Value {
270        self.input_schema.clone()
271    }
272
273    fn output_schema(&self) -> serde_json::Value {
274        self.output_schema.clone()
275    }
276
277    fn description(&self) -> &str {
278        &self.description
279    }
280
281    async fn execute(
282        &self,
283        inputs: serde_json::Value,
284        _ctx: &Context<serde_json::Value>,
285    ) -> Result<serde_json::Value, ModuleError> {
286        let mut actual_path = self.url_path.clone();
287        let mut query: HashMap<String, String> = HashMap::new();
288        let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
289
290        if let Some(obj) = inputs.as_object() {
291            for (key, value) in obj {
292                if self.path_params.contains(key) {
293                    let val_str = match value {
294                        serde_json::Value::String(s) => s.clone(),
295                        other => other.to_string(),
296                    };
297                    actual_path = actual_path.replace(
298                        &format!("{{{key}}}"),
299                        &percent_encode_path_segment(&val_str),
300                    );
301                } else if self.http_method == "GET" {
302                    let val_str = match value {
303                        serde_json::Value::String(s) => s.clone(),
304                        other => other.to_string(),
305                    };
306                    query.insert(key.clone(), val_str);
307                } else {
308                    body.insert(key.clone(), value.clone());
309                }
310            }
311        }
312
313        if let Err(msg) = validate_path_params_filled(&actual_path) {
314            return Err(ModuleError::new(
315                apcore::errors::ErrorCode::ModuleExecuteError,
316                msg,
317            ));
318        }
319
320        let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
321
322        let mut request = match self.http_method.as_str() {
323            "GET" => self.client.get(&url),
324            "POST" => self.client.post(&url),
325            "PUT" => self.client.put(&url),
326            "PATCH" => self.client.patch(&url),
327            "DELETE" => self.client.delete(&url),
328            other => {
329                return Err(ModuleError::new(
330                    apcore::errors::ErrorCode::ModuleExecuteError,
331                    format!("Unsupported HTTP method: {other}"),
332                ))
333            }
334        };
335
336        // Apply auth headers from the factory, if configured
337        if let Some(ref factory) = self.auth_header_factory {
338            for (header_name, header_value) in factory() {
339                request = request.header(&header_name, &header_value);
340            }
341        }
342
343        if !query.is_empty() {
344            request = request.query(&query.iter().collect::<Vec<_>>());
345        }
346        if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
347            request = request.json(&body);
348        }
349
350        let resp = request.send().await.map_err(|e| {
351            ModuleError::new(
352                apcore::errors::ErrorCode::ModuleExecuteError,
353                format!("HTTP request failed: {e}"),
354            )
355        })?;
356
357        let status = resp.status();
358        if status.is_success() {
359            if status.as_u16() == 204 {
360                return Ok(serde_json::json!({}));
361            }
362            resp.json().await.map_err(|e| {
363                ModuleError::new(
364                    apcore::errors::ErrorCode::ModuleExecuteError,
365                    format!("Failed to parse response JSON: {e}"),
366                )
367            })
368        } else {
369            let error_text = resp.text().await.unwrap_or_default();
370            let message = extract_error_message(&error_text);
371            Err(ModuleError::new(
372                apcore::errors::ErrorCode::ModuleExecuteError,
373                format!("HTTP {}: {}", status.as_u16(), message),
374            ))
375        }
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_new_rejects_non_http_scheme() {
        let result = HTTPProxyRegistryWriter::new("file:///etc/passwd".into(), None, 30.0);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("scheme 'file' is not allowed"));
    }

    #[test]
    fn test_new_rejects_invalid_url() {
        let result = HTTPProxyRegistryWriter::new("not a url".into(), None, 30.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_nan_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::NAN);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_negative_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, -1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_zero_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, 0.0);
        assert!(matches!(
            result,
            Err(HTTPProxyWriterError::InvalidTimeout(_))
        ));
    }

    #[test]
    fn test_new_rejects_infinite_timeout() {
        let result =
            HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::INFINITY);
        assert!(matches!(
            result,
            Err(HTTPProxyWriterError::InvalidTimeout(_))
        ));
    }

    #[test]
    fn test_new_accepts_https_scheme() {
        let result = HTTPProxyRegistryWriter::new("https://api.example.com".into(), None, 30.0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_get_http_fields_defaults() {
        let module = ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        );
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "GET");
        assert_eq!(path, "/");
    }

    #[test]
    fn test_get_http_fields_from_metadata() {
        let mut module = ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        );
        module.metadata.insert(
            "http_method".into(),
            serde_json::Value::String("POST".into()),
        );
        module.metadata.insert(
            "url_path".into(),
            serde_json::Value::String("/users".into()),
        );
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "POST");
        assert_eq!(path, "/users");
    }

    #[test]
    fn test_extract_path_params() {
        let params = extract_path_params("/users/{user_id}/tasks/{task_id}");
        assert!(params.contains("user_id"));
        assert!(params.contains("task_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_extract_path_params_none() {
        let params = extract_path_params("/users");
        assert!(params.is_empty());
    }

    #[test]
    fn test_extract_error_message_json_error_message() {
        let body = r#"{"error_message": "not found"}"#;
        assert_eq!(extract_error_message(body), "not found");
    }

    #[test]
    fn test_extract_error_message_json_detail() {
        let body = r#"{"detail": "unauthorized"}"#;
        assert_eq!(extract_error_message(body), "unauthorized");
    }

    #[test]
    fn test_extract_error_message_json_error() {
        let body = r#"{"error": "bad request"}"#;
        assert_eq!(extract_error_message(body), "bad request");
    }

    #[test]
    fn test_extract_error_message_json_message() {
        let body = r#"{"message": "server error"}"#;
        assert_eq!(extract_error_message(body), "server error");
    }

    #[test]
    fn test_extract_error_message_json_priority() {
        // error_message takes priority over message
        let body = r#"{"error_message": "first", "message": "second"}"#;
        assert_eq!(extract_error_message(body), "first");
    }

    #[test]
    fn test_extract_error_message_json_non_string_value() {
        // Structured values (e.g. validation error lists) are rendered as compact JSON.
        let body = r#"{"detail": [{"msg": "field required"}]}"#;
        assert_eq!(extract_error_message(body), r#"[{"msg":"field required"}]"#);
    }

    #[test]
    fn test_extract_error_message_plain_text_short() {
        let body = "plain text error";
        assert_eq!(extract_error_message(body), "plain text error");
    }

    #[test]
    fn test_extract_error_message_plain_text_truncated() {
        let body = "x".repeat(300);
        let result = extract_error_message(&body);
        assert_eq!(result.len(), 200);
    }

    #[test]
    fn test_validate_path_params_filled_no_placeholders() {
        assert!(validate_path_params_filled("/users/123/tasks/456").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_static_path() {
        assert!(validate_path_params_filled("/health").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_unfilled_placeholder() {
        let result = validate_path_params_filled("/users/{user_id}/tasks");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("user_id"),
            "error should name the unfilled param: {msg}"
        );
    }

    #[test]
    fn test_validate_path_params_filled_multiple_unfilled() {
        let result = validate_path_params_filled("/users/{user_id}/tasks/{task_id}");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        // Every unfilled placeholder must be reported, not just one of them.
        assert!(msg.contains("user_id") && msg.contains("task_id"), "{msg}");
    }

    #[test]
    fn test_percent_encode_path_segment_unreserved_passthrough() {
        assert_eq!(percent_encode_path_segment("Az09-._~"), "Az09-._~");
    }

    #[test]
    fn test_percent_encode_path_segment_reserved_and_multibyte() {
        assert_eq!(percent_encode_path_segment("a b/c"), "a%20b%2Fc");
        assert_eq!(percent_encode_path_segment("{id}"), "%7Bid%7D");
        assert_eq!(percent_encode_path_segment("é"), "%C3%A9");
    }

    #[test]
    fn test_encoded_braces_are_not_treated_as_placeholders() {
        let path = format!("/users/{}", percent_encode_path_segment("{user_id}"));
        assert!(validate_path_params_filled(&path).is_ok());
    }

    #[test]
    fn test_safe_truncate_short_string_unchanged() {
        assert_eq!(safe_truncate("abc", 200), "abc");
    }

    #[test]
    fn test_safe_truncate_multibyte() {
        // Each emoji is multiple bytes but one char
        let body = "\u{1F600}".repeat(300);
        let result = safe_truncate(&body, 200);
        assert_eq!(result.chars().count(), 200);
    }
}