Skip to main content

pi/connectors/
http.rs

1//! HTTP connector for extension hostcalls (pi.http).
2//!
3//! This adapter validates request payloads, enforces TLS policy, executes
4//! a simple GET/POST request via the internal HTTP client, and returns a
5//! normalized response shape.
6
7use crate::connectors::{
8    Connector, HostCallErrorCode, HostCallPayload, HostResultPayload, host_result_err,
9    host_result_err_with_details, host_result_ok,
10};
11use crate::error::{Error, Result};
12use crate::http::client::Client;
13use asupersync::http::h1::http_client::Scheme;
14use asupersync::http::h1::{ClientError, ParsedUrl};
15use async_trait::async_trait;
16use base64::Engine as _;
17use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
18use futures::StreamExt;
19use serde_json::{Value, json};
20use std::time::Duration;
21
/// Timeout applied when neither the hostcall nor its params specify one (30 s).
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1000;
/// Default cap on outgoing request bodies (50 MiB).
const DEFAULT_MAX_REQUEST_BYTES: usize = 50 << 20;
/// Default cap on buffered response bodies (50 MiB).
const DEFAULT_MAX_RESPONSE_BYTES: usize = 50 << 20;

/// Policy and limits for [`HttpConnector`].
#[derive(Debug, Clone)]
pub struct HttpConnectorConfig {
    /// When set, plain `http://` URLs are rejected with a `Denied` error.
    pub require_tls: bool,
    /// Forces allowlist enforcement even when `allowlist` is empty (which then
    /// denies every host). A non-empty allowlist is enforced regardless.
    pub enforce_allowlist: bool,
    /// Hosts that may be contacted; `*` and `*.domain` wildcards are accepted.
    pub allowlist: Vec<String>,
    /// Hosts that are always refused; checked before the allowlist.
    pub denylist: Vec<String>,
    /// Maximum request body size in bytes; `0` disables the check.
    pub max_request_bytes: usize,
    /// Maximum buffered response body size in bytes; `0` disables the check.
    pub max_response_bytes: usize,
    /// Fallback timeout in milliseconds; `0` means requests run without one.
    pub default_timeout_ms: u64,
}

impl Default for HttpConnectorConfig {
    /// TLS required, no allow/deny lists, 50 MiB body limits, 30 s timeout.
    fn default() -> Self {
        let (allowlist, denylist) = (Vec::new(), Vec::new());
        Self {
            require_tls: true,
            enforce_allowlist: false,
            allowlist,
            denylist,
            max_request_bytes: DEFAULT_MAX_REQUEST_BYTES,
            max_response_bytes: DEFAULT_MAX_RESPONSE_BYTES,
            default_timeout_ms: DEFAULT_TIMEOUT_MS,
        }
    }
}
50
51#[derive(Debug, Clone)]
52pub struct HttpConnector {
53    config: HttpConnectorConfig,
54    client: Client,
55}
56
57impl HttpConnector {
58    #[must_use]
59    pub fn new(mut config: HttpConnectorConfig) -> Self {
60        config.allowlist = normalize_allowlist(config.allowlist);
61        config.denylist = normalize_allowlist(config.denylist);
62        Self {
63            config,
64            client: Client::new(),
65        }
66    }
67
68    #[must_use]
69    pub fn with_defaults() -> Self {
70        Self::new(HttpConnectorConfig::default())
71    }
72}
73
74fn invalid_request(call_id: &str, message: impl Into<String>) -> HostResultPayload {
75    host_result_err(call_id, HostCallErrorCode::InvalidRequest, message, None)
76}
77
78fn io_error(call_id: &str, message: impl Into<String>) -> HostResultPayload {
79    host_result_err(call_id, HostCallErrorCode::Io, message, None)
80}
81
82fn timeout_error(call_id: &str, message: impl Into<String>) -> HostResultPayload {
83    host_result_err(call_id, HostCallErrorCode::Timeout, message, Some(true))
84}
85
86fn sanitize_invalid_url_reason(err: &ClientError) -> String {
87    match err {
88        ClientError::InvalidUrl(reason) => {
89            let reason = reason.trim();
90            if reason
91                .to_ascii_lowercase()
92                .starts_with("unsupported scheme in:")
93            {
94                "Unsupported URL scheme".to_string()
95            } else {
96                reason.to_string()
97            }
98        }
99        _ => "Invalid URL".to_string(),
100    }
101}
102
103fn deny_allowlist(call_id: &str, host: &str, allowlist: &[String]) -> HostResultPayload {
104    let details = json!({
105        "host": host,
106        "allowlist": allowlist,
107        "hint": "Add the host to capability_manifest.scope.hosts"
108    });
109    host_result_err_with_details(
110        call_id,
111        HostCallErrorCode::Denied,
112        "Host not in allowlist",
113        details,
114        None,
115    )
116}
117
118fn deny_denylist(call_id: &str, host: &str, denylist: &[String]) -> HostResultPayload {
119    let details = json!({
120        "host": host,
121        "denylist": denylist,
122        "hint": "Remove the host from the denylist to allow access"
123    });
124    host_result_err_with_details(
125        call_id,
126        HostCallErrorCode::Denied,
127        "Host is in denylist",
128        details,
129        None,
130    )
131}
132
133fn normalize_allowlist(allowlist: Vec<String>) -> Vec<String> {
134    allowlist
135        .into_iter()
136        .filter_map(|entry| normalize_host_entry(&entry))
137        .collect()
138}
139
140fn normalize_host_entry(raw: &str) -> Option<String> {
141    let trimmed = raw.trim();
142    if trimmed.is_empty() {
143        return None;
144    }
145
146    let mut host = trimmed.to_string();
147    if trimmed.contains("://") {
148        if let Ok(parsed) = ParsedUrl::parse(trimmed) {
149            host = parsed.host;
150        }
151    }
152
153    let host = host.trim().trim_end_matches('.');
154    let host = if host.starts_with('[') {
155        host.find(']').map_or(host, |end| &host[1..end])
156    } else if host.matches(':').count() == 1 {
157        host.split_once(':').map_or(host, |(left, _)| left)
158    } else {
159        host
160    };
161
162    let host = host.trim();
163    if host.is_empty() {
164        None
165    } else {
166        Some(host.to_ascii_lowercase())
167    }
168}
169
170fn host_is_allowed(host: &str, allowlist: &[String]) -> bool {
171    let Some(host) = normalize_host_entry(host) else {
172        return false;
173    };
174
175    for entry in allowlist {
176        let entry = entry.trim();
177        if entry.is_empty() {
178            continue;
179        }
180        if entry == "*" {
181            return true;
182        }
183        if let Some(suffix) = entry.strip_prefix("*.") {
184            if host == suffix || host.ends_with(&format!(".{suffix}")) {
185                return true;
186            }
187            continue;
188        }
189        if host == entry {
190            return true;
191        }
192        if host.ends_with(&format!(".{entry}")) {
193            return true;
194        }
195    }
196    false
197}
198
199fn host_is_denied(host: &str, denylist: &[String]) -> bool {
200    let Some(host) = normalize_host_entry(host) else {
201        return false;
202    };
203
204    for entry in denylist {
205        let entry = entry.trim();
206        if entry.is_empty() {
207            continue;
208        }
209        if entry == "*" {
210            return true;
211        }
212        if let Some(suffix) = entry.strip_prefix("*.") {
213            if host == suffix || host.ends_with(&format!(".{suffix}")) {
214                return true;
215            }
216            continue;
217        }
218        if host == entry {
219            return true;
220        }
221        if host.ends_with(&format!(".{entry}")) {
222            return true;
223        }
224    }
225    false
226}
227
228fn is_timeout_error(err: &Error) -> bool {
229    match err {
230        Error::Api(message) => message.to_ascii_lowercase().contains("timed out"),
231        Error::Io(err) => {
232            err.kind() == std::io::ErrorKind::TimedOut
233                || err.to_string().to_ascii_lowercase().contains("timed out")
234        }
235        _ => false,
236    }
237}
238
/// `true` when an I/O error represents a timeout: either its kind is
/// `TimedOut`, or its message contains "timed out" (case-insensitive).
fn is_timeout_io(err: &std::io::Error) -> bool {
    if matches!(err.kind(), std::io::ErrorKind::TimedOut) {
        return true;
    }
    let text = err.to_string().to_ascii_lowercase();
    text.contains("timed out")
}
243
/// A validated `http` hostcall, ready to be turned into a client request.
struct PreparedRequest {
    /// Trimmed target URL that passed the TLS and host-policy checks.
    url: String,
    /// Upper-cased HTTP method: `GET` or `POST`.
    method: String,
    /// Header `(name, value)` pairs in the order they were read from params.
    headers: Vec<(String, String)>,
    /// Request body bytes; `None` when no body was supplied.
    body: Option<Vec<u8>>,
    /// Timeout in milliseconds; `None` means the request has no timeout.
    timeout_ms: Option<u64>,
}
251
252impl HttpConnector {
253    #[allow(clippy::too_many_lines)]
254    fn prepare_request(
255        &self,
256        call: &HostCallPayload,
257    ) -> std::result::Result<PreparedRequest, Box<HostResultPayload>> {
258        if !call.method.trim().eq_ignore_ascii_case("http") {
259            return Err(Box::new(invalid_request(
260                &call.call_id,
261                "Unsupported hostcall method for http connector",
262            )));
263        }
264
265        let Some(params) = call.params.as_object() else {
266            return Err(Box::new(invalid_request(
267                &call.call_id,
268                "http params must be an object",
269            )));
270        };
271
272        let url = match params.get("url").and_then(Value::as_str) {
273            Some(value) if !value.trim().is_empty() => value.trim().to_string(),
274            _ => return Err(Box::new(invalid_request(&call.call_id, "url is required"))),
275        };
276
277        let parsed = match ParsedUrl::parse(&url) {
278            Ok(parsed) => parsed,
279            Err(err) => {
280                let reason = sanitize_invalid_url_reason(&err);
281                return Err(Box::new(invalid_request(
282                    &call.call_id,
283                    format!("Invalid URL: {reason}"),
284                )));
285            }
286        };
287
288        if parsed.host.trim().is_empty() {
289            return Err(Box::new(invalid_request(
290                &call.call_id,
291                "URL host is required",
292            )));
293        }
294
295        match parsed.scheme {
296            Scheme::Http if self.config.require_tls => {
297                return Err(Box::new(host_result_err(
298                    &call.call_id,
299                    HostCallErrorCode::Denied,
300                    "TLS required: use https:// URLs",
301                    None,
302                )));
303            }
304            Scheme::Http | Scheme::Https => {}
305        }
306
307        if host_is_denied(&parsed.host, &self.config.denylist) {
308            return Err(Box::new(deny_denylist(
309                &call.call_id,
310                &parsed.host,
311                &self.config.denylist,
312            )));
313        }
314
315        let enforce_allowlist = self.config.enforce_allowlist || !self.config.allowlist.is_empty();
316        if enforce_allowlist {
317            if self.config.allowlist.is_empty() {
318                return Err(Box::new(host_result_err(
319                    &call.call_id,
320                    HostCallErrorCode::Denied,
321                    "HTTP allowlist is empty; update capability_manifest scope.hosts",
322                    None,
323                )));
324            }
325
326            if !host_is_allowed(&parsed.host, &self.config.allowlist) {
327                return Err(Box::new(deny_allowlist(
328                    &call.call_id,
329                    &parsed.host,
330                    &self.config.allowlist,
331                )));
332            }
333        }
334
335        let method = params
336            .get("method")
337            .and_then(Value::as_str)
338            .unwrap_or("GET")
339            .trim()
340            .to_ascii_uppercase();
341
342        if method != "GET" && method != "POST" {
343            return Err(Box::new(invalid_request(
344                &call.call_id,
345                format!("Unsupported HTTP method: {method}"),
346            )));
347        }
348
349        let body_val = params.get("body");
350        let body_bytes_val = params.get("body_bytes").or_else(|| params.get("bodyBytes"));
351
352        if body_val.is_some() && body_bytes_val.is_some() {
353            return Err(Box::new(invalid_request(
354                &call.call_id,
355                "body and body_bytes are mutually exclusive",
356            )));
357        }
358
359        if method == "GET" && (body_val.is_some() || body_bytes_val.is_some()) {
360            return Err(Box::new(invalid_request(
361                &call.call_id,
362                "GET requests must not include a body",
363            )));
364        }
365
366        let headers = if let Some(headers_value) = params.get("headers") {
367            let Some(headers_obj) = headers_value.as_object() else {
368                return Err(Box::new(invalid_request(
369                    &call.call_id,
370                    "headers must be an object",
371                )));
372            };
373            let mut out = Vec::with_capacity(headers_obj.len());
374            for (key, value) in headers_obj {
375                let value = value.as_str().map_or_else(
376                    || {
377                        if value.is_null() {
378                            String::new()
379                        } else {
380                            value.to_string()
381                        }
382                    },
383                    str::to_string,
384                );
385                out.push((key.clone(), value));
386            }
387            out
388        } else {
389            Vec::new()
390        };
391
392        let body = if let Some(body_bytes_value) = body_bytes_val {
393            let Some(encoded) = body_bytes_value.as_str() else {
394                return Err(Box::new(invalid_request(
395                    &call.call_id,
396                    "body_bytes must be a base64 string",
397                )));
398            };
399            let decoded = match BASE64_STANDARD.decode(encoded.as_bytes()) {
400                Ok(bytes) => bytes,
401                Err(err) => {
402                    return Err(Box::new(invalid_request(
403                        &call.call_id,
404                        format!("Invalid base64 body_bytes: {err}"),
405                    )));
406                }
407            };
408            Some(decoded)
409        } else if let Some(body_value) = body_val {
410            let body = body_value.as_str().map_or_else(
411                || {
412                    if body_value.is_null() {
413                        String::new()
414                    } else {
415                        body_value.to_string()
416                    }
417                },
418                str::to_string,
419            );
420            Some(body.into_bytes())
421        } else {
422            None
423        };
424
425        if let Some(ref bytes) = body {
426            if self.config.max_request_bytes > 0 && bytes.len() > self.config.max_request_bytes {
427                return Err(Box::new(invalid_request(
428                    &call.call_id,
429                    "request body too large",
430                )));
431            }
432        }
433
434        let timeout_ms_param = params
435            .get("timeout")
436            .and_then(Value::as_u64)
437            .or_else(|| params.get("timeoutMs").and_then(Value::as_u64))
438            .or_else(|| params.get("timeout_ms").and_then(Value::as_u64))
439            .filter(|value| *value > 0);
440
441        let timeout_ms_call = call.timeout_ms.filter(|value| *value > 0);
442
443        let timeout_ms = timeout_ms_param.or(timeout_ms_call).or({
444            if self.config.default_timeout_ms > 0 {
445                Some(self.config.default_timeout_ms)
446            } else {
447                None
448            }
449        });
450
451        Ok(PreparedRequest {
452            url,
453            method,
454            headers,
455            body,
456            timeout_ms,
457        })
458    }
459}
460
461#[async_trait]
462impl Connector for HttpConnector {
463    fn capability(&self) -> &'static str {
464        "http"
465    }
466
467    async fn dispatch(&self, call: &HostCallPayload) -> Result<HostResultPayload> {
468        let prepared = match self.prepare_request(call) {
469            Ok(prepared) => prepared,
470            Err(payload) => return Ok(*payload),
471        };
472
473        let mut builder = if prepared.method == "GET" {
474            self.client.get(&prepared.url)
475        } else {
476            self.client.post(&prepared.url)
477        };
478
479        for (key, value) in prepared.headers {
480            builder = builder.header(&key, value);
481        }
482
483        if let Some(body) = prepared.body {
484            builder = builder.body(body);
485        }
486
487        if let Some(timeout_ms) = prepared.timeout_ms {
488            builder = builder.timeout(Duration::from_millis(timeout_ms));
489        } else {
490            builder = builder.no_timeout();
491        }
492
493        let response = match builder.send().await {
494            Ok(response) => response,
495            Err(err) => {
496                if is_timeout_error(&err) {
497                    return Ok(timeout_error(&call.call_id, err.to_string()));
498                }
499                return Ok(io_error(&call.call_id, err.to_string()));
500            }
501        };
502
503        let status = response.status();
504        let headers = response.headers().to_vec();
505        let mut stream = response.bytes_stream();
506        let mut body_bytes = Vec::new();
507
508        while let Some(chunk) = stream.next().await {
509            match chunk {
510                Ok(bytes) => {
511                    if self.config.max_response_bytes > 0
512                        && body_bytes.len().saturating_add(bytes.len())
513                            > self.config.max_response_bytes
514                    {
515                        return Ok(invalid_request(&call.call_id, "response body too large"));
516                    }
517                    body_bytes.extend_from_slice(&bytes);
518                }
519                Err(err) => {
520                    if is_timeout_io(&err) {
521                        return Ok(timeout_error(&call.call_id, err.to_string()));
522                    }
523                    return Ok(io_error(&call.call_id, err.to_string()));
524                }
525            }
526        }
527
528        let mut headers_map = serde_json::Map::new();
529        for (key, value) in headers {
530            match headers_map.get_mut(&key) {
531                Some(Value::String(existing)) => {
532                    if !existing.is_empty() {
533                        existing.push_str(", ");
534                    }
535                    existing.push_str(&value);
536                }
537                _ => {
538                    headers_map.insert(key, Value::String(value));
539                }
540            }
541        }
542
543        let mut output = serde_json::Map::new();
544        output.insert("status".to_string(), json!(status));
545        output.insert("headers".to_string(), Value::Object(headers_map));
546
547        if let Ok(text) = String::from_utf8(body_bytes.clone()) {
548            output.insert("body".to_string(), Value::String(text));
549        } else {
550            let encoded = BASE64_STANDARD.encode(&body_bytes);
551            output.insert("body_bytes".to_string(), Value::String(encoded));
552        }
553
554        Ok(host_result_ok(&call.call_id, Value::Object(output)))
555    }
556}
557
558impl HttpConnector {
559    pub async fn dispatch_streaming(
560        &self,
561        call: &HostCallPayload,
562    ) -> std::result::Result<crate::http::client::Response, HostResultPayload> {
563        let prepared = match self.prepare_request(call) {
564            Ok(prepared) => prepared,
565            Err(payload) => return Err(*payload),
566        };
567
568        let mut builder = if prepared.method == "GET" {
569            self.client.get(&prepared.url)
570        } else {
571            self.client.post(&prepared.url)
572        };
573
574        for (key, value) in prepared.headers {
575            builder = builder.header(&key, value);
576        }
577
578        if let Some(body) = prepared.body {
579            builder = builder.body(body);
580        }
581
582        if let Some(timeout_ms) = prepared.timeout_ms {
583            builder = builder.timeout(Duration::from_millis(timeout_ms));
584        } else {
585            builder = builder.no_timeout();
586        }
587
588        match builder.send().await {
589            Ok(response) => Ok(response),
590            Err(err) => {
591                if is_timeout_error(&err) {
592                    Err(timeout_error(&call.call_id, err.to_string()))
593                } else {
594                    Err(io_error(&call.call_id, err.to_string()))
595                }
596            }
597        }
598    }
599}