use crate::category::Category;
use crate::error::{Result, SanitizeError};
use crate::store::MappingStore;
use aho_corasick::AhoCorasick;
use regex::bytes::{Regex, RegexBuilder, RegexSet, RegexSetBuilder};
use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::sync::Arc;
/// Bytes requested from the reader per iteration by default (1 MiB).
const DEFAULT_CHUNK_SIZE: usize = 1 << 20;
/// Trailing bytes of each window re-scanned together with the next chunk by default (4 KiB).
const DEFAULT_OVERLAP_SIZE: usize = 4 * 1024;
/// Base approximate compiled-size limit handed to the regex builders' `size_limit` (1 MiB).
const REGEX_SIZE_LIMIT: usize = 1024 * 1024;
/// Base cache-size limit handed to the regex builders' `dfa_size_limit` (1 MiB).
const REGEX_DFA_SIZE_LIMIT: usize = 1024 << 10;
/// Default cap on the number of patterns accepted by `StreamScanner::new`.
const DEFAULT_MAX_PATTERNS: usize = 10000;
/// Chunking parameters for the streaming scanner.
///
/// Input is read `chunk_size` bytes at a time. Up to `overlap_size` trailing
/// bytes of each scanned window are held back and scanned again with the next
/// chunk, so that matches crossing a chunk boundary can still be found.
#[derive(Debug, Clone)]
pub struct ScanConfig {
    /// Bytes requested from the reader per iteration; must be non-zero.
    pub chunk_size: usize,
    /// Trailing bytes carried into the next window; must be smaller than `chunk_size`.
    pub overlap_size: usize,
}

impl Default for ScanConfig {
    /// Uses `DEFAULT_CHUNK_SIZE` and `DEFAULT_OVERLAP_SIZE`.
    fn default() -> Self {
        Self::new(DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP_SIZE)
    }
}

impl ScanConfig {
    /// Creates a configuration without validating it; call [`ScanConfig::validate`]
    /// (the scanner constructors do) before use.
    #[must_use]
    pub fn new(chunk_size: usize, overlap_size: usize) -> Self {
        Self {
            chunk_size,
            overlap_size,
        }
    }

    /// Checks that the configuration can drive a scan.
    ///
    /// # Errors
    ///
    /// Returns `SanitizeError::InvalidConfig` when `chunk_size` is zero, or when
    /// `overlap_size` is not strictly smaller than `chunk_size`. The zero-chunk
    /// check is reported first.
    pub fn validate(&self) -> Result<()> {
        let problem = if self.chunk_size == 0 {
            Some("chunk_size must be > 0")
        } else if self.overlap_size >= self.chunk_size {
            Some("overlap_size must be < chunk_size")
        } else {
            None
        };
        match problem {
            Some(msg) => Err(SanitizeError::InvalidConfig(msg.into())),
            None => Ok(()),
        }
    }
}
/// Converts any displayable compile failure (regex or Aho-Corasick build error)
/// into `SanitizeError::PatternCompileError`, keeping only its message text.
#[inline]
fn compile_err(e: impl std::fmt::Display) -> SanitizeError {
    let message = format!("{e}");
    SanitizeError::PatternCompileError(message)
}
/// One redaction rule: a compiled matcher plus the category and label used when
/// replacing and reporting its matches.
///
/// `Clone` is derived because every field is independently cloneable. The
/// previous hand-written impl did exactly the same field-wise clone.
#[derive(Clone)]
pub struct ScanPattern {
    /// Compiled matcher. For literal patterns this is the escaped literal.
    regex: Regex,
    /// Category passed to the mapping store when generating replacements.
    category: Category,
    /// Key under which matches are counted in `ScanStats::pattern_counts`.
    label: String,
    /// The raw literal for patterns built with `from_literal`; `None` for regex patterns.
    literal: Option<String>,
}

impl std::fmt::Debug for ScanPattern {
    /// Formats the regex source text (via `Regex::as_str`) together with the
    /// category, label and optional literal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScanPattern")
            .field("pattern", &self.regex.as_str())
            .field("category", &self.category)
            .field("label", &self.label)
            .field("literal", &self.literal.as_deref())
            .finish()
    }
}
impl ScanPattern {
pub fn from_regex(pattern: &str, category: Category, label: impl Into<String>) -> Result<Self> {
let regex = RegexBuilder::new(pattern)
.size_limit(REGEX_SIZE_LIMIT)
.dfa_size_limit(REGEX_DFA_SIZE_LIMIT)
.build()
.map_err(compile_err)?;
Ok(Self {
regex,
category,
label: label.into(),
literal: None,
})
}
pub fn from_literal(
literal: &str,
category: Category,
label: impl Into<String>,
) -> Result<Self> {
let escaped = regex::escape(literal);
let regex = RegexBuilder::new(&escaped)
.size_limit(REGEX_SIZE_LIMIT)
.dfa_size_limit(REGEX_DFA_SIZE_LIMIT)
.build()
.map_err(compile_err)?;
Ok(Self {
regex,
category,
label: label.into(),
literal: Some(literal.to_owned()),
})
}
#[must_use]
pub fn category(&self) -> &Category {
&self.category
}
#[must_use]
pub fn label(&self) -> &str {
&self.label
}
#[must_use]
pub fn regex_pattern(&self) -> &str {
self.regex.as_str()
}
}
/// A candidate match inside the current scan window, as byte offsets.
#[derive(Debug, Clone, Copy)]
struct RawMatch {
    /// Start offset (inclusive) within the window.
    start: usize,
    /// End offset (exclusive) within the window.
    end: usize,
    /// Index of the producing pattern in the scanner's pattern list.
    pattern_idx: usize,
}

/// Buffers reused across every window of a single scan, so that the per-chunk
/// work does not reallocate.
struct ScanScratch {
    /// Every raw match found in the current window, before overlap resolution.
    all_matches: Vec<RawMatch>,
    /// Non-overlapping matches chosen from `all_matches`, ordered by start.
    selected: Vec<RawMatch>,
    /// Sanitized bytes produced for the committed part of the current window.
    output: Vec<u8>,
    /// Per-pattern replacement counts for the current window, indexed like the patterns.
    pattern_counts: Vec<u64>,
}

impl ScanScratch {
    /// Allocates scratch space for `pattern_count` patterns. The output buffer is
    /// sized for the largest expected window, `chunk_size + overlap_size`.
    fn new(pattern_count: usize, chunk_size: usize, overlap_size: usize) -> Self {
        let window_capacity = chunk_size + overlap_size;
        let mut pattern_counts = Vec::with_capacity(pattern_count);
        pattern_counts.resize(pattern_count, 0);
        Self {
            all_matches: Vec::default(),
            selected: Vec::default(),
            output: Vec::with_capacity(window_capacity),
            pattern_counts,
        }
    }
}
/// Aggregate counters returned when a scan completes.
#[derive(Debug, Clone, Default)]
pub struct ScanStats {
    /// Total bytes read from the input.
    pub bytes_processed: u64,
    /// Total bytes written to the output.
    pub bytes_output: u64,
    /// Number of matches that were replaced. The scanner increments this together
    /// with `replacements_applied`, so the two are currently always equal.
    pub matches_found: u64,
    /// Number of replacements written to the output.
    pub replacements_applied: u64,
    /// Replacement counts keyed by pattern label. Patterns sharing a label are summed.
    pub pattern_counts: HashMap<String, u64>,
}
/// Snapshot handed to the progress callback after each scanned window has been
/// written.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct ScanProgress {
    /// Bytes read from the input so far.
    pub bytes_processed: u64,
    /// Bytes written to the output so far.
    pub bytes_output: u64,
    /// Expected input size as supplied by the caller, passed through unchanged.
    pub total_bytes: Option<u64>,
    /// Matches replaced so far.
    pub matches_found: u64,
    /// Replacements written so far.
    pub replacements_applied: u64,
}
/// Streaming redactor. It reads input in chunks, finds matches for a set of
/// [`ScanPattern`]s, and replaces each match with a value obtained from a shared
/// [`MappingStore`].
pub struct StreamScanner {
    /// All patterns, in the order supplied by the caller.
    patterns: Vec<ScanPattern>,
    /// Pre-filter over the non-literal patterns. Set index `i` corresponds to
    /// `patterns[regex_indices[i]]`.
    regex_set: RegexSet,
    /// Maps a `regex_set` index back to a position in `patterns`.
    regex_indices: Vec<usize>,
    /// Automaton over the literal patterns, or `None` when there are none.
    aho_corasick: Option<AhoCorasick>,
    /// Maps an Aho-Corasick pattern id back to a position in `patterns`.
    literal_indices: Vec<usize>,
    /// Source of replacement values, shared via `Arc`.
    store: Arc<MappingStore>,
    /// Chunking parameters, validated at construction.
    config: ScanConfig,
}
impl StreamScanner {
pub fn new(
patterns: Vec<ScanPattern>,
store: Arc<MappingStore>,
config: ScanConfig,
) -> Result<Self> {
Self::new_with_max_patterns(patterns, store, config, DEFAULT_MAX_PATTERNS)
}
pub fn new_with_max_patterns(
patterns: Vec<ScanPattern>,
store: Arc<MappingStore>,
config: ScanConfig,
max_patterns: usize,
) -> Result<Self> {
config.validate()?;
if patterns.len() > max_patterns {
return Err(SanitizeError::InvalidConfig(format!(
"pattern count ({}) exceeds maximum allowed ({}) — \
RegexSet memory scales linearly with pattern count",
patterns.len(),
max_patterns
)));
}
let mut literal_bytes: Vec<Vec<u8>> = Vec::new();
let mut literal_indices: Vec<usize> = Vec::new();
let mut regex_strs: Vec<&str> = Vec::new();
let mut regex_indices: Vec<usize> = Vec::new();
for (i, pattern) in patterns.iter().enumerate() {
if let Some(lit) = &pattern.literal {
literal_bytes.push(lit.as_bytes().to_vec());
literal_indices.push(i);
} else {
regex_strs.push(pattern.regex_pattern());
regex_indices.push(i);
}
}
let aho_corasick = if literal_bytes.is_empty() {
None
} else {
Some(
AhoCorasick::new(&literal_bytes)
.map_err(compile_err)?,
)
};
let regex_set = if regex_strs.is_empty() {
RegexSetBuilder::new(Vec::<&str>::new())
.size_limit(REGEX_SIZE_LIMIT)
.dfa_size_limit(REGEX_DFA_SIZE_LIMIT)
.build()
.map_err(compile_err)?
} else {
RegexSetBuilder::new(®ex_strs)
.size_limit(REGEX_SIZE_LIMIT * regex_strs.len().max(1))
.dfa_size_limit(REGEX_DFA_SIZE_LIMIT * regex_strs.len().max(1))
.build()
.map_err(compile_err)?
};
Ok(Self {
patterns,
regex_set,
regex_indices,
aho_corasick,
literal_indices,
store,
config,
})
}
pub fn with_extra_literals(&self, extra: Vec<ScanPattern>) -> Result<Self> {
let mut patterns = self.patterns.clone();
patterns.extend(extra);
Self::new(patterns, Arc::clone(&self.store), self.config.clone())
}
pub fn scan_reader<R: Read, W: Write>(&self, reader: R, writer: W) -> Result<ScanStats> {
self.scan_reader_with_progress(reader, writer, None, |_| {})
}
pub fn scan_reader_with_progress<R: Read, W: Write, F>(
&self,
mut reader: R,
mut writer: W,
total_bytes: Option<u64>,
mut on_progress: F,
) -> Result<ScanStats>
where
F: FnMut(&ScanProgress),
{
let mut stats = ScanStats::default();
let mut carry: Vec<u8> = Vec::new();
let mut read_buf = vec![0u8; self.config.chunk_size];
let mut window: Vec<u8> =
Vec::with_capacity(self.config.chunk_size + self.config.overlap_size);
let mut scratch = ScanScratch::new(
self.patterns.len(),
self.config.chunk_size,
self.config.overlap_size,
);
loop {
let bytes_read = read_fully(&mut reader, &mut read_buf)?;
let is_eof = bytes_read < read_buf.len();
stats.bytes_processed += bytes_read as u64;
if bytes_read == 0 && carry.is_empty() {
break;
}
let new_data = &read_buf[..bytes_read];
window.clear();
window.extend_from_slice(&carry);
window.extend_from_slice(new_data);
if window.is_empty() {
break;
}
self.find_matches(&window, &mut scratch);
let base_commit = if is_eof {
window.len()
} else {
window.len().saturating_sub(self.config.overlap_size)
};
let commit_point =
self.adjusted_commit_point(&scratch.selected, base_commit, window.len(), is_eof);
self.apply_replacements(
&window[..commit_point],
&scratch.selected,
&mut stats,
&mut scratch.output,
&mut scratch.pattern_counts,
)?;
writer
.write_all(&scratch.output)
.map_err(|e| SanitizeError::IoError(e.to_string()))?;
stats.bytes_output += scratch.output.len() as u64;
for (idx, count) in scratch.pattern_counts.iter_mut().enumerate() {
if *count > 0 {
*stats
.pattern_counts
.entry(self.patterns[idx].label.clone())
.or_insert(0) += *count;
*count = 0; }
}
on_progress(&ScanProgress {
bytes_processed: stats.bytes_processed,
bytes_output: stats.bytes_output,
total_bytes,
matches_found: stats.matches_found,
replacements_applied: stats.replacements_applied,
});
if is_eof {
carry.clear();
break;
}
carry.clear();
carry.extend_from_slice(&window[commit_point..]);
}
Ok(stats)
}
pub fn scan_bytes(&self, input: &[u8]) -> Result<(Vec<u8>, ScanStats)> {
self.scan_bytes_with_progress(input, |_| {})
}
pub fn scan_bytes_with_progress<F>(
&self,
input: &[u8],
on_progress: F,
) -> Result<(Vec<u8>, ScanStats)>
where
F: FnMut(&ScanProgress),
{
let mut output = Vec::with_capacity(input.len());
let stats = self.scan_reader_with_progress(
input,
&mut output,
Some(input.len() as u64),
on_progress,
)?;
Ok((output, stats))
}
#[must_use]
pub fn config(&self) -> &ScanConfig {
&self.config
}
#[must_use]
pub fn store(&self) -> &Arc<MappingStore> {
&self.store
}
#[must_use]
pub fn pattern_count(&self) -> usize {
self.patterns.len()
}
pub fn from_encrypted_secrets(
encrypted_bytes: &[u8],
password: &str,
format: Option<crate::secrets::SecretsFormat>,
store: Arc<MappingStore>,
config: ScanConfig,
extra_patterns: Vec<ScanPattern>,
) -> Result<(Self, Vec<(usize, SanitizeError)>)> {
let (mut patterns, warnings) =
crate::secrets::load_encrypted_secrets(encrypted_bytes, password, format)?;
patterns.extend(extra_patterns);
let scanner = Self::new(patterns, store, config)?;
Ok((scanner, warnings))
}
pub fn from_plaintext_secrets(
plaintext: &[u8],
format: Option<crate::secrets::SecretsFormat>,
store: Arc<MappingStore>,
config: ScanConfig,
extra_patterns: Vec<ScanPattern>,
) -> Result<(Self, Vec<(usize, SanitizeError)>)> {
let (mut patterns, warnings) = crate::secrets::load_plaintext_secrets(plaintext, format)?;
patterns.extend(extra_patterns);
let scanner = Self::new(patterns, store, config)?;
Ok((scanner, warnings))
}
fn find_matches(&self, window: &[u8], scratch: &mut ScanScratch) {
scratch.all_matches.clear();
scratch.selected.clear();
if let Some(ac) = &self.aho_corasick {
for mat in ac.find_overlapping_iter(window) {
scratch.all_matches.push(RawMatch {
start: mat.start(),
end: mat.end(),
pattern_idx: self.literal_indices[mat.pattern().as_usize()],
});
}
}
for rs_idx in self.regex_set.matches(window) {
let pattern_idx = self.regex_indices[rs_idx];
for m in self.patterns[pattern_idx].regex.find_iter(window) {
scratch.all_matches.push(RawMatch {
start: m.start(),
end: m.end(),
pattern_idx,
});
}
}
if scratch.all_matches.is_empty() {
return;
}
scratch.all_matches.sort_unstable_by(|a, b| {
a.start
.cmp(&b.start)
.then_with(|| (b.end - b.start).cmp(&(a.end - a.start)))
});
let mut last_end = 0;
for m in scratch.all_matches.drain(..) {
if m.start >= last_end {
last_end = m.end;
scratch.selected.push(m);
}
}
}
#[allow(clippy::unused_self)] fn adjusted_commit_point(
&self,
matches: &[RawMatch],
base_commit: usize,
window_len: usize,
is_eof: bool,
) -> usize {
if is_eof {
return window_len;
}
let mut commit = base_commit;
for m in matches {
if m.start < commit && m.end > commit {
commit = m.end;
}
}
commit.min(window_len)
}
fn apply_replacements(
&self,
committed: &[u8],
matches: &[RawMatch],
stats: &mut ScanStats,
output_buf: &mut Vec<u8>,
pattern_counts: &mut [u64],
) -> Result<()> {
output_buf.clear();
let mut last_end = 0;
for &m in matches {
if m.start >= committed.len() {
break;
}
output_buf.extend_from_slice(&committed[last_end..m.start]);
let matched_text = String::from_utf8_lossy(&committed[m.start..m.end]);
let pattern = &self.patterns[m.pattern_idx];
let replacement = self.store.get_or_insert(&pattern.category, &matched_text)?;
output_buf.extend_from_slice(replacement.as_bytes());
last_end = m.end;
stats.matches_found += 1;
stats.replacements_applied += 1;
pattern_counts[m.pattern_idx] += 1;
}
output_buf.extend_from_slice(&committed[last_end..]);
Ok(())
}
}
// Compile-time assertion that `StreamScanner` is `Send + Sync`. The closure is
// only type-checked: a `const _` item cannot be referenced, so it never runs.
// If a field ever breaks either bound, this stops compiling.
const _: fn() = || {
    fn assert_send<T: Send>() {}
    fn assert_sync<T: Sync>() {}
    assert_send::<StreamScanner>();
    assert_sync::<StreamScanner>();
};
/// Reads from `reader` until `buf` is full or the reader reports end of input.
///
/// Reads interrupted with `ErrorKind::Interrupted` are retried. Returns the
/// number of bytes placed in `buf`. A value smaller than `buf.len()` means the
/// reader returned `Ok(0)`, i.e. end of input.
///
/// # Errors
///
/// Any other I/O error is returned as `SanitizeError::IoError`.
fn read_fully<R: Read>(reader: &mut R, buf: &mut [u8]) -> Result<usize> {
    let mut filled = 0usize;
    loop {
        if filled == buf.len() {
            return Ok(filled);
        }
        match reader.read(&mut buf[filled..]) {
            Ok(0) => return Ok(filled),
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(SanitizeError::IoError(e.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::HmacGenerator;

    // Shared fixture: a fixed HMAC key and small 64-byte chunks with a 16-byte
    // overlap, so that modest inputs already span several windows.
    // NOTE(review): `gen` is a reserved keyword in the Rust 2024 edition; rename
    // these bindings if the crate migrates.
    fn test_scanner(patterns: Vec<ScanPattern>) -> StreamScanner {
        let gen = Arc::new(HmacGenerator::new([42u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        StreamScanner::new(
            patterns,
            store,
            ScanConfig {
                chunk_size: 64,
                overlap_size: 16,
            },
        )
        .unwrap()
    }

    // Simplified email regex, not a full RFC 5322 grammar.
    fn email_pattern() -> ScanPattern {
        ScanPattern::from_regex(
            r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
            Category::Email,
            "email",
        )
        .unwrap()
    }

    // Dotted-quad matcher. Octets are not range-checked (e.g. 999.1.1.1 matches).
    fn ipv4_pattern() -> ScanPattern {
        ScanPattern::from_regex(
            r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",
            Category::IpV4,
            "ipv4",
        )
        .unwrap()
    }

    #[test]
    fn scanner_creation() {
        let scanner = test_scanner(vec![email_pattern()]);
        assert_eq!(scanner.pattern_count(), 1);
    }

    // `ScanConfig::validate` must reject a zero chunk size.
    #[test]
    fn invalid_config_zero_chunk() {
        let gen = Arc::new(HmacGenerator::new([0u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        let result = StreamScanner::new(vec![], store, ScanConfig::new(0, 0));
        assert!(result.is_err());
    }

    // Equality counts as invalid: the overlap must be strictly smaller than the chunk.
    #[test]
    fn invalid_config_overlap_ge_chunk() {
        let gen = Arc::new(HmacGenerator::new([0u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        let result = StreamScanner::new(vec![], store, ScanConfig::new(100, 100));
        assert!(result.is_err());
    }

    #[test]
    fn empty_input() {
        let scanner = test_scanner(vec![email_pattern()]);
        let (output, stats) = scanner.scan_bytes(b"").unwrap();
        assert!(output.is_empty());
        assert_eq!(stats.matches_found, 0);
        assert_eq!(stats.bytes_processed, 0);
    }

    // Input without matches must pass through byte-for-byte.
    #[test]
    fn no_matches() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"There are no email addresses here.";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(output, input.as_slice());
        assert_eq!(stats.matches_found, 0);
    }

    // One email: the local part is redacted, the domain and the surrounding text
    // are kept, and the overall length is unchanged.
    #[test]
    fn single_email_replaced() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"Contact alice@corp.com for help.";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.matches_found, 1);
        assert_eq!(stats.replacements_applied, 1);
        assert!(!output
            .windows(b"alice@corp.com".len())
            .any(|w| w == b"alice@corp.com"));
        let output_str = String::from_utf8_lossy(&output);
        assert!(output_str.contains("@corp.com"));
        assert_eq!(output.len(), input.len(), "length must be preserved");
        assert!(output_str.starts_with("Contact "));
        assert!(output_str.ends_with(" for help."));
    }

    #[test]
    fn multiple_emails_replaced() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"From alice@corp.com to bob@corp.com cc admin@corp.com";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.matches_found, 3);
        let out_str = String::from_utf8_lossy(&output);
        assert!(!out_str.contains("alice@corp.com"));
        assert!(!out_str.contains("bob@corp.com"));
        assert!(!out_str.contains("admin@corp.com"));
    }

    // Both occurrences are replaced, so splitting on the kept domain yields three parts.
    #[test]
    fn same_secret_same_replacement() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"First alice@corp.com then alice@corp.com again.";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.matches_found, 2);
        let out_str = String::from_utf8_lossy(&output);
        let parts: Vec<&str> = out_str.split("@corp.com").collect();
        assert_eq!(parts.len(), 3);
    }

    // Exercises the Aho-Corasick (literal) matching path.
    #[test]
    fn literal_pattern_matched() {
        let pat = ScanPattern::from_literal(
            "SECRET_API_KEY_12345",
            Category::Custom("api_key".into()),
            "api_key",
        )
        .unwrap();
        let scanner = test_scanner(vec![pat]);
        let input = b"key=SECRET_API_KEY_12345&foo=bar";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.matches_found, 1);
        assert!(!output
            .windows(b"SECRET_API_KEY_12345".len())
            .any(|w| w == b"SECRET_API_KEY_12345"));
    }

    // Per-label counts are reported separately for each pattern.
    #[test]
    fn multiple_pattern_types() {
        let scanner = test_scanner(vec![email_pattern(), ipv4_pattern()]);
        let input = b"Server 192.168.1.100 contact admin@server.com";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.matches_found, 2);
        let out_str = String::from_utf8_lossy(&output);
        assert!(!out_str.contains("192.168.1.100"));
        assert!(!out_str.contains("admin@server.com"));
        assert_eq!(*stats.pattern_counts.get("email").unwrap(), 1);
        assert_eq!(*stats.pattern_counts.get("ipv4").unwrap(), 1);
    }

    // With 20-byte chunks, the first read ends inside "alice", so the email is
    // only found once the carried overlap is combined with the next chunk.
    #[test]
    fn match_at_chunk_boundary() {
        let gen = Arc::new(HmacGenerator::new([42u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        let scanner = StreamScanner::new(
            vec![email_pattern()],
            store,
            ScanConfig {
                chunk_size: 20,
                overlap_size: 16,
            },
        )
        .unwrap();
        let input = b"AAAAAAAAAAAAAAAA alice@corp.com BBBBBBBBBBBBB";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.matches_found, 1);
        let out_str = String::from_utf8_lossy(&output);
        assert!(!out_str.contains("alice@corp.com"));
        assert!(out_str.contains("@corp.com"), "domain must be preserved");
    }

    // Roughly 900 bytes through 64-byte chunks: many windows, 20 emails in total.
    #[test]
    fn large_input_many_chunks() {
        let scanner = test_scanner(vec![email_pattern()]);
        let mut input = Vec::new();
        let filler = b"Lorem ipsum dolor sit amet. ";
        for i in 0..20 {
            input.extend_from_slice(filler);
            let email = format!("user{}@example.com ", i);
            input.extend_from_slice(email.as_bytes());
        }
        let (output, stats) = scanner.scan_bytes(&input).unwrap();
        assert_eq!(stats.matches_found, 20);
        let out_str = String::from_utf8_lossy(&output);
        for i in 0..20 {
            let email = format!("user{}@example.com", i);
            assert!(!out_str.contains(&email));
        }
    }

    // Progress reporting must not change the output or the stats, and the last
    // update must reflect the final totals.
    #[test]
    fn scan_bytes_with_progress_preserves_output_and_stats() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"Contact alice@corp.com and bob@corp.com for help.";
        let (baseline_output, baseline_stats) = scanner.scan_bytes(input).unwrap();
        let mut updates = Vec::new();
        let (progress_output, progress_stats) = scanner
            .scan_bytes_with_progress(input, |progress| updates.push(progress.clone()))
            .unwrap();
        assert_eq!(progress_output, baseline_output);
        assert_eq!(
            progress_stats.bytes_processed,
            baseline_stats.bytes_processed
        );
        assert_eq!(progress_stats.bytes_output, baseline_stats.bytes_output);
        assert_eq!(progress_stats.matches_found, baseline_stats.matches_found);
        assert_eq!(
            progress_stats.replacements_applied,
            baseline_stats.replacements_applied
        );
        assert!(!updates.is_empty());
        assert_eq!(updates.last().unwrap().bytes_processed, input.len() as u64);
        assert_eq!(
            updates.last().unwrap().total_bytes,
            Some(input.len() as u64)
        );
        assert_eq!(updates.last().unwrap().matches_found, 2);
    }

    // An input larger than one 64-byte chunk must yield at least two callbacks.
    #[test]
    fn scan_reader_with_progress_reports_multiple_updates_for_multi_chunk_input() {
        let scanner = test_scanner(vec![email_pattern()]);
        let mut input = Vec::new();
        for i in 0..8 {
            input.extend_from_slice(b"padding padding padding ");
            input.extend_from_slice(format!("user{i}@example.com ").as_bytes());
        }
        let mut output = Vec::new();
        let mut updates = Vec::new();
        let stats = scanner
            .scan_reader_with_progress(
                &input[..],
                &mut output,
                Some(input.len() as u64),
                |progress| {
                    updates.push(progress.clone());
                },
            )
            .unwrap();
        assert!(updates.len() >= 2);
        assert_eq!(
            updates.last().unwrap().bytes_processed,
            stats.bytes_processed
        );
        assert_eq!(updates.last().unwrap().bytes_output, stats.bytes_output);
        assert_eq!(
            updates.last().unwrap().total_bytes,
            Some(input.len() as u64)
        );
    }

    #[test]
    fn scan_reader_writer() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"hello alice@corp.com world";
        let mut output = Vec::new();
        let stats = scanner.scan_reader(&input[..], &mut output).unwrap();
        assert_eq!(stats.matches_found, 1);
        let out_str = String::from_utf8_lossy(&output);
        assert!(out_str.contains("@corp.com"), "domain must be preserved");
    }

    #[test]
    fn invalid_regex_pattern() {
        let result = ScanPattern::from_regex("[invalid(", Category::Email, "bad");
        assert!(result.is_err());
    }

    #[test]
    fn default_config_valid() {
        ScanConfig::default().validate().unwrap();
    }

    // Smallest valid configuration: a one-byte chunk with no overlap.
    #[test]
    fn config_chunk_1_overlap_0() {
        let gen = Arc::new(HmacGenerator::new([42u8; 32]));
        let store = Arc::new(MappingStore::new(gen, None));
        let scanner = StreamScanner::new(vec![], store, ScanConfig::new(1, 0)).unwrap();
        let (output, _) = scanner.scan_bytes(b"hello").unwrap();
        assert_eq!(output, b"hello");
    }

    // The whole input is one match; expects a length-preserving replacement.
    #[test]
    fn bytes_output_preserved_on_replacement() {
        let scanner = test_scanner(vec![email_pattern()]);
        let input = b"a@b.cc";
        let (output, stats) = scanner.scan_bytes(input).unwrap();
        assert_eq!(stats.bytes_processed, input.len() as u64);
        assert_eq!(stats.bytes_output, output.len() as u64);
        assert_eq!(output.len(), input.len());
    }
}