1use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, LazyLock};
8
9use async_trait::async_trait;
10use regex::Regex;
11use tracing::debug;
12
13use apcore::context::Context;
14use apcore::errors::ModuleError;
15use apcore::module::{Module, ModuleAnnotations};
16use apcore::Registry;
17
18use crate::output::types::WriteResult;
19use crate::types::ScannedModule;
20
/// Registry writer that registers every scanned module as an HTTP proxy module.
///
/// The values stored here are copied into each proxy created by `write`.
pub struct HTTPProxyRegistryWriter {
    /// Base URL that each proxy prepends to its module's URL path.
    base_url: String,
    /// Optional callback producing extra request headers; proxies invoke it on every call.
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
    /// Per-request timeout in seconds, used when a proxy builds its HTTP client.
    timeout_secs: f64,
}
30
31impl HTTPProxyRegistryWriter {
32 pub fn new(
38 base_url: String,
39 auth_header_factory: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
40 timeout_secs: f64,
41 ) -> Self {
42 Self {
43 base_url,
44 auth_header_factory: auth_header_factory.map(Arc::from),
45 timeout_secs,
46 }
47 }
48
49 pub fn write(&self, modules: &[ScannedModule], registry: &mut Registry) -> Vec<WriteResult> {
51 let mut results: Vec<WriteResult> = Vec::new();
52
53 for module in modules {
54 let (http_method, url_path) = get_http_fields(module);
55 let path_params = extract_path_params(&url_path);
56 let proxy = ProxyModule {
57 base_url: self.base_url.clone(),
58 http_method,
59 url_path,
60 path_params,
61 input_schema: module.input_schema.clone(),
62 output_schema: module.output_schema.clone(),
63 description: module.description.clone(),
64 annotations: module.annotations.clone().unwrap_or_default(),
65 timeout_secs: self.timeout_secs,
66 auth_header_factory: self.auth_header_factory.clone(),
67 };
68
69 let descriptor = apcore::registry::registry::ModuleDescriptor {
70 name: module.module_id.clone(),
71 annotations: proxy.annotations.clone(),
72 input_schema: module.input_schema.clone(),
73 output_schema: module.output_schema.clone(),
74 enabled: true,
75 tags: module.tags.clone(),
76 dependencies: vec![],
77 };
78
79 match registry.register(&module.module_id, Box::new(proxy), descriptor) {
80 Ok(()) => {
81 debug!("Registered HTTP proxy: {}", module.module_id);
82 results.push(WriteResult::new(module.module_id.clone()));
83 }
84 Err(e) => {
85 debug!("Skipped {}: {}", module.module_id, e);
86 results.push(WriteResult::failed(
87 module.module_id.clone(),
88 None,
89 e.to_string(),
90 ));
91 }
92 }
93 }
94
95 results
96 }
97}
98
99fn get_http_fields(module: &ScannedModule) -> (String, String) {
101 let http_method = module
102 .metadata
103 .get("http_method")
104 .and_then(|v| v.as_str())
105 .unwrap_or("GET")
106 .to_string();
107 let url_path = module
108 .metadata
109 .get("url_path")
110 .and_then(|v| v.as_str())
111 .unwrap_or("/")
112 .to_string();
113 (http_method, url_path)
114}
115
/// Matches `{name}` placeholders in a URL path; capture group 1 is the name (`\w+`).
///
/// Compiled once, on first use.
static PATH_PARAM_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\{(\w+)\}").expect("static regex"));
119
120fn extract_path_params(url_path: &str) -> HashSet<String> {
122 PATH_PARAM_RE
123 .captures_iter(url_path)
124 .filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
125 .collect()
126}
127
128fn extract_error_message(body: &str) -> String {
134 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body) {
135 for key in &["error_message", "detail", "error", "message"] {
136 if let Some(val) = parsed.get(key) {
137 let msg = match val {
138 serde_json::Value::String(s) => s.clone(),
139 other => other.to_string(),
140 };
141 if !msg.is_empty() {
142 return msg;
143 }
144 }
145 }
146 }
147
148 safe_truncate(body, 200)
149}
150
/// Returns at most the first `max_chars` characters of `s` as an owned string.
///
/// The cut is made on a `char` boundary, so multi-byte characters are never split.
fn safe_truncate(s: &str, max_chars: usize) -> String {
    // The byte offset of character number `max_chars` (zero-based) is exactly where the
    // kept prefix ends; if there is no such character the whole string fits.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_string(),
        None => s.to_string(),
    }
}
160
/// Module implementation that forwards each execution to a remote HTTP endpoint.
struct ProxyModule {
    /// Base URL of the remote service; trailing slashes are trimmed before the path is appended.
    base_url: String,
    /// HTTP method name; `execute` accepts `GET`, `POST`, `PUT`, `PATCH` and `DELETE`.
    http_method: String,
    /// URL path template, possibly containing `{name}` placeholders.
    url_path: String,
    /// Placeholder names; inputs with these keys are substituted into the path.
    path_params: HashSet<String>,
    /// JSON schema returned by `input_schema`.
    input_schema: serde_json::Value,
    /// JSON schema returned by `output_schema`.
    output_schema: serde_json::Value,
    /// Text returned by `description`.
    description: String,
    /// Module annotations; the registry writer also copies them into the module descriptor.
    annotations: ModuleAnnotations,
    /// Timeout in seconds applied to the HTTP client built for each call.
    timeout_secs: f64,
    /// Optional callback producing extra request headers; invoked on every call.
    auth_header_factory: Option<Arc<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
}
174
175#[async_trait]
176impl Module for ProxyModule {
177 fn input_schema(&self) -> serde_json::Value {
178 self.input_schema.clone()
179 }
180
181 fn output_schema(&self) -> serde_json::Value {
182 self.output_schema.clone()
183 }
184
185 fn description(&self) -> &str {
186 &self.description
187 }
188
189 async fn execute(
190 &self,
191 inputs: serde_json::Value,
192 _ctx: &Context<serde_json::Value>,
193 ) -> Result<serde_json::Value, ModuleError> {
194 let client = reqwest::Client::builder()
195 .timeout(std::time::Duration::from_secs_f64(self.timeout_secs))
196 .build()
197 .map_err(|e| {
198 ModuleError::new(
199 apcore::errors::ErrorCode::ModuleExecuteError,
200 format!("Failed to create HTTP client: {e}"),
201 )
202 })?;
203
204 let mut actual_path = self.url_path.clone();
205 let mut query: HashMap<String, String> = HashMap::new();
206 let mut body: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
207
208 if let Some(obj) = inputs.as_object() {
209 for (key, value) in obj {
210 if self.path_params.contains(key) {
211 let val_str = match value {
212 serde_json::Value::String(s) => s.clone(),
213 other => other.to_string(),
214 };
215 actual_path = actual_path.replace(&format!("{{{key}}}"), &val_str);
216 } else if self.http_method == "GET" {
217 let val_str = match value {
218 serde_json::Value::String(s) => s.clone(),
219 other => other.to_string(),
220 };
221 query.insert(key.clone(), val_str);
222 } else {
223 body.insert(key.clone(), value.clone());
224 }
225 }
226 }
227
228 let url = format!("{}{}", self.base_url.trim_end_matches('/'), actual_path);
229
230 let mut request = match self.http_method.as_str() {
231 "GET" => client.get(&url),
232 "POST" => client.post(&url),
233 "PUT" => client.put(&url),
234 "PATCH" => client.patch(&url),
235 "DELETE" => client.delete(&url),
236 other => {
237 return Err(ModuleError::new(
238 apcore::errors::ErrorCode::ModuleExecuteError,
239 format!("Unsupported HTTP method: {other}"),
240 ))
241 }
242 };
243
244 if let Some(ref factory) = self.auth_header_factory {
246 for (header_name, header_value) in factory() {
247 request = request.header(&header_name, &header_value);
248 }
249 }
250
251 if !query.is_empty() {
252 request = request.query(&query.iter().collect::<Vec<_>>());
253 }
254 if !body.is_empty() && matches!(self.http_method.as_str(), "POST" | "PUT" | "PATCH") {
255 request = request.json(&body);
256 }
257
258 let resp = request.send().await.map_err(|e| {
259 ModuleError::new(
260 apcore::errors::ErrorCode::ModuleExecuteError,
261 format!("HTTP request failed: {e}"),
262 )
263 })?;
264
265 let status = resp.status();
266 if status.is_success() {
267 if status.as_u16() == 204 {
268 return Ok(serde_json::json!({}));
269 }
270 resp.json().await.map_err(|e| {
271 ModuleError::new(
272 apcore::errors::ErrorCode::ModuleExecuteError,
273 format!("Failed to parse response JSON: {e}"),
274 )
275 })
276 } else {
277 let error_text = resp.text().await.unwrap_or_default();
278 let message = extract_error_message(&error_text);
279 Err(ModuleError::new(
280 apcore::errors::ErrorCode::ModuleExecuteError,
281 format!("HTTP {}: {}", status.as_u16(), message),
282 ))
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a scanned module with empty schemas and no metadata.
    fn bare_module() -> ScannedModule {
        ScannedModule::new(
            "test".into(),
            "test".into(),
            json!({}),
            json!({}),
            vec![],
            "app:func".into(),
        )
    }

    #[test]
    fn test_get_http_fields_defaults() {
        // Without metadata the method and path fall back to their defaults.
        let fields = get_http_fields(&bare_module());
        assert_eq!(fields.0, "GET");
        assert_eq!(fields.1, "/");
    }

    #[test]
    fn test_get_http_fields_from_metadata() {
        let mut module = bare_module();
        module.metadata.insert("http_method".into(), json!("POST"));
        module.metadata.insert("url_path".into(), json!("/users"));

        let fields = get_http_fields(&module);
        assert_eq!(fields.0, "POST");
        assert_eq!(fields.1, "/users");
    }

    #[test]
    fn test_extract_path_params() {
        let found = extract_path_params("/users/{user_id}/tasks/{task_id}");
        assert_eq!(found.len(), 2);
        for expected in ["user_id", "task_id"].iter() {
            assert!(found.contains(*expected));
        }
    }

    #[test]
    fn test_extract_path_params_none() {
        assert!(extract_path_params("/users").is_empty());
    }

    #[test]
    fn test_extract_error_message_json_error_message() {
        let message = extract_error_message(r#"{"error_message": "not found"}"#);
        assert_eq!(message, "not found");
    }

    #[test]
    fn test_extract_error_message_json_detail() {
        let message = extract_error_message(r#"{"detail": "unauthorized"}"#);
        assert_eq!(message, "unauthorized");
    }

    #[test]
    fn test_extract_error_message_json_error() {
        let message = extract_error_message(r#"{"error": "bad request"}"#);
        assert_eq!(message, "bad request");
    }

    #[test]
    fn test_extract_error_message_json_message() {
        let message = extract_error_message(r#"{"message": "server error"}"#);
        assert_eq!(message, "server error");
    }

    #[test]
    fn test_extract_error_message_json_priority() {
        // `error_message` outranks `message` when both are present.
        let message = extract_error_message(r#"{"error_message": "first", "message": "second"}"#);
        assert_eq!(message, "first");
    }

    #[test]
    fn test_extract_error_message_plain_text_short() {
        // Non-JSON bodies that fit the limit come back unchanged.
        let message = extract_error_message("plain text error");
        assert_eq!(message, "plain text error");
    }

    #[test]
    fn test_extract_error_message_plain_text_truncated() {
        let long_body = "x".repeat(300);
        let message = extract_error_message(&long_body);
        assert_eq!(message.len(), 200);
    }

    #[test]
    fn test_safe_truncate_multibyte() {
        // Truncation counts characters, not bytes, so multi-byte input must not be split.
        let emoji_body = "\u{1F600}".repeat(300);
        let truncated = safe_truncate(&emoji_body, 200);
        assert_eq!(truncated.chars().count(), 200);
    }
}