Skip to main content

apcore_toolkit/output/
http_proxy_writer.rs

1// HTTP proxy registry writer.
2//
3// Registers scanned modules as HTTP proxy implementations that forward
4// requests to a running web API. Feature-gated behind `http-proxy`.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, LazyLock};
8
9use async_trait::async_trait;
10use regex::Regex;
11use tracing::debug;
12
13use apcore::context::Context;
14use apcore::errors::ModuleError;
15use apcore::module::{Module, ModuleAnnotations};
16use apcore::Registry;
17
18use crate::output::types::WriteResult;
19use crate::types::ScannedModule;
20
/// Registry writer that installs scanned modules as HTTP proxies.
///
/// Modules registered through this writer do not invoke the scanned handler
/// in-process; their `execute()` issues an HTTP request against the target
/// API and returns the response.
pub struct HTTPProxyRegistryWriter {
    /// Base URL of the target API that every proxied request is sent to.
    base_url: String,
    /// Optional callback producing the HTTP headers used for authentication.
    /// Held behind an `Arc` so one factory can be shared by every proxy module.
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
    /// HTTP request timeout, in seconds.
    timeout_secs: f64,
}
30
31impl HTTPProxyRegistryWriter {
32    /// Create a new HTTP proxy writer.
33    ///
34    /// - `base_url`: Base URL of the target API.
35    /// - `auth_header_factory`: Optional callable returning HTTP headers for auth.
36    /// - `timeout_secs`: HTTP request timeout in seconds.
37    pub fn new(
38        base_url: String,
39        auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
40        timeout_secs: f64,
41    ) -> Self {
42        Self {
43            base_url,
44            auth_header_factory: auth_header_factory.map(Arc::from),
45            timeout_secs,
46        }
47    }
48
49    /// Register each ScannedModule as an HTTP proxy module.
50    pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
51        let mut results: Vec<WriteResult> = Vec::new();
52
53        for module in modules {
54            let (http_method, url_path) = get_http_fields(module);
55            let path_params = extract_path_params(&url_path);
56            let proxy = ProxyModule {
57                base_url: self.base_url.clone(),
58                http_method,
59                url_path,
60                path_params,
61                input_schema: module.input_schema.clone(),
62                output_schema: module.output_schema.clone(),
63                description: module.description.clone(),
64                annotations: module.annotations.clone().unwrap_or_default(),
65                timeout_secs: self.timeout_secs,
66                auth_header_factory: self.auth_header_factory.clone(),
67            };
68
69            let descriptor = apcore::registry::registry::ModuleDescriptor {
70                name: module.module_id.clone(),
71                annotations: proxy.annotations.clone(),
72                input_schema: module.input_schema.clone(),
73                output_schema: module.output_schema.clone(),
74                enabled: true,
75                tags: module.tags.clone(),
76                dependencies: vec![],
77            };
78
79            match registry.register(&module.module_id, Box::new(proxy), descriptor) {
80                Ok(()) => {
81                    debug!("Registered HTTP proxy: {}", module.module_id);
82                    results.push(WriteResult::new(module.module_id.clone()));
83                }
84                Err(e) => {
85                    debug!("Skipped {}: {}", module.module_id, e);
86                    results.push(WriteResult::failed(
87                        module.module_id.clone(),
88                        None,
89                        e.to_string(),
90                    ));
91                }
92            }
93        }
94
95        results
96    }
97}
98
99/// Extract http_method and url_path from a ScannedModule's metadata.
100fn get_http_fields(module: &ScannedModule) -> (String, String) {
101    let http_method = module
102        .metadata
103        .get("http_method")
104        .and_then(|v| v.as_str())
105        .unwrap_or("GET")
106        .to_string();
107    let url_path = module
108        .metadata
109        .get("url_path")
110        .and_then(|v| v.as_str())
111        .unwrap_or("/")
112        .to_string();
113    (http_method, url_path)
114}
115
116/// Regex matching URL path parameters like `{user_id}`.
117static PATH_PARAM_RE: LazyLock<Regex> =
118    LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
119
120/// Extract path parameter names from a URL pattern like `/users/{user_id}`.
121fn extract_path_params(url_path: &str) -> HashSet<String> {
122    PATH_PARAM_RE
123        .captures_iter(url_path)
124        .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
125        .collect()
126}
127
128/// Extract a human-readable error message from an HTTP error response body.
129///
130/// Tries to parse the body as JSON and looks for common error fields
131/// (`error_message`, `detail`, `error`, `message`) before falling back
132/// to a safely-truncated version of the raw text.
133fn extract_error_message(body: &str) -> String {
134    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
135        for key in &["error_message", "detail", "error", "message"] {
136            if let Some(val) = parsed.get(key) {
137                let msg = match val {
138                    serde_json::Value::String(s) => s.clone(),
139                    other => other.to_string(),
140                };
141                if !msg.is_empty() {
142                    return msg;
143                }
144            }
145        }
146    }
147
148    safe_truncate(body, 200)
149}
150
/// Return at most the first `max_chars` characters of `s`.
///
/// The cut is made on a `char` boundary, so multi-byte UTF-8 sequences are
/// never split and the function cannot panic on non-ASCII input. Strings that
/// already fit are returned unchanged.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    // The byte offset of the character at index `max_chars` (if there is one)
    // is exactly where the kept prefix ends.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_string(),
        None => s.to_string(),
    }
}
160
161/// A module that proxies requests to an HTTP API.
162struct ProxyModule {
163    base_url: String,
164    http_method: String,
165    url_path: String,
166    path_params: HashSet<String>,
167    input_schema: serde_json::Value,
168    output_schema: serde_json::Value,
169    description: String,
170    annotations: ModuleAnnotations,
171    timeout_secs: f64,
172    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
173}
174
175#[async_trait]
176impl Module for ProxyModule {
177    fn input_schema(&self) -> serde_json::Value {
178        self.input_schema.clone()
179    }
180
181    fn output_schema(&self) -> serde_json::Value {
182        self.output_schema.clone()
183    }
184
185    fn description(&self) -> &str {
186        &self.description
187    }
188
189    async fn execute(
190        &self,
191        inputs: serde_json::Value,
192        _ctx: &Context<serde_json::Value>,
193    ) -> Result<serde_json::Value, ModuleError> {
194        let client = reqwest::Client::builder()
195            .timeout(std::time::Duration::from_secs_f64(self.timeout_secs))
196            .build()
197            .map_err(|e| {
198                ModuleError::new(
199                    apcore::errors::ErrorCode::ModuleExecuteError,
200                    format!("Failed to create HTTP client: {e}"),
201                )
202            })?;
203
204        let mut actual_path = self.url_path.clone();
205        let mut query: HashMap<String, String> = HashMap::new();
206        let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
207
208        if let Some(obj) = inputs.as_object() {
209            for (key, value) in obj {
210                if self.path_params.contains(key) {
211                    let val_str = match value {
212                        serde_json::Value::String(s) => s.clone(),
213                        other => other.to_string(),
214                    };
215                    actual_path = actual_path.replace(&format!("{{{key}}}"), &val_str);
216                } else if self.http_method == "GET" {
217                    let val_str = match value {
218                        serde_json::Value::String(s) => s.clone(),
219                        other => other.to_string(),
220                    };
221                    query.insert(key.clone(), val_str);
222                } else {
223                    body.insert(key.clone(), value.clone());
224                }
225            }
226        }
227
228        let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
229
230        let mut request = match self.http_method.as_str() {
231            "GET" => client.get(&url),
232            "POST" => client.post(&url),
233            "PUT" => client.put(&url),
234            "PATCH" => client.patch(&url),
235            "DELETE" => client.delete(&url),
236            other => {
237                return Err(ModuleError::new(
238                    apcore::errors::ErrorCode::ModuleExecuteError,
239                    format!("Unsupported HTTP method: {other}"),
240                ))
241            }
242        };
243
244        // Apply auth headers from the factory, if configured
245        if let Some(ref factory) = self.auth_header_factory {
246            for (header_name, header_value) in factory() {
247                request = request.header(&header_name, &header_value);
248            }
249        }
250
251        if !query.is_empty() {
252            request = request.query(&query.iter().collect::<Vec<_>>());
253        }
254        if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
255            request = request.json(&body);
256        }
257
258        let resp = request.send().await.map_err(|e| {
259            ModuleError::new(
260                apcore::errors::ErrorCode::ModuleExecuteError,
261                format!("HTTP request failed: {e}"),
262            )
263        })?;
264
265        let status = resp.status();
266        if status.is_success() {
267            if status.as_u16() == 204 {
268                return Ok(serde_json::json!({}));
269            }
270            resp.json().await.map_err(|e| {
271                ModuleError::new(
272                    apcore::errors::ErrorCode::ModuleExecuteError,
273                    format!("Failed to parse response JSON: {e}"),
274                )
275            })
276        } else {
277            let error_text = resp.text().await.unwrap_or_default();
278            let message = extract_error_message(&error_text);
279            Err(ModuleError::new(
280                apcore::errors::ErrorCode::ModuleExecuteError,
281                format!("HTTP {}: {}", status.as_u16(), message),
282            ))
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal scanned module: empty schemas, no tags, no metadata set here.
    fn bare_module() -> ScannedModule {
        ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        )
    }

    // --- get_http_fields -------------------------------------------------

    #[test]
    fn test_get_http_fields_defaults() {
        let fields = get_http_fields(&bare_module());
        assert_eq!(fields.0, "GET");
        assert_eq!(fields.1, "/");
    }

    #[test]
    fn test_get_http_fields_from_metadata() {
        let mut module = bare_module();
        module.metadata.insert("http_method".into(), json!("POST"));
        module.metadata.insert("url_path".into(), json!("/users"));

        let fields = get_http_fields(&module);
        assert_eq!(fields.0, "POST");
        assert_eq!(fields.1, "/users");
    }

    // --- extract_path_params ---------------------------------------------

    #[test]
    fn test_extract_path_params() {
        let found = extract_path_params("/users/{user_id}/tasks/{task_id}");
        assert_eq!(found.len(), 2);
        assert!(found.contains("user_id"));
        assert!(found.contains("task_id"));
    }

    #[test]
    fn test_extract_path_params_none() {
        assert!(extract_path_params("/users").is_empty());
    }

    // --- extract_error_message -------------------------------------------

    #[test]
    fn test_extract_error_message_json_error_message() {
        let message = extract_error_message(r#"{"error_message": "not found"}"#);
        assert_eq!(message, "not found");
    }

    #[test]
    fn test_extract_error_message_json_detail() {
        let message = extract_error_message(r#"{"detail": "unauthorized"}"#);
        assert_eq!(message, "unauthorized");
    }

    #[test]
    fn test_extract_error_message_json_error() {
        let message = extract_error_message(r#"{"error": "bad request"}"#);
        assert_eq!(message, "bad request");
    }

    #[test]
    fn test_extract_error_message_json_message() {
        let message = extract_error_message(r#"{"message": "server error"}"#);
        assert_eq!(message, "server error");
    }

    #[test]
    fn test_extract_error_message_json_priority() {
        // `error_message` outranks `message` when both are present.
        let message = extract_error_message(r#"{"error_message": "first", "message": "second"}"#);
        assert_eq!(message, "first");
    }

    #[test]
    fn test_extract_error_message_plain_text_short() {
        assert_eq!(extract_error_message("plain text error"), "plain text error");
    }

    #[test]
    fn test_extract_error_message_plain_text_truncated() {
        // ASCII input, so byte length equals character count.
        let oversized = "x".repeat(300);
        assert_eq!(extract_error_message(&oversized).len(), 200);
    }

    // --- safe_truncate ---------------------------------------------------

    #[test]
    fn test_safe_truncate_multibyte() {
        // Every emoji is one `char` but several bytes.
        let emoji_run = "\u{1F600}".repeat(300);
        let shortened = safe_truncate(&emoji_run, 200);
        assert_eq!(shortened.chars().count(), 200);
    }
}
395}