1use std::collections::HashMap;
6use std::net::IpAddr;
7use std::sync::Arc;
8use std::time::Instant;
9
10use chrono::Utc;
11use reqwest::Client;
12use thiserror::Error;
13
14use crate::allowlist::AllowlistEnforcer;
15use crate::audit::AuditEntry;
16use crate::credentials::CredentialInjector;
17use crate::leak_detection::LeakDetector;
18use crate::rate_limiter::RateLimiter;
19
/// Errors produced by [`ProxyService`] while validating and forwarding a request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ProxyError {
    /// The URL was rejected by the configured allowlist; carries the URL as given.
    #[error("URL blocked by allowlist: {0}")]
    Blocked(String),
    /// A known secret was found in the outbound request.
    #[error("credential leak detected in outbound request")]
    LeakDetected,
    /// Failure reported by `reqwest` (e.g. while sending the request or reading the body).
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// The URL could not be parsed. `forward_request` also uses this variant for DNS
    /// resolution failures and invalid HTTP method strings.
    #[error("invalid URL: {0}")]
    InvalidUrl(String),
    /// The target host is, or resolves to, a private/internal address.
    #[error("request to private/internal IP blocked: {0}")]
    PrivateIpBlocked(String),
    /// The underlying HTTP client could not be built.
    #[error("failed to build HTTP client: {0}")]
    ClientBuild(String),
    /// The attached rate limiter rejected the request; carries the rate-limit key.
    #[error("rate limited: {0}")]
    RateLimited(String),
}
39
/// Settings for a [`ProxyService`].
#[non_exhaustive]
#[must_use]
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Entries handed to the allowlist enforcer.
    pub allowlist: Vec<String>,
    /// Upper bound on the number of response-body bytes read from upstream.
    pub max_response_bytes: usize,
    /// HTTP client request timeout, in milliseconds.
    pub timeout_ms: u64,
}

impl ProxyConfig {
    /// Response-body cap used by `Default`: 10 MiB.
    const DEFAULT_MAX_RESPONSE_BYTES: usize = 10 * 1024 * 1024;
    /// Request timeout used by `Default`: 30 seconds.
    const DEFAULT_TIMEOUT_MS: u64 = 30_000;

    /// Builds a configuration from explicit values.
    pub fn new(allowlist: Vec<String>, max_response_bytes: usize, timeout_ms: u64) -> Self {
        Self {
            allowlist,
            max_response_bytes,
            timeout_ms,
        }
    }
}

impl Default for ProxyConfig {
    /// Empty allowlist, 10 MiB response cap, 30 s timeout.
    fn default() -> Self {
        Self::new(
            Vec::new(),
            Self::DEFAULT_MAX_RESPONSE_BYTES,
            Self::DEFAULT_TIMEOUT_MS,
        )
    }
}
70
/// A response returned by [`ProxyService::forward_request`].
#[derive(Debug)]
#[must_use]
#[non_exhaustive]
pub struct ProxyResponse {
    /// Upstream HTTP status code.
    pub status: u16,
    /// Upstream response headers. If a header name repeats, only one value is kept; values that
    /// are not visible-ASCII are mapped to `""`; values containing a detected secret may be
    /// replaced with `"[REDACTED]"`.
    pub headers: HashMap<String, String>,
    /// Upstream body, capped at the configured `max_response_bytes`, decoded lossily as UTF-8,
    /// with detected secrets redacted.
    pub body: String,
    /// Audit record for this request.
    pub audit: AuditEntry,
}
81
/// Outbound HTTP proxy that applies an allowlist, private-IP blocking, credential injection,
/// leak detection and optional rate limiting to each forwarded request.
#[non_exhaustive]
pub struct ProxyService {
    /// Allowlist built from `config.allowlist` at construction time.
    enforcer: AllowlistEnforcer,
    /// Supplies per-domain credential headers.
    injector: CredentialInjector,
    /// Scans outbound requests and upstream responses for known secrets.
    leak_detector: LeakDetector,
    /// Default HTTP client; used whenever no per-request DNS-pinned client is built.
    client: Client,
    /// Settings the service was built with (timeout, response-size cap, allowlist).
    config: ProxyConfig,
    /// Optional shared limiter, consulted before any other check.
    rate_limiter: Option<Arc<RateLimiter>>,
    /// Key passed to `rate_limiter`; `"default"` is used when unset.
    rate_limit_key: Option<String>,
}
93
/// Returns `true` if `ip` must not be contacted by the proxy because it is loopback, private,
/// link-local, or otherwise a non-public special-purpose address.
///
/// IPv6 addresses that embed an IPv4 address are judged by the embedded address. This covers
/// IPv4-mapped `::ffff:0:0/96`, IPv4-compatible `::/96`, NAT64 `64:ff9b::/96` and 6to4
/// `2002::/16`. For example, `64:ff9b::a9fe:a9fe` (169.254.169.254 via NAT64) is rejected.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, _, _] = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_unspecified()
                || v4.is_broadcast()
                || v4.is_multicast()
                // 0.0.0.0/8 "this network"; some stacks route it to the local host.
                || a == 0
                // 100.64.0.0/10 shared address space (RFC 6598). It hosts some cloud
                // metadata services, e.g. 100.100.100.200.
                || (a == 100 && (b & 0xc0) == 0x40)
                // 198.18.0.0/15 benchmarking (RFC 2544); not globally routable.
                || (a == 198 && (b & 0xfe) == 18)
                // 240.0.0.0/4 reserved (includes the limited broadcast address).
                || a >= 240
        }
        IpAddr::V6(v6) => {
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(&IpAddr::V4(mapped));
            }
            let o = v6.octets();
            match v6.segments() {
                // ::/96 (IPv4-compatible, deprecated; also covers `::` and `::1`) and
                // 64:ff9b::/96 (NAT64 well-known prefix) carry IPv4 in the low 32 bits.
                [0, 0, 0, 0, 0, 0, _, _] | [0x64, 0xff9b, 0, 0, 0, 0, _, _] => is_private_ip(
                    &IpAddr::V4(std::net::Ipv4Addr::new(o[12], o[13], o[14], o[15])),
                ),
                // 2002::/16 (6to4) carries IPv4 in bits 16..48.
                [0x2002, ..] => is_private_ip(&IpAddr::V4(std::net::Ipv4Addr::new(
                    o[2], o[3], o[4], o[5],
                ))),
                [first, ..] => {
                    v6.is_loopback()
                        || v6.is_unspecified()
                        || v6.is_multicast()
                        // fc00::/7 unique local
                        || (first & 0xfe00) == 0xfc00
                        // fe80::/10 link-local
                        || (first & 0xffc0) == 0xfe80
                        // fec0::/10 site-local (deprecated but still routable internally)
                        || (first & 0xffc0) == 0xfec0
                }
            }
        }
    }
}
117
118fn check_private_ip(parsed: &url::Url) -> Result<(), ProxyError> {
122 if let Some(host) = parsed.host_str() {
123 if let Ok(ip) = host.parse::<IpAddr>()
124 && is_private_ip(&ip)
125 {
126 return Err(ProxyError::PrivateIpBlocked(host.to_string()));
127 }
128 let trimmed = host.trim_start_matches('[').trim_end_matches(']');
130 if let Ok(ip) = trimmed.parse::<IpAddr>()
131 && is_private_ip(&ip)
132 {
133 return Err(ProxyError::PrivateIpBlocked(trimmed.to_string()));
134 }
135 }
136 Ok(())
137}
138
139impl ProxyService {
140 pub fn new(
141 config: ProxyConfig,
142 injector: CredentialInjector,
143 leak_detector: LeakDetector,
144 ) -> Result<Self, ProxyError> {
145 let enforcer = AllowlistEnforcer::new(config.allowlist.clone());
146 let client = Client::builder()
147 .danger_accept_invalid_certs(false)
148 .redirect(reqwest::redirect::Policy::none())
149 .timeout(std::time::Duration::from_millis(config.timeout_ms))
150 .build()
151 .map_err(|e| ProxyError::ClientBuild(e.to_string()))?;
152 Ok(Self {
153 enforcer,
154 injector,
155 leak_detector,
156 client,
157 config,
158 rate_limiter: None,
159 rate_limit_key: None,
160 })
161 }
162
163 pub fn with_client(mut self, client: Client) -> Self {
166 self.client = client;
167 self
168 }
169
170 pub fn with_rate_limiter(mut self, limiter: Arc<RateLimiter>) -> Self {
171 self.rate_limiter = Some(limiter);
172 self
173 }
174
175 pub fn with_rate_limit_key(mut self, key: impl Into<String>) -> Self {
177 self.rate_limit_key = Some(key.into());
178 self
179 }
180
181 pub async fn forward_request(
183 &self,
184 url: &str,
185 method: &str,
186 headers: HashMap<String, String>,
187 body: Option<String>,
188 ) -> Result<ProxyResponse, ProxyError> {
189 let start = Instant::now();
190 let mut audit = AuditEntry::new(url.to_string(), method.to_string());
191
192 if let Some(ref limiter) = self.rate_limiter {
194 let key = self.rate_limit_key.as_deref().unwrap_or("default");
195 if !limiter.check(key) {
196 return Err(ProxyError::RateLimited(key.to_string()));
197 }
198 }
199
200 if !self.enforcer.is_allowed(url) {
202 audit.blocked = true;
203 audit.duration_ms = start.elapsed().as_millis() as u64;
204 return Err(ProxyError::Blocked(url.to_string()));
205 }
206
207 let parsed = url::Url::parse(url).map_err(|e| ProxyError::InvalidUrl(e.to_string()))?;
209 check_private_ip(&parsed)?;
210
211 let pinned_client = if let Some(host) = parsed.host_str() {
214 if host.parse::<IpAddr>().is_err() {
215 let port = parsed.port_or_known_default().unwrap_or(80);
216 let lookup = format!("{}:{}", host, port);
217 match tokio::net::lookup_host(&lookup).await {
218 Ok(addrs) => {
219 let addrs: Vec<_> = addrs.collect();
220 for addr in &addrs {
221 if is_private_ip(&addr.ip()) {
222 return Err(ProxyError::PrivateIpBlocked(format!(
223 "{} resolves to private IP {}",
224 host,
225 addr.ip()
226 )));
227 }
228 }
229 let validated_addr = addrs[0];
231 Some(
232 Client::builder()
233 .danger_accept_invalid_certs(false)
234 .redirect(reqwest::redirect::Policy::none())
235 .timeout(std::time::Duration::from_millis(self.config.timeout_ms))
236 .resolve(host, validated_addr)
237 .pool_max_idle_per_host(0)
238 .build()
239 .map_err(|e| ProxyError::ClientBuild(e.to_string()))?,
240 )
241 }
242 Err(e) => {
243 return Err(ProxyError::InvalidUrl(format!(
244 "DNS resolution failed for {}: {}",
245 host, e
246 )));
247 }
248 }
249 } else {
250 None }
252 } else {
253 None
254 };
255 let client = pinned_client.as_ref().unwrap_or(&self.client);
256
257 let url_findings = self.leak_detector.scan(url);
259 if !url_findings.is_empty() {
260 audit.leak_detected = true;
261 audit.duration_ms = start.elapsed().as_millis() as u64;
262 return Err(ProxyError::LeakDetected);
263 }
264
265 for v in headers.values() {
267 let findings = self.leak_detector.scan(v);
268 if !findings.is_empty() {
269 audit.leak_detected = true;
270 audit.duration_ms = start.elapsed().as_millis() as u64;
271 return Err(ProxyError::LeakDetected);
272 }
273 }
274
275 if let Some(ref body_content) = body {
277 let findings = self.leak_detector.scan(body_content);
278 if !findings.is_empty() {
279 audit.leak_detected = true;
280 audit.duration_ms = start.elapsed().as_millis() as u64;
281 return Err(ProxyError::LeakDetected);
282 }
283 }
284
285 let domain = parsed.host_str().unwrap_or("");
287
288 let mut req_headers = reqwest::header::HeaderMap::new();
289 for (k, v) in &headers {
290 if let (Ok(name), Ok(val)) = (
291 reqwest::header::HeaderName::from_bytes(k.as_bytes()),
292 reqwest::header::HeaderValue::from_str(v),
293 ) {
294 req_headers.insert(name, val);
295 }
296 }
297
298 if let Some(mapping) = self.injector.get_mapping(domain)
299 && let (Ok(name), Ok(val)) = (
300 reqwest::header::HeaderName::from_bytes(mapping.header.as_bytes()),
301 reqwest::header::HeaderValue::from_str(&mapping.value),
302 )
303 {
304 req_headers.insert(name, val);
305 audit.credential_injected = Some(domain.to_string());
306 }
307
308 let reqwest_method = reqwest::Method::from_bytes(method.as_bytes())
310 .map_err(|_| ProxyError::InvalidUrl(format!("invalid method: {method}")))?;
311
312 let mut builder = client.request(reqwest_method, url).headers(req_headers);
313 if let Some(body_content) = body {
314 builder = builder.body(body_content);
315 }
316
317 let mut response = builder.send().await?;
318
319 let status = response.status().as_u16();
320 let resp_headers: HashMap<String, String> = response
321 .headers()
322 .iter()
323 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
324 .collect();
325
326 let max_bytes = self.config.max_response_bytes;
327 let mut body_bytes = Vec::with_capacity(max_bytes.min(65536));
328 while let Some(chunk) = response.chunk().await? {
329 body_bytes.extend_from_slice(&chunk);
330 if body_bytes.len() >= max_bytes {
331 body_bytes.truncate(max_bytes);
332 break;
333 }
334 }
335 let resp_body = String::from_utf8_lossy(&body_bytes).to_string();
336
337 let resp_findings = self.leak_detector.scan(&resp_body);
339 if !resp_findings.is_empty() {
340 audit.leak_detected = true;
341 audit.duration_ms = start.elapsed().as_millis() as u64;
342 tracing::warn!(
343 url = url,
344 findings = resp_findings.len(),
345 "Credential leak detected in response body — redacting"
346 );
347 let redacted_body = self.leak_detector.redact(&resp_body);
348 return Ok(ProxyResponse {
349 status,
350 headers: resp_headers,
351 body: redacted_body,
352 audit,
353 });
354 }
355
356 let mut resp_headers = resp_headers;
358 let mut leaked_header_names = Vec::new();
359 for (k, v) in &resp_headers {
360 let findings = self.leak_detector.scan(v);
361 if !findings.is_empty() {
362 audit.leak_detected = true;
363 leaked_header_names.push(k.clone());
364 }
365 }
366 if !leaked_header_names.is_empty() {
367 tracing::warn!(
368 url = url,
369 headers = ?leaked_header_names,
370 "Credential leak detected in response headers — redacting"
371 );
372 for header_name in &leaked_header_names {
373 resp_headers.insert(header_name.clone(), "[REDACTED]".to_string());
374 }
375 }
376
377 audit.status = status;
378 audit.duration_ms = start.elapsed().as_millis() as u64;
379 audit.timestamp = Utc::now().to_rfc3339();
380
381 Ok(ProxyResponse {
382 status,
383 headers: resp_headers,
384 body: resp_body,
385 audit,
386 })
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Service with the given allowlist, an empty credential injector, and `detector`.
    fn service_with(allowlist: &[&str], detector: LeakDetector) -> ProxyService {
        let config = ProxyConfig {
            allowlist: allowlist.iter().map(|entry| entry.to_string()).collect(),
            ..Default::default()
        };
        ProxyService::new(config, CredentialInjector::new(), detector).unwrap()
    }

    /// Leak detector that knows exactly one secret.
    fn detector_knowing(secret: &'static str) -> LeakDetector {
        let mut detector = LeakDetector::new();
        detector.add_known_secret(secret);
        detector
    }

    /// Drives `forward_request` to completion on a fresh Tokio runtime.
    fn forward(
        service: &ProxyService,
        url: &str,
        method: &str,
        headers: HashMap<String, String>,
        body: Option<String>,
    ) -> Result<ProxyResponse, ProxyError> {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(service.forward_request(url, method, headers, body))
    }

    #[test]
    fn test_blocked_url() {
        let service = service_with(&["api.github.com"], LeakDetector::new());
        let outcome = forward(&service, "https://evil.com/steal", "GET", HashMap::new(), None);
        assert!(matches!(outcome, Err(ProxyError::Blocked(_))));
    }

    #[test]
    fn test_leak_detected_in_body() {
        let service = service_with(
            &["api.github.com"],
            detector_knowing("super_secret_key_12345"),
        );
        let outcome = forward(
            &service,
            "https://api.github.com/repos",
            "POST",
            HashMap::new(),
            Some("body contains super_secret_key_12345 here".into()),
        );
        assert!(matches!(outcome, Err(ProxyError::LeakDetected)));
    }

    #[test]
    fn test_leak_detected_in_url() {
        let service = service_with(&["api.github.com"], detector_knowing("my_secret_token"));
        let outcome = forward(
            &service,
            "https://api.github.com/repos?key=my_secret_token",
            "GET",
            HashMap::new(),
            None,
        );
        assert!(matches!(outcome, Err(ProxyError::LeakDetected)));
    }

    #[test]
    fn test_leak_detected_in_headers() {
        let service = service_with(&["api.github.com"], detector_knowing("header_secret_value"));
        let request_headers: HashMap<String, String> =
            HashMap::from([("X-Custom".to_string(), "header_secret_value".to_string())]);
        let outcome = forward(
            &service,
            "https://api.github.com/repos",
            "GET",
            request_headers,
            None,
        );
        assert!(matches!(outcome, Err(ProxyError::LeakDetected)));
    }

    #[test]
    fn test_private_ip_blocked() {
        let service = service_with(&["*"], LeakDetector::new());
        let targets = [
            "http://127.0.0.1/latest/meta-data",
            "http://10.0.0.1/internal",
            "http://172.16.0.1/internal",
            "http://192.168.1.1/internal",
            "http://169.254.169.254/latest/meta-data",
            "http://0.0.0.0/",
        ];
        for url in targets {
            let result = forward(&service, url, "GET", HashMap::new(), None);
            assert!(
                matches!(result, Err(ProxyError::PrivateIpBlocked(_))),
                "Expected PrivateIpBlocked for {url}, got {result:?}"
            );
        }
    }

    #[test]
    fn test_redirect_not_followed() {
        // Only checks that a service (whose clients disable redirects) can be constructed.
        // Redirect behaviour itself is not exercised: that would need a reachable,
        // non-private test server.
        let _service = service_with(&["httpbin.org"], LeakDetector::new());
    }

    #[test]
    fn test_allowed_url_passes_check() {
        let allowlist: Vec<String> = vec!["httpbin.org".into()];
        let enforcer = AllowlistEnforcer::new(allowlist);
        assert!(enforcer.is_allowed("https://httpbin.org/get"));
    }

    #[test]
    fn test_ipv6_mapped_ipv4_blocked() {
        let cases = [
            ("::ffff:127.0.0.1", true),
            ("::ffff:10.0.0.1", true),
            ("::ffff:192.168.1.1", true),
            ("::ffff:172.16.0.1", true),
            ("::ffff:8.8.8.8", false),
            ("::1", true),
        ];
        for (s, expected) in cases {
            let ip: IpAddr = s.parse().unwrap();
            assert_eq!(
                is_private_ip(&ip),
                expected,
                "is_private_ip({s}) = {expected}"
            );
        }
    }

    #[test]
    fn test_dns_resolution_blocks_localhost() {
        let svc = service_with(&["*"], LeakDetector::new());
        let result = forward(&svc, "http://localhost:9800/test", "GET", HashMap::new(), None);
        assert!(
            matches!(result, Err(ProxyError::PrivateIpBlocked(_))),
            "Expected PrivateIpBlocked for localhost, got {result:?}"
        );
    }

    #[test]
    fn test_response_header_leak_redacted() {
        // Mirrors the header-redaction logic used in `forward_request` on a hand-built map.
        let secret = "super_secret_credential_xyz";
        let detector = detector_knowing(secret);

        let mut resp_headers: HashMap<String, String> = [
            ("x-safe", "harmless".to_string()),
            ("x-leaked", format!("Bearer {}", secret)),
            ("content-type", "application/json".to_string()),
        ]
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect();

        let leaked_header_names: Vec<String> = resp_headers
            .iter()
            .filter(|(_, value)| !detector.scan(value.as_str()).is_empty())
            .map(|(name, _)| name.clone())
            .collect();
        for header_name in &leaked_header_names {
            resp_headers.insert(header_name.clone(), "[REDACTED]".to_string());
        }

        assert_eq!(
            leaked_header_names.len(),
            1,
            "should detect exactly one leaked header"
        );
        assert!(leaked_header_names.contains(&"x-leaked".to_string()));
        assert_eq!(resp_headers.get("x-leaked").unwrap(), "[REDACTED]");
        assert_eq!(resp_headers.get("x-safe").unwrap(), "harmless");
        assert_eq!(
            resp_headers.get("content-type").unwrap(),
            "application/json"
        );
    }
}