// serde_xml/escape.rs

//! XML escape and unescape utilities.
//!
//! This module provides fast, allocation-minimizing functions for escaping
//! and unescaping XML special characters.
5
6use memchr::memchr3;
7
8
9/// Escapes XML special characters in a string.
10///
11/// Returns a `Cow<str>` to avoid allocation when no escaping is needed.
12#[inline]
13pub fn escape(s: &str) -> std::borrow::Cow<'_, str> {
14    // Fast path: check if any escaping is needed
15    if !needs_escape(s.as_bytes()) {
16        return std::borrow::Cow::Borrowed(s);
17    }
18
19    let mut result = String::with_capacity(s.len() + s.len() / 8);
20    escape_to(s, &mut result);
21    std::borrow::Cow::Owned(result)
22}
23
24/// Checks if a byte slice needs escaping.
25#[inline]
26fn needs_escape(bytes: &[u8]) -> bool {
27    memchr3(b'<', b'>', b'&', bytes).is_some()
28        || memchr::memchr2(b'"', b'\'', bytes).is_some()
29}
30
/// Escapes XML special characters and appends the result to `out`.
///
/// Each of `<`, `>`, `&`, `"` and `'` is replaced by its predefined entity
/// (`&lt;`, `&gt;`, `&amp;`, `&quot;`, `&apos;`); every other character is
/// copied through unchanged. Existing contents of `out` are kept.
#[inline]
pub fn escape_to(s: &str, out: &mut String) {
    // Start of the run of plain text not yet copied to `out`.
    let mut start = 0;

    for (i, byte) in s.bytes().enumerate() {
        let escaped = match byte {
            b'<' => "&lt;",
            b'>' => "&gt;",
            b'&' => "&amp;",
            b'"' => "&quot;",
            b'\'' => "&apos;",
            _ => continue,
        };

        // `i` indexes an ASCII byte, and UTF-8 continuation bytes are never
        // ASCII, so both `i` and `i + 1` are char boundaries. Safe slicing
        // therefore cannot panic, and no `unsafe` is needed.
        out.push_str(&s[start..i]);
        out.push_str(escaped);
        start = i + 1;
    }

    out.push_str(&s[start..]);
}
60
61/// Escapes XML special characters for attribute values.
62///
63/// This is the same as `escape` but optimized for attribute values.
64#[inline]
65pub fn escape_attr(s: &str) -> std::borrow::Cow<'_, str> {
66    escape(s)
67}
68
69/// Unescapes XML entities in a string.
70///
71/// Returns a `Cow<str>` to avoid allocation when no unescaping is needed.
72#[inline]
73pub fn unescape(s: &str) -> Result<std::borrow::Cow<'_, str>, UnescapeError> {
74    // Fast path: check if any unescaping is needed
75    if !s.contains('&') {
76        return Ok(std::borrow::Cow::Borrowed(s));
77    }
78
79    let mut result = String::with_capacity(s.len());
80    unescape_to(s, &mut result)?;
81    Ok(std::borrow::Cow::Owned(result))
82}
83
/// Error produced when unescaping meets a malformed or unrecognized entity
/// reference.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UnescapeError {
    /// The offending reference text, e.g. `"&foo;"`, or just `"&"` when the
    /// ampersand did not start a terminated reference.
    pub entity: String,
    /// Byte offset of the `&` that starts the offending reference.
    pub position: usize,
}

impl std::fmt::Display for UnescapeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let UnescapeError { entity, position } = self;
        write!(f, "invalid XML entity '{}' at position {}", entity, position)
    }
}

impl std::error::Error for UnescapeError {}
100
101/// Unescapes XML entities and appends to the given string.
102pub fn unescape_to(s: &str, out: &mut String) -> Result<(), UnescapeError> {
103    let bytes = s.as_bytes();
104    let mut i = 0;
105    let mut start = 0;
106
107    while i < bytes.len() {
108        if bytes[i] == b'&' {
109            // Append text before the entity
110            if start < i {
111                out.push_str(unsafe { std::str::from_utf8_unchecked(&bytes[start..i]) });
112            }
113
114            let entity_start = i;
115            i += 1;
116
117            // Find the end of the entity
118            let semicolon = bytes[i..].iter().position(|&b| b == b';');
119            match semicolon {
120                Some(len) if len > 0 => {
121                    let entity = unsafe { std::str::from_utf8_unchecked(&bytes[i..i + len]) };
122                    let decoded = decode_entity(entity);
123
124                    match decoded {
125                        Some(c) => out.push(c),
126                        None => {
127                            // Try numeric character reference
128                            if let Some(c) = decode_numeric_entity(entity) {
129                                out.push(c);
130                            } else {
131                                return Err(UnescapeError {
132                                    entity: format!("&{};", entity),
133                                    position: entity_start,
134                                });
135                            }
136                        }
137                    }
138
139                    i += len + 1;
140                    start = i;
141                }
142                _ => {
143                    return Err(UnescapeError {
144                        entity: String::from("&"),
145                        position: entity_start,
146                    });
147                }
148            }
149        } else {
150            i += 1;
151        }
152    }
153
154    if start < bytes.len() {
155        out.push_str(unsafe { std::str::from_utf8_unchecked(&bytes[start..]) });
156    }
157
158    Ok(())
159}
160
/// Maps the name of one of the five predefined XML entities to its character.
///
/// `entity` is the text between `&` and `;` (for example `"amp"`). The
/// lookup is case-sensitive; any other name yields `None`.
#[inline]
fn decode_entity(entity: &str) -> Option<char> {
    let decoded = match entity {
        "amp" => '&',
        "lt" => '<',
        "gt" => '>',
        "apos" => '\'',
        "quot" => '"',
        _ => return None,
    };
    Some(decoded)
}
173
/// Decodes a numeric character reference (`&#NNN;` or `&#xHHH;`).
///
/// `entity` is the text between `&` and `;`, e.g. `"#65"` or `"#x41"`. For
/// compatibility, an upper-case `X` is also accepted as the hex marker.
///
/// Returns `None` unless `entity` meets both conditions:
///
/// * it is `#` followed by one or more decimal digits, or `#x`/`#X` followed
///   by one or more hex digits;
/// * its value is a character XML allows.
///
/// In particular this rejects:
///
/// * any sign character (`u32::from_str_radix` alone would accept a leading
///   `+`, decoding `&#+65;` as `A`);
/// * values that overflow `u32`;
/// * code points outside the union of the XML 1.0 and 1.1 `Char`
///   productions: U+0000, surrogates, U+FFFE, U+FFFF and anything above
///   U+10FFFF. These are not legal in any XML document.
#[inline]
fn decode_numeric_entity(entity: &str) -> Option<char> {
    let body = entity.strip_prefix('#')?;
    let (radix, digits) = match body.strip_prefix('x').or_else(|| body.strip_prefix('X')) {
        Some(hex) => (16, hex),
        None => (10, body),
    };

    // Only bare digits are valid in a character reference; check explicitly
    // because `from_str_radix` tolerates a leading `+`.
    if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }

    let code = u32::from_str_radix(digits, radix).ok()?;
    if !matches!(code, 0x1..=0xD7FF | 0xE000..=0xFFFD | 0x10000..=0x10FFFF) {
        return None;
    }
    char::from_u32(code)
}
199
200#[cfg(test)]
201mod tests {
202    use super::*;
203
204    #[test]
205    fn test_escape_no_special_chars() {
206        let s = "Hello, World!";
207        let escaped = escape(s);
208        assert!(matches!(escaped, std::borrow::Cow::Borrowed(_)));
209        assert_eq!(escaped, s);
210    }
211
212    #[test]
213    fn test_escape_lt() {
214        assert_eq!(escape("<"), "&lt;");
215    }
216
217    #[test]
218    fn test_escape_gt() {
219        assert_eq!(escape(">"), "&gt;");
220    }
221
222    #[test]
223    fn test_escape_amp() {
224        assert_eq!(escape("&"), "&amp;");
225    }
226
227    #[test]
228    fn test_escape_quot() {
229        assert_eq!(escape("\""), "&quot;");
230    }
231
232    #[test]
233    fn test_escape_apos() {
234        assert_eq!(escape("'"), "&apos;");
235    }
236
237    #[test]
238    fn test_escape_mixed() {
239        assert_eq!(
240            escape("<div class=\"foo\">Hello & goodbye</div>"),
241            "&lt;div class=&quot;foo&quot;&gt;Hello &amp; goodbye&lt;/div&gt;"
242        );
243    }
244
245    #[test]
246    fn test_unescape_no_entities() {
247        let s = "Hello, World!";
248        let unescaped = unescape(s).unwrap();
249        assert!(matches!(unescaped, std::borrow::Cow::Borrowed(_)));
250        assert_eq!(unescaped, s);
251    }
252
253    #[test]
254    fn test_unescape_lt() {
255        assert_eq!(unescape("&lt;").unwrap(), "<");
256    }
257
258    #[test]
259    fn test_unescape_gt() {
260        assert_eq!(unescape("&gt;").unwrap(), ">");
261    }
262
263    #[test]
264    fn test_unescape_amp() {
265        assert_eq!(unescape("&amp;").unwrap(), "&");
266    }
267
268    #[test]
269    fn test_unescape_quot() {
270        assert_eq!(unescape("&quot;").unwrap(), "\"");
271    }
272
273    #[test]
274    fn test_unescape_apos() {
275        assert_eq!(unescape("&apos;").unwrap(), "'");
276    }
277
278    #[test]
279    fn test_unescape_mixed() {
280        assert_eq!(
281            unescape("&lt;div class=&quot;foo&quot;&gt;Hello &amp; goodbye&lt;/div&gt;").unwrap(),
282            "<div class=\"foo\">Hello & goodbye</div>"
283        );
284    }
285
286    #[test]
287    fn test_unescape_numeric_decimal() {
288        assert_eq!(unescape("&#65;").unwrap(), "A");
289        assert_eq!(unescape("&#97;").unwrap(), "a");
290        assert_eq!(unescape("&#8364;").unwrap(), "€");
291    }
292
293    #[test]
294    fn test_unescape_numeric_hex() {
295        assert_eq!(unescape("&#x41;").unwrap(), "A");
296        assert_eq!(unescape("&#x61;").unwrap(), "a");
297        assert_eq!(unescape("&#x20AC;").unwrap(), "€");
298    }
299
300    #[test]
301    fn test_unescape_invalid_entity() {
302        let result = unescape("&invalid;");
303        assert!(result.is_err());
304        let err = result.unwrap_err();
305        assert_eq!(err.entity, "&invalid;");
306        assert_eq!(err.position, 0);
307    }
308
309    #[test]
310    fn test_unescape_unterminated_entity() {
311        let result = unescape("&lt");
312        assert!(result.is_err());
313    }
314
315    #[test]
316    fn test_escape_to() {
317        let mut out = String::new();
318        escape_to("<test>", &mut out);
319        assert_eq!(out, "&lt;test&gt;");
320    }
321
322    #[test]
323    fn test_roundtrip() {
324        let original = "<div class=\"foo\">Hello & goodbye</div>";
325        let escaped = escape(original);
326        let unescaped = unescape(&escaped).unwrap();
327        assert_eq!(unescaped, original);
328    }
329}