1use memchr::memchr3;
7
8
9#[inline]
13pub fn escape(s: &str) -> std::borrow::Cow<'_, str> {
14 if !needs_escape(s.as_bytes()) {
16 return std::borrow::Cow::Borrowed(s);
17 }
18
19 let mut result = String::with_capacity(s.len() + s.len() / 8);
20 escape_to(s, &mut result);
21 std::borrow::Cow::Owned(result)
22}
23
24#[inline]
26fn needs_escape(bytes: &[u8]) -> bool {
27 memchr3(b'<', b'>', b'&', bytes).is_some()
28 || memchr::memchr2(b'"', b'\'', bytes).is_some()
29}
30
/// Appends `s` to `out`, replacing every XML special character with its
/// predefined entity reference:
///
/// | input | appended |
/// |-------|----------|
/// | `<`   | `&lt;`   |
/// | `>`   | `&gt;`   |
/// | `&`   | `&amp;`  |
/// | `"`   | `&quot;` |
/// | `'`   | `&apos;` |
///
/// All other text is copied through unchanged. `out` is only appended to;
/// any existing contents are preserved.
#[inline]
pub fn escape_to(s: &str, out: &mut String) {
    // Byte offset where the pending run of unescaped text begins.
    let mut start = 0;

    for (i, byte) in s.bytes().enumerate() {
        let escaped = match byte {
            b'<' => "&lt;",
            b'>' => "&gt;",
            b'&' => "&amp;",
            b'"' => "&quot;",
            b'\'' => "&apos;",
            _ => continue,
        };

        // Every matched byte is ASCII, so both `i` and `i + 1` fall on char
        // boundaries and these slices can never panic.
        out.push_str(&s[start..i]);
        out.push_str(escaped);
        start = i + 1;
    }

    // Flush whatever follows the last special character (possibly all of `s`).
    out.push_str(&s[start..]);
}
60
61#[inline]
65pub fn escape_attr(s: &str) -> std::borrow::Cow<'_, str> {
66 escape(s)
67}
68
69#[inline]
73pub fn unescape(s: &str) -> Result<std::borrow::Cow<'_, str>, UnescapeError> {
74 if !s.contains('&') {
76 return Ok(std::borrow::Cow::Borrowed(s));
77 }
78
79 let mut result = String::with_capacity(s.len());
80 unescape_to(s, &mut result)?;
81 Ok(std::borrow::Cow::Owned(result))
82}
83
/// Error describing an entity reference that could not be decoded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnescapeError {
    /// Text of the offending reference as reported by the decoder.
    pub entity: String,
    /// Position in the input associated with the offending reference.
    pub position: usize,
}

impl std::fmt::Display for UnescapeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { entity, position } = self;
        write!(f, "invalid XML entity '{}' at position {}", entity, position)
    }
}

impl std::error::Error for UnescapeError {}
100
101pub fn unescape_to(s: &str, out: &mut String) -> Result<(), UnescapeError> {
103 let bytes = s.as_bytes();
104 let mut i = 0;
105 let mut start = 0;
106
107 while i < bytes.len() {
108 if bytes[i] == b'&' {
109 if start < i {
111 out.push_str(unsafe { std::str::from_utf8_unchecked(&bytes[start..i]) });
112 }
113
114 let entity_start = i;
115 i += 1;
116
117 let semicolon = bytes[i..].iter().position(|&b| b == b';');
119 match semicolon {
120 Some(len) if len > 0 => {
121 let entity = unsafe { std::str::from_utf8_unchecked(&bytes[i..i + len]) };
122 let decoded = decode_entity(entity);
123
124 match decoded {
125 Some(c) => out.push(c),
126 None => {
127 if let Some(c) = decode_numeric_entity(entity) {
129 out.push(c);
130 } else {
131 return Err(UnescapeError {
132 entity: format!("&{};", entity),
133 position: entity_start,
134 });
135 }
136 }
137 }
138
139 i += len + 1;
140 start = i;
141 }
142 _ => {
143 return Err(UnescapeError {
144 entity: String::from("&"),
145 position: entity_start,
146 });
147 }
148 }
149 } else {
150 i += 1;
151 }
152 }
153
154 if start < bytes.len() {
155 out.push_str(unsafe { std::str::from_utf8_unchecked(&bytes[start..]) });
156 }
157
158 Ok(())
159}
160
/// Maps the name of a predefined entity (without `&` and `;`) to its
/// character.
///
/// Recognises exactly `lt`, `gt`, `amp`, `quot` and `apos`, case-sensitively;
/// every other name yields `None`.
#[inline]
fn decode_entity(entity: &str) -> Option<char> {
    const PREDEFINED: [(&str, char); 5] = [
        ("lt", '<'),
        ("gt", '>'),
        ("amp", '&'),
        ("quot", '"'),
        ("apos", '\''),
    ];

    PREDEFINED
        .iter()
        .find(|&&(name, _)| name == entity)
        .map(|&(_, ch)| ch)
}
173
/// Decodes the body of a numeric character reference (without `&` and `;`).
///
/// Accepts `#` followed by decimal digits, or `#x` / `#X` followed by
/// hexadecimal digits. Returns `None` when:
///
/// * `entity` does not start with `#`;
/// * no digits follow the prefix;
/// * any character after the prefix is not a digit of the radix. This
///   includes a leading `+`, which `u32::from_str_radix` would otherwise
///   accept;
/// * the value overflows `u32` or is not a valid Unicode scalar value
///   (for example a surrogate).
#[inline]
fn decode_numeric_entity(entity: &str) -> Option<char> {
    let body = entity.strip_prefix('#')?;

    let (radix, digits) = match body.strip_prefix(|c: char| c == 'x' || c == 'X') {
        Some(hex) => (16, hex),
        None => (10, body),
    };

    // Validate explicitly instead of relying on `from_str_radix`, which
    // tolerates a sign character that is not legal in a character reference.
    if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }

    let code = u32::from_str_radix(digits, radix).ok()?;
    char::from_u32(code)
}
199
200#[cfg(test)]
201mod tests {
202 use super::*;
203
204 #[test]
205 fn test_escape_no_special_chars() {
206 let s = "Hello, World!";
207 let escaped = escape(s);
208 assert!(matches!(escaped, std::borrow::Cow::Borrowed(_)));
209 assert_eq!(escaped, s);
210 }
211
212 #[test]
213 fn test_escape_lt() {
214 assert_eq!(escape("<"), "<");
215 }
216
217 #[test]
218 fn test_escape_gt() {
219 assert_eq!(escape(">"), ">");
220 }
221
222 #[test]
223 fn test_escape_amp() {
224 assert_eq!(escape("&"), "&");
225 }
226
227 #[test]
228 fn test_escape_quot() {
229 assert_eq!(escape("\""), """);
230 }
231
232 #[test]
233 fn test_escape_apos() {
234 assert_eq!(escape("'"), "'");
235 }
236
237 #[test]
238 fn test_escape_mixed() {
239 assert_eq!(
240 escape("<div class=\"foo\">Hello & goodbye</div>"),
241 "<div class="foo">Hello & goodbye</div>"
242 );
243 }
244
245 #[test]
246 fn test_unescape_no_entities() {
247 let s = "Hello, World!";
248 let unescaped = unescape(s).unwrap();
249 assert!(matches!(unescaped, std::borrow::Cow::Borrowed(_)));
250 assert_eq!(unescaped, s);
251 }
252
253 #[test]
254 fn test_unescape_lt() {
255 assert_eq!(unescape("<").unwrap(), "<");
256 }
257
258 #[test]
259 fn test_unescape_gt() {
260 assert_eq!(unescape(">").unwrap(), ">");
261 }
262
263 #[test]
264 fn test_unescape_amp() {
265 assert_eq!(unescape("&").unwrap(), "&");
266 }
267
268 #[test]
269 fn test_unescape_quot() {
270 assert_eq!(unescape(""").unwrap(), "\"");
271 }
272
273 #[test]
274 fn test_unescape_apos() {
275 assert_eq!(unescape("'").unwrap(), "'");
276 }
277
278 #[test]
279 fn test_unescape_mixed() {
280 assert_eq!(
281 unescape("<div class="foo">Hello & goodbye</div>").unwrap(),
282 "<div class=\"foo\">Hello & goodbye</div>"
283 );
284 }
285
286 #[test]
287 fn test_unescape_numeric_decimal() {
288 assert_eq!(unescape("A").unwrap(), "A");
289 assert_eq!(unescape("a").unwrap(), "a");
290 assert_eq!(unescape("€").unwrap(), "€");
291 }
292
293 #[test]
294 fn test_unescape_numeric_hex() {
295 assert_eq!(unescape("A").unwrap(), "A");
296 assert_eq!(unescape("a").unwrap(), "a");
297 assert_eq!(unescape("€").unwrap(), "€");
298 }
299
300 #[test]
301 fn test_unescape_invalid_entity() {
302 let result = unescape("&invalid;");
303 assert!(result.is_err());
304 let err = result.unwrap_err();
305 assert_eq!(err.entity, "&invalid;");
306 assert_eq!(err.position, 0);
307 }
308
309 #[test]
310 fn test_unescape_unterminated_entity() {
311 let result = unescape("<");
312 assert!(result.is_err());
313 }
314
315 #[test]
316 fn test_escape_to() {
317 let mut out = String::new();
318 escape_to("<test>", &mut out);
319 assert_eq!(out, "<test>");
320 }
321
322 #[test]
323 fn test_roundtrip() {
324 let original = "<div class=\"foo\">Hello & goodbye</div>";
325 let escaped = escape(original);
326 let unescaped = unescape(&escaped).unwrap();
327 assert_eq!(unescaped, original);
328 }
329}