use std::fmt;
use crate::conn_string::{
parse as parse_conn_string, ConnectionTarget, ParseError as ConnParseError,
};
/// Output context that a tainted string is being prepared for.
///
/// The boundary selects which escaping rule `Tainted::<String>::escape_for`
/// applies to the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Boundary {
    /// HTTP header value: CR, LF, NUL and TAB are stripped.
    HttpHeader,
    /// gRPC metadata value: escaped with the same rule as `HttpHeader`.
    GrpcMetadata,
    /// Log field: ASCII control characters are percent-encoded as `%XX`.
    LogField,
    /// Audit field: `escape_for` returns the value unchanged.
    AuditField,
    /// JSON value: `escape_for` returns the value unchanged.
    JsonValue,
}

impl Boundary {
    /// Returns the variant name as a static string (e.g. `"LogField"`).
    ///
    /// Used when rendering `EscapeError` messages.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::HttpHeader => "HttpHeader",
            Self::GrpcMetadata => "GrpcMetadata",
            Self::LogField => "LogField",
            Self::AuditField => "AuditField",
            Self::JsonValue => "JsonValue",
        }
    }
}
/// Failure returned by `escape_for` when a value cannot be emitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EscapeError {
    /// The escaped output would be longer than
    /// `Tainted::<String>::MAX_ESCAPED_LEN` bytes; `bytes` is the length the
    /// escaped output actually had.
    TooLong { boundary: Boundary, bytes: usize },
}

impl fmt::Display for EscapeError {
    /// Renders e.g. `escape_for(LogField) would emit 9000 bytes (limit 8192)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let limit = Tainted::<String>::MAX_ESCAPED_LEN;
        match *self {
            EscapeError::TooLong { boundary, bytes } => {
                let name = boundary.as_str();
                write!(
                    f,
                    "escape_for({}) would emit {} bytes (limit {})",
                    name, bytes, limit
                )
            }
        }
    }
}

impl std::error::Error for EscapeError {}
/// A string that has already been escaped for one specific [`Boundary`].
///
/// The fields are private; in this module the value is built by
/// `escape_for`, which records the boundary it escaped for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EscapedFor {
    boundary: Boundary,
    value: String,
}

impl EscapedFor {
    /// The boundary this value was escaped for.
    pub fn boundary(&self) -> Boundary {
        self.boundary
    }

    /// Borrows the escaped text.
    pub fn as_str(&self) -> &str {
        self.value.as_str()
    }

    /// Consumes the wrapper and returns the escaped text.
    pub fn into_string(self) -> String {
        let Self { value, .. } = self;
        value
    }
}

/// Writes the escaped text verbatim (width/fill flags are not applied).
impl fmt::Display for EscapedFor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Wrapper marking a value as untrusted (tainted) input.
///
/// No `Display` impl is provided here, so the value cannot be interpolated
/// with `{}` by accident; the raw value is reached through the explicitly
/// named accessors below. The field is `pub(crate)`, so code inside this
/// crate can still read it directly.
#[derive(Clone, PartialEq, Eq)]
pub struct Tainted<T>(pub(crate) T);

impl<T> Tainted<T> {
    /// Wraps `value` as tainted.
    pub fn new(value: T) -> Self {
        Tainted(value)
    }

    /// Borrows the raw, unescaped value. Call sites using this bypass
    /// boundary escaping.
    pub fn expose_secret(&self) -> &T {
        let Tainted(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the raw, unescaped value.
    pub fn into_inner(self) -> T {
        let Tainted(inner) = self;
        inner
    }
}

/// Formats as `Tainted(<inner Debug>)`.
///
/// The inner value is printed through its own `Debug` impl (for strings this
/// escapes control characters such as `\n`). The value is not redacted.
impl<T: fmt::Debug> fmt::Debug for Tainted<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Tainted(inner) = self;
        f.debug_tuple("Tainted").field(inner).finish()
    }
}
impl Tainted<String> {
    /// Upper bound, in bytes, on the escaped output `escape_for` will
    /// return (8 KiB).
    pub const MAX_ESCAPED_LEN: usize = 8 * 1024;

    /// Escapes the tainted string for `boundary` and tags the result with it.
    ///
    /// * `HttpHeader` / `GrpcMetadata`: CR, LF, NUL and TAB are removed.
    /// * `LogField`: ASCII control characters are percent-encoded as `%XX`.
    /// * `AuditField` / `JsonValue`: the value is returned unchanged; no
    ///   escaping happens here (presumably the downstream encoder handles
    ///   it — confirm with callers).
    ///
    /// # Errors
    ///
    /// Returns [`EscapeError::TooLong`] when the escaped output is longer than
    /// [`Self::MAX_ESCAPED_LEN`] bytes. The limit is checked after escaping,
    /// so the full output is built before it can be rejected.
    pub fn escape_for(&self, boundary: Boundary) -> Result<EscapedFor, EscapeError> {
        let raw = self.0.as_str();
        let value = match boundary {
            Boundary::HttpHeader | Boundary::GrpcMetadata => escape_http_header(raw),
            Boundary::LogField => escape_log_field(raw),
            Boundary::AuditField | Boundary::JsonValue => raw.to_owned(),
        };
        let bytes = value.len();
        if bytes <= Self::MAX_ESCAPED_LEN {
            Ok(EscapedFor { boundary, value })
        } else {
            Err(EscapeError::TooLong { boundary, bytes })
        }
    }
}
impl From<String> for Tainted<String> {
fn from(s: String) -> Self {
Tainted(s)
}
}
impl From<&str> for Tainted<String> {
fn from(s: &str) -> Self {
Tainted(s.to_string())
}
}
/// Removes CR, LF, NUL and TAB so a value cannot break out of an HTTP header
/// or gRPC metadata line (CR/LF would allow header splitting).
///
/// Iterates over `char`s rather than bytes. Pushing each UTF-8 byte as its
/// own `char` would re-encode non-ASCII text (e.g. `é` would become `Ã©`).
/// Every removed character is ASCII, and ASCII bytes never occur inside a
/// multi-byte UTF-8 sequence, so exactly the same characters are stripped.
fn escape_http_header(s: &str) -> String {
    s.chars()
        .filter(|&c| !matches!(c, '\r' | '\n' | '\0' | '\t'))
        .collect()
}
/// Percent-encodes ASCII control characters (U+0000..=U+001F and U+007F) as
/// `%XX` with uppercase hex, so a log field cannot inject line breaks,
/// escape sequences or other control bytes into a log line.
///
/// All other characters, including non-ASCII text, are copied unchanged.
/// Iteration is per `char`: pushing individual UTF-8 bytes as `char`s would
/// re-encode multi-byte characters (e.g. `é` would become `Ã©`).
fn escape_log_field(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        if c.is_ascii_control() {
            // `is_ascii_control` guarantees `c <= '\x7F'`, so this cast is lossless.
            let b = c as u8;
            out.push('%');
            out.push(hex_nibble(b >> 4));
            out.push(hex_nibble(b & 0x0F));
        } else {
            out.push(c);
        }
    }
    out
}

/// Returns the uppercase hexadecimal digit for a value in `0..=15`.
///
/// # Panics
///
/// Panics (via `unreachable!`) if `n > 15`; callers only pass 4-bit nibbles.
fn hex_nibble(n: u8) -> char {
    match n {
        0..=9 => (b'0' + n) as char,
        10..=15 => (b'A' + (n - 10)) as char,
        _ => unreachable!(),
    }
}
/// Returns a `Display` adapter that writes `value` with ASCII control
/// characters (U+0000..=U+001F and U+007F) percent-encoded as `%XX` in
/// uppercase hex, without building an intermediate `String`.
///
/// Its output is meant to match escaping for `Boundary::LogField`.
pub fn audit_safe_log_field(value: &str) -> impl fmt::Display + '_ {
    AuditSafeLogField(value)
}

/// `Display` adapter returned by `audit_safe_log_field`.
struct AuditSafeLogField<'a>(&'a str);

impl fmt::Display for AuditSafeLogField<'_> {
    /// Writes each run of safe text as one `&str` slice and each ASCII control
    /// byte as `%XX`. Non-ASCII text passes through intact; writing byte by
    /// byte would turn every byte of a multi-byte UTF-8 character into `?`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut rest = self.0;
        while let Some(pos) = rest.bytes().position(|b| b.is_ascii_control()) {
            f.write_str(&rest[..pos])?;
            write!(f, "%{:02X}", rest.as_bytes()[pos])?;
            // The byte at `pos` is ASCII, so `pos` and `pos + 1` are both char
            // boundaries and neither slice can panic.
            rest = &rest[pos + 1..];
        }
        f.write_str(rest)
    }
}
/// Entry point that parses connection URIs into [`ParsedConnString`].
pub struct ConnStringSanitizer;

impl ConnStringSanitizer {
    /// Parses `uri` with the `conn_string` parser and wraps the resulting
    /// target.
    ///
    /// # Errors
    ///
    /// Propagates the parser's error (via `?`) when `uri` cannot be parsed.
    pub fn parse(uri: &str) -> Result<ParsedConnString, ConnParseError> {
        Ok(ParsedConnString {
            target: parse_conn_string(uri)?,
        })
    }
}
/// A successfully parsed connection string, wrapping the parser's
/// [`ConnectionTarget`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedConnString {
    target: ConnectionTarget,
}

impl ParsedConnString {
    /// Borrows the parsed target as a [`TaintedTarget`].
    ///
    /// The gRPC endpoint, the cluster primary, the HTTP base URL and the
    /// RedWire host are wrapped in [`TaintedRef`]. File paths and cluster
    /// replicas are borrowed as-is, without a taint wrapper.
    pub fn target(&self) -> TaintedTarget<'_> {
        let target = &self.target;
        match target {
            ConnectionTarget::Memory => TaintedTarget::Memory,
            ConnectionTarget::File { path } => TaintedTarget::File { path },
            ConnectionTarget::Grpc { endpoint } => {
                let endpoint = TaintedRef(endpoint);
                TaintedTarget::Grpc { endpoint }
            }
            ConnectionTarget::GrpcCluster {
                primary,
                replicas,
                force_primary,
            } => {
                let primary = TaintedRef(primary);
                TaintedTarget::GrpcCluster {
                    primary,
                    replicas,
                    force_primary: *force_primary,
                }
            }
            ConnectionTarget::Http { base_url } => {
                let base_url = TaintedRef(base_url);
                TaintedTarget::Http { base_url }
            }
            ConnectionTarget::RedWire { host, port, tls } => {
                let host = TaintedRef(host);
                TaintedTarget::RedWire {
                    host,
                    port: *port,
                    tls: *tls,
                }
            }
        }
    }

    /// Consumes `self` and returns the raw parser target (string fields are
    /// not taint-wrapped).
    pub fn into_connection_target(self) -> ConnectionTarget {
        let Self { target } = self;
        target
    }

    /// Borrows the raw parser target (string fields are not taint-wrapped).
    pub fn as_connection_target(&self) -> &ConnectionTarget {
        &self.target
    }
}
/// Borrowed, `Copy` view of a tainted string held inside a
/// [`ParsedConnString`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TaintedRef<'a>(&'a String);

impl<'a> TaintedRef<'a> {
    /// Borrows the raw, unescaped string for the full lifetime `'a`.
    pub fn expose_secret(&self) -> &'a str {
        let TaintedRef(raw) = *self;
        raw
    }

    /// Copies the referenced string into an owned [`Tainted`] value.
    pub fn to_owned_tainted(&self) -> Tainted<String> {
        Tainted(String::clone(self.0))
    }

    /// Escapes the referenced string for `boundary`.
    ///
    /// Copies the string into an owned `Tainted<String>` first and then
    /// delegates to its `escape_for`, so the rules and the length limit are
    /// identical.
    ///
    /// # Errors
    ///
    /// Returns `EscapeError::TooLong` under the same conditions as
    /// `Tainted::<String>::escape_for`.
    pub fn escape_for(&self, boundary: Boundary) -> Result<EscapedFor, EscapeError> {
        let owned = self.to_owned_tainted();
        owned.escape_for(boundary)
    }
}
/// Borrowed view of a parsed connection target, as returned by
/// `ParsedConnString::target`.
///
/// String components typed as [`TaintedRef`] must go through
/// `expose_secret` or `escape_for` to be used. `File::path` and
/// `GrpcCluster::replicas` are plain borrows without a taint wrapper.
/// NOTE(review): `replicas` come from the same connection string as
/// `primary` but are not wrapped in `TaintedRef` — confirm this is intended.
#[derive(Debug)]
pub enum TaintedTarget<'a> {
    /// In-memory target; carries no data.
    Memory,
    /// File-backed target.
    File {
        /// Path from the connection string (not taint-wrapped).
        path: &'a std::path::Path,
    },
    /// Single gRPC endpoint.
    Grpc {
        /// Endpoint string as produced by the parser.
        endpoint: TaintedRef<'a>,
    },
    /// gRPC primary plus replicas.
    GrpcCluster {
        /// Primary endpoint.
        primary: TaintedRef<'a>,
        /// Replica endpoints (raw strings, not taint-wrapped).
        replicas: &'a [String],
        /// Copied from the parsed target; semantics are defined by the
        /// `conn_string` parser — TODO confirm.
        force_primary: bool,
    },
    /// HTTP target.
    Http {
        /// Base URL.
        base_url: TaintedRef<'a>,
    },
    /// RedWire target.
    RedWire {
        /// Host name.
        host: TaintedRef<'a>,
        /// Port number.
        port: u16,
        /// Whether TLS is used (the tests expect `true` for `reds://`).
        tls: bool,
    },
}
// Unit tests for the boundary escapers, the `Tainted` wrapper and the
// connection-string view.
#[cfg(test)]
mod tests {
    use super::*;

    /// CR, LF, NUL and TAB are removed (not encoded) for HTTP headers, so an
    /// injected `\r\n` cannot start a forged header line.
    #[test]
    fn header_strip_crlf_nul_tab() {
        let t = Tainted::<String>::from("v1\r\nX-Forged: yes\0\there");
        let e = t.escape_for(Boundary::HttpHeader).unwrap();
        assert_eq!(e.boundary(), Boundary::HttpHeader);
        assert!(!e.as_str().contains('\r'));
        assert!(!e.as_str().contains('\n'));
        assert!(!e.as_str().contains('\0'));
        assert!(!e.as_str().contains('\t'));
        assert_eq!(e.as_str(), "v1X-Forged: yeshere");
    }

    /// gRPC metadata must be escaped exactly like an HTTP header value.
    #[test]
    fn grpc_metadata_matches_http_header_contract() {
        let payload = "alice\r\nx-trace-id: forged";
        let h = Tainted::from(payload)
            .escape_for(Boundary::HttpHeader)
            .unwrap();
        let g = Tainted::from(payload)
            .escape_for(Boundary::GrpcMetadata)
            .unwrap();
        assert_eq!(h.as_str(), g.as_str());
    }

    /// Every ASCII control byte (including DEL) is replaced by `%XX` in a log
    /// field; none survives raw.
    #[test]
    fn log_field_percent_encodes_control_bytes() {
        let t = Tainted::<String>::from(
            "alice\nlevel=ERROR\rcluster_breach=true\ttab\0nul\x07bel\x1bescape\x7fdel",
        );
        let e = t.escape_for(Boundary::LogField).unwrap();
        let s = e.as_str();
        assert!(!s.contains('\n'));
        assert!(!s.contains('\r'));
        assert!(!s.contains('\0'));
        assert!(!s.contains('\t'));
        assert!(!s.contains('\x07'));
        assert!(!s.contains('\x1b'));
        assert!(!s.contains('\x7f'));
        assert!(s.contains("%0A"));
        assert!(s.contains("%0D"));
        assert!(s.contains("%00"));
        assert!(s.contains("%09"));
        assert!(s.contains("%07"));
        assert!(s.contains("%1B"));
        assert!(s.contains("%7F"));
    }

    /// `AuditField` and `JsonValue` are not escaped by `escape_for`; the raw
    /// value comes back tagged with the requested boundary.
    #[test]
    fn audit_and_json_pass_through() {
        let raw = "alice\nbob";
        let a = Tainted::from(raw).escape_for(Boundary::AuditField).unwrap();
        let j = Tainted::from(raw).escape_for(Boundary::JsonValue).unwrap();
        assert_eq!(a.as_str(), raw);
        assert_eq!(j.as_str(), raw);
        assert_eq!(a.boundary(), Boundary::AuditField);
        assert_eq!(j.boundary(), Boundary::JsonValue);
    }

    /// The `Display`-based helper percent-encodes a newline instead of
    /// emitting it.
    #[test]
    fn audit_safe_log_field_strips_crlf() {
        let evil = "alice\nlevel=ERROR cluster_breach=true";
        let rendered = format!("{}", audit_safe_log_field(evil));
        assert!(!rendered.contains('\n'));
        assert!(!rendered.contains('\r'));
        assert!(rendered.contains("%0A"));
    }

    /// The helper and `Boundary::LogField` must render identical text.
    #[test]
    fn audit_safe_log_field_matches_log_field_boundary() {
        let evil = "user\rname\nrow=1\0nul\x1Besc\x7Fdel";
        let helper = format!("{}", audit_safe_log_field(evil));
        let typed = Tainted::from(evil)
            .escape_for(Boundary::LogField)
            .unwrap()
            .into_string();
        assert_eq!(helper, typed);
    }

    /// Despite its name, this checks at runtime only that `Debug` escapes
    /// control characters. The absence of a `Display` impl is a compile-time
    /// property and is not asserted here.
    #[test]
    fn tainted_is_not_display() {
        let t = Tainted::from("alice\nbob");
        let dbg = format!("{:?}", t);
        assert!(dbg.contains("\\n"), "Debug must escape control bytes");
    }

    /// `grpc://` parses to a `Grpc` target. The expected endpoint uses
    /// `http://`, so the `conn_string` parser presumably rewrites the scheme;
    /// that behavior is owned by that module.
    #[test]
    fn parser_round_trip_grpc() {
        let parsed = ConnStringSanitizer::parse("grpc://node-1:5055").unwrap();
        match parsed.target() {
            TaintedTarget::Grpc { endpoint } => {
                assert_eq!(endpoint.expose_secret(), "http://node-1:5055");
                let h = endpoint.escape_for(Boundary::HttpHeader).unwrap();
                assert!(!h.as_str().contains('\n'));
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    /// `reds://` parses to a RedWire target with TLS enabled.
    #[test]
    fn parser_round_trip_redwire() {
        let parsed = ConnStringSanitizer::parse("reds://example.com:9999").unwrap();
        match parsed.target() {
            TaintedTarget::RedWire { host, port, tls } => {
                assert_eq!(host.expose_secret(), "example.com");
                assert_eq!(port, 9999);
                assert!(tls);
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    /// `into_connection_target` still hands back the raw parser type.
    #[test]
    fn parser_into_connection_target_compat() {
        let parsed = ConnStringSanitizer::parse("memory://").unwrap();
        assert_eq!(parsed.into_connection_target(), ConnectionTarget::Memory);
    }

    /// Output one byte over `MAX_ESCAPED_LEN` is rejected with a typed
    /// `TooLong` error rather than being returned.
    #[test]
    fn escape_too_long_surfaces_typed_error() {
        let big = "a".repeat(Tainted::<String>::MAX_ESCAPED_LEN + 1);
        let err = Tainted::from(big.as_str())
            .escape_for(Boundary::LogField)
            .unwrap_err();
        match err {
            EscapeError::TooLong { boundary, bytes } => {
                assert_eq!(boundary, Boundary::LogField);
                assert!(bytes > Tainted::<String>::MAX_ESCAPED_LEN);
            }
        }
    }
}