use crate::config::SafetyConfig;
/// Outcome of running a piece of content through [`SafetyGate::check`].
#[derive(Debug, Clone, PartialEq)]
pub enum SafetyDecision {
/// No redaction was applied and no custom pattern matched.
Allow,
/// At least one redaction pass replaced part of the content; the payload is
/// the content after all enabled redaction passes have run.
AllowRedacted(String),
/// The content contained one of the configured custom patterns.
Deny {
/// Human-readable explanation naming the custom pattern that matched.
reason: String,
},
}
/// Content filter that redacts sensitive data and rejects content containing
/// configured custom patterns.
pub struct SafetyGate {
// Settings read on every `check` call: the redaction toggles and the list of
// custom deny patterns.
config: SafetyConfig,
}
impl SafetyGate {
pub fn new(config: SafetyConfig) -> Self {
Self { config }
}
pub fn check(&self, content: &str) -> SafetyDecision {
let mut redacted = content.to_string();
let mut was_redacted = false;
if self.config.credit_card_redaction {
let (new_text, found) = redact_credit_cards(&redacted);
if found {
redacted = new_text;
was_redacted = true;
}
}
if self.config.ssn_redaction {
let (new_text, found) = redact_ssns(&redacted);
if found {
redacted = new_text;
was_redacted = true;
}
}
if self.config.pii_detection {
let (new_text, found) = redact_emails(&redacted);
if found {
redacted = new_text;
was_redacted = true;
}
}
for pattern in &self.config.custom_patterns {
if content.contains(pattern.as_str()) {
return SafetyDecision::Deny {
reason: format!("Custom pattern matched: {}", pattern),
};
}
}
if was_redacted {
SafetyDecision::AllowRedacted(redacted)
} else {
SafetyDecision::Allow
}
}
pub fn redact(&self, content: &str) -> String {
match self.check(content) {
SafetyDecision::Allow => content.to_string(),
SafetyDecision::AllowRedacted(redacted) => redacted,
SafetyDecision::Deny { .. } => "[REDACTED]".to_string(),
}
}
}
/// Replaces likely credit-card numbers in `text` with `[CC_REDACTED]`.
///
/// A candidate is a run that starts at an ASCII digit and continues through
/// ASCII digits, spaces and hyphens. The run is treated as a card number when
/// it contains between 13 and 16 digits (inclusive). Only the span from the
/// first to the last digit is replaced; separators trailing the last digit are
/// kept so the surrounding text keeps its spacing.
///
/// Returns the rewritten text and whether at least one number was redacted.
fn redact_credit_cards(text: &str) -> (String, bool) {
    let chars: Vec<char> = text.chars().collect();
    let mut result = String::with_capacity(text.len());
    let mut found = false;
    let mut i = 0;

    while i < chars.len() {
        if !chars[i].is_ascii_digit() {
            result.push(chars[i]);
            i += 1;
            continue;
        }

        let start = i;
        let mut digit_count = 0;
        // One past the last digit seen; everything in `digits_end..i` is a
        // trailing separator that does not belong to the number itself.
        let mut digits_end = i;
        while i < chars.len()
            && (chars[i].is_ascii_digit() || chars[i] == ' ' || chars[i] == '-')
        {
            if chars[i].is_ascii_digit() {
                digit_count += 1;
                digits_end = i + 1;
            }
            i += 1;
        }

        if (13..=16).contains(&digit_count) {
            result.push_str("[CC_REDACTED]");
            // Re-emit the separators that followed the number.
            result.extend(&chars[digits_end..i]);
            found = true;
        } else {
            result.extend(&chars[start..i]);
        }
    }

    (result, found)
}
/// Replaces every `ddd-dd-dddd` sequence in `text` with `[SSN_REDACTED]`.
///
/// The text is scanned left to right; once a match is replaced, scanning
/// resumes immediately after it. Returns the rewritten text and whether at
/// least one replacement happened.
fn redact_ssns(text: &str) -> (String, bool) {
    // Number of characters in a formatted SSN, dashes included.
    const SSN_LEN: usize = 11;

    let chars: Vec<char> = text.chars().collect();
    let mut output = String::new();
    let mut any_redacted = false;
    let mut cursor = 0;

    while let Some(&current) = chars.get(cursor) {
        let fits = cursor + SSN_LEN <= chars.len();
        if fits && is_ssn_at(&chars, cursor) {
            output.push_str("[SSN_REDACTED]");
            any_redacted = true;
            cursor += SSN_LEN;
        } else {
            output.push(current);
            cursor += 1;
        }
    }

    (output, any_redacted)
}
/// Reports whether the 11 characters starting at `pos` have the shape
/// `ddd-dd-dddd` (ASCII digits separated by hyphens).
///
/// Returns `false` when fewer than 11 characters remain at `pos`.
fn is_ssn_at(chars: &[char], pos: usize) -> bool {
    let window = match chars.get(pos..).and_then(|rest| rest.get(..11)) {
        Some(window) => window,
        None => return false,
    };

    window.iter().enumerate().all(|(offset, c)| match offset {
        // The two separator positions inside the window.
        3 | 6 => *c == '-',
        _ => c.is_ascii_digit(),
    })
}
/// Replaces e-mail addresses in `text` with `[EMAIL_REDACTED]`.
///
/// Detection is anchored on each `@`: the local part is the run of
/// local-part characters immediately before it, and the domain is the run of
/// domain characters immediately after it, with trailing `.` / `-` trimmed so
/// sentence punctuation is not swallowed. A candidate is redacted when the
/// local part is non-empty and the trimmed domain is at least three
/// characters long and contains a dot.
///
/// All text outside the redacted spans, including whitespace, is copied
/// through unchanged. Returns the rewritten text and whether at least one
/// address was redacted.
fn redact_emails(text: &str) -> (String, bool) {
    let chars: Vec<char> = text.chars().collect();
    let len = chars.len();
    let mut result = String::with_capacity(text.len());
    let mut found = false;
    // Index of the first character that has not been written to `result` yet.
    let mut emitted = 0;
    let mut i = 0;

    while i < len {
        if chars[i] != '@' {
            i += 1;
            continue;
        }

        // Walk back over the local part, but never into text that a previous
        // redaction already consumed (e.g. the domain in "a@b.com@c.de").
        let mut local_start = i;
        while local_start > emitted && is_email_local_char(chars[local_start - 1]) {
            local_start -= 1;
        }

        let mut domain_end = i + 1;
        while domain_end < len && is_email_domain_char(chars[domain_end]) {
            domain_end += 1;
        }
        // Trailing dots/hyphens are punctuation, not part of the domain.
        while domain_end > i + 1 && matches!(chars[domain_end - 1], '.' | '-') {
            domain_end -= 1;
        }

        let domain = &chars[i + 1..domain_end];
        let has_local = local_start < i;
        if has_local && domain.len() >= 3 && domain.contains(&'.') {
            // Flush the untouched text before the address, then the marker.
            result.extend(&chars[emitted..local_start]);
            result.push_str("[EMAIL_REDACTED]");
            found = true;
            emitted = domain_end;
            i = domain_end;
        } else {
            i += 1;
        }
    }

    result.extend(&chars[emitted..]);
    (result, found)
}
/// Reports whether `c` is accepted in the local part of an e-mail address
/// (the part before `@`): ASCII letters and digits plus `.`, `+`, `-`, `_`.
fn is_email_local_char(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '.' | '+' | '-' | '_')
}
/// Reports whether `c` is accepted in the domain part of an e-mail address
/// (the part after `@`): ASCII letters and digits plus `.` and `-`.
fn is_email_domain_char(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '.' | '-')
}
// Unit tests. The first group exercises the e-mail redactor and the gate in
// this file; the remaining groups compare helpers from `crate::wasm::helpers`
// against their native counterparts and check the MMR reranker.
// NOTE(review): the routing and MMR tests exercise other modules — consider
// whether they belong next to that code instead.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::SafetyConfig;
// --- redact_emails: text surrounding the address must be copied verbatim ---
// A tab before and after the address is kept.
#[test]
fn test_email_redaction_preserves_tabs() {
let (result, found) = redact_emails("contact\tuser@example.com\there");
assert!(found);
assert_eq!(result, "contact\t[EMAIL_REDACTED]\there");
}
// Newlines around the address are kept.
#[test]
fn test_email_redaction_preserves_newlines() {
let (result, found) = redact_emails("contact\nuser@example.com\nhere");
assert!(found);
assert_eq!(result, "contact\n[EMAIL_REDACTED]\nhere");
}
// NOTE(review): the name says "multi_spaces" but the input and expected
// strings below contain only single spaces — confirm the literals were not
// whitespace-collapsed at some point.
#[test]
fn test_email_redaction_preserves_multi_spaces() {
let (result, found) = redact_emails("contact user@example.com here");
assert!(found);
assert_eq!(result, "contact [EMAIL_REDACTED] here");
}
// A mix of tab, space and newline around the address is kept.
#[test]
fn test_email_redaction_preserves_mixed_whitespace() {
let (result, found) = redact_emails("contact\t user@example.com\n here");
assert!(found);
assert_eq!(result, "contact\t [EMAIL_REDACTED]\n here");
}
// Single address in the middle of a sentence.
#[test]
fn test_email_redaction_basic() {
let (result, found) = redact_emails("email user@example.com here");
assert!(found);
assert_eq!(result, "email [EMAIL_REDACTED] here");
}
// Text without an address is returned unchanged with `found == false`.
#[test]
fn test_email_redaction_no_email() {
let (result, found) = redact_emails("no email here");
assert!(!found);
assert_eq!(result, "no email here");
}
// Every address in the text is replaced, not just the first.
#[test]
fn test_email_redaction_multiple_emails() {
let (result, found) = redact_emails("a@b.com and c@d.org");
assert!(found);
assert_eq!(result, "[EMAIL_REDACTED] and [EMAIL_REDACTED]");
}
// Address at the very start of the input.
#[test]
fn test_email_redaction_at_start() {
let (result, found) = redact_emails("user@example.com is the contact");
assert!(found);
assert_eq!(result, "[EMAIL_REDACTED] is the contact");
}
// Address at the very end of the input.
#[test]
fn test_email_redaction_at_end() {
let (result, found) = redact_emails("contact: user@example.com");
assert!(found);
assert_eq!(result, "contact: [EMAIL_REDACTED]");
}
// --- SafetyGate end to end ---
// Relies on the default configuration having e-mail redaction enabled;
// any other decision than AllowRedacted fails the test.
#[test]
fn test_safety_gate_email_preserves_whitespace() {
let config = SafetyConfig::default();
let gate = SafetyGate::new(config);
let decision = gate.check("contact\tuser@example.com\nhere");
match decision {
SafetyDecision::AllowRedacted(redacted) => {
assert_eq!(redacted, "contact\t[EMAIL_REDACTED]\nhere");
}
other => panic!("Expected AllowRedacted, got {:?}", other),
}
}
// --- Query routing parity: `route_query` (wasm helper) must name the same
// route that the native `QueryRouter` picks for each query ---
// Queries containing time expressions are expected to route to Temporal.
#[test]
fn test_wasm_routing_matches_native_temporal() {
use crate::search::router::QueryRouter;
use crate::search::router::QueryRoute;
use crate::wasm::helpers::route_query;
let router = QueryRouter::new();
let queries = [
"what did I see yesterday",
"show me last week",
"results from today",
];
for q in &queries {
assert_eq!(
router.route(q),
QueryRoute::Temporal,
"Native router failed for: {}", q
);
assert_eq!(
route_query(q),
"Temporal",
"WASM router failed for: {}", q
);
}
}
// Relationship-style queries are expected to route to Graph.
#[test]
fn test_wasm_routing_matches_native_graph() {
use crate::search::router::QueryRouter;
use crate::search::router::QueryRoute;
use crate::wasm::helpers::route_query;
let router = QueryRouter::new();
let queries = [
"documents related to authentication",
"things connected to the API module",
];
for q in &queries {
assert_eq!(
router.route(q),
QueryRoute::Graph,
"Native router failed for: {}", q
);
assert_eq!(
route_query(q),
"Graph",
"WASM router failed for: {}", q
);
}
}
// One- and two-word queries are expected to route to Keyword.
#[test]
fn test_wasm_routing_matches_native_keyword_short() {
use crate::search::router::QueryRouter;
use crate::search::router::QueryRoute;
use crate::wasm::helpers::route_query;
let router = QueryRouter::new();
let queries = [
"hello",
"rust programming",
];
for q in &queries {
assert_eq!(
router.route(q),
QueryRoute::Keyword,
"Native router failed for: {}", q
);
assert_eq!(
route_query(q),
"Keyword",
"WASM router failed for: {}", q
);
}
}
// A query wrapped in double quotes is expected to route to Keyword.
#[test]
fn test_wasm_routing_matches_native_keyword_quoted() {
use crate::search::router::QueryRouter;
use crate::search::router::QueryRoute;
use crate::wasm::helpers::route_query;
let router = QueryRouter::new();
let q = "\"exact phrase search\"";
assert_eq!(router.route(q), QueryRoute::Keyword);
assert_eq!(route_query(q), "Keyword");
}
// Longer natural-language queries are expected to route to Hybrid.
#[test]
fn test_wasm_routing_matches_native_hybrid() {
use crate::search::router::QueryRouter;
use crate::search::router::QueryRoute;
use crate::wasm::helpers::route_query;
let router = QueryRouter::new();
let queries = [
"how to implement authentication in Rust",
"explain how embeddings work",
"something about machine learning",
];
for q in &queries {
assert_eq!(
router.route(q),
QueryRoute::Hybrid,
"Native router failed for: {}", q
);
assert_eq!(
route_query(q),
"Hybrid",
"WASM router failed for: {}", q
);
}
}
// --- Safety classification parity: `safety_classify` (wasm helper) must
// agree with the decision variant produced by the native SafetyGate ---
// Card number: native AllowRedacted <-> wasm "redact".
#[test]
fn test_wasm_safety_matches_native_cc() {
use crate::wasm::helpers::safety_classify;
let config = SafetyConfig::default();
let gate = SafetyGate::new(config);
let content = "pay with 4111-1111-1111-1111";
assert!(matches!(gate.check(content), SafetyDecision::AllowRedacted(_)));
assert_eq!(safety_classify(content), "redact");
}
// SSN: native AllowRedacted <-> wasm "redact".
#[test]
fn test_wasm_safety_matches_native_ssn() {
use crate::wasm::helpers::safety_classify;
let config = SafetyConfig::default();
let gate = SafetyGate::new(config);
let content = "my ssn 123-45-6789";
assert!(matches!(gate.check(content), SafetyDecision::AllowRedacted(_)));
assert_eq!(safety_classify(content), "redact");
}
// E-mail address: native AllowRedacted <-> wasm "redact".
#[test]
fn test_wasm_safety_matches_native_email() {
use crate::wasm::helpers::safety_classify;
let config = SafetyConfig::default();
let gate = SafetyGate::new(config);
let content = "email user@example.com here";
assert!(matches!(gate.check(content), SafetyDecision::AllowRedacted(_)));
assert_eq!(safety_classify(content), "redact");
}
// Custom pattern: native Deny <-> wasm "deny".
// NOTE(review): `safety_classify` takes no config, so it presumably has
// "password" built in — confirm against the wasm helper.
#[test]
fn test_wasm_safety_matches_native_custom_deny() {
use crate::wasm::helpers::safety_classify;
let config = SafetyConfig {
custom_patterns: vec!["password".to_string()],
..Default::default()
};
let gate = SafetyGate::new(config);
let content = "my password is foo";
assert!(matches!(gate.check(content), SafetyDecision::Deny { .. }));
assert_eq!(safety_classify(content), "deny");
}
// Harmless text: native Allow <-> wasm "allow".
#[test]
fn test_wasm_safety_matches_native_allow() {
use crate::wasm::helpers::safety_classify;
let config = SafetyConfig::default();
let gate = SafetyGate::new(config);
let content = "the weather is nice";
assert_eq!(gate.check(content), SafetyDecision::Allow);
assert_eq!(safety_classify(content), "allow");
}
// --- MMR reranking ---
// "b" scores higher than "c" but its vector is nearly identical to "a", so
// the reranker is expected to place the orthogonal "c" ahead of it.
#[test]
fn test_mmr_produces_different_order_than_cosine() {
use crate::search::mmr::MmrReranker;
let mmr = MmrReranker::new(0.3);
let query = vec![1.0, 0.0, 0.0, 0.0];
let results = vec![
("a".to_string(), 0.95, vec![1.0, 0.0, 0.0, 0.0]),
("b".to_string(), 0.90, vec![0.99, 0.01, 0.0, 0.0]),
("c".to_string(), 0.60, vec![0.0, 1.0, 0.0, 0.0]),
];
let ranked = mmr.rerank(&query, &results, 3);
assert_eq!(ranked.len(), 3);
assert_eq!(ranked[0].0, "a");
assert_eq!(ranked[1].0, "c", "MMR should promote diverse result");
assert_eq!(ranked[2].0, "b");
}
}