// clawbox_proxy/proxy.rs

//! Forward proxy service — the core request pipeline.
//!
//! Flow: rate limit → allowlist check → private IP and leak checks → credential injection
//! → forward → response leak redaction → audit.

use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Instant;

use chrono::Utc;
use reqwest::Client;
use thiserror::Error;

use crate::allowlist::AllowlistEnforcer;
use crate::audit::AuditEntry;
use crate::credentials::CredentialInjector;
use crate::leak_detection::LeakDetector;
use crate::rate_limiter::RateLimiter;

20/// Errors from the proxy pipeline.
21#[derive(Debug, Error)]
22#[non_exhaustive]
23pub enum ProxyError {
24    #[error("URL blocked by allowlist: {0}")]
25    Blocked(String),
26    #[error("credential leak detected in outbound request")]
27    LeakDetected,
28    #[error("HTTP error: {0}")]
29    Http(#[from] reqwest::Error),
30    #[error("invalid URL: {0}")]
31    InvalidUrl(String),
32    #[error("request to private/internal IP blocked: {0}")]
33    PrivateIpBlocked(String),
34    #[error("failed to build HTTP client: {0}")]
35    ClientBuild(String),
36    #[error("rate limited: {0}")]
37    RateLimited(String),
38}
39
/// Configuration for the proxy service.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[must_use]
pub struct ProxyConfig {
    /// Allowlist entries handed to the allowlist enforcer.
    pub allowlist: Vec<String>,
    /// Upper bound on the number of response body bytes that are read and returned.
    pub max_response_bytes: usize,
    /// Per-request timeout for the HTTP client, in milliseconds.
    pub timeout_ms: u64,
}

50impl ProxyConfig {
51    /// Create a new proxy configuration.
52    pub fn new(allowlist: Vec<String>, max_response_bytes: usize, timeout_ms: u64) -> Self {
53        Self {
54            allowlist,
55            max_response_bytes,
56            timeout_ms,
57        }
58    }
59}
60
61impl Default for ProxyConfig {
62    fn default() -> Self {
63        Self {
64            allowlist: Vec::new(),
65            max_response_bytes: 10 * 1024 * 1024, // 10MB
66            timeout_ms: 30_000,
67        }
68    }
69}
70
71/// Response from a proxied request.
72#[derive(Debug)]
73#[must_use]
74#[non_exhaustive]
75pub struct ProxyResponse {
76    pub status: u16,
77    pub headers: HashMap<String, String>,
78    pub body: String,
79    pub audit: AuditEntry,
80}
81
82/// The forward proxy service.
83#[non_exhaustive]
84pub struct ProxyService {
85    enforcer: AllowlistEnforcer,
86    injector: CredentialInjector,
87    leak_detector: LeakDetector,
88    client: Client,
89    config: ProxyConfig,
90    rate_limiter: Option<Arc<RateLimiter>>,
91    rate_limit_key: Option<String>,
92}
93
/// Check whether an IP address is private, loopback, link-local, or otherwise not a
/// publicly routable unicast destination.
///
/// This is the SSRF guard for the proxy, so it errs on the side of blocking. Besides
/// the RFC 1918 / loopback / link-local ranges it also rejects:
///
/// * `0.0.0.0/8` ("this network"),
/// * `100.64.0.0/10` (carrier-grade NAT shared address space),
/// * IPv4 and IPv6 multicast,
/// * `240.0.0.0/4` (reserved, includes the broadcast address),
/// * IPv6 unique-local (`fc00::/7`) and link-local (`fe80::/10`),
/// * IPv4-mapped (`::ffff:a.b.c.d`) and NAT64 (`64:ff9b::/96`) addresses whose embedded
///   IPv4 address is itself blocked.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [first, second, ..] = v4.octets();
            v4.is_loopback()           // 127.0.0.0/8
                || v4.is_private()     // 10/8, 172.16/12, 192.168/16
                || v4.is_link_local()  // 169.254/16
                || v4.is_unspecified() // 0.0.0.0
                || v4.is_broadcast()   // 255.255.255.255
                || v4.is_multicast()   // 224.0.0.0/4
                || first == 0          // 0.0.0.0/8
                || (first == 100 && (second & 0xc0) == 0x40) // 100.64.0.0/10
                || first >= 240 // 240.0.0.0/4
        }
        IpAddr::V6(v6) => {
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(&IpAddr::V4(mapped));
            }

            let segments = v6.segments();

            // NAT64 well-known prefix: the low 32 bits are an IPv4 address that a
            // translator would connect to, so judge the embedded address instead.
            if segments[..6] == [0x0064u16, 0xff9b, 0, 0, 0, 0] {
                let [a, b] = segments[6].to_be_bytes();
                let [c, d] = segments[7].to_be_bytes();
                return is_private_ip(&IpAddr::V4(Ipv4Addr::new(a, b, c, d)));
            }

            v6.is_loopback()           // ::1
                || v6.is_unspecified() // ::
                || v6.is_multicast()   // ff00::/8
                // fc00::/7 (unique local)
                || (segments[0] & 0xfe00) == 0xfc00
                // fe80::/10 (link-local)
                || (segments[0] & 0xffc0) == 0xfe80
        }
    }
}

118/// Check if a URL host is a private/internal IP. Returns Err if blocked.
119/// NOTE: This only catches IP-literal hosts. Full DNS rebinding protection
120/// requires a custom DNS resolver that checks resolved IPs before connecting.
121fn check_private_ip(parsed: &url::Url) -> Result<(), ProxyError> {
122    if let Some(host) = parsed.host_str() {
123        if let Ok(ip) = host.parse::<IpAddr>()
124            && is_private_ip(&ip)
125        {
126            return Err(ProxyError::PrivateIpBlocked(host.to_string()));
127        }
128        // Also catch IPv6 in brackets
129        let trimmed = host.trim_start_matches('[').trim_end_matches(']');
130        if let Ok(ip) = trimmed.parse::<IpAddr>()
131            && is_private_ip(&ip)
132        {
133            return Err(ProxyError::PrivateIpBlocked(trimmed.to_string()));
134        }
135    }
136    Ok(())
137}
138
139impl ProxyService {
140    pub fn new(
141        config: ProxyConfig,
142        injector: CredentialInjector,
143        leak_detector: LeakDetector,
144    ) -> Result<Self, ProxyError> {
145        let enforcer = AllowlistEnforcer::new(config.allowlist.clone());
146        let client = Client::builder()
147            .danger_accept_invalid_certs(false)
148            .redirect(reqwest::redirect::Policy::none())
149            .timeout(std::time::Duration::from_millis(config.timeout_ms))
150            .build()
151            .map_err(|e| ProxyError::ClientBuild(e.to_string()))?;
152        Ok(Self {
153            enforcer,
154            injector,
155            leak_detector,
156            client,
157            config,
158            rate_limiter: None,
159            rate_limit_key: None,
160        })
161    }
162
163    /// Attach a shared rate limiter.
164    /// Use a pre-built HTTP client (for connection pooling across services).
165    pub fn with_client(mut self, client: Client) -> Self {
166        self.client = client;
167        self
168    }
169
170    pub fn with_rate_limiter(mut self, limiter: Arc<RateLimiter>) -> Self {
171        self.rate_limiter = Some(limiter);
172        self
173    }
174
175    /// Set the rate limit key (typically tool name or container ID).
176    pub fn with_rate_limit_key(mut self, key: impl Into<String>) -> Self {
177        self.rate_limit_key = Some(key.into());
178        self
179    }
180
181    /// Forward a request through the proxy pipeline.
182    pub async fn forward_request(
183        &self,
184        url: &str,
185        method: &str,
186        headers: HashMap<String, String>,
187        body: Option<String>,
188    ) -> Result<ProxyResponse, ProxyError> {
189        let start = Instant::now();
190        let mut audit = AuditEntry::new(url.to_string(), method.to_string());
191
192        // 0. Rate limiting
193        if let Some(ref limiter) = self.rate_limiter {
194            let key = self.rate_limit_key.as_deref().unwrap_or("default");
195            if !limiter.check(key) {
196                return Err(ProxyError::RateLimited(key.to_string()));
197            }
198        }
199
200        // 1. Allowlist check
201        if !self.enforcer.is_allowed(url) {
202            audit.blocked = true;
203            audit.duration_ms = start.elapsed().as_millis() as u64;
204            return Err(ProxyError::Blocked(url.to_string()));
205        }
206
207        // 2. Parse URL and block private IPs
208        let parsed = url::Url::parse(url).map_err(|e| ProxyError::InvalidUrl(e.to_string()))?;
209        check_private_ip(&parsed)?;
210
211        // 2b. DNS pre-resolution - block domains resolving to private IPs
212        //     Pin the validated IP via reqwest::resolve() to prevent DNS rebinding (TOCTOU).
213        let pinned_client = if let Some(host) = parsed.host_str() {
214            if host.parse::<IpAddr>().is_err() {
215                let port = parsed.port_or_known_default().unwrap_or(80);
216                let lookup = format!("{}:{}", host, port);
217                match tokio::net::lookup_host(&lookup).await {
218                    Ok(addrs) => {
219                        let addrs: Vec<_> = addrs.collect();
220                        for addr in &addrs {
221                            if is_private_ip(&addr.ip()) {
222                                return Err(ProxyError::PrivateIpBlocked(format!(
223                                    "{} resolves to private IP {}",
224                                    host,
225                                    addr.ip()
226                                )));
227                            }
228                        }
229                        // Pin the first validated address to prevent a second DNS lookup
230                        let validated_addr = addrs[0];
231                        Some(
232                            Client::builder()
233                                .danger_accept_invalid_certs(false)
234                                .redirect(reqwest::redirect::Policy::none())
235                                .timeout(std::time::Duration::from_millis(self.config.timeout_ms))
236                                .resolve(host, validated_addr)
237                                .pool_max_idle_per_host(0)
238                                .build()
239                                .map_err(|e| ProxyError::ClientBuild(e.to_string()))?,
240                        )
241                    }
242                    Err(e) => {
243                        return Err(ProxyError::InvalidUrl(format!(
244                            "DNS resolution failed for {}: {}",
245                            host, e
246                        )));
247                    }
248                }
249            } else {
250                None // IP literal, no DNS pinning needed
251            }
252        } else {
253            None
254        };
255        let client = pinned_client.as_ref().unwrap_or(&self.client);
256
257        // 3. Leak detection on URL
258        let url_findings = self.leak_detector.scan(url);
259        if !url_findings.is_empty() {
260            audit.leak_detected = true;
261            audit.duration_ms = start.elapsed().as_millis() as u64;
262            return Err(ProxyError::LeakDetected);
263        }
264
265        // 4. Leak detection on headers
266        for v in headers.values() {
267            let findings = self.leak_detector.scan(v);
268            if !findings.is_empty() {
269                audit.leak_detected = true;
270                audit.duration_ms = start.elapsed().as_millis() as u64;
271                return Err(ProxyError::LeakDetected);
272            }
273        }
274
275        // 5. Leak detection on outbound body
276        if let Some(ref body_content) = body {
277            let findings = self.leak_detector.scan(body_content);
278            if !findings.is_empty() {
279                audit.leak_detected = true;
280                audit.duration_ms = start.elapsed().as_millis() as u64;
281                return Err(ProxyError::LeakDetected);
282            }
283        }
284
285        // 6. Inject credentials based on domain
286        let domain = parsed.host_str().unwrap_or("");
287
288        let mut req_headers = reqwest::header::HeaderMap::new();
289        for (k, v) in &headers {
290            if let (Ok(name), Ok(val)) = (
291                reqwest::header::HeaderName::from_bytes(k.as_bytes()),
292                reqwest::header::HeaderValue::from_str(v),
293            ) {
294                req_headers.insert(name, val);
295            }
296        }
297
298        if let Some(mapping) = self.injector.get_mapping(domain)
299            && let (Ok(name), Ok(val)) = (
300                reqwest::header::HeaderName::from_bytes(mapping.header.as_bytes()),
301                reqwest::header::HeaderValue::from_str(&mapping.value),
302            )
303        {
304            req_headers.insert(name, val);
305            audit.credential_injected = Some(domain.to_string());
306        }
307
308        // 7. Forward request
309        let reqwest_method = reqwest::Method::from_bytes(method.as_bytes())
310            .map_err(|_| ProxyError::InvalidUrl(format!("invalid method: {method}")))?;
311
312        let mut builder = client.request(reqwest_method, url).headers(req_headers);
313        if let Some(body_content) = body {
314            builder = builder.body(body_content);
315        }
316
317        let mut response = builder.send().await?;
318
319        let status = response.status().as_u16();
320        let resp_headers: HashMap<String, String> = response
321            .headers()
322            .iter()
323            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
324            .collect();
325
326        let max_bytes = self.config.max_response_bytes;
327        let mut body_bytes = Vec::with_capacity(max_bytes.min(65536));
328        while let Some(chunk) = response.chunk().await? {
329            body_bytes.extend_from_slice(&chunk);
330            if body_bytes.len() >= max_bytes {
331                body_bytes.truncate(max_bytes);
332                break;
333            }
334        }
335        let resp_body = String::from_utf8_lossy(&body_bytes).to_string();
336
337        // 8. Leak detection on response body
338        let resp_findings = self.leak_detector.scan(&resp_body);
339        if !resp_findings.is_empty() {
340            audit.leak_detected = true;
341            audit.duration_ms = start.elapsed().as_millis() as u64;
342            tracing::warn!(
343                url = url,
344                findings = resp_findings.len(),
345                "Credential leak detected in response body — redacting"
346            );
347            let redacted_body = self.leak_detector.redact(&resp_body);
348            return Ok(ProxyResponse {
349                status,
350                headers: resp_headers,
351                body: redacted_body,
352                audit,
353            });
354        }
355
356        // Scan response headers and redact any that contain leaked credentials
357        let mut resp_headers = resp_headers;
358        let mut leaked_header_names = Vec::new();
359        for (k, v) in &resp_headers {
360            let findings = self.leak_detector.scan(v);
361            if !findings.is_empty() {
362                audit.leak_detected = true;
363                leaked_header_names.push(k.clone());
364            }
365        }
366        if !leaked_header_names.is_empty() {
367            tracing::warn!(
368                url = url,
369                headers = ?leaked_header_names,
370                "Credential leak detected in response headers — redacting"
371            );
372            for header_name in &leaked_header_names {
373                resp_headers.insert(header_name.clone(), "[REDACTED]".to_string());
374            }
375        }
376
377        audit.status = status;
378        audit.duration_ms = start.elapsed().as_millis() as u64;
379        audit.timestamp = Utc::now().to_rfc3339();
380
381        Ok(ProxyResponse {
382            status,
383            headers: resp_headers,
384            body: resp_body,
385            audit,
386        })
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a service with the given allowlist entries, no credential mappings, and
    /// the supplied leak detector.
    fn service_with(allowlist: &[&str], detector: LeakDetector) -> ProxyService {
        let config = ProxyConfig {
            allowlist: allowlist.iter().map(|entry| entry.to_string()).collect(),
            ..Default::default()
        };
        ProxyService::new(config, CredentialInjector::new(), detector).unwrap()
    }

    #[test]
    fn test_blocked_url() {
        let service = service_with(&["api.github.com"], LeakDetector::new());

        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(service.forward_request(
            "https://evil.com/steal",
            "GET",
            HashMap::new(),
            None,
        ));

        assert!(matches!(result, Err(ProxyError::Blocked(_))));
    }

    #[test]
    fn test_leak_detected_in_body() {
        let mut detector = LeakDetector::new();
        detector.add_known_secret("super_secret_key_12345");
        let service = service_with(&["api.github.com"], detector);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(service.forward_request(
            "https://api.github.com/repos",
            "POST",
            HashMap::new(),
            Some("body contains super_secret_key_12345 here".into()),
        ));

        assert!(matches!(result, Err(ProxyError::LeakDetected)));
    }

    #[test]
    fn test_leak_detected_in_url() {
        let mut detector = LeakDetector::new();
        detector.add_known_secret("my_secret_token");
        let service = service_with(&["api.github.com"], detector);

        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(service.forward_request(
            "https://api.github.com/repos?key=my_secret_token",
            "GET",
            HashMap::new(),
            None,
        ));

        assert!(matches!(result, Err(ProxyError::LeakDetected)));
    }

    #[test]
    fn test_leak_detected_in_headers() {
        let mut detector = LeakDetector::new();
        detector.add_known_secret("header_secret_value");
        let service = service_with(&["api.github.com"], detector);

        let mut headers = HashMap::new();
        headers.insert("X-Custom".to_string(), "header_secret_value".to_string());

        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(service.forward_request(
            "https://api.github.com/repos",
            "GET",
            headers,
            None,
        ));

        assert!(matches!(result, Err(ProxyError::LeakDetected)));
    }

    #[test]
    fn test_private_ip_blocked() {
        let service = service_with(&["*"], LeakDetector::new());

        let rt = tokio::runtime::Runtime::new().unwrap();

        for url in &[
            "http://127.0.0.1/latest/meta-data",
            "http://10.0.0.1/internal",
            "http://172.16.0.1/internal",
            "http://192.168.1.1/internal",
            "http://169.254.169.254/latest/meta-data",
            "http://0.0.0.0/",
        ] {
            let result = rt.block_on(service.forward_request(url, "GET", HashMap::new(), None));
            assert!(
                matches!(result, Err(ProxyError::PrivateIpBlocked(_))),
                "Expected PrivateIpBlocked for {url}, got {result:?}"
            );
        }
    }

    #[test]
    fn test_redirect_not_followed() {
        // The client has redirect policy set to none.
        // We can't easily test this without a server, but we verify construction works.
        let _service = service_with(&["httpbin.org"], LeakDetector::new());
    }

    #[test]
    fn test_allowed_url_passes_check() {
        let config = ProxyConfig {
            allowlist: vec!["httpbin.org".into()],
            ..Default::default()
        };
        let enforcer = AllowlistEnforcer::new(config.allowlist.clone());
        assert!(enforcer.is_allowed("https://httpbin.org/get"));
    }

    #[test]
    fn test_ipv6_mapped_ipv4_blocked() {
        let cases: Vec<(&str, bool)> = vec![
            ("::ffff:127.0.0.1", true),
            ("::ffff:10.0.0.1", true),
            ("::ffff:192.168.1.1", true),
            ("::ffff:172.16.0.1", true),
            ("::ffff:8.8.8.8", false),
            ("::1", true),
        ];
        for (s, expected) in cases {
            let ip: IpAddr = s.parse().unwrap();
            assert_eq!(
                is_private_ip(&ip),
                expected,
                "is_private_ip({s}) should be {expected}"
            );
        }
    }

    #[test]
    fn test_dns_resolution_blocks_localhost() {
        let svc = service_with(&["*"], LeakDetector::new());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(svc.forward_request(
            "http://localhost:9800/test",
            "GET",
            HashMap::new(),
            None,
        ));
        assert!(
            matches!(result, Err(ProxyError::PrivateIpBlocked(_))),
            "Expected PrivateIpBlocked for localhost, got {result:?}"
        );
    }

    #[test]
    fn test_response_header_leak_redacted() {
        // Verify that the response header leak redaction logic works:
        // When a response header value contains a known secret, it should be
        // replaced with "[REDACTED]" rather than passed through.
        let mut detector = LeakDetector::new();
        let secret = "super_secret_credential_xyz";
        detector.add_known_secret(secret);

        let mut resp_headers: HashMap<String, String> = HashMap::new();
        resp_headers.insert("x-safe".to_string(), "harmless".to_string());
        resp_headers.insert("x-leaked".to_string(), format!("Bearer {}", secret));
        resp_headers.insert("content-type".to_string(), "application/json".to_string());

        // Simulate the response header scanning logic from forward_request
        let mut leaked_header_names = Vec::new();
        for (k, v) in &resp_headers {
            let findings = detector.scan(v);
            if !findings.is_empty() {
                leaked_header_names.push(k.clone());
            }
        }
        for header_name in &leaked_header_names {
            resp_headers.insert(header_name.clone(), "[REDACTED]".to_string());
        }

        assert_eq!(
            leaked_header_names.len(),
            1,
            "should detect exactly one leaked header"
        );
        assert!(leaked_header_names.contains(&"x-leaked".to_string()));
        assert_eq!(resp_headers.get("x-leaked").unwrap(), "[REDACTED]");
        assert_eq!(resp_headers.get("x-safe").unwrap(), "harmless");
        assert_eq!(
            resp_headers.get("content-type").unwrap(),
            "application/json"
        );
    }
}
609}