use arc_swap::ArcSwap;
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::sync::Arc;
/// How a matched leaf value is rewritten.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum MaskStrategy {
    /// Partial redaction: only the first character survives (for e-mail-like
    /// strings, the first character of the local part and of the host, plus the TLD).
    Redact,
    /// Replaced by the lowercase hex SHA-256 digest of the value's text.
    Hash,
    /// Replaced by `*`s except (at most) the final four characters.
    Last4,
    /// Replaced by JSON `null`.
    Drop,
}

impl MaskStrategy {
    /// Parses a strategy name as written in a field spec.
    ///
    /// Accepts exactly `redact`, `hash`, `last4` or `drop` (case-sensitive);
    /// every other input yields `None`.
    fn parse(s: &str) -> Option<Self> {
        const BY_NAME: [(&str, MaskStrategy); 4] = [
            ("redact", MaskStrategy::Redact),
            ("hash", MaskStrategy::Hash),
            ("last4", MaskStrategy::Last4),
            ("drop", MaskStrategy::Drop),
        ];
        BY_NAME
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, strategy)| strategy)
    }
}
/// One step of a [`MaskRule`] path, produced from a `.`-separated spec segment.
#[derive(Debug, Clone)]
pub enum PathSeg {
    /// Descend into the object field with exactly this name. When the current
    /// value is an array, the same step is applied to each element instead.
    Key(String),
    /// `*` in a spec: descend into every element of an array or every value of an object.
    Any,
}
/// A single masking instruction: where to look (`path`) and how to mask (`strategy`).
#[derive(Debug, Clone)]
pub struct MaskRule {
    /// Steps walked from the document root to the value(s) to mask.
    pub path: Vec<PathSeg>,
    /// How each value reached by `path` is rewritten.
    pub strategy: MaskStrategy,
}
impl MaskRule {
    /// Parses a field spec of the form `path[:strategy]`.
    ///
    /// `path` is a `.`-separated list of object keys; a segment that is exactly
    /// `*` becomes [`PathSeg::Any`]. `strategy` is one of `redact`, `hash`,
    /// `last4`, `drop` and defaults to `redact` when no `:` is present. The spec
    /// is split at the *last* `:`, so a key may itself contain `:` only when a
    /// strategy suffix follows it (`a:b:hash` → key `a:b`).
    ///
    /// Returns `None` when the strategy is unknown, the path is empty, or any
    /// path segment is empty (`a..b`, `.a`, `a.`). An empty segment could only
    /// match a literal `""` key, so it is treated as a typo instead of silently
    /// producing a rule that never masks anything.
    pub fn parse(spec: &str) -> Option<Self> {
        let (path_str, strategy) = match spec.rsplit_once(':') {
            Some((p, s)) => (p, MaskStrategy::parse(s)?),
            None => (spec, MaskStrategy::Redact),
        };
        // `"".split('.')` yields a single empty segment, so this also rejects
        // an empty path.
        let path = path_str
            .split('.')
            .map(|seg| match seg {
                "" => None,
                "*" => Some(PathSeg::Any),
                key => Some(PathSeg::Key(key.to_owned())),
            })
            .collect::<Option<Vec<_>>>()?;
        Some(Self { path, strategy })
    }
}
/// A versioned, ordered list of [`MaskRule`]s.
///
/// Built with [`MaskingPolicy::new`] followed by chained [`MaskingPolicy::field`] calls.
pub struct MaskingPolicy {
    /// Policy version; `Masker::reload` ignores any policy whose version is not
    /// strictly greater than the active one.
    pub version: u64,
    /// Rules in insertion order, which is also the order they are applied in.
    rules: Vec<MaskRule>,
}
impl MaskingPolicy {
    /// Creates a policy with no rules, tagged with `version`.
    pub fn new(version: u64) -> Self {
        let rules = Vec::new();
        Self { version, rules }
    }

    /// Builder step: consumes the policy, appends the rule described by `spec`
    /// (syntax as accepted by [`MaskRule::parse`]) and returns the policy.
    ///
    /// # Panics
    ///
    /// Panics if `spec` is not a valid field spec. This surfaces configuration
    /// mistakes when the policy is built.
    pub fn field(self, spec: &str) -> Self {
        let Some(rule) = MaskRule::parse(spec) else {
            panic!("invalid mask field spec: {spec:?}");
        };
        let mut policy = self;
        policy.rules.push(rule);
        policy
    }

    /// The rules in the order they were added.
    pub fn rules(&self) -> &[MaskRule] {
        self.rules.as_slice()
    }
}
/// Applies the active [`MaskingPolicy`] to JSON values.
///
/// The policy lives in an `ArcSwap`, so it can be replaced through `&self`
/// (see `Masker::reload`) while the masker is shared.
pub struct Masker {
    /// Currently active policy.
    policy: ArcSwap<MaskingPolicy>,
}
impl Masker {
    /// Creates a masker whose active policy is `initial`.
    pub fn new(initial: MaskingPolicy) -> Self {
        Self {
            policy: ArcSwap::from_pointee(initial),
        }
    }

    /// Replaces the active policy with `next` if, and only if, `next.version`
    /// is strictly greater than the active policy's version.
    ///
    /// Stale or equal versions are ignored with a warning. The version check
    /// and the swap form a single compare-and-swap loop. Two concurrent reloads
    /// therefore can never let an older policy overwrite a newer one that was
    /// installed in between.
    pub fn reload(&self, next: MaskingPolicy) {
        let next = Arc::new(next);
        let mut current = self.policy.load();
        loop {
            if next.version <= current.version {
                tracing::warn!(
                    current = current.version,
                    offered = next.version,
                    "ignoring stale masking policy reload"
                );
                return;
            }
            // Install `next` only if the policy we just version-checked is still
            // the active one; otherwise re-check against whatever replaced it.
            let previous = self.policy.compare_and_swap(&*current, Arc::clone(&next));
            if Arc::ptr_eq(&*previous, &*current) {
                tracing::info!(version = next.version, "masking policy reloaded (live)");
                return;
            }
            current = previous;
        }
    }

    /// Version of the currently active policy.
    pub fn version(&self) -> u64 {
        self.policy.load().version
    }

    /// Applies every rule of the active policy to `value`, in place and in
    /// rule order. Returns `true` if at least one value was masked.
    pub fn apply(&self, value: &mut Value) -> bool {
        // One snapshot for the whole call, so a concurrent reload cannot mix
        // rules from two policies within a single document.
        let policy = self.policy.load();
        let mut touched = false;
        for rule in &policy.rules {
            touched |= apply_rule(value, &rule.path, rule.strategy);
        }
        touched
    }

    /// Like [`Masker::apply`], then additionally applies the caller-supplied
    /// `extra` rules. Returns `true` if any value was masked by either set.
    pub fn apply_with(&self, value: &mut Value, extra: &[MaskRule]) -> bool {
        let mut touched = self.apply(value);
        for rule in extra {
            touched |= apply_rule(value, &rule.path, rule.strategy);
        }
        touched
    }
}
/// Walks `path` through `v` and masks every value it reaches, in place.
///
/// - `Key(k)` descends into field `k` of an object. When `v` is an array, the
///   *same* segment is retried on each element, so `users.email` also reaches
///   `users[i].email`.
/// - `Any` descends into every element of an array or every value of an object.
/// - Once the path is exhausted, the value there is replaced by `mask_leaf`
///   unless it is `null`.
///
/// Returns `true` if at least one value was replaced. All branches are
/// visited (no short-circuiting), so every match is masked, not just the first.
fn apply_rule(v: &mut Value, path: &[PathSeg], strategy: MaskStrategy) -> bool {
    let Some((head, tail)) = path.split_first() else {
        // Path exhausted: `v` is the target. Nulls are left alone and not counted.
        if matches!(v, Value::Null) {
            return false;
        }
        let replacement = mask_leaf(v, strategy);
        *v = replacement;
        return true;
    };

    let mut touched = false;
    match (head, v) {
        (PathSeg::Key(key), Value::Object(fields)) => {
            if let Some(child) = fields.get_mut(key) {
                touched = apply_rule(child, tail, strategy);
            }
        }
        (PathSeg::Key(_), Value::Array(elements)) => {
            // Arrays are transparent to key steps: retry the full remaining path.
            for element in elements.iter_mut() {
                touched |= apply_rule(element, path, strategy);
            }
        }
        (PathSeg::Any, Value::Array(elements)) => {
            for element in elements.iter_mut() {
                touched |= apply_rule(element, tail, strategy);
            }
        }
        (PathSeg::Any, Value::Object(fields)) => {
            for child in fields.values_mut() {
                touched |= apply_rule(child, tail, strategy);
            }
        }
        _ => {}
    }
    touched
}
/// Computes the masked replacement for a single leaf value.
///
/// Except for `Drop`, the leaf is first turned into text. Strings are used
/// as-is; any other value (number, bool, object, array) is rendered as compact JSON.
///
/// - `Drop`   → `null`.
/// - `Hash`   → lowercase hex SHA-256 of the text.
///   NOTE(review): the digest is unsalted, so low-entropy inputs (e-mails,
///   phone numbers, short IDs) can be recovered by hashing candidate values.
/// - `Last4`  → every character except the final four replaced by `*`. Values
///   of four characters or fewer are masked completely, because keeping "the
///   last four" of such a value would return it verbatim.
/// - `Redact` → partial redaction via `redact_string`.
fn mask_leaf(v: &Value, strategy: MaskStrategy) -> Value {
    // Computed lazily so `Drop` never renders the value.
    let text = || match v {
        Value::String(s) => s.clone(),
        other => other.to_string(),
    };
    match strategy {
        MaskStrategy::Drop => Value::Null,
        MaskStrategy::Hash => {
            let digest = Sha256::digest(text().as_bytes());
            Value::String(digest.iter().map(|b| format!("{b:02x}")).collect())
        }
        MaskStrategy::Last4 => {
            let chars: Vec<char> = text().chars().collect();
            // Only reveal a tail when something is actually left hidden before it.
            let keep = if chars.len() > 4 { 4 } else { 0 };
            let hidden = chars.len() - keep;
            let masked: String = std::iter::repeat_n('*', hidden)
                .chain(chars[hidden..].iter().copied())
                .collect();
            Value::String(masked)
        }
        MaskStrategy::Redact => Value::String(redact_string(&text())),
    }
}
/// Partially redacts `s`, keeping just enough to recognise its shape.
///
/// - Contains `@` with a `.` after it (e-mail-like, `local@host.tld`): keeps
///   the first character of `local` and of `host` and the full `tld`, e.g.
///   `john@example.com` → `j***@e***.com`. An empty `local`/`host` shows as `*`.
/// - Otherwise: first character followed by `***` (`secret` → `s***`); the
///   empty string becomes `***`.
///
/// NOTE(review): `tld` is everything after the last `.`, so a non-e-mail string
/// such as `p@ss.word123` keeps `word123` in clear.
fn redact_string(s: &str) -> String {
    let email_parts = s.split_once('@').and_then(|(local, domain)| {
        domain
            .rsplit_once('.')
            .map(|(host, tld)| (local, host, tld))
    });
    if let Some((local, host, tld)) = email_parts {
        let initial = |part: &str| part.chars().next().unwrap_or('*');
        return format!("{}***@{}***.{}", initial(local), initial(host), tld);
    }
    s.chars()
        .next()
        .map_or_else(|| "***".to_owned(), |c| format!("{c}***"))
}
/// Largest response body, in bytes (256 KiB), that `mask_response` will buffer for masking.
const MAX_MASKED_BODY: usize = 256 << 10;
#[doc(hidden)]
pub async fn mask_response(
ctx: &crate::web::context::RequestContext,
fields: &'static [&'static str],
resp: axum::response::Response,
) -> axum::response::Response {
let Some(masker) = ctx.try_inject::<Masker>() else {
return resp;
};
let is_json = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|ct| ct.starts_with("application/json"))
.unwrap_or(false);
if !is_json {
return resp;
}
let (parts, body) = resp.into_parts();
let bytes = match axum::body::to_bytes(body, MAX_MASKED_BODY).await {
Ok(b) => b,
Err(_) => return axum::response::Response::from_parts(parts, axum::body::Body::empty()),
};
let Ok(mut value) = serde_json::from_slice::<Value>(&bytes) else {
return axum::response::Response::from_parts(parts, axum::body::Body::from(bytes));
};
let extra: Vec<MaskRule> = fields.iter().filter_map(|f| MaskRule::parse(f)).collect();
if masker.apply_with(&mut value, &extra) {
metrics::counter!("masked_responses_total").increment(1);
}
let masked = serde_json::to_vec(&value).unwrap_or_else(|_| bytes.to_vec());
let mut parts = parts;
parts.headers.remove("content-length"); axum::response::Response::from_parts(parts, axum::body::Body::from(masked))
}