use std::collections::HashSet;
use std::fmt::Write as _;
use std::sync::LazyLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
/// Serde default helper: config fields referencing this default to `true`
/// when the key is absent from the serialized form.
fn default_true() -> bool { true }
/// Configuration toggling each exfiltration-protection layer.
///
/// Every field defaults to `true`, both for missing keys during
/// deserialization (via `default_true`) and via the `Default` impl.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct ExfiltrationGuardConfig {
    // Strip markdown images pointing at external URLs from scanned output.
    #[serde(default = "default_true")]
    pub block_markdown_images: bool,
    // Check tool-call arguments for previously flagged URLs.
    #[serde(default = "default_true")]
    pub validate_tool_urls: bool,
    // Block memory writes of content that carried injection flags.
    #[serde(default = "default_true")]
    pub guard_memory_writes: bool,
}
impl Default for ExfiltrationGuardConfig {
fn default() -> Self {
Self {
block_markdown_images: true,
validate_tool_urls: true,
guard_memory_writes: true,
}
}
}
// Inline markdown image `![alt](https?://…)`; capture group 2 is the URL.
static MARKDOWN_IMAGE_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"!\[([^\]]*)\]\((https?://[^)]+)\)").expect("valid MARKDOWN_IMAGE_RE")
});
// Reference definition line `[label]: https?://…`; `(?m)^` anchors per line.
static REFERENCE_DEF_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?m)^\[([^\]]+)\]:\s*(https?://\S+)").expect("valid REFERENCE_DEF_RE")
});
// Reference-style image usage `![alt][label]`; capture group 2 is the label.
static REFERENCE_USAGE_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"!\[([^\]]*)\]\[([^\]]+)\]").expect("valid REFERENCE_USAGE_RE"));
// Bare http(s) URL extractor used for tool-argument and content scans.
static URL_EXTRACT_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r#"https?://[^\s"'<>]+"#).expect("valid URL_EXTRACT_RE"));
/// Security events reported by the exfiltration guard.
#[derive(Debug, Clone, PartialEq)]
pub enum ExfiltrationEvent {
    /// A markdown image pointing at an external URL was removed from output.
    MarkdownImageBlocked { url: String },
    /// A tool call's arguments contained a previously flagged URL.
    SuspiciousToolUrl { url: String, tool_name: String },
    /// A memory write was blocked; `reason` explains why.
    MemoryWriteGuarded { reason: String },
}
/// Guards against data exfiltration through model output, tool calls, and
/// memory writes, according to its [`ExfiltrationGuardConfig`].
#[derive(Debug, Clone)]
pub struct ExfiltrationGuard {
    config: ExfiltrationGuardConfig,
}
impl ExfiltrationGuard {
    /// Creates a guard that enforces the given configuration.
    #[must_use]
    pub fn new(config: ExfiltrationGuardConfig) -> Self {
        Self { config }
    }

    /// Scans output text and removes markdown images that point at external
    /// (`http`/`https`) URLs — a rendered image fetch can exfiltrate data
    /// through an attacker-controlled query string.
    ///
    /// Handles both inline images (`![alt](url)`) and reference-style images
    /// (`![alt][label]` plus a `[label]: url` definition line). Returns the
    /// cleaned text together with one
    /// [`ExfiltrationEvent::MarkdownImageBlocked`] per removed image.
    /// No-op when `block_markdown_images` is disabled.
    #[must_use]
    pub fn scan_output(&self, text: &str) -> (String, Vec<ExfiltrationEvent>) {
        if !self.config.block_markdown_images {
            return (text.to_owned(), vec![]);
        }
        let mut events = Vec::new();
        let inline_cleaned = Self::strip_inline_images(text, &mut events);
        let fully_cleaned = Self::strip_reference_images(&inline_cleaned, &mut events);
        (fully_cleaned, events)
    }

    /// Replaces inline `![alt](https?://…)` images whose URL is external with
    /// a `[image removed: …]` notice, recording one event per removal.
    fn strip_inline_images(text: &str, events: &mut Vec<ExfiltrationEvent>) -> String {
        let mut replacement = String::with_capacity(text.len());
        let mut last_end = 0usize;
        for cap in MARKDOWN_IMAGE_RE.captures_iter(text) {
            let m = cap.get(0).expect("full match");
            let raw_url = cap.get(2).expect("url group").as_str();
            // Decode percent-escapes so obfuscated URLs cannot slip past the
            // external-scheme check.
            let url = percent_decode_url(raw_url);
            if is_external_url(&url) {
                replacement.push_str(&text[last_end..m.start()]);
                let _ = write!(replacement, "[image removed: {url}]");
                last_end = m.end();
                events.push(ExfiltrationEvent::MarkdownImageBlocked { url });
            }
        }
        if last_end == 0 {
            // Nothing was replaced; avoid rebuilding the string.
            return text.to_owned();
        }
        replacement.push_str(&text[last_end..]);
        replacement
    }

    /// Replaces reference-style image usages whose definitions resolve to an
    /// external URL, then drops the now-dangling definition lines.
    fn strip_reference_images(text: &str, events: &mut Vec<ExfiltrationEvent>) -> String {
        // Lowercased label -> external URL. Labels resolving to local paths
        // are deliberately left out so their usages survive untouched.
        let mut ref_defs: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for cap in REFERENCE_DEF_RE.captures_iter(text) {
            let label = cap.get(1).expect("label").as_str().to_lowercase();
            let url = percent_decode_url(cap.get(2).expect("url").as_str());
            if is_external_url(&url) {
                ref_defs.insert(label, url);
            }
        }
        if ref_defs.is_empty() {
            return text.to_owned();
        }
        // Replace `![alt][label]` usages that resolve to an external URL.
        let mut cleaned = String::with_capacity(text.len());
        let mut last_end = 0usize;
        for cap in REFERENCE_USAGE_RE.captures_iter(text) {
            let m = cap.get(0).expect("full match");
            let label = cap.get(2).expect("label").as_str().to_lowercase();
            if let Some(url) = ref_defs.get(&label) {
                cleaned.push_str(&text[last_end..m.start()]);
                let _ = write!(cleaned, "[image removed: {url}]");
                last_end = m.end();
                events.push(ExfiltrationEvent::MarkdownImageBlocked { url: url.clone() });
            }
        }
        cleaned.push_str(&text[last_end..]);
        // Drop the definition lines for external labels. Joining the kept
        // lines with '\n' preserves the presence or absence of a trailing
        // newline exactly; the previous push-'\n'-per-line approach added a
        // spurious blank line when the input ended with '\n'.
        let kept: Vec<&str> = cleaned
            .split('\n')
            .filter(|line| {
                !REFERENCE_DEF_RE.captures_iter(line).any(|cap| {
                    let label = cap.get(1).expect("label").as_str().to_lowercase();
                    ref_defs.contains_key(&label)
                })
            })
            .collect();
        kept.join("\n")
    }

    /// Checks a tool call's JSON arguments for URLs that were previously
    /// flagged (e.g. extracted from injected content).
    ///
    /// Falls back to a raw text scan when `args_json` is not valid JSON, so
    /// malformed arguments cannot bypass the check. Returns one
    /// [`ExfiltrationEvent::SuspiciousToolUrl`] per flagged URL occurrence.
    #[must_use]
    pub fn validate_tool_call(
        &self,
        tool_name: &str,
        args_json: &str,
        flagged_urls: &HashSet<String>,
    ) -> Vec<ExfiltrationEvent> {
        if !self.config.validate_tool_urls || flagged_urls.is_empty() {
            return vec![];
        }
        let parsed: serde_json::Value = match serde_json::from_str(args_json) {
            Ok(v) => v,
            Err(_) => {
                // Not JSON — scan the raw bytes so escaping tricks don't help.
                return Self::scan_raw_args(tool_name, args_json, flagged_urls);
            }
        };
        // Parsing first means JSON escapes (e.g. `https:\/\/…`) are resolved
        // before URL extraction.
        let mut events = Vec::new();
        let mut strings = Vec::new();
        collect_strings(&parsed, &mut strings);
        for s in &strings {
            for url_match in URL_EXTRACT_RE.find_iter(s) {
                let url = url_match.as_str();
                if flagged_urls.contains(url) {
                    events.push(ExfiltrationEvent::SuspiciousToolUrl {
                        url: url.to_owned(),
                        tool_name: tool_name.to_owned(),
                    });
                }
            }
        }
        events
    }

    /// Returns a guard event when memory writes should be blocked because the
    /// content carried injection flags; `None` when the write may proceed.
    #[must_use]
    pub fn should_guard_memory_write(
        &self,
        has_injection_flags: bool,
    ) -> Option<ExfiltrationEvent> {
        if !self.config.guard_memory_writes || !has_injection_flags {
            return None;
        }
        Some(ExfiltrationEvent::MemoryWriteGuarded {
            reason: "content contained injection patterns flagged by ContentSanitizer".to_owned(),
        })
    }

    /// Fallback for non-JSON argument payloads: extract every URL from the
    /// raw text and report the ones that were flagged.
    fn scan_raw_args(
        tool_name: &str,
        args: &str,
        flagged_urls: &HashSet<String>,
    ) -> Vec<ExfiltrationEvent> {
        URL_EXTRACT_RE
            .find_iter(args)
            .filter(|m| flagged_urls.contains(m.as_str()))
            .map(|m| ExfiltrationEvent::SuspiciousToolUrl {
                url: m.as_str().to_owned(),
                tool_name: tool_name.to_owned(),
            })
            .collect()
    }
}
/// Collects every `http(s)` URL found in `content` into a set, typically for
/// later comparison against tool-call arguments.
#[must_use]
pub fn extract_flagged_urls(content: &str) -> HashSet<String> {
    let mut urls = HashSet::new();
    for found in URL_EXTRACT_RE.find_iter(content) {
        urls.insert(found.as_str().to_owned());
    }
    urls
}
/// Best-effort percent-decoder used to normalize URLs before scheme checks,
/// so encodings like `%68ttps://…` cannot hide an external URL.
///
/// Decodes `%XX` escapes at the byte level and reassembles the result with
/// [`String::from_utf8_lossy`], which keeps multi-byte UTF-8 sequences intact
/// whether they appear literally or percent-encoded (the previous
/// byte-as-char approach reinterpreted non-ASCII bytes as Latin-1, turning
/// `%C3%A9` into `Ã©` instead of `é`). Invalid escapes (`%`, `%2`, `%zz`)
/// pass through unchanged.
fn percent_decode_url(raw: &str) -> String {
    let bytes = raw.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        // A candidate escape needs a '%' plus two hex digits within bounds.
        let escaped = if bytes[i] == b'%' && i + 2 < bytes.len() {
            match (
                (bytes[i + 1] as char).to_digit(16),
                (bytes[i + 2] as char).to_digit(16),
            ) {
                (Some(hi), Some(lo)) => {
                    #[allow(clippy::cast_possible_truncation)] // two nibbles fit in u8
                    let byte = ((hi << 4) | lo) as u8;
                    Some(byte)
                }
                _ => None,
            }
        } else {
            None
        };
        if let Some(byte) = escaped {
            decoded.push(byte);
            i += 3;
        } else {
            decoded.push(bytes[i]);
            i += 1;
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Returns `true` when `url` uses a remote `http`/`https` scheme; local
/// paths, data URIs, and other schemes are not considered external.
fn is_external_url(url: &str) -> bool {
    ["http://", "https://"]
        .iter()
        .any(|scheme| url.starts_with(scheme))
}
/// Recursively gathers every string leaf of a JSON value into `out`,
/// descending through arrays and object values; other leaf kinds
/// (numbers, booleans, null) are ignored.
fn collect_strings<'a>(value: &'a serde_json::Value, out: &mut Vec<&'a str>) {
    match value {
        serde_json::Value::Array(items) => {
            items.iter().for_each(|item| collect_strings(item, out));
        }
        serde_json::Value::Object(fields) => {
            fields.values().for_each(|field| collect_strings(field, out));
        }
        serde_json::Value::String(text) => out.push(text.as_str()),
        _ => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard with every protection enabled (the default configuration).
    fn guard() -> ExfiltrationGuard {
        ExfiltrationGuard::new(ExfiltrationGuardConfig::default())
    }

    /// Guard with every protection disabled.
    fn guard_disabled() -> ExfiltrationGuard {
        ExfiltrationGuard::new(ExfiltrationGuardConfig {
            block_markdown_images: false,
            validate_tool_urls: false,
            guard_memory_writes: false,
        })
    }

    #[test]
    fn strips_external_inline_image() {
        let (cleaned, events) =
            guard().scan_output("Before ![](https://evil.com/p.gif) after");
        assert_eq!(
            cleaned,
            "Before [image removed: https://evil.com/p.gif] after"
        );
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], ExfiltrationEvent::MarkdownImageBlocked { url } if url == "https://evil.com/p.gif")
        );
    }

    #[test]
    fn preserves_local_image() {
        let text = "Look: ![diagram](./assets/diagram.png) — local";
        let (cleaned, events) = guard().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn preserves_data_uri() {
        let text = "Inline: ![pixel](data:image/png;base64,iVBORw0KGgo=)";
        let (cleaned, events) = guard().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn strips_multiple_external_images() {
        let text = "![a](https://evil.com/1.png) text ![b](https://evil.com/2.png)";
        let (cleaned, events) = guard().scan_output(text);
        assert!(
            !cleaned.contains("![a](https://evil.com/1.png)"),
            "first image syntax must be removed: {cleaned}"
        );
        assert!(
            !cleaned.contains("![b](https://evil.com/2.png)"),
            "second image syntax must be removed: {cleaned}"
        );
        assert_eq!(events.len(), 2);
    }

    #[test]
    fn scan_output_noop_when_disabled() {
        let text = "![x](https://evil.com/p.gif)";
        let (cleaned, events) = guard_disabled().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn strips_reference_style_image() {
        let text = "Here is the image: ![alt][ref]\n[ref]: https://evil.com/track.gif\nend";
        let (cleaned, events) = guard().scan_output(text);
        assert!(
            !cleaned.contains("![alt][ref]"),
            "image usage syntax must be removed: {cleaned}"
        );
        assert!(
            !cleaned.contains("[ref]:"),
            "reference definition must be removed: {cleaned}"
        );
        assert!(
            cleaned.contains("[image removed:"),
            "replacement label must be present: {cleaned}"
        );
        assert!(!events.is_empty(), "must generate event");
    }

    #[test]
    fn preserves_local_reference_image() {
        let text = "![alt][ref]\n[ref]: ./local.png\n";
        let (cleaned, events) = guard().scan_output(text);
        assert_eq!(cleaned, text);
        assert!(events.is_empty());
    }

    #[test]
    fn decodes_percent_encoded_url_in_inline_image() {
        let text = "![x](%68ttps://evil.com/p.gif)";
        let (cleaned, _events) = guard().scan_output(text);
        assert_eq!(
            cleaned, text,
            "percent-encoded scheme not detected by inline regex"
        );
        let normal = "![x](https://evil.com/p.gif)";
        let (normal_cleaned, normal_events) = guard().scan_output(normal);
        assert!(
            !normal_cleaned.contains("![x](https://evil.com/p.gif)"),
            "normal URL must be removed"
        );
        assert_eq!(normal_events.len(), 1);
    }

    #[test]
    fn empty_alt_text_still_blocked() {
        let text = "![](https://evil.com/e.png)";
        let (cleaned, events) = guard().scan_output(text);
        assert!(
            !cleaned.contains("![](https://evil.com/e.png)"),
            "markdown image syntax must be removed: {cleaned}"
        );
        assert!(
            cleaned.contains("[image removed:"),
            "replacement label must be present: {cleaned}"
        );
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn detects_flagged_url_in_json_string() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/payload".to_owned());
        let args = r#"{"url": "https://evil.com/payload"}"#;
        let events = guard().validate_tool_call("fetch", args, &flagged);
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], ExfiltrationEvent::SuspiciousToolUrl { url, tool_name }
                if url == "https://evil.com/payload" && tool_name == "fetch")
        );
    }

    #[test]
    fn no_event_when_url_not_flagged() {
        let mut flagged = HashSet::new();
        flagged.insert("https://other.com/benign".to_owned());
        let args = r#"{"url": "https://legitimate.com/page"}"#;
        let events = guard().validate_tool_call("fetch", args, &flagged);
        assert!(events.is_empty());
    }

    #[test]
    fn validate_tool_call_noop_when_disabled() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/x".to_owned());
        let args = r#"{"url": "https://evil.com/x"}"#;
        let events = guard_disabled().validate_tool_call("fetch", args, &flagged);
        assert!(events.is_empty());
    }

    #[test]
    fn validate_tool_call_noop_with_empty_flagged() {
        let args = r#"{"url": "https://evil.com/x"}"#;
        let events = guard().validate_tool_call("fetch", args, &HashSet::new());
        assert!(events.is_empty());
    }

    #[test]
    fn extracts_urls_from_nested_json() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/deep".to_owned());
        let args = r#"{"nested": {"inner": ["https://evil.com/deep"]}}"#;
        let events = guard().validate_tool_call("tool", args, &flagged);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn handles_escaped_slashes_in_json() {
        let mut flagged = HashSet::new();
        flagged.insert("https://evil.com/path".to_owned());
        let args = r#"{"url": "https:\/\/evil.com\/path"}"#;
        let parsed: serde_json::Value = serde_json::from_str(args).unwrap();
        assert_eq!(parsed["url"], "https://evil.com/path");
        let events = guard().validate_tool_call("fetch", args, &flagged);
        assert_eq!(events.len(), 1, "JSON-escaped URL must be caught");
    }

    #[test]
    fn guards_when_injection_flags_set() {
        let event = guard().should_guard_memory_write(true);
        assert!(event.is_some());
        assert!(matches!(
            event.unwrap(),
            ExfiltrationEvent::MemoryWriteGuarded { .. }
        ));
    }

    #[test]
    fn passes_when_no_injection_flags() {
        let event = guard().should_guard_memory_write(false);
        assert!(event.is_none());
    }

    #[test]
    fn guard_memory_write_noop_when_disabled() {
        let event = guard_disabled().should_guard_memory_write(true);
        assert!(event.is_none());
    }

    #[test]
    fn percent_decode_roundtrip() {
        assert_eq!(
            percent_decode_url("https://example.com"),
            "https://example.com"
        );
        assert_eq!(
            percent_decode_url("%68ttps://example.com"),
            "https://example.com"
        );
        assert_eq!(percent_decode_url("hello%20world"), "hello world");
    }

    #[test]
    fn extracts_urls_from_plain_text() {
        let content = "check https://evil.com/x and https://other.com/y for details";
        let urls = extract_flagged_urls(content);
        assert!(urls.contains("https://evil.com/x"));
        assert!(urls.contains("https://other.com/y"));
    }
}