1#![cfg_attr(not(all(target_os = "linux", feature = "kvm")), allow(dead_code))]
2
3use std::collections::HashMap;
4use std::io::Read;
5use std::net::IpAddr;
6use std::net::SocketAddr;
7use std::net::ToSocketAddrs;
8use std::time::Duration;
9
10use base64::Engine;
11use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
12use mimobox_core::SandboxConfig;
13use reqwest::Method;
14use reqwest::blocking::Client;
15use reqwest::redirect::Policy;
16use serde::{Deserialize, Serialize};
17
/// Client timeout used when a request does not specify `timeout_ms` (30 seconds).
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1000;
/// Response-body cap used when a request does not specify `max_response_bytes` (1 MiB).
const DEFAULT_MAX_RESPONSE_BYTES: usize = 1 << 20;
/// Fixed cap on the decoded request body (1 MiB); callers cannot override it.
const MAX_REQUEST_BODY_BYTES: usize = 1 << 20;
21
/// Wire (JSON) form of an HTTP request submitted to the proxy.
///
/// Converted into [`HttpRequest`] via `TryFrom`, which decodes `body_b64`, enforces the
/// request-body size limit and fills in default timeout/response limits.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HttpProxyRequestPayload {
    /// HTTP method name (e.g. "GET"); parsed with `Method::from_bytes` when the request is sent.
    pub method: String,
    /// Absolute request URL; only `https` URLs pass validation.
    pub url: String,
    /// Request headers to forward, name -> value. Defaults to empty.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Request body encoded with the standard base64 alphabet; `None` means no body.
    #[serde(default)]
    pub body_b64: Option<String>,
    /// Client timeout in milliseconds; `None` selects `DEFAULT_TIMEOUT_MS`.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
    /// Maximum accepted response body size in bytes; `None` selects
    /// `DEFAULT_MAX_RESPONSE_BYTES`.
    #[serde(default)]
    pub max_response_bytes: Option<usize>,
}
46
/// Validated HTTP request ready to be passed to [`execute_http_request`].
///
/// The constructors (`HttpRequest::new`, `HttpRequest::from_json`, and
/// `TryFrom<HttpProxyRequestPayload>`) enforce `MAX_REQUEST_BODY_BYTES` and resolve the
/// optional limits to concrete values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpRequest {
    /// HTTP method name; parsed with `Method::from_bytes` at send time.
    pub method: String,
    /// Absolute URL; parsed and policy-checked at send time.
    pub url: String,
    /// Headers forwarded to the upstream server as-is.
    pub headers: HashMap<String, String>,
    /// Raw (already decoded) request body, if any.
    pub body: Option<Vec<u8>>,
    /// Client timeout in milliseconds; 0 is raised to 1 ms when sending.
    pub timeout_ms: u64,
    /// Largest response body accepted, in bytes.
    // NOTE(review): this value is caller-controlled and has no upper bound in this file;
    // confirm a host-side cap is enforced elsewhere, otherwise a caller can make the host
    // buffer arbitrarily large responses.
    pub max_response_bytes: usize,
}
63
/// Buffered result of a successful [`execute_http_request`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// Numeric HTTP status code (redirects are not followed, so 3xx can appear here).
    pub status: u16,
    /// Response headers keyed by name; values are decoded lossily as UTF-8 and only one
    /// value is kept when a header name repeats.
    pub headers: HashMap<String, String>,
    /// Full response body, bounded by the request's `max_response_bytes`.
    pub body: Vec<u8>,
}
74
/// Errors produced by the HTTP proxy. `HttpProxyError::code` maps each variant to a
/// stable string code.
#[derive(Debug, thiserror::Error)]
pub enum HttpProxyError {
    /// The host is not on the allow-list, or it is a private IP literal.
    #[error("domain not in whitelist: {0}")]
    DeniedHost(
        /// The normalized (trailing dots stripped, ASCII-lowercased) host that was refused.
        String,
    ),
    /// reqwest reported a timeout while sending the request.
    #[error("HTTP request timed out")]
    Timeout,
    /// The request body exceeded `MAX_REQUEST_BODY_BYTES`, or the response body exceeded
    /// the request's `max_response_bytes`.
    #[error("HTTP body exceeds size limit")]
    BodyTooLarge,
    /// DNS resolution failed or returned nothing, or a non-TLS connection error occurred.
    #[error("HTTP connection failed: {0}")]
    ConnectFail(
        /// Human-readable failure detail.
        String,
    ),
    /// A connection error that looks TLS-related (certificate / TLS / SSL / handshake).
    #[error("TLS handshake failed: {0}")]
    TlsFail(
        /// Human-readable failure detail.
        String,
    ),
    /// Malformed or disallowed input. Besides unparsable URLs this also covers non-HTTPS
    /// schemes, direct IP access, invalid methods, malformed JSON, invalid base64 bodies,
    /// and reqwest builder/request errors.
    #[error("invalid URL: {0}")]
    InvalidUrl(
        /// Human-readable failure detail.
        String,
    ),
    /// The host resolved to a private/internal address.
    #[error("DNS resolution hit private address: {0}")]
    DnsRebind(
        /// Human-readable failure detail naming the host and address.
        String,
    ),
    /// Unexpected failure, e.g. building the client or reading the response body.
    #[error("HTTP proxy internal error: {0}")]
    Internal(
        /// Human-readable failure detail.
        String,
    ),
}
121
122impl HttpProxyError {
123 pub fn code(&self) -> &'static str {
125 match self {
126 Self::DeniedHost(_) => "DENIED_HOST",
127 Self::Timeout => "TIMEOUT",
128 Self::BodyTooLarge => "BODY_TOO_LARGE",
129 Self::ConnectFail(_) => "CONNECT_FAIL",
130 Self::TlsFail(_) => "TLS_FAIL",
131 Self::InvalidUrl(_) => "INVALID_URL",
132 Self::DnsRebind(_) => "DNS_REBIND",
133 Self::Internal(_) => "INTERNAL",
134 }
135 }
136}
137
138impl TryFrom<HttpProxyRequestPayload> for HttpRequest {
139 type Error = HttpProxyError;
140
141 fn try_from(value: HttpProxyRequestPayload) -> Result<Self, Self::Error> {
142 let body = match value.body_b64 {
143 Some(encoded) => {
144 let bytes = BASE64_STANDARD.decode(encoded).map_err(|err| {
145 HttpProxyError::InvalidUrl(format!("body_b64 is not valid base64: {err}"))
146 })?;
147 if bytes.len() > MAX_REQUEST_BODY_BYTES {
148 return Err(HttpProxyError::BodyTooLarge);
149 }
150 Some(bytes)
151 }
152 None => None,
153 };
154
155 Ok(Self {
156 method: value.method,
157 url: value.url,
158 headers: value.headers,
159 body,
160 timeout_ms: value.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS),
161 max_response_bytes: value
162 .max_response_bytes
163 .unwrap_or(DEFAULT_MAX_RESPONSE_BYTES),
164 })
165 }
166}
167
168impl HttpRequest {
169 pub fn new(
174 method: impl Into<String>,
175 url: impl Into<String>,
176 headers: HashMap<String, String>,
177 body: Option<Vec<u8>>,
178 timeout_ms: Option<u64>,
179 max_response_bytes: Option<usize>,
180 ) -> Result<Self, HttpProxyError> {
181 let body = match body {
182 Some(bytes) if bytes.len() > MAX_REQUEST_BODY_BYTES => {
183 return Err(HttpProxyError::BodyTooLarge);
184 }
185 other => other,
186 };
187
188 Ok(Self {
189 method: method.into(),
190 url: url.into(),
191 headers,
192 body,
193 timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS),
194 max_response_bytes: max_response_bytes.unwrap_or(DEFAULT_MAX_RESPONSE_BYTES),
195 })
196 }
197
198 pub fn from_json(json: &str) -> Result<Self, HttpProxyError> {
202 let payload = serde_json::from_str::<HttpProxyRequestPayload>(json).map_err(|err| {
203 HttpProxyError::InvalidUrl(format!("invalid HTTP request JSON: {err}"))
204 })?;
205 Self::try_from(payload)
206 }
207}
208
209pub fn execute_http_request(
215 config: &SandboxConfig,
216 request: &HttpRequest,
217) -> Result<HttpResponse, HttpProxyError> {
218 let url = reqwest::Url::parse(&request.url)
219 .map_err(|err| HttpProxyError::InvalidUrl(err.to_string()))?;
220 validate_http_request(config, &url)?;
221 let verified_ip = validate_dns_resolution(&url)?;
222 let host = url
223 .host_str()
224 .ok_or_else(|| HttpProxyError::InvalidUrl("URL missing host".into()))?;
225 let port = url
226 .port_or_known_default()
227 .ok_or_else(|| HttpProxyError::InvalidUrl("URL missing port information".into()))?;
228 let socket_addr = SocketAddr::new(verified_ip, port);
229 let resolve_key = format!("{host}:{port}");
230
231 let method = Method::from_bytes(request.method.as_bytes())
232 .map_err(|err| HttpProxyError::InvalidUrl(format!("invalid HTTP method: {err}")))?;
233 let timeout = Duration::from_millis(request.timeout_ms.max(1));
234 let client = Client::builder()
235 .timeout(timeout)
236 .redirect(Policy::none())
237 .resolve(&resolve_key, socket_addr)
238 .build()
239 .map_err(|err| HttpProxyError::Internal(format!("failed to build HTTP client: {err}")))?;
240
241 let mut builder = client.request(method, url);
242 for (key, value) in &request.headers {
243 builder = builder.header(key, value);
244 }
245 if let Some(body) = &request.body {
246 builder = builder.body(body.clone());
247 }
248
249 let mut response = builder.send().map_err(map_reqwest_error)?;
250 let mut headers = HashMap::new();
251 for (name, value) in response.headers() {
252 headers.insert(
253 name.as_str().to_string(),
254 String::from_utf8_lossy(value.as_bytes()).into_owned(),
255 );
256 }
257 let body = read_response_body(&mut response, request.max_response_bytes)?;
258
259 Ok(HttpResponse {
260 status: response.status().as_u16(),
261 headers,
262 body,
263 })
264}
265
266pub fn is_allowed_http_host(config: &SandboxConfig, host: &str) -> bool {
272 let normalized_host = host.trim_end_matches('.').to_ascii_lowercase();
273 if normalized_host.is_empty() {
274 return false;
275 }
276
277 config.allowed_http_domains.iter().any(|rule| {
278 let rule = rule.trim_end_matches('.').to_ascii_lowercase();
279 if let Some(suffix) = rule.strip_prefix("*.") {
280 normalized_host.len() > suffix.len()
281 && normalized_host.ends_with(suffix)
282 && normalized_host
283 .as_bytes()
284 .get(normalized_host.len() - suffix.len() - 1)
285 == Some(&b'.')
286 } else {
287 normalized_host == rule
288 }
289 })
290}
291
292fn validate_http_request(config: &SandboxConfig, url: &reqwest::Url) -> Result<(), HttpProxyError> {
293 if url.scheme() != "https" {
294 return Err(HttpProxyError::InvalidUrl(format!(
295 "only HTTPS is allowed, got {}",
296 url.scheme()
297 )));
298 }
299
300 let host = url
301 .host_str()
302 .ok_or_else(|| HttpProxyError::InvalidUrl("URL missing host".into()))?;
303 validate_host(config, host)
304}
305
306fn validate_host(config: &SandboxConfig, host: &str) -> Result<(), HttpProxyError> {
307 let normalized_host = host.trim_end_matches('.').to_ascii_lowercase();
308 if normalized_host.is_empty() {
309 return Err(HttpProxyError::InvalidUrl("host must not be empty".into()));
310 }
311
312 if let Ok(ip) = normalized_host.parse::<IpAddr>() {
313 if is_private_ip(ip) {
314 return Err(HttpProxyError::DeniedHost(normalized_host));
315 }
316 return Err(HttpProxyError::InvalidUrl(
317 "direct IP access is forbidden".into(),
318 ));
319 }
320
321 if !is_allowed_http_host(config, &normalized_host) {
322 return Err(HttpProxyError::DeniedHost(normalized_host));
323 }
324
325 Ok(())
326}
327
328fn validate_dns_resolution(url: &reqwest::Url) -> Result<IpAddr, HttpProxyError> {
329 let host = url
330 .host_str()
331 .ok_or_else(|| HttpProxyError::InvalidUrl("URL missing host".into()))?;
332 let port = url
333 .port_or_known_default()
334 .ok_or_else(|| HttpProxyError::InvalidUrl("URL missing port information".into()))?;
335
336 let addrs = (host, port)
337 .to_socket_addrs()
338 .map_err(|err| HttpProxyError::ConnectFail(format!("DNS resolution failed: {err}")))?;
339
340 let mut has_addr = false;
341 let mut verified_ip = None;
342
343 for addr in addrs {
344 has_addr = true;
345 let ip = addr.ip();
346 if is_private_ip(ip) {
347 return Err(HttpProxyError::DnsRebind(format!(
348 "{host} resolved to private address {}",
349 ip
350 )));
351 }
352 verified_ip.get_or_insert(ip);
353 }
354
355 if !has_addr {
356 return Err(HttpProxyError::ConnectFail(format!(
357 "DNS resolution returned no addresses for {host}"
358 )));
359 }
360
361 verified_ip.ok_or_else(|| {
362 HttpProxyError::DnsRebind(format!("{host} resolved only to private addresses"))
363 })
364}
365
/// Returns `true` if `ip` must not be contacted from the sandbox.
///
/// IPv4: RFC 1918 private ranges, loopback, link-local (169.254.0.0/16), unspecified,
/// broadcast, "this network" 0.0.0.0/8, and the RFC 6598 shared address space
/// 100.64.0.0/10 (carrier/cloud-internal).
///
/// IPv6: loopback, unspecified, unique-local (fc00::/7), unicast link-local (fe80::/10),
/// and IPv4-mapped addresses (`::ffff:a.b.c.d`). Mapped addresses are judged by their
/// embedded IPv4 address, because connecting to `::ffff:127.0.0.1` reaches 127.0.0.1 on
/// dual-stack sockets.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => {
            let [first, second, ..] = ipv4.octets();
            ipv4.is_private()
                || ipv4.is_loopback()
                || ipv4.is_link_local()
                || ipv4.is_unspecified()
                || ipv4.is_broadcast()
                // 0.0.0.0/8: "this network"; not a valid unicast destination.
                || first == 0
                // 100.64.0.0/10 (`Ipv4Addr::is_shared` is not yet stable).
                || (first == 100 && (64..=127).contains(&second))
        }
        IpAddr::V6(ipv6) => {
            if let Some(mapped) = ipv6.to_ipv4_mapped() {
                return is_private_ip(IpAddr::V4(mapped));
            }
            ipv6.is_loopback()
                || ipv6.is_unspecified()
                || ipv6.is_unique_local()
                || ipv6.is_unicast_link_local()
        }
    }
}
379
380fn read_response_body(
381 response: &mut reqwest::blocking::Response,
382 max_response_bytes: usize,
383) -> Result<Vec<u8>, HttpProxyError> {
384 let mut body = Vec::new();
385 let mut buffer = [0u8; 8192];
386
387 loop {
388 let read = response.read(&mut buffer).map_err(|err| {
389 HttpProxyError::Internal(format!("failed to read HTTP response: {err}"))
390 })?;
391 if read == 0 {
392 break;
393 }
394 if body.len().saturating_add(read) > max_response_bytes {
395 return Err(HttpProxyError::BodyTooLarge);
396 }
397 body.extend_from_slice(&buffer[..read]);
398 }
399
400 Ok(body)
401}
402
403fn map_reqwest_error(err: reqwest::Error) -> HttpProxyError {
404 if err.is_timeout() {
405 return HttpProxyError::Timeout;
406 }
407 if err.is_connect() {
408 let message = err.to_string();
409 let lower = message.to_ascii_lowercase();
410 if lower.contains("certificate")
411 || lower.contains("tls")
412 || lower.contains("ssl")
413 || lower.contains("handshake")
414 {
415 return HttpProxyError::TlsFail(message);
416 }
417 return HttpProxyError::ConnectFail(message);
418 }
419 if err.is_builder() || err.is_request() {
420 return HttpProxyError::InvalidUrl(err.to_string());
421 }
422 HttpProxyError::Internal(err.to_string())
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a default `SandboxConfig` whose HTTP allow-list is exactly `domains`.
    fn config_with_domains(domains: &[&str]) -> SandboxConfig {
        let mut sandbox = SandboxConfig::default();
        sandbox.allowed_http_domains = domains.iter().copied().map(String::from).collect();
        sandbox
    }

    /// Parses a URL hard-coded by a test; a failure means the test itself is broken.
    fn parse_url(raw: &str) -> reqwest::Url {
        // Message: "URL must be valid".
        reqwest::Url::parse(raw).expect("URL 必须合法")
    }

    #[test]
    fn wildcard_domain_matches_subdomain_only() {
        let config = config_with_domains(&["*.openai.com"]);

        // A `*.` rule accepts subdomains at any depth...
        for host in ["api.openai.com", "foo.bar.openai.com"] {
            assert!(is_allowed_http_host(&config, host), "{host} should be allowed");
        }
        // ...but neither the apex domain nor a different TLD.
        for host in ["openai.com", "api.openai.org"] {
            assert!(!is_allowed_http_host(&config, host), "{host} should be denied");
        }
    }

    #[test]
    fn literal_ip_is_rejected() {
        let config = config_with_domains(&["*.openai.com"]);
        let url = parse_url("https://127.0.0.1/v1/models");

        // Message: "direct IP access must be rejected".
        let err = validate_http_request(&config, &url).expect_err("IP 直连必须被拒绝");
        assert!(
            matches!(err, HttpProxyError::DeniedHost(_) | HttpProxyError::InvalidUrl(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn non_whitelisted_domain_is_rejected() {
        let config = config_with_domains(&["*.openai.com"]);
        let url = parse_url("https://example.com/");

        // Message: "domains outside the allow-list must be rejected".
        let err = validate_http_request(&config, &url).expect_err("白名单外域名必须被拒绝");
        match err {
            HttpProxyError::DeniedHost(host) => assert_eq!(host, "example.com"),
            other => panic!("expected DeniedHost, got {other:?}"),
        }
    }

    #[test]
    fn localhost_is_blocked_by_dns_rebind_guard() {
        let url = parse_url("https://localhost/");

        // Message: "localhost must be rejected".
        let err = validate_dns_resolution(&url).expect_err("localhost 必须被拒绝");
        assert!(
            matches!(err, HttpProxyError::DnsRebind(_)),
            "unexpected error: {err:?}"
        );
    }
}