use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use crate::engine::{NoopEngine, RedactionEngine};
use crate::label::PrivacyLabel;
use crate::merge::{merge_findings, Finding};
use crate::HARD_RULES;
/// Outcome of one text scan: the merged findings, the redacted text, and aggregate risk.
#[derive(Debug, Clone, PartialEq)]
pub struct RedactionResult {
/// Findings left after merging hard-rule matches with the engine's (model) findings.
pub findings: Vec<Finding>,
/// The scanned input with finding spans replaced by `[REDACTED <label>]` placeholders.
pub redacted_text: String,
/// Per-label counts and total risk aggregated over `findings`.
pub risk_signals: RiskSignals,
}
/// Aggregate risk figures computed from a list of findings.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct RiskSignals {
/// Number of findings per privacy label; findings whose kind maps to no label are not counted.
pub counts_by_label: BTreeMap<PrivacyLabel, u32>,
/// Saturating sum of `risk_delta` over all findings (including those without a label).
pub total_risk_delta: u32,
/// `true` when at least one finding was counted under `PrivacyLabel::Secret`.
pub has_secret: bool,
}
/// Errors returned by the `scan_text*` entry points.
#[derive(Debug, thiserror::Error)]
pub enum ScanError {
/// Engine inference failed; `reason` carries the failure description shown in the message.
/// NOTE(review): no construction site is visible in this part of the file — presumably
/// produced by an engine-error conversion defined elsewhere; confirm.
#[error("inference failed: {reason}")]
InferenceFailed {
reason: String,
},
/// The input string was empty; the scan functions reject it before invoking the engine.
#[error("empty input not allowed (use Option before calling)")]
EmptyInput,
}
/// Scans `input` with [`NoopEngine`] as the inference engine.
///
/// This is a convenience wrapper around [`scan_text_with_engine`].
///
/// # Errors
///
/// Returns whatever [`scan_text_with_engine`] returns, including
/// [`ScanError::EmptyInput`] when `input` is empty.
pub fn scan_text(input: &str) -> Result<RedactionResult, ScanError> {
    let engine = NoopEngine;
    scan_text_with_engine(input, &engine)
}
/// Scans `input` with the supplied `engine` and no language hint.
///
/// Equivalent to calling [`scan_text_with_engine_with_lang`] with `lang = None`.
///
/// # Errors
///
/// Propagates every error of [`scan_text_with_engine_with_lang`].
pub fn scan_text_with_engine(
    input: &str,
    engine: &dyn RedactionEngine,
) -> Result<RedactionResult, ScanError> {
    let no_lang: Option<&str> = None;
    scan_text_with_engine_with_lang(input, engine, no_lang)
}
pub fn scan_text_with_engine_with_lang(
input: &str,
engine: &dyn RedactionEngine,
lang: Option<&str>,
) -> Result<RedactionResult, ScanError> {
if input.is_empty() {
return Err(ScanError::EmptyInput);
}
let mut model_findings = engine.infer_with_lang(input, lang)?;
for f in &mut model_findings {
f.risk_delta = risk_of(f.kind);
}
let mut hard_findings = collect_hard_findings(input);
hard_findings.extend(collect_url_hard_findings(input));
let merged = merge_findings(&hard_findings, &model_findings);
let redacted_text = build_redacted_text(input, &merged);
let risk_signals = aggregate_risk(&merged);
Ok(RedactionResult {
findings: merged,
redacted_text,
risk_signals,
})
}
/// How the engine behaved during a budgeted scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineStatus {
/// The engine delivered findings within the budget; they were merged into the result.
Ok,
/// No engine result arrived within the budget; the result contains no model findings.
DegradedTimeout,
/// The engine run failed (e.g. it returned an error); the result contains no model findings.
DegradedError,
}
/// Result of a budgeted scan together with the engine's status for that scan.
#[derive(Debug, Clone)]
pub struct BudgetedScanOutcome {
/// Scan result; when `status` is not `EngineStatus::Ok` no model findings were merged in.
pub result: RedactionResult,
/// Whether the engine contributed findings, timed out, or failed.
pub status: EngineStatus,
}
/// Scans `input`, giving the engine at most `budget` of wall-clock time.
///
/// The engine runs `infer` on a detached worker thread. Hard-rule findings are
/// always collected, so an engine problem only degrades the result and never
/// turns into an `Err`:
///
/// * engine answers in time with findings → [`EngineStatus::Ok`];
/// * engine answers in time with an error → [`EngineStatus::DegradedError`];
/// * no answer within `budget` → [`EngineStatus::DegradedTimeout`];
/// * the worker ended without answering (for example the engine panicked), or
///   the worker thread could not be started → [`EngineStatus::DegradedError`].
///
/// In every degraded case the model findings are empty and only rule findings
/// are merged. Model findings get their `risk_delta` overwritten via [`risk_of`].
///
/// On timeout the worker thread is not cancelled: it keeps running until
/// `infer` returns, holding its own `Arc` clone of the engine, and its late
/// result is discarded.
///
/// # Errors
///
/// Returns [`ScanError::EmptyInput`] when `input` is empty (no thread is started).
pub fn scan_text_with_engine_budgeted(
    input: &str,
    engine: Arc<dyn RedactionEngine>,
    budget: Duration,
) -> Result<BudgetedScanOutcome, ScanError> {
    if input.is_empty() {
        return Err(ScanError::EmptyInput);
    }

    let (tx, rx) = std::sync::mpsc::channel();
    let input_owned = input.to_string();
    let engine_for_thread = Arc::clone(&engine);

    // `Builder::spawn` reports a failed thread start as `Err` instead of
    // panicking, which lets the scan fall back to the hard-rules-only path.
    let spawned = std::thread::Builder::new().spawn(move || {
        let res = engine_for_thread.infer(&input_owned);
        // The receiver may already have given up after a timeout; that is fine.
        let _ = tx.send(res);
    });

    let (mut model_findings, status) = match spawned {
        Err(_spawn_err) => (Vec::new(), EngineStatus::DegradedError),
        // Dropping the join handle detaches the worker.
        Ok(_detached) => match rx.recv_timeout(budget) {
            Ok(Ok(findings)) => (findings, EngineStatus::Ok),
            Ok(Err(_engine_err)) => (Vec::new(), EngineStatus::DegradedError),
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                (Vec::new(), EngineStatus::DegradedTimeout)
            }
            // The sender was dropped without sending: the worker died (e.g. the
            // engine panicked). The budget did not elapse, so this is an engine
            // failure rather than a timeout.
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                (Vec::new(), EngineStatus::DegradedError)
            }
        },
    };

    for f in &mut model_findings {
        f.risk_delta = risk_of(f.kind);
    }

    let mut hard_findings = collect_hard_findings(input);
    hard_findings.extend(collect_url_hard_findings(input));

    let merged = merge_findings(&hard_findings, &model_findings);
    let redacted_text = build_redacted_text(input, &merged);
    let risk_signals = aggregate_risk(&merged);

    Ok(BudgetedScanOutcome {
        result: RedactionResult {
            findings: merged,
            redacted_text,
            risk_signals,
        },
        status,
    })
}
/// Runs every rule in `HARD_RULES` over `text` and returns one hard finding per match.
///
/// Findings are ordered rule by rule, and by match position within each rule.
/// Each finding carries the rule name as its kind, the match's byte span, and
/// the risk weight from [`risk_of`].
fn collect_hard_findings(text: &str) -> Vec<Finding> {
    HARD_RULES
        .iter()
        .flat_map(|rule| {
            rule.pattern.find_iter(text).map(move |hit| {
                Finding::hard(rule.name, (hit.start(), hit.end()), risk_of(rule.name))
            })
        })
        .collect()
}
/// Returns hard findings for the `generic_url` and `internal_ipv4` rules of `ALL_RULES`.
///
/// All other rules in `ALL_RULES` are ignored here. Findings are ordered rule by
/// rule, and by match position within each rule.
fn collect_url_hard_findings(text: &str) -> Vec<Finding> {
    use crate::ALL_RULES;

    ALL_RULES
        .iter()
        .filter(|rule| matches!(rule.name, "generic_url" | "internal_ipv4"))
        .flat_map(|rule| {
            rule.pattern.find_iter(text).map(move |hit| {
                Finding::hard(rule.name, (hit.start(), hit.end()), risk_of(rule.name))
            })
        })
        .collect()
}
/// Maps a finding kind to its risk weight.
///
/// * kinds labelled `Secret` → 25
/// * kinds labelled `Email` or `Url` → 10
/// * every other label, and kinds with no label at all → 5
pub(crate) fn risk_of(kind: &str) -> u32 {
    const SECRET_RISK: u32 = 25;
    const ELEVATED_RISK: u32 = 10;
    const BASELINE_RISK: u32 = 5;

    PrivacyLabel::from_kind(kind).map_or(BASELINE_RISK, |label| match label {
        PrivacyLabel::Secret => SECRET_RISK,
        PrivacyLabel::Email | PrivacyLabel::Url => ELEVATED_RISK,
        _ => BASELINE_RISK,
    })
}
/// Returns `input` with every valid finding span replaced by a `[REDACTED <label>]` placeholder.
///
/// The placeholder uses `PrivacyLabel::as_str()` when the finding's kind maps to
/// a label, and the raw kind otherwise.
///
/// Spans are byte ranges into the *original* `input`. A span is ignored when
/// `start > end`, when it reaches past the end of `input`, or when either edge
/// is not on a UTF-8 char boundary.
///
/// The text is assembled in one forward pass, so spans never interfere with each
/// other's offsets:
/// * non-overlapping, non-empty spans produce the same output as replacing each
///   span in place;
/// * a span lying entirely inside an already redacted region is dropped (that
///   text is already hidden);
/// * a span that only partly overlaps a redacted region gets its own placeholder
///   for the part not yet covered, so every byte covered by a valid span is
///   removed from the output.
///
/// When several spans start at the same offset the longest one is applied first.
fn build_redacted_text(input: &str, findings: &[Finding]) -> String {
    // Validate against `input` itself; its length and boundaries never change.
    let mut spans: Vec<&Finding> = findings
        .iter()
        .filter(|f| {
            let (start, end) = f.span;
            start <= end
                && end <= input.len()
                && input.is_char_boundary(start)
                && input.is_char_boundary(end)
        })
        .collect();
    // Ascending start; for equal starts the longest span goes first.
    spans.sort_by_key(|f| (f.span.0, std::cmp::Reverse(f.span.1)));

    let mut out = String::with_capacity(input.len());
    // Everything before `cursor` has already been copied or redacted. It is
    // always 0 or a validated span end, hence a char boundary.
    let mut cursor = 0;
    for f in spans {
        let (start, end) = f.span;
        if start < cursor && end <= cursor {
            // Fully covered by an earlier placeholder.
            continue;
        }
        if start > cursor {
            // Untouched text between the previous redaction and this span.
            out.push_str(&input[cursor..start]);
        }
        let placeholder = match PrivacyLabel::from_kind(f.kind) {
            Some(label) => format!("[REDACTED {}]", label.as_str()),
            None => format!("[REDACTED {}]", f.kind),
        };
        out.push_str(&placeholder);
        cursor = end;
    }
    out.push_str(&input[cursor..]);
    out
}
fn aggregate_risk(findings: &[Finding]) -> RiskSignals {
let mut counts: BTreeMap<PrivacyLabel, u32> = BTreeMap::new();
let mut total: u32 = 0;
for f in findings {
total = total.saturating_add(f.risk_delta);
if let Some(label) = PrivacyLabel::from_kind(f.kind) {
*counts.entry(label).or_insert(0) += 1;
}
}
let has_secret = counts.contains_key(&PrivacyLabel::Secret);
RiskSignals {
counts_by_label: counts,
total_risk_delta: total,
has_secret,
}
}
// Unit tests for the scan entry points, the redaction / risk helpers, and the budgeted scan.
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::{EngineError, RedactionEngine};
use crate::merge::{Finding, FindingSource};
use std::sync::Mutex;
// Mock engine that records the `lang` argument seen by each inference call.
struct LangCapturingMockEngine {
// One entry per call: `None` when `infer` ran, `Some(lang)` when a language was passed.
captured: Mutex<Vec<Option<String>>>,
}
impl LangCapturingMockEngine {
// Creates a mock with an empty call log.
fn new() -> Self {
Self {
captured: Mutex::new(Vec::new()),
}
}
// Returns a snapshot of the recorded calls.
fn captured(&self) -> Vec<Option<String>> {
self.captured.lock().unwrap().clone()
}
}
impl RedactionEngine for LangCapturingMockEngine {
// Records `None` and returns no findings.
fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
self.captured.lock().unwrap().push(None);
Ok(Vec::new())
}
// Delegates to `infer` when `lang` is `None`; otherwise records the language.
fn infer_with_lang(
&self,
text: &str,
lang: Option<&str>,
) -> Result<Vec<Finding>, EngineError> {
if lang.is_none() {
return self.infer(text);
}
self.captured.lock().unwrap().push(lang.map(String::from));
Ok(Vec::new())
}
}
// The legacy entry point must reach the engine with no language.
#[test]
fn scan_text_with_engine_calls_infer_no_lang() {
let engine = LangCapturingMockEngine::new();
let _ = scan_text_with_engine("hello world test", &engine).unwrap();
let captured = engine.captured();
assert_eq!(captured, vec![None], "legacy 路径应走 infer (lang=None)");
}
// Passing `lang = None` explicitly must behave like the legacy entry point.
#[test]
fn scan_text_with_engine_with_lang_none_equivalent_to_legacy() {
let engine = LangCapturingMockEngine::new();
let _ = scan_text_with_engine_with_lang("hello world test", &engine, None).unwrap();
let captured = engine.captured();
assert_eq!(
captured,
vec![None],
"lang None 应等价 legacy 路径(走 infer)"
);
}
// A concrete language must be handed to the engine unchanged.
#[test]
fn scan_text_with_engine_with_lang_transports_real_lang() {
let engine = LangCapturingMockEngine::new();
let _ = scan_text_with_engine_with_lang("hello world test", &engine, Some("de")).unwrap();
let captured = engine.captured();
assert_eq!(
captured,
vec![Some("de".to_string())],
"lang Some 应通过 infer_with_lang 透传 caller 字符串"
);
}
// Empty input is rejected before the engine is called.
#[test]
fn scan_text_with_engine_with_lang_empty_input_fail_closed() {
let engine = LangCapturingMockEngine::new();
let r = scan_text_with_engine_with_lang("", &engine, Some("de"));
assert!(matches!(r, Err(ScanError::EmptyInput)));
assert!(engine.captured().is_empty(), "空输入早返,不应调用 engine");
}
// `scan_text` rejects empty input with `EmptyInput`.
#[test]
fn scan_text_empty_input_fail_closed() {
let r = scan_text("");
assert!(
matches!(r, Err(ScanError::EmptyInput)),
"空输入应返 EmptyInput,实际: {:?}",
r
);
}
// A GitHub token is found, counted as a secret, and removed from the redacted text.
#[test]
fn scan_text_secret_variant() {
let text = "log: token = ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ is rotated";
let r = scan_text(text).expect("非空应成功");
assert!(!r.findings.is_empty(), "应命中 github_token");
assert!(r.findings.iter().any(|f| f.kind == "github_token"));
assert!(r.risk_signals.has_secret, "counts 应含 Secret 桶");
assert!(
r.risk_signals
.counts_by_label
.get(&PrivacyLabel::Secret)
.copied()
.unwrap_or(0)
>= 1
);
assert!(!r
.redacted_text
.contains("ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"));
assert!(r.redacted_text.contains("[REDACTED secret]"));
}
// A model finding of kind `private_email` is counted under the Email label.
#[test]
fn scan_text_email_via_model_mock() {
let model_email = Finding::model("private_email", (8, 28), 0.99, 10);
let merged = merge_findings(&[], &[model_email]);
let signals = aggregate_risk(&merged);
assert_eq!(
signals.counts_by_label.get(&PrivacyLabel::Email).copied(),
Some(1)
);
}
// `internal_ipv4` maps to the Url label, and a `private_url` model finding is counted under Url.
#[test]
fn scan_text_url_variant() {
assert_eq!(
PrivacyLabel::from_kind("internal_ipv4"),
Some(PrivacyLabel::Url)
);
let ip = Finding::model("private_url", (12, 25), 0.95, 10);
let merged = merge_findings(&[], &[ip]);
let signals = aggregate_risk(&merged);
assert_eq!(
signals.counts_by_label.get(&PrivacyLabel::Url).copied(),
Some(1)
);
}
// A lone model finding survives the merge with source Model and is counted under Person.
#[test]
fn scan_text_person_via_model_mock() {
let f = Finding::model("private_person", (0, 13), 0.9, 5);
let merged = merge_findings(&[], &[f]);
assert_eq!(merged.len(), 1);
assert_eq!(merged[0].source, FindingSource::Model);
assert_eq!(
PrivacyLabel::from_kind(merged[0].kind),
Some(PrivacyLabel::Person)
);
let signals = aggregate_risk(&merged);
assert_eq!(
signals.counts_by_label.get(&PrivacyLabel::Person).copied(),
Some(1)
);
}
// A `private_phone` model finding is counted under Phone.
#[test]
fn scan_text_phone_via_model_mock() {
let f = Finding::model("private_phone", (5, 18), 0.88, 5);
let merged = merge_findings(&[], &[f]);
let signals = aggregate_risk(&merged);
assert_eq!(
signals.counts_by_label.get(&PrivacyLabel::Phone).copied(),
Some(1)
);
}
// A `private_address` model finding is counted under Address.
#[test]
fn scan_text_address_via_model_mock() {
let f = Finding::model("private_address", (10, 50), 0.91, 5);
let merged = merge_findings(&[], &[f]);
let signals = aggregate_risk(&merged);
assert_eq!(
signals.counts_by_label.get(&PrivacyLabel::Address).copied(),
Some(1)
);
}
// A `private_date` model finding is counted under Date.
#[test]
fn scan_text_date_via_model_mock() {
let f = Finding::model("private_date", (20, 30), 0.96, 5);
let merged = merge_findings(&[], &[f]);
let signals = aggregate_risk(&merged);
assert_eq!(
signals.counts_by_label.get(&PrivacyLabel::Date).copied(),
Some(1)
);
}
// A `private_account_number` model finding is counted under AccountNumber.
#[test]
fn scan_text_account_number_via_model_mock() {
let f = Finding::model("private_account_number", (0, 16), 0.97, 5);
let merged = merge_findings(&[], &[f]);
let signals = aggregate_risk(&merged);
assert_eq!(
signals
.counts_by_label
.get(&PrivacyLabel::AccountNumber)
.copied(),
Some(1)
);
}
// Two different secrets in one text are both found and both replaced by placeholders.
#[test]
fn scan_text_roundtrip_redacts_all_findings() {
let token = "ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ";
let anth = "sk-ant-api03_ABCDEFGHIJKLMNOPQRSTUVWX";
let text = format!("one {token} two {anth} three");
let r = scan_text(&text).expect("非空");
assert!(
r.findings.len() >= 2,
"至少 2 条 finding,实际 {}: {:?}",
r.findings.len(),
r.findings
);
assert!(
!r.redacted_text.contains(token),
"github token 原文泄漏:{}",
r.redacted_text
);
assert!(
!r.redacted_text.contains(anth),
"anthropic key 原文泄漏:{}",
r.redacted_text
);
let count = r.redacted_text.matches("[REDACTED secret]").count();
assert!(count >= 2, "应至少 2 处 [REDACTED secret],实际 {count}");
assert!(r.risk_signals.has_secret);
}
// A hard and a model finding on the same span count once; on disjoint spans both count.
#[test]
fn scan_text_risk_signals_no_double_weighting() {
let hard = vec![Finding::hard("email", (10, 30), 10)];
let model = vec![Finding::model("private_email", (10, 30), 1.0, 10)];
let merged = merge_findings(&hard, &model);
assert_eq!(merged.len(), 1, "同 span 重叠应只留 Hard");
let signals = aggregate_risk(&merged);
assert_eq!(
signals.total_risk_delta, 10,
"重叠时只计 Hard 一次,不应累加到 20"
);
let model2 = vec![Finding::model("private_email", (100, 120), 1.0, 10)];
let merged2 = merge_findings(&hard, &model2);
let s2 = aggregate_risk(&merged2);
assert_eq!(s2.total_risk_delta, 20, "非重叠时应 Hard + Model 累加");
}
// Smoke test of the crate-level API: `redact`, `scrub_text`, `scan_hard_findings`,
// `detect_hard_secret`, and the `ITERATION` constant.
#[test]
fn scan_text_v03_public_api_intact() {
let v = serde_json::json!({"token": "ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"});
let (redacted, summary) = crate::redact(&v);
let s = serde_json::to_string(&redacted).expect("ser");
assert!(!s.contains("ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"));
assert!(summary.contains("finding:"));
let out = crate::scrub_text("token = ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ");
assert!(!out.contains("ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"));
assert!(out.contains("[REDACTED"));
let names = crate::scan_hard_findings("x = ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ");
assert!(names.contains(&"github_token"));
assert_eq!(
crate::detect_hard_secret("x=ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"),
Some("github_token")
);
assert_eq!(crate::detect_hard_secret("hello"), None);
assert_eq!(crate::ITERATION, "I01");
}
// A span that lies past the end of the text is skipped and the text is returned unchanged.
#[test]
fn build_redacted_text_out_of_bounds_span_is_skipped() {
let bad_finding = Finding::hard("github_token", (100, 200), 25);
let out = build_redacted_text("short text", &[bad_finding]);
assert_eq!(out, "short text", "越界 span 应跳过,原文不变");
}
// Byte offset 1 falls inside the first multi-byte character, so the span is skipped.
#[test]
fn build_redacted_text_non_char_boundary_is_skipped() {
let text = "你好 world";
let bad = Finding::model("private_person", (1, 5), 0.9, 5);
let out = build_redacted_text(text, &[bad]);
assert_eq!(out, text);
}
// Risk tiers: secrets 25, email / url 10, other labels and unknown kinds 5.
#[test]
fn risk_of_tiers_match_adr_0012() {
assert_eq!(risk_of("github_token"), 25, "Secret = 25");
assert_eq!(risk_of("anthropic_api_key"), 25);
assert_eq!(risk_of("email"), 10, "Email = 10");
assert_eq!(risk_of("internal_ipv4"), 10, "Url = 10");
assert_eq!(risk_of("private_person"), 5, "Person = 5");
assert_eq!(risk_of("private_date"), 5, "Date = 5");
assert_eq!(risk_of("not_a_kind"), 5, "未知 kind 保守 5");
}
// Engine that sleeps for `dur` before returning no findings; used to force a timeout.
struct SleepyEngine {
dur: Duration,
}
impl RedactionEngine for SleepyEngine {
fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
std::thread::sleep(self.dur);
Ok(Vec::new())
}
}
// Engine whose `infer` always fails with `EngineError::InferRun`.
struct ErrorEngine;
impl RedactionEngine for ErrorEngine {
fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
Err(EngineError::InferRun("mock failure".to_string()))
}
}
// An engine that answers within the budget yields status Ok, and hard rules still match.
#[test]
fn budgeted_scan_within_budget_returns_ok() {
let engine: Arc<dyn RedactionEngine> = Arc::new(NoopEngine);
let outcome = scan_text_with_engine_budgeted(
"token=ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ",
engine,
Duration::from_millis(500),
)
.expect("budgeted scan should succeed within budget");
assert_eq!(outcome.status, EngineStatus::Ok, "NoopEngine 立即返,应 Ok");
assert!(
outcome
.result
.findings
.iter()
.any(|f| f.kind == "github_token"),
"Hard 路径应命中 github_token"
);
}
// A 500 ms engine against a 50 ms budget yields DegradedTimeout; the hard-rule secret is still found.
#[test]
fn budgeted_scan_timeout_degrades_to_hardonly() {
let engine: Arc<dyn RedactionEngine> = Arc::new(SleepyEngine {
dur: Duration::from_millis(500), });
let outcome = scan_text_with_engine_budgeted(
"token=ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ",
engine,
Duration::from_millis(50), )
.expect("budgeted scan should return outcome even on timeout");
assert_eq!(
outcome.status,
EngineStatus::DegradedTimeout,
"engine 500ms vs budget 50ms 应触发 timeout"
);
assert!(
outcome
.result
.findings
.iter()
.any(|f| f.kind == "github_token"),
"DegradedTimeout 路径下 Hard secret 必须仍命中(fail-closed bottom line)"
);
}
// An engine error yields DegradedError instead of `Err`; the hard-rule secret is still found.
#[test]
fn budgeted_scan_engine_error_degrades_to_hardonly() {
let engine: Arc<dyn RedactionEngine> = Arc::new(ErrorEngine);
let outcome = scan_text_with_engine_budgeted(
"token=ghp_abcdefghijklmnopqrstuvwxyzABCDEFGHIJ",
engine,
Duration::from_millis(500),
)
.expect("budgeted scan should not propagate engine error");
assert_eq!(
outcome.status,
EngineStatus::DegradedError,
"engine InferRun 应触发 DegradedError"
);
assert!(
outcome
.result
.findings
.iter()
.any(|f| f.kind == "github_token"),
"DegradedError 路径下 Hard secret 必须仍命中"
);
}
// A plain https URL produces a Url-labelled finding starting at byte 6.
#[test]
fn r1a_scan_text_generic_url_hard_match() {
let text = "Visit https://api.example.com/v1/users for docs.";
let r = scan_text(text).expect("scan ok");
let url_findings: Vec<_> = r
.findings
.iter()
.filter(|f| crate::PrivacyLabel::from_kind(f.kind) == Some(crate::PrivacyLabel::Url))
.collect();
assert!(
!url_findings.is_empty(),
"generic_url 应命中,实际 findings={:?}",
r.findings
);
assert!(
url_findings.iter().any(|f| f.span.0 == 6),
"url span 应起 idx=6,实际: {:?}",
url_findings
);
}
// An internal IPv4 address still produces a Url-labelled finding.
#[test]
fn r1a_internal_ipv4_still_matches() {
let text = "Server at 10.0.0.5 in cluster.";
let r = scan_text(text).expect("scan ok");
let url_findings: Vec<_> = r
.findings
.iter()
.filter(|f| crate::PrivacyLabel::from_kind(f.kind) == Some(crate::PrivacyLabel::Url))
.collect();
assert!(
!url_findings.is_empty(),
"internal_ipv4 应继续命中,findings={:?}",
r.findings
);
}
// `generic_url` carries the Url-tier risk weight of 10.
#[test]
fn r1a_generic_url_risk_delta_is_10() {
assert_eq!(
risk_of("generic_url"),
10,
"generic_url 应走 Url canonical risk = 10(ADR 0012 §1.3)"
);
}
// The budgeted scan also rejects empty input with `EmptyInput`.
#[test]
fn budgeted_scan_empty_input_fail_closed() {
let engine: Arc<dyn RedactionEngine> = Arc::new(NoopEngine);
let r = scan_text_with_engine_budgeted("", engine, Duration::from_millis(500));
assert!(
matches!(r, Err(ScanError::EmptyInput)),
"空输入应返 EmptyInput,实际: {:?}",
r
);
}
}