// Why `dead_code` is allowed: when AVX2 is enabled at compile time on
// x86/x86_64 (and not under Miri), `escape_to_buf` calls `avx2::escape`
// directly and its runtime-dispatch block is compiled out. Items that are only
// referenced from that dispatch block (e.g. `sse2::escape`, `fallback::escape`)
// are then unused outside of tests, which would otherwise trigger warnings.
#![cfg_attr(
    all(
        any(target_arch = "x86", target_arch = "x86_64"),
        not(miri),
        target_feature = "avx2"
    ),
    allow(dead_code)
)]
13
// AVX2 implementation (`avx2::escape`); only built on x86/x86_64 and never
// under Miri.
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)))]
mod avx2;
// Portable implementation (`fallback::escape`) for long inputs when no SIMD
// path is selected.
mod fallback;
// Pointer-based implementation (`naive::escape`, `naive::escape_small`); used
// for inputs shorter than 16 bytes, under Miri, and as the reference in tests.
mod naive;
// SSE2 implementation (`sse2::escape`); only built on x86/x86_64 and never
// under Miri.
#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)))]
mod sse2;
20
/// Byte-indexed lookup table for the escaper.
///
/// The five bytes that need escaping map to their position in `ESCAPED`
/// (`"` -> 0, `&` -> 1, `'` -> 2, `<` -> 3, `>` -> 4); every other byte maps
/// to 9.
// NOTE(review): 9 is presumably just "any value >= ESCAPED_LEN", i.e. "no
// escape needed" — confirm against the users in `naive`/`fallback`/`sse2`/`avx2`.
static ESCAPE_LUT: [u8; 256] = {
    let mut table = [9u8; 256];
    table[b'"' as usize] = 0;
    table[b'&' as usize] = 1;
    table[b'\'' as usize] = 2;
    table[b'<' as usize] = 3;
    table[b'>' as usize] = 4;
    table
};
33
/// Replacement text for each escapable byte, indexed by the value stored in
/// `ESCAPE_LUT`: `"`, `&`, `'`, `<`, `>` in that order.
///
/// Every entry is at most 6 bytes long, which is what the
/// `reserve_small(feed.len() * 6)` calls in `escape_to_buf` rely on.
const ESCAPED: [&str; ESCAPED_LEN] = ["&quot;", "&amp;", "&#039;", "&lt;", "&gt;"];

/// Number of entries in `ESCAPED`.
const ESCAPED_LEN: usize = 5;
36
37use super::buffer::Buffer;
38
39#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), not(miri)))]
41#[cfg_attr(feature = "perf-inline", inline)]
42pub fn escape_to_buf(feed: &str, buf: &mut Buffer) {
43 #[cfg(not(target_feature = "avx2"))]
44 {
45 use std::sync::atomic::{AtomicPtr, Ordering};
46
47 type FnRaw = *mut ();
48 static FN: AtomicPtr<()> = AtomicPtr::new(detect as FnRaw);
49
50 fn detect(feed: &str, buf: &mut Buffer) {
51 debug_assert!(feed.len() >= 16);
52 let fun = if is_x86_feature_detected!("avx2") {
53 avx2::escape
54 } else if is_x86_feature_detected!("sse2") {
55 sse2::escape
56 } else {
57 fallback::escape
58 };
59
60 FN.store(fun as FnRaw, Ordering::Relaxed);
61 unsafe { fun(feed, buf) };
62 }
63
64 unsafe {
65 if feed.len() < 16 {
66 buf.reserve_small(feed.len() * 6);
67 let l = naive::escape_small(feed, buf.as_mut_ptr().add(buf.len()));
68 buf.advance(l);
69 } else {
70 let fun = FN.load(Ordering::Relaxed);
71 std::mem::transmute::<FnRaw, fn(&str, &mut Buffer)>(fun)(feed, buf);
72 }
73 }
74 }
75
76 #[cfg(target_feature = "avx2")]
77 unsafe {
78 if feed.len() < 16 {
79 buf.reserve_small(feed.len() * 6);
80 let l = naive::escape_small(feed, buf.as_mut_ptr().add(buf.len()));
81 buf.advance(l);
82 } else if cfg!(target_feature = "avx2") {
83 avx2::escape(feed, buf);
84 }
85 }
86}
87
/// Writes the escaped form of `feed` to the end of `buf`.
///
/// This variant is built for non-x86 targets and for Miri. Under Miri every
/// input goes through `naive::escape`; otherwise inputs shorter than 16 bytes
/// use `naive::escape_small` and longer ones use `fallback::escape`.
#[cfg(not(all(any(target_arch = "x86", target_arch = "x86_64"), not(miri))))]
#[cfg_attr(feature = "perf-inline", inline)]
pub fn escape_to_buf(feed: &str, buf: &mut Buffer) {
    let is_short = feed.len() < 16;

    unsafe {
        match (cfg!(miri), is_short) {
            // Miri: always take the naive implementation.
            (true, _) => {
                let start = feed.as_ptr();
                let end = start.add(feed.len());
                naive::escape(buf, start, start, end)
            }
            // Short input: reserve the worst case (6 bytes per input byte) and
            // write directly at the current end of the buffer.
            (false, true) => {
                buf.reserve_small(feed.len() * 6);
                let dst = buf.as_mut_ptr().add(buf.len());
                let written = naive::escape_small(feed, dst);
                buf.advance(written);
            }
            (false, false) => fallback::escape(feed, buf),
        }
    }
}
105
106#[inline]
118pub fn escape_to_string(feed: &str, s: &mut String) {
119 let mut s2 = String::new();
120 std::mem::swap(s, &mut s2);
121 let mut buf = Buffer::from(s2);
122 escape_to_buf(feed, &mut buf);
123 let mut s2 = buf.into_string();
124 std::mem::swap(s, &mut s2);
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Escapes `feed` into a fresh `String` via `escape_to_string`.
    fn escape(feed: &str) -> String {
        let mut s = String::new();
        escape_to_string(feed, &mut s);
        s
    }

    /// Inputs without any of `"`, `&`, `'`, `<`, `>` must pass through as-is.
    #[test]
    fn noescape() {
        assert_eq!(escape(""), "");
        assert_eq!(escape("1234567890"), "1234567890");
        assert_eq!(
            escape("abcdefghijklmnopqrstrvwxyz"),
            "abcdefghijklmnopqrstrvwxyz"
        );
        assert_eq!(escape("!#$%()*+,-.:;=?_^"), "!#$%()*+,-.:;=?_^");
        assert_eq!(
            escape("漢字はエスケープしないはずだよ"),
            "漢字はエスケープしないはずだよ"
        );
    }

    /// Short and medium inputs containing escapable characters.
    #[test]
    fn escape_short() {
        assert_eq!(escape("<"), "&lt;");
        assert_eq!(escape("\"&<>'"), "&quot;&amp;&lt;&gt;&#039;");
        assert_eq!(
            escape("{\"title\": \"This is a JSON!\"}"),
            "{&quot;title&quot;: &quot;This is a JSON!&quot;}"
        );
        assert_eq!(
            escape("<html><body><h1>Hello, world</h1></body></html>"),
            "&lt;html&gt;&lt;body&gt;&lt;h1&gt;Hello, world&lt;/h1&gt;\
            &lt;/body&gt;&lt;/html&gt;"
        );
    }

    /// A long input mixing escapable and ordinary punctuation.
    #[test]
    #[rustfmt::skip]
    fn escape_long() {
        assert_eq!(
            escape(r###"m{jml&,?6>\2~08g)\=3`,_`$1@0{i5j}.}2ki\^t}k"'@p4$~?;!;pn_l8v."ki`%/&^=\[y+qcerr`@3*|?du.\0vd#40'.>bcpf\u@m|c<2t7`hk)^?"0u{v%9}4y2hhv?%-f`<;rzwx`7}l(j2b:c\<|z&$x{+k;f`0+w3e0\m.wmdli>94e2hp\$}j0&m(*h$/lwlj#}99r';o.kj@1#}~v+;y~b[~m.eci}&l7fxt`\\{~#k*9z/d{}(.^j}[(,]:<\h]9k2+0*w60/|23~5;/!-h&ci*~e1h~+:1lhh\>y_*>:-\zzv+8uo],,a^k3_,uip]-/.-~\t51a*<{6!<(_|<#o6=\h1*`[2x_?#-/])x};};r@wqx|;/w&jrv~?\`t:^/dug3(g(ener?!t$}h4:57ptnm@71e=t>@o*"$]799r=+)t>co?rvgk%u0c@.9os;#t_*/gqv<za&~r^]"{t4by2t`<q4bfo^&!so5/~(nxk:7l\;#0w41u~w3i$g|>e/t;o<*`~?3.jyx+h)+^cn^j4td|>)~rs)vm#]:"&\fi;54%+z~fhe|w~\q|ui={54[b9tg*?@]g+q!mq]3jg2?eoo"chyat3k#7pq1u=.l]c14twa4tg#5k_""###),
            r###"m{jml&amp;,?6&gt;\2~08g)\=3`,_`$1@0{i5j}.}2ki\^t}k&quot;&#039;@p4$~?;!;pn_l8v.&quot;ki`%/&amp;^=\[y+qcerr`@3*|?du.\0vd#40&#039;.&gt;bcpf\u@m|c&lt;2t7`hk)^?&quot;0u{v%9}4y2hhv?%-f`&lt;;rzwx`7}l(j2b:c\&lt;|z&amp;$x{+k;f`0+w3e0\m.wmdli&gt;94e2hp\$}j0&amp;m(*h$/lwlj#}99r&#039;;o.kj@1#}~v+;y~b[~m.eci}&amp;l7fxt`\\{~#k*9z/d{}(.^j}[(,]:&lt;\h]9k2+0*w60/|23~5;/!-h&amp;ci*~e1h~+:1lhh\&gt;y_*&gt;:-\zzv+8uo],,a^k3_,uip]-/.-~\t51a*&lt;{6!&lt;(_|&lt;#o6=\h1*`[2x_?#-/])x};};r@wqx|;/w&amp;jrv~?\`t:^/dug3(g(ener?!t$}h4:57ptnm@71e=t&gt;@o*&quot;$]799r=+)t&gt;co?rvgk%u0c@.9os;#t_*/gqv&lt;za&amp;~r^]&quot;{t4by2t`&lt;q4bfo^&amp;!so5/~(nxk:7l\;#0w41u~w3i$g|&gt;e/t;o&lt;*`~?3.jyx+h)+^cn^j4td|&gt;)~rs)vm#]:&quot;&amp;\fi;54%+z~fhe|w~\q|ui={54[b9tg*?@]g+q!mq]3jg2?eoo&quot;chyat3k#7pq1u=.l]c14twa4tg#5k_&quot;"###
        );
    }

    /// Cross-checks `fallback`, `sse2` and `avx2` against `naive::escape` on
    /// pseudo-random ASCII inputs of every length from 16 to 99.
    #[test]
    #[cfg(not(miri))]
    fn random() {
        const ASCII_CHARS: &[u8] = br##"abcdefghijklmnopqrstuvwxyz0123456789-^\@[;:],./\!"#$%&'()~=~|`{+*}<>?_"##;
        // Fixed xorshift seed keeps the generated inputs reproducible.
        let mut state = 88172645463325252u64;
        let mut data = Vec::with_capacity(100);

        let mut buf_naive = Buffer::new();
        let mut buf = Buffer::new();

        for len in 16..100 {
            for _ in 0..5 {
                data.clear();
                for _ in 0..len {
                    // xorshift64 step
                    state ^= state << 13;
                    state ^= state >> 7;
                    state ^= state << 17;

                    let idx = state as usize % ASCII_CHARS.len();
                    data.push(ASCII_CHARS[idx]);
                }

                // `ASCII_CHARS` is ASCII-only, so this never fails.
                let s = std::str::from_utf8(&data).expect("ASCII input is valid UTF-8");

                unsafe {
                    naive::escape(
                        &mut buf_naive,
                        s.as_ptr(),
                        s.as_ptr(),
                        s.as_ptr().add(s.len()),
                    );

                    fallback::escape(s, &mut buf);
                    assert_eq!(buf.as_str(), buf_naive.as_str());
                    buf.clear();

                    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                    {
                        // Only exercise a SIMD path the running CPU supports.
                        if is_x86_feature_detected!("sse2") {
                            sse2::escape(s, &mut buf);
                            assert_eq!(buf.as_str(), buf_naive.as_str());
                            buf.clear();
                        }

                        if is_x86_feature_detected!("avx2") {
                            avx2::escape(s, &mut buf);
                            assert_eq!(buf.as_str(), buf_naive.as_str());
                            buf.clear();
                        }
                    }
                }

                buf_naive.clear();
            }
        }
    }
}