// output_sanitize_rs/lib.rs

//! # output-sanitize-rs
//!
//! Strip dangerous HTML/SQL/shell snippets from LLM output before they
//! reach a render path, a query, or a shell. Each detected snippet is replaced
//! with a `[removed:<kind>]` marker and reported as a finding. Rust port of
//! [`@mukundakatta/llm-output-sanitizer`](https://www.npmjs.com/package/@mukundakatta/llm-output-sanitizer).
//!
//! ## Example
//!
//! ```
//! use output_sanitize_rs::sanitize;
//! let r = sanitize("Hello <script>steal()</script>", Default::default());
//! assert!(!r.safe);
//! assert!(!r.text.contains("<script"));
//! ```
15
16#![deny(missing_docs)]
17
/// Where the sanitized text is headed; this decides the final escaping pass.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Sink {
    /// Emit the stripped text as-is, with no extra escaping. This is the default.
    #[default]
    Markdown,
    /// After stripping, HTML-escape `<`, `>` and `&`.
    Html,
}
27
/// A single detection that [`sanitize`] replaced with a `[removed:<kind>]` marker.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Finding {
    /// Which detector fired: `"html"`, `"sql"` or `"shell"`.
    pub kind: &'static str,
    /// The exact substring of the input that was replaced.
    pub matched: String,
}
36
37/// Result of one sanitize pass.
38#[derive(Debug, Clone)]
39pub struct SanitizeResult {
40    /// True when no findings fired.
41    pub safe: bool,
42    /// Rewritten text.
43    pub text: String,
44    /// What was removed.
45    pub findings: Vec<Finding>,
46}
47
48/// Run the sanitizer.
49pub fn sanitize(text: &str, sink: Sink) -> SanitizeResult {
50    let mut findings = Vec::new();
51    let mut out = String::with_capacity(text.len());
52    let lower = text.to_ascii_lowercase();
53    let lower_bytes = lower.as_bytes();
54    let original_bytes = text.as_bytes();
55
56    let mut i = 0;
57    while i < original_bytes.len() {
58        // HTML: <script>, <iframe>, <object>, <embed>, <form>, <meta>, <link>
59        let tags = ["script", "iframe", "object", "embed", "form", "meta", "link"];
60        let mut matched = false;
61        if original_bytes[i] == b'<' {
62            for tag in tags {
63                let open = format!("<{}", tag);
64                let close = format!("</{}", tag);
65                if lower[i..].starts_with(&open) || lower[i..].starts_with(&close) {
66                    // Find the closing '>' for this tag.
67                    if let Some(end_rel) = lower_bytes[i..].iter().position(|&c| c == b'>') {
68                        let end = i + end_rel + 1;
69                        findings.push(Finding {
70                            kind: "html",
71                            matched: text[i..end].to_string(),
72                        });
73                        out.push_str("[removed:html]");
74                        i = end;
75                        matched = true;
76                        break;
77                    }
78                }
79            }
80        }
81        if matched {
82            continue;
83        }
84
85        // SQL: DROP/TRUNCATE/ALTER/DELETE FROM/INSERT INTO at word boundary
86        let sql_kws: &[&str] = &["drop ", "truncate ", "alter ", "delete from ", "insert into "];
87        let mut sql_matched = None;
88        for kw in sql_kws {
89            if lower[i..].starts_with(kw) && at_word_boundary(original_bytes, i) {
90                sql_matched = Some(*kw);
91                break;
92            }
93        }
94        if let Some(kw) = sql_matched {
95            // Take through end of word or end of input.
96            let mut end = i + kw.len();
97            while end < original_bytes.len() && !original_bytes[end].is_ascii_whitespace() {
98                end += 1;
99            }
100            findings.push(Finding {
101                kind: "sql",
102                matched: text[i..end].to_string(),
103            });
104            out.push_str("[removed:sql]");
105            i = end;
106            continue;
107        }
108
109        // Shell: rm -rf, curl|sh, wget|sh, chmod 777, sudo
110        let shell_signals: &[&str] = &["rm -rf", "chmod 777", "sudo ", "curl ", "wget "];
111        let mut shell_match = None;
112        for sig in shell_signals {
113            if lower[i..].starts_with(sig) && at_word_boundary(original_bytes, i) {
114                shell_match = Some(*sig);
115                break;
116            }
117        }
118        if let Some(sig) = shell_match {
119            let mut end = i + sig.len();
120            while end < original_bytes.len()
121                && original_bytes[end] != b'\n'
122                && original_bytes[end] != b';'
123            {
124                end += 1;
125            }
126            findings.push(Finding {
127                kind: "shell",
128                matched: text[i..end].to_string(),
129            });
130            out.push_str("[removed:shell]");
131            i = end;
132            continue;
133        }
134
135        // Default: copy one byte (UTF-8 safe — non-ASCII passes through one byte at a time
136        // because none of our patterns include multi-byte chars).
137        let c = text[i..].chars().next().unwrap();
138        out.push(c);
139        i += c.len_utf8();
140    }
141
142    if sink == Sink::Html {
143        out = out
144            .replace('&', "&amp;")
145            .replace('<', "&lt;")
146            .replace('>', "&gt;");
147    }
148
149    SanitizeResult {
150        safe: findings.is_empty(),
151        text: out,
152        findings,
153    }
154}
155
156fn at_word_boundary(bytes: &[u8], i: usize) -> bool {
157    i == 0 || !bytes[i - 1].is_ascii_alphanumeric()
158}