1use std::borrow::Cow;
7
8use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
9use percent_encoding::percent_decode;
10
11use super::config::{SecretsConfig, ViolationAction};
12
13pub struct SecretsHandler {
22 eligible: Vec<EligibleSecret>,
24 all_placeholders: Vec<String>,
26 on_violation: ViolationAction,
28 has_ineligible: bool,
30 tls_intercepted: bool,
32 max_placeholder_len: usize,
34 prev_tail: Vec<u8>,
38}
39
/// A secret that may be substituted on the current connection, together with
/// the scopes it may be injected into.
struct EligibleSecret {
    /// Text the client sends in place of the real value.
    placeholder: String,
    /// Real secret value written in place of `placeholder`.
    value: String,
    /// Only substitute when the connection is TLS-intercepted.
    require_tls_identity: bool,
    /// Plain textual substitution in header lines.
    inject_headers: bool,
    /// Substitution inside base64 `Authorization: Basic` credentials.
    inject_basic_auth: bool,
    /// Substitution in the request line (path / query string).
    inject_query_params: bool,
    /// Substitution in the message body.
    inject_body: bool,
}
50
51impl EligibleSecret {
56 fn wants_header_injection(&self) -> bool {
59 self.inject_headers || self.inject_basic_auth || self.inject_query_params
60 }
61
62 fn substitute_in_headers(&self, headers: &str) -> String {
65 let mut result = String::with_capacity(headers.len());
66 for (i, line) in headers.split("\r\n").enumerate() {
67 if i > 0 {
68 result.push_str("\r\n");
69 }
70 match self.substitute_in_header_line(line, i == 0) {
71 Some(s) => result.push_str(&s),
72 None => result.push_str(line),
73 }
74 }
75 result
76 }
77
78 fn substitute_in_header_line(&self, line: &str, is_request_line: bool) -> Option<String> {
82 if self.inject_basic_auth
83 && is_authorization_header(line)
84 && let Some(replaced) = self.substitute_basic_auth_header(line)
85 {
86 return Some(replaced);
87 }
88 if self.inject_headers {
89 return Some(line.replace(&self.placeholder, &self.value));
90 }
91 if is_request_line && self.inject_query_params {
92 return Some(line.replace(&self.placeholder, &self.value));
93 }
94 None
95 }
96
97 fn substitute_basic_auth_header(&self, line: &str) -> Option<String> {
103 let decoded = decode_basic_credentials(line)?;
104 if !decoded.contains(&self.placeholder) {
105 return None;
106 }
107 let (name, _) = line.split_once(':')?;
108 let replaced = decoded.replace(&self.placeholder, &self.value);
109 Some(format!(
110 "{name}: Basic {}",
111 BASE64.encode(replaced.as_bytes())
112 ))
113 }
114}
115
116impl SecretsHandler {
117 pub fn new(config: &SecretsConfig, sni: &str, tls_intercepted: bool) -> Self {
124 let mut eligible = Vec::new();
125 let mut all_placeholders = Vec::new();
126
127 for secret in &config.secrets {
128 all_placeholders.push(secret.placeholder.clone());
129
130 let host_allowed = secret.allowed_hosts.is_empty()
131 || secret.allowed_hosts.iter().any(|p| p.matches(sni));
132
133 if host_allowed {
134 eligible.push(EligibleSecret {
135 placeholder: secret.placeholder.clone(),
136 value: secret.value.clone(),
137 inject_headers: secret.injection.headers,
138 inject_basic_auth: secret.injection.basic_auth,
139 inject_query_params: secret.injection.query_params,
140 inject_body: secret.injection.body,
141 require_tls_identity: secret.require_tls_identity,
142 });
143 }
144 }
145
146 let has_ineligible = eligible.len() < all_placeholders.len();
147 let max_placeholder_len = all_placeholders.iter().map(String::len).max().unwrap_or(0);
148
149 Self {
150 eligible,
151 all_placeholders,
152 on_violation: config.on_violation.clone(),
153 has_ineligible,
154 tls_intercepted,
155 max_placeholder_len,
156 prev_tail: Vec::new(),
157 }
158 }
159
160 pub fn substitute<'a>(&mut self, data: &'a [u8]) -> Option<Cow<'a, [u8]>> {
171 let boundary = find_header_boundary(data);
174 let (header_bytes, body_bytes) = match boundary {
175 Some(pos) => (&data[..pos], &data[pos..]),
176 None => (data, &[] as &[u8]),
177 };
178 let mut header_str = String::from_utf8_lossy(header_bytes).into_owned();
179 let mut body_str = if boundary.is_some() {
180 String::from_utf8_lossy(body_bytes).into_owned()
181 } else {
182 String::new()
183 };
184
185 if self.has_ineligible && self.has_violation(data, &header_str) {
187 self.update_tail(data);
188 match self.on_violation {
189 ViolationAction::Block => return None,
190 ViolationAction::BlockAndLog => {
191 tracing::warn!("secret violation: placeholder detected for disallowed host");
192 return None;
193 }
194 ViolationAction::BlockAndTerminate => {
195 tracing::error!(
196 "secret violation: placeholder detected for disallowed host — terminating"
197 );
198 return None;
199 }
200 }
201 }
202 self.update_tail(data);
203
204 if self.eligible.is_empty() {
205 return Some(Cow::Borrowed(data));
207 }
208
209 for secret in &self.eligible {
210 if secret.require_tls_identity && !self.tls_intercepted {
212 continue;
213 }
214 if secret.wants_header_injection() {
215 header_str = secret.substitute_in_headers(&header_str);
216 }
217 if boundary.is_some() && secret.inject_body && body_str.contains(&secret.placeholder) {
218 body_str = body_str.replace(&secret.placeholder, &secret.value);
219 }
220 }
221
222 if boundary.is_some() && body_str.len() != body_bytes.len() {
224 header_str = update_content_length(&header_str, body_str.len());
225 }
226
227 let mut output = header_str;
228 output.push_str(&body_str);
229 Some(Cow::Owned(output.into_bytes()))
230 }
231
232 pub fn is_empty(&self) -> bool {
234 self.all_placeholders.is_empty()
235 }
236
237 pub fn terminates_on_violation(&self) -> bool {
239 matches!(self.on_violation, ViolationAction::BlockAndTerminate)
240 }
241
242 fn has_violation(&self, data: &[u8], headers: &str) -> bool {
248 if self.eligible.len() == self.all_placeholders.len() {
251 return false;
252 }
253
254 let scan_buf: Cow<[u8]> = if self.prev_tail.is_empty() {
255 Cow::Borrowed(data)
256 } else {
257 let mut stitched = Vec::with_capacity(self.prev_tail.len() + data.len());
258 stitched.extend_from_slice(&self.prev_tail);
259 stitched.extend_from_slice(data);
260 Cow::Owned(stitched)
261 };
262 let scan = scan_buf.as_ref();
263
264 for placeholder in &self.all_placeholders {
265 if self.eligible.iter().any(|s| s.placeholder == *placeholder) {
266 continue;
267 }
268 let needle = placeholder.as_bytes();
269 if contains_bytes(scan, needle)
270 || url_decoded_contains(scan, needle)
271 || json_escaped_contains(scan, needle)
272 || basic_auth_decoded_contains(headers, placeholder)
273 {
274 return true;
275 }
276 }
277
278 false
279 }
280
281 fn update_tail(&mut self, data: &[u8]) {
285 let tail_size = self.max_placeholder_len.saturating_sub(1);
286 if tail_size == 0 {
287 return;
288 }
289 if data.len() >= tail_size {
290 self.prev_tail.clear();
291 self.prev_tail
292 .extend_from_slice(&data[data.len() - tail_size..]);
293 return;
294 }
295 self.prev_tail.extend_from_slice(data);
296 let overflow = self.prev_tail.len().saturating_sub(tail_size);
297 if overflow > 0 {
298 self.prev_tail.drain(..overflow);
299 }
300 }
301}
302
/// Returns `true` when `line` starts with the `Authorization:` header name,
/// compared ASCII case-insensitively.
fn is_authorization_header(line: &str) -> bool {
    const NAME: &[u8] = b"authorization:";
    match line.as_bytes().get(..NAME.len()) {
        Some(prefix) => prefix.eq_ignore_ascii_case(NAME),
        None => false,
    }
}
314
315fn decode_basic_credentials(line: &str) -> Option<String> {
319 let (_, raw_value) = line.split_once(':')?;
320 let (scheme, encoded) = split_auth_scheme(raw_value.trim_start())?;
321 if !scheme.eq_ignore_ascii_case("basic") {
322 return None;
323 }
324 let bytes = BASE64.decode(encoded.trim()).ok()?;
325 String::from_utf8(bytes).ok()
326}
327
/// Splits an authorization header value into its scheme and the remaining
/// credentials, skipping all whitespace between them.
///
/// Returns `None` when the value contains no whitespace (scheme only).
fn split_auth_scheme(header_value: &str) -> Option<(&str, &str)> {
    let (scheme, credentials) = header_value.split_once(char::is_whitespace)?;
    Some((scheme, credentials.trim_start()))
}
335
336fn basic_auth_decoded_contains(headers: &str, placeholder: &str) -> bool {
339 headers
340 .split("\r\n")
341 .filter(|line| is_authorization_header(line))
342 .filter_map(decode_basic_credentials)
343 .any(|decoded| decoded.contains(placeholder))
344}
345
/// Returns `true` if `needle` occurs as a contiguous run inside `haystack`.
/// An empty `needle` is never considered present.
fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() {
        return false;
    }
    let Some(last_start) = haystack.len().checked_sub(needle.len()) else {
        return false;
    };
    (0..=last_start).any(|start| haystack[start..].starts_with(needle))
}
353
354fn url_decoded_contains(haystack: &[u8], needle: &[u8]) -> bool {
356 let decoded: Vec<u8> = percent_decode(haystack).collect();
357 contains_bytes(&decoded, needle)
358}
359
360fn json_escaped_contains(haystack: &[u8], needle: &[u8]) -> bool {
364 let mut decoded = Vec::with_capacity(haystack.len());
365 let mut i = 0;
366 while i < haystack.len() {
367 if haystack[i] == b'\\'
368 && i + 5 < haystack.len()
369 && haystack[i + 1] == b'u'
370 && let (Some(a), Some(b), Some(c), Some(d)) = (
371 hex_digit(haystack[i + 2]),
372 hex_digit(haystack[i + 3]),
373 hex_digit(haystack[i + 4]),
374 hex_digit(haystack[i + 5]),
375 )
376 {
377 let cp = ((a as u32) << 12) | ((b as u32) << 8) | ((c as u32) << 4) | (d as u32);
378 if let Some(ch) = char::from_u32(cp) {
379 let mut buf = [0u8; 4];
380 decoded.extend_from_slice(ch.encode_utf8(&mut buf).as_bytes());
381 }
382 i += 6;
383 continue;
384 }
385 decoded.push(haystack[i]);
386 i += 1;
387 }
388 contains_bytes(&decoded, needle)
389}
390
/// Returns the value of an ASCII hexadecimal digit (either case), or `None`.
fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
394
/// Rewrites every `Content-Length:` line of the CRLF-separated `headers`
/// (name matched case-insensitively) to `Content-Length: {new_len}`.
/// All other lines and every separator are kept unchanged.
fn update_content_length(headers: &str, new_len: usize) -> String {
    const NAME: &[u8] = b"content-length:";
    headers
        .split("\r\n")
        .map(|line| {
            let is_length_header = line
                .as_bytes()
                .get(..NAME.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(NAME));
            if is_length_header {
                Cow::Owned(format!("Content-Length: {new_len}"))
            } else {
                Cow::Borrowed(line)
            }
        })
        .collect::<Vec<_>>()
        .join("\r\n")
}
417
/// Returns the offset just past the first `\r\n\r\n` in `data` (the start of
/// the body), or `None` when the header block is not terminated in `data`.
fn find_header_boundary(data: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    (0..data.len().saturating_sub(TERMINATOR.len() - 1))
        .find(|&start| data[start..].starts_with(TERMINATOR))
        .map(|start| start + TERMINATOR.len())
}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::secrets::config::*;

    /// Wraps `secrets` in a config whose violation action is `Block`.
    fn make_config(secrets: Vec<SecretEntry>) -> SecretsConfig {
        SecretsConfig {
            secrets,
            on_violation: ViolationAction::Block,
        }
    }

    /// Builds a secret allowed only for exactly `host`, with default injection
    /// scopes and `require_tls_identity` enabled.
    fn make_secret(placeholder: &str, value: &str, host: &str) -> SecretEntry {
        SecretEntry {
            env_var: "TEST_KEY".into(),
            value: value.into(),
            placeholder: placeholder.into(),
            allowed_hosts: vec![HostPattern::Exact(host.into())],
            injection: SecretInjection::default(),
            require_tls_identity: true,
        }
    }

    /// Injection scopes with only Basic-auth credential substitution enabled.
    fn basic_auth_only() -> SecretInjection {
        SecretInjection {
            headers: false,
            basic_auth: true,
            query_params: false,
            body: false,
        }
    }

    #[test]
    fn substitute_in_headers() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let input = b"GET / HTTP/1.1\r\nAuthorization: Bearer $KEY\r\n\r\n";
        let output = handler.substitute(input).unwrap();
        assert_eq!(
            String::from_utf8(output.into_owned()).unwrap(),
            "GET / HTTP/1.1\r\nAuthorization: Bearer real-secret\r\n\r\n"
        );
    }

    #[test]
    fn no_substitute_for_wrong_host() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "evil.com", true);

        let input = b"GET / HTTP/1.1\r\nAuthorization: Bearer $KEY\r\n\r\n";
        assert!(handler.substitute(input).is_none());
    }

    #[test]
    fn body_injection_disabled_by_default() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let input = b"POST / HTTP/1.1\r\n\r\n{\"key\": \"$KEY\"}";
        let output = handler.substitute(input).unwrap();
        assert!(
            String::from_utf8(output.into_owned())
                .unwrap()
                .contains("$KEY")
        );
    }

    #[test]
    fn body_injection_when_enabled() {
        let mut secret = make_secret("$KEY", "real-secret", "api.openai.com");
        secret.injection.body = true;
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let input = b"POST / HTTP/1.1\r\n\r\n{\"key\": \"$KEY\"}";
        let output = handler.substitute(input).unwrap();
        assert_eq!(
            String::from_utf8(output.into_owned()).unwrap(),
            "POST / HTTP/1.1\r\n\r\n{\"key\": \"real-secret\"}"
        );
    }

    #[test]
    fn body_injection_updates_content_length() {
        let mut secret = make_secret("$KEY", "a]longer]secret]value", "api.openai.com");
        secret.injection.body = true;
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let body = "{\"key\": \"$KEY\"}";
        let input = format!(
            "POST / HTTP/1.1\r\nContent-Length: {}\r\n\r\n{}",
            body.len(),
            body
        );
        let output = handler.substitute(input.as_bytes()).unwrap();
        let result = String::from_utf8(output.into_owned()).unwrap();

        let expected_body = "{\"key\": \"a]longer]secret]value\"}";
        assert!(result.contains(expected_body));
        assert!(result.contains(&format!("Content-Length: {}", expected_body.len())));
    }

    #[test]
    fn body_injection_no_content_length_header() {
        let mut secret = make_secret("$KEY", "longer-secret", "api.openai.com");
        secret.injection.body = true;
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        // No Content-Length header exists, so none may be added.
        let input = b"POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n{\"key\": \"$KEY\"}";
        let output = handler.substitute(input).unwrap();
        let result = String::from_utf8(output.into_owned()).unwrap();
        assert!(result.contains("longer-secret"));
        assert!(!result.contains("Content-Length"));
    }

    #[test]
    fn header_only_substitution_preserves_content_length() {
        let config = make_config(vec![make_secret("$KEY", "longer-value", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let input =
            b"GET / HTTP/1.1\r\nAuthorization: Bearer $KEY\r\nContent-Length: 5\r\n\r\nhello";
        let output = handler.substitute(input).unwrap();
        let result = String::from_utf8(output.into_owned()).unwrap();
        // Header growth must not affect the body length.
        assert!(result.contains("Content-Length: 5"));
        assert!(result.ends_with("hello"));
    }

    #[test]
    fn no_secrets_passthrough() {
        let config = make_config(vec![]);
        let mut handler = SecretsHandler::new(&config, "anything.com", true);

        let input = b"GET / HTTP/1.1\r\n\r\n";
        let output = handler.substitute(input).unwrap();
        assert_eq!(&*output, input);
    }

    #[test]
    fn require_tls_identity_blocks_on_non_intercepted() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", false);

        // The secret is eligible but not injected without TLS interception.
        let input = b"GET / HTTP/1.1\r\nAuthorization: Bearer $KEY\r\n\r\n";
        let output = handler.substitute(input).unwrap();
        assert!(
            String::from_utf8(output.into_owned())
                .unwrap()
                .contains("$KEY")
        );
    }

    #[test]
    fn basic_auth_only_does_not_substitute_other_schemes() {
        let mut secret = make_secret("$KEY", "real-secret", "api.openai.com");
        secret.injection = basic_auth_only();
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let input = b"GET / HTTP/1.1\r\nAuthorization: Bearer $KEY\r\nX-Custom: $KEY\r\n\r\n";
        let output = handler.substitute(input).unwrap();
        let result = String::from_utf8(output.into_owned()).unwrap();
        assert!(result.contains("Authorization: Bearer $KEY"));
        assert!(result.contains("X-Custom: $KEY"));
    }

    #[test]
    fn basic_auth_decodes_substitutes_and_reencodes_credentials() {
        let mut user = make_secret("$MSB_USER", "alice", "api.openai.com");
        user.env_var = "USER".into();
        user.injection = basic_auth_only();
        let mut password = make_secret("$MSB_PASSWORD", "s3cr3t", "api.openai.com");
        password.env_var = "PASSWORD".into();
        password.injection = basic_auth_only();
        let config = make_config(vec![user, password]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let encoded = BASE64.encode(b"$MSB_USER:$MSB_PASSWORD");
        let input = format!("GET / HTTP/1.1\r\nAuthorization: Basic {encoded}\r\n\r\n");
        let output = handler.substitute(input.as_bytes()).unwrap();
        let result = String::from_utf8(output.into_owned()).unwrap();

        assert!(result.contains(&format!(
            "Authorization: Basic {}",
            BASE64.encode(b"alice:s3cr3t")
        )));
        assert!(!result.contains("$MSB_USER"));
        assert!(!result.contains("$MSB_PASSWORD"));
    }

    #[test]
    fn basic_auth_encoded_placeholder_is_blocked_for_wrong_host() {
        let mut secret = make_secret("$MSB_PASSWORD", "s3cr3t", "api.openai.com");
        secret.injection = basic_auth_only();
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "evil.com", true);

        let encoded = BASE64.encode(b"user:$MSB_PASSWORD");
        let input = format!("GET / HTTP/1.1\r\nAuthorization: Basic {encoded}\r\n\r\n");

        assert!(handler.substitute(input.as_bytes()).is_none());
    }

    #[test]
    fn basic_auth_encoded_placeholder_is_not_replaced_when_scope_disabled() {
        let mut secret = make_secret("$MSB_PASSWORD", "s3cr3t", "api.openai.com");
        secret.injection = SecretInjection {
            headers: false,
            basic_auth: false,
            query_params: false,
            body: false,
        };
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let encoded = BASE64.encode(b"user:$MSB_PASSWORD");
        let input = format!("GET / HTTP/1.1\r\nAuthorization: Basic {encoded}\r\n\r\n");
        let output = handler.substitute(input.as_bytes()).unwrap();

        assert_eq!(String::from_utf8(output.into_owned()).unwrap(), input);
    }

    #[test]
    fn query_params_substitution() {
        let mut secret = make_secret("$KEY", "real-secret", "api.openai.com");
        secret.injection = SecretInjection {
            headers: false,
            basic_auth: false,
            query_params: true,
            body: false,
        };
        let config = make_config(vec![secret]);
        let mut handler = SecretsHandler::new(&config, "api.openai.com", true);

        let input = b"GET /api?key=$KEY HTTP/1.1\r\nHost: api.openai.com\r\n\r\n";
        let output = handler.substitute(input).unwrap();
        let result = String::from_utf8(output.into_owned()).unwrap();
        assert!(result.contains("GET /api?key=real-secret HTTP/1.1"));
    }

    #[test]
    fn url_encoded_placeholder_in_query_blocks_for_wrong_host() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "evil.com", true);

        // `%24` is the percent-encoding of `$`.
        let input = b"GET /api?token=%24KEY HTTP/1.1\r\nHost: evil.com\r\n\r\n";
        assert!(handler.substitute(input).is_none());
    }

    #[test]
    fn url_encoded_placeholder_in_body_blocks_for_wrong_host() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "evil.com", true);

        // Content-Length matches the 14-byte body exactly.
        let input = b"POST / HTTP/1.1\r\nContent-Length: 14\r\n\r\nkey=%24KEY&x=1";
        assert!(handler.substitute(input).is_none());
    }

    #[test]
    fn json_escaped_placeholder_in_body_blocks_for_wrong_host() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "evil.com", true);

        // `\u0024` is the JSON escape of `$`.
        let input =
            b"POST / HTTP/1.1\r\nContent-Type: application/json\r\n\r\n{\"k\":\"\\u0024KEY\"}";
        assert!(handler.substitute(input).is_none());
    }

    #[test]
    fn placeholder_split_across_writes_blocks_for_wrong_host() {
        let config = make_config(vec![make_secret("$KEY", "real-secret", "api.openai.com")]);
        let mut handler = SecretsHandler::new(&config, "evil.com", true);

        let first = b"GET / HTTP/1.1\r\nX-Token: $K";
        let second = b"EY\r\nHost: evil.com\r\n\r\n";

        // Neither half contains the placeholder on its own; only the stitched
        // view does.
        assert!(handler.substitute(first).is_some());
        assert!(handler.substitute(second).is_none());
    }

    #[test]
    fn url_decoded_contains_basic() {
        assert!(url_decoded_contains(b"foo%24KEYbar", b"$KEY"));
        assert!(!url_decoded_contains(b"fooKEYbar", b"$KEY"));
        // Incomplete escapes pass through unchanged.
        assert!(url_decoded_contains(b"%2", b"%2"));
    }

    #[test]
    fn json_escaped_contains_basic() {
        assert!(json_escaped_contains(b"\"\\u0024KEY\"", b"$KEY"));
        assert!(json_escaped_contains(
            b"\\u0024\\u004B\\u0045\\u0059",
            b"$KEY"
        ));
        assert!(!json_escaped_contains(b"KEY", b"$KEY"));
    }

    #[test]
    fn find_header_boundary_points_past_blank_line() {
        assert_eq!(find_header_boundary(b"GET / HTTP/1.1\r\n\r\nbody"), Some(18));
        assert_eq!(find_header_boundary(b"GET / HTTP/1.1\r\n"), None);
        assert_eq!(find_header_boundary(b""), None);
    }

    #[test]
    fn update_content_length_matches_name_case_insensitively() {
        assert_eq!(
            update_content_length("POST / HTTP/1.1\r\ncontent-length: 3\r\n\r\n", 42),
            "POST / HTTP/1.1\r\nContent-Length: 42\r\n\r\n"
        );
    }

    #[test]
    fn split_auth_scheme_skips_extra_whitespace() {
        assert_eq!(split_auth_scheme("Basic  abc"), Some(("Basic", "abc")));
        assert_eq!(split_auth_scheme("Basic"), None);
    }

    #[test]
    fn is_authorization_header_is_case_insensitive() {
        assert!(is_authorization_header("authorization: Basic x"));
        assert!(!is_authorization_header("X-Authorization: Basic x"));
    }
}