// apcore_toolkit/output/http_proxy_writer.rs
//
// HTTP proxy registry writer.
//
// Registers scanned modules as HTTP proxy implementations that forward
// requests to a running web API. Feature-gated behind `http-proxy`.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, LazyLock};
8
9use async_trait::async_trait;
10use regex::Regex;
11use thiserror::Error;
12use tracing::{debug, warn};
13
14use apcore::context::Context;
15use apcore::errors::ModuleError;
16use apcore::module::Module;
17use apcore::Registry;
18
19use crate::http_verb_map::extract_path_param_names;
20use crate::output::types::WriteResult;
21use crate::types::ScannedModule;
22
23/// Errors returned by [`HTTPProxyRegistryWriter::new`].
24#[derive(Debug, Error)]
25pub enum HTTPProxyRegistryWriterError {
26    /// `base_url` is not a valid URL or uses a non-http(s) scheme.
27    #[error("invalid base_url: {0}")]
28    InvalidBaseUrl(String),
29    /// `timeout_secs` is not a valid positive finite number.
30    #[error("invalid timeout_secs: {0}")]
31    InvalidTimeout(String),
32}
33
34/// Register scanned modules as HTTP proxy modules in the registry.
35///
36/// Each module's `execute()` sends an HTTP request to the target API
37/// instead of calling the handler directly.
38pub struct HTTPProxyRegistryWriter {
39    base_url: String,
40    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
41    client: reqwest::Client,
42}
43
44impl std::fmt::Debug for HTTPProxyRegistryWriter {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("HTTPProxyRegistryWriter")
47            .field("base_url", &self.base_url)
48            .field(
49                "auth_header_factory",
50                &self.auth_header_factory.as_ref().map(|_| "<factory>"),
51            )
52            .field("client", &self.client)
53            .finish()
54    }
55}
56
57impl HTTPProxyRegistryWriter {
58    /// Create a new HTTP proxy writer.
59    ///
60    /// - `base_url`: Base URL of the target API (must be `http://` or `https://`).
61    /// - `auth_header_factory`: Optional callable returning HTTP headers for auth.
62    /// - `timeout_secs`: HTTP request timeout in seconds (must be a positive finite number).
63    ///
64    /// # Errors
65    ///
66    /// Returns [`HTTPProxyRegistryWriterError::InvalidBaseUrl`] if `base_url` is not a valid URL
67    /// or its scheme is not `http` or `https` (SSRF prevention).
68    /// Returns [`HTTPProxyRegistryWriterError::InvalidTimeout`] if `timeout_secs` is not a
69    /// positive finite number.
70    pub fn new(
71        base_url: String,
72        auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
73        timeout_secs: f64,
74    ) -> Result<Self, HTTPProxyRegistryWriterError> {
75        let parsed = reqwest::Url::parse(&base_url).map_err(|e| {
76            HTTPProxyRegistryWriterError::InvalidBaseUrl(format!("'{}': {e}", base_url))
77        })?;
78        if !matches!(parsed.scheme(), "http" | "https") {
79            return Err(HTTPProxyRegistryWriterError::InvalidBaseUrl(format!(
80                "scheme '{}' is not allowed — only http and https are permitted",
81                parsed.scheme()
82            )));
83        }
84
85        if !timeout_secs.is_finite() || timeout_secs <= 0.0 {
86            return Err(HTTPProxyRegistryWriterError::InvalidTimeout(format!(
87                "must be a positive finite number, got {timeout_secs}"
88            )));
89        }
90
91        let client = reqwest::Client::builder()
92            .timeout(std::time::Duration::from_secs_f64(timeout_secs))
93            .build()
94            .map_err(|e| {
95                HTTPProxyRegistryWriterError::InvalidBaseUrl(format!(
96                    "failed to build HTTP client: {e}"
97                ))
98            })?;
99
100        Ok(Self {
101            base_url,
102            auth_header_factory: auth_header_factory.map(Arc::from),
103            client,
104        })
105    }
106
107    /// Register each ScannedModule as an HTTP proxy module.
108    pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
109        let mut results: Vec<WriteResult> = Vec::new();
110
111        for module in modules {
112            let (http_method, url_path) = get_http_fields(module);
113            let path_params = extract_path_param_names(&url_path);
114            let proxy = ProxyModule {
115                base_url: self.base_url.clone(),
116                http_method,
117                url_path,
118                path_params,
119                input_schema: module.input_schema.clone(),
120                output_schema: module.output_schema.clone(),
121                description: module.description.clone(),
122                auth_header_factory: self.auth_header_factory.clone(),
123                client: self.client.clone(),
124            };
125
126            let descriptor = apcore::registry::registry::ModuleDescriptor {
127                module_id: module.module_id.clone(),
128                name: Some(module.module_id.clone()),
129                description: module.description.clone(),
130                documentation: module.documentation.clone(),
131                input_schema: module.input_schema.clone(),
132                output_schema: module.output_schema.clone(),
133                version: module.version.clone(),
134                tags: module.tags.clone(),
135                annotations: module.annotations.clone(),
136                examples: module.examples.clone(),
137                metadata: module.metadata.clone(),
138                display: module.display.clone(),
139                sunset_date: None,
140                dependencies: vec![],
141                enabled: true,
142            };
143
144            match registry.register(&module.module_id, Box::new(proxy), descriptor) {
145                Ok(()) => {
146                    debug!("Registered HTTP proxy: {}", module.module_id);
147                    results.push(WriteResult::new(module.module_id.clone()));
148                }
149                Err(e) => {
150                    warn!(module_id = %module.module_id, error = %e, "HTTPProxyRegistryWriter registration failed");
151                    results.push(WriteResult::failed(
152                        module.module_id.clone(),
153                        None,
154                        e.to_string(),
155                    ));
156                }
157            }
158        }
159
160        results
161    }
162}
163
164/// Extract http_method and url_path from a ScannedModule's metadata.
165fn get_http_fields(module: &ScannedModule) -> (String, String) {
166    let http_method = module
167        .metadata
168        .get("http_method")
169        .and_then(|v| v.as_str())
170        .unwrap_or("GET")
171        .to_string();
172    let url_path = module
173        .metadata
174        .get("url_path")
175        .and_then(|v| v.as_str())
176        .unwrap_or("/")
177        .to_string();
178    (http_method, url_path)
179}
180
/// HTTP methods whose non-path inputs are sent as a JSON request body.
///
/// For any other method (for example `GET` or `DELETE`) the non-path inputs
/// are forwarded in the query string instead, so they are not silently
/// dropped. This matches the Python and TypeScript SDKs.
const BODY_METHODS: &[&str] = &["POST", "PUT", "PATCH"];
186
187/// Regex matching URL path parameters like `{user_id}`.
188static PATH_PARAM_RE: LazyLock<Regex> =
189    LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
190
191/// Validate that all `{param}` placeholders in `actual_path` were substituted.
192///
193/// Returns `Err` with the list of still-unfilled parameter names if any remain.
194fn validate_path_params_filled(actual_path: &str) -> Result<(), String> {
195    if PATH_PARAM_RE.is_match(actual_path) {
196        let unfilled: Vec<&str> = PATH_PARAM_RE
197            .captures_iter(actual_path)
198            .filter_map(|cap| cap.get(1).map(|m| m.as_str()))
199            .collect();
200        Err(format!(
201            "Missing required path parameters {:?} — inputs must supply values for all path params in '{actual_path}'",
202            unfilled
203        ))
204    } else {
205        Ok(())
206    }
207}
208
/// Percent-encode a value for use inside a single URL path segment.
///
/// The RFC 3986 §2.3 unreserved characters (`A-Z a-z 0-9 - . _ ~`) are copied
/// through unchanged. Every other byte of the UTF-8 input is written as `%XX`
/// with upper-case hexadecimal digits, so a multi-byte character becomes one
/// `%XX` triplet per byte.
fn percent_encode_path_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for &byte in s.as_bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                encoded.push(char::from(byte));
            }
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}
225
226/// Extract a human-readable error message from an HTTP error response body.
227///
228/// Private helper — tries to parse the body as JSON and looks for common error fields
229/// (`error_message`, `detail`, `error`, `message`) in that priority order, before
230/// falling back to a safely-truncated version of the raw text (max 200 characters).
231fn extract_error_message(body: &str) -> String {
232    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
233        for key in &["error_message", "detail", "error", "message"] {
234            if let Some(val) = parsed.get(key) {
235                let msg = match val {
236                    serde_json::Value::String(s) => s.clone(),
237                    other => other.to_string(),
238                };
239                if !msg.is_empty() {
240                    return msg;
241                }
242            }
243        }
244    }
245
246    safe_truncate(body, 200)
247}
248
/// Return at most the first `max_chars` characters of `s`.
///
/// Private helper. Length is measured in Unicode scalar values (`char`s), not
/// bytes, and the cut is made on a character boundary, so multi-byte sequences
/// (e.g. emoji) count as one character each and slicing cannot panic.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    // The byte offset of the character at index `max_chars` (if there is one)
    // is exactly where the first `max_chars` characters end.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_owned(),
        None => s.to_owned(),
    }
}
261
262/// A module that proxies requests to an HTTP API.
263struct ProxyModule {
264    base_url: String,
265    http_method: String,
266    url_path: String,
267    path_params: HashSet<String>,
268    input_schema: serde_json::Value,
269    output_schema: serde_json::Value,
270    description: String,
271    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
272    // Shared HTTP client — cloned from HTTPProxyRegistryWriter to reuse connection pool.
273    client: reqwest::Client,
274}
275
276#[async_trait]
277impl Module for ProxyModule {
278    fn input_schema(&self) -> serde_json::Value {
279        self.input_schema.clone()
280    }
281
282    fn output_schema(&self) -> serde_json::Value {
283        self.output_schema.clone()
284    }
285
286    fn description(&self) -> &str {
287        &self.description
288    }
289
290    async fn execute(
291        &self,
292        inputs: serde_json::Value,
293        _ctx: &Context<serde_json::Value>,
294    ) -> Result<serde_json::Value, ModuleError> {
295        let mut actual_path = self.url_path.clone();
296        let mut query: HashMap<String, String> = HashMap::new();
297        let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
298
299        if let Some(obj) = inputs.as_object() {
300            let uses_body = BODY_METHODS.contains(&self.http_method.as_str());
301            for (key, value) in obj {
302                if self.path_params.contains(key) {
303                    let val_str = match value {
304                        serde_json::Value::String(s) => s.clone(),
305                        other => other.to_string(),
306                    };
307                    actual_path = actual_path.replace(
308                        &format!("{{{key}}}"),
309                        &percent_encode_path_segment(&val_str),
310                    );
311                } else if uses_body {
312                    body.insert(key.clone(), value.clone());
313                } else {
314                    // GET / HEAD / DELETE / OPTIONS — forward as query
315                    // string to mirror Python / TypeScript behaviour.
316                    let val_str = match value {
317                        serde_json::Value::String(s) => s.clone(),
318                        other => other.to_string(),
319                    };
320                    query.insert(key.clone(), val_str);
321                }
322            }
323        }
324
325        if let Err(msg) = validate_path_params_filled(&actual_path) {
326            return Err(ModuleError::new(
327                apcore::errors::ErrorCode::ModuleExecuteError,
328                msg,
329            ));
330        }
331
332        let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
333
334        let mut request = match self.http_method.as_str() {
335            "GET" => self.client.get(&url),
336            "POST" => self.client.post(&url),
337            "PUT" => self.client.put(&url),
338            "PATCH" => self.client.patch(&url),
339            "DELETE" => self.client.delete(&url),
340            other => {
341                return Err(ModuleError::new(
342                    apcore::errors::ErrorCode::ModuleExecuteError,
343                    format!("Unsupported HTTP method: {other}"),
344                ))
345            }
346        };
347
348        // Apply auth headers from the factory, if configured
349        if let Some(ref factory) = self.auth_header_factory {
350            for (header_name, header_value) in factory() {
351                request = request.header(&header_name, &header_value);
352            }
353        }
354
355        if !query.is_empty() {
356            request = request.query(&query.iter().collect::<Vec<_>>());
357        }
358        if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
359            request = request.json(&body);
360        }
361
362        let resp = request.send().await.map_err(|e| {
363            ModuleError::new(
364                apcore::errors::ErrorCode::ModuleExecuteError,
365                format!("HTTP request failed: {e}"),
366            )
367        })?;
368
369        let status = resp.status();
370        if status.is_success() {
371            if status.as_u16() == 204 {
372                return Ok(serde_json::json!({}));
373            }
374            resp.json().await.map_err(|e| {
375                ModuleError::new(
376                    apcore::errors::ErrorCode::ModuleExecuteError,
377                    format!("Failed to parse response JSON: {e}"),
378                )
379            })
380        } else {
381            let error_text = resp.text().await.unwrap_or_default();
382            let message = extract_error_message(&error_text);
383            Err(ModuleError::new(
384                apcore::errors::ErrorCode::ModuleExecuteError,
385                format!("HTTP {}: {}", status.as_u16(), message),
386            ))
387        }
388    }
389}
390
/// Unit tests for the writer's constructor validation and the private
/// helpers in this file. No test performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a minimal `ScannedModule` with empty schemas and no metadata.
    fn bare_module() -> ScannedModule {
        ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        )
    }

    #[test]
    fn test_new_rejects_non_http_scheme() {
        let result = HTTPProxyRegistryWriter::new("file:///etc/passwd".into(), None, 30.0);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("scheme 'file' is not allowed"));
    }

    #[test]
    fn test_new_rejects_invalid_url() {
        let result = HTTPProxyRegistryWriter::new("not a url".into(), None, 30.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_nan_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::NAN);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_negative_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, -1.0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_zero_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, 0.0);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_rejects_infinite_timeout() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, f64::INFINITY);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timeout"));
    }

    #[test]
    fn test_new_accepts_http_scheme() {
        let result = HTTPProxyRegistryWriter::new("http://localhost".into(), None, 30.0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_new_accepts_https_scheme() {
        let result = HTTPProxyRegistryWriter::new("https://api.example.com".into(), None, 30.0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_get_http_fields_defaults() {
        let module = bare_module();
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "GET");
        assert_eq!(path, "/");
    }

    #[test]
    fn test_get_http_fields_from_metadata() {
        let mut module = bare_module();
        module.metadata.insert(
            "http_method".into(),
            serde_json::Value::String("POST".into()),
        );
        module.metadata.insert(
            "url_path".into(),
            serde_json::Value::String("/users".into()),
        );
        let (method, path) = get_http_fields(&module);
        assert_eq!(method, "POST");
        assert_eq!(path, "/users");
    }

    #[test]
    fn test_extract_path_params() {
        let params = extract_path_param_names("/users/{user_id}/tasks/{task_id}");
        assert!(params.contains("user_id"));
        assert!(params.contains("task_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_extract_path_params_none() {
        let params = extract_path_param_names("/users");
        assert!(params.is_empty());
    }

    #[test]
    fn test_extract_path_params_colon_style() {
        // Regression test: colon-style params must not be silently dropped.
        // The private PATH_PARAM_RE only handles brace-style; this test
        // verifies that extract_path_param_names (from http_verb_map) handles
        // both styles correctly.
        let params = extract_path_param_names("/users/:id");
        assert!(
            params.contains("id"),
            "colon-style param ':id' should be recognised; got: {params:?}"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_extract_path_params_mixed_styles() {
        let params = extract_path_param_names("/users/:user_id/tasks/{task_id}");
        assert!(params.contains("user_id"));
        assert!(params.contains("task_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_extract_error_message_json_error_message() {
        let body = r#"{"error_message": "not found"}"#;
        assert_eq!(extract_error_message(body), "not found");
    }

    #[test]
    fn test_extract_error_message_json_detail() {
        let body = r#"{"detail": "unauthorized"}"#;
        assert_eq!(extract_error_message(body), "unauthorized");
    }

    #[test]
    fn test_extract_error_message_json_error() {
        let body = r#"{"error": "bad request"}"#;
        assert_eq!(extract_error_message(body), "bad request");
    }

    #[test]
    fn test_extract_error_message_json_message() {
        let body = r#"{"message": "server error"}"#;
        assert_eq!(extract_error_message(body), "server error");
    }

    #[test]
    fn test_extract_error_message_json_priority() {
        // error_message takes priority over message
        let body = r#"{"error_message": "first", "message": "second"}"#;
        assert_eq!(extract_error_message(body), "first");
    }

    #[test]
    fn test_extract_error_message_plain_text_short() {
        let body = "plain text error";
        assert_eq!(extract_error_message(body), "plain text error");
    }

    #[test]
    fn test_extract_error_message_plain_text_truncated() {
        let body = "x".repeat(300);
        let result = extract_error_message(&body);
        assert_eq!(result.len(), 200);
    }

    #[test]
    fn test_validate_path_params_filled_no_placeholders() {
        assert!(validate_path_params_filled("/users/123/tasks/456").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_static_path() {
        assert!(validate_path_params_filled("/health").is_ok());
    }

    #[test]
    fn test_validate_path_params_filled_unfilled_placeholder() {
        let result = validate_path_params_filled("/users/{user_id}/tasks");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("user_id"),
            "error should name the unfilled param: {msg}"
        );
    }

    #[test]
    fn test_validate_path_params_filled_multiple_unfilled() {
        let result = validate_path_params_filled("/users/{user_id}/tasks/{task_id}");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        // Every unfilled placeholder must be reported, not just one of them.
        assert!(msg.contains("user_id"), "{msg}");
        assert!(msg.contains("task_id"), "{msg}");
    }

    #[test]
    fn test_percent_encode_path_segment_unreserved_passthrough() {
        assert_eq!(percent_encode_path_segment("AZaz09-._~"), "AZaz09-._~");
        assert_eq!(percent_encode_path_segment(""), "");
    }

    #[test]
    fn test_percent_encode_path_segment_reserved_and_multibyte() {
        assert_eq!(percent_encode_path_segment("a b/c"), "a%20b%2Fc");
        assert_eq!(percent_encode_path_segment("{x}:y"), "%7Bx%7D%3Ay");
        assert_eq!(percent_encode_path_segment("é"), "%C3%A9");
    }

    #[test]
    fn test_safe_truncate_short_input_unchanged() {
        assert_eq!(safe_truncate("abc", 200), "abc");
        assert_eq!(safe_truncate("abc", 3), "abc");
        assert_eq!(safe_truncate("abc", 0), "");
    }

    #[test]
    fn test_safe_truncate_multibyte() {
        // Each emoji is multiple bytes but one char
        let body = "\u{1F600}".repeat(300);
        let result = safe_truncate(&body, 200);
        assert_eq!(result.chars().count(), 200);
    }

    // D11-1 regression: BODY_METHODS = {POST, PUT, PATCH}. All other methods
    // (GET, HEAD, DELETE, OPTIONS) MUST forward non-path inputs as the query
    // string, mirroring Python and TypeScript. Previously Rust routed any
    // non-GET method into the JSON body — DELETE proxies were emitting bodies
    // that most servers ignore or reject (RFC 9110 §9.3.5).
    #[test]
    fn test_body_methods_set_contents() {
        assert!(BODY_METHODS.contains(&"POST"));
        assert!(BODY_METHODS.contains(&"PUT"));
        assert!(BODY_METHODS.contains(&"PATCH"));
        assert!(!BODY_METHODS.contains(&"GET"));
        assert!(!BODY_METHODS.contains(&"DELETE"));
        assert!(!BODY_METHODS.contains(&"HEAD"));
        assert!(!BODY_METHODS.contains(&"OPTIONS"));
    }
}