1use std::ops::{Range, RangeInclusive};
2
3use anyhow::{anyhow, Error, Result};
4use derive_builder::Builder;
5use fuzzerang::{StandardBuffered, StandardSeedableRng, TryDistribution, TryRanged};
6use rand::{thread_rng, Rng, SeedableRng};
7use regex_syntax::hir::{Class, Hir, HirKind};
8
/// Callback interface driven by the free function `visit` while it walks a
/// regex `Hir` tree.
pub trait Visitor {
    /// Value produced by [`Visitor::finish`] once the walk is over.
    type Output;
    /// Error type reported by [`Visitor::visit`] and [`Visitor::finish`].
    type Err;

    /// Invoked once before the first node is visited.
    fn start(&mut self);

    /// Invoked after the walk to obtain the accumulated result.
    fn finish(&mut self) -> Result<Self::Output, Self::Err>;

    /// Invoked with a node and the traversal bookkeeping attached to it.
    ///
    /// The default implementation ignores the node and succeeds.
    fn visit(&mut self, _hir: &Hir, _meta: &Meta) -> Result<(), Self::Err> {
        Ok(())
    }
}
21
/// Visitor that turns a regex `Hir` into bytes, making its choices with
/// values drawn from a seeded random source.
#[derive(Builder)]
pub struct RngVisitor {
    /// When set, a failed draw is returned as an error on the paths that
    /// consult this flag, instead of falling back to a fixed default choice.
    #[builder(default = "false")]
    fail_on_exhaust: bool,
    /// When set, a failed draw of a range index or alternation branch is
    /// replaced by a draw from `rand::thread_rng()`.
    #[builder(default = "true")]
    fallback_to_thread_rng: bool,
    /// Random source handed to `distribution` for every draw. It has no
    /// builder default, so it must be supplied explicitly.
    rng: StandardSeedableRng,
    /// Distribution used to turn the random source into typed values.
    #[builder(default = "StandardBuffered::new()")]
    distribution: StandardBuffered,
    /// Bytes generated so far.
    #[builder(default)]
    output: Vec<u8>,
}
42
43impl RngVisitor {
44 pub fn try_sample<T>(&mut self) -> Result<T>
45 where
46 StandardBuffered: TryDistribution<T>,
47 {
48 self.distribution.try_sample(&mut self.rng)
49 }
50
51 pub fn try_sample_range<T>(&mut self, range: Range<T>) -> Result<T>
52 where
53 StandardBuffered: TryRanged<T>,
54 {
55 self.distribution.try_sample_range(&mut self.rng, range)
56 }
57
58 pub fn try_sample_range_inclusive<T>(&mut self, range: RangeInclusive<T>) -> Result<T>
59 where
60 StandardBuffered: TryRanged<T>,
61 {
62 self.distribution
63 .try_sample_range_inclusive(&mut self.rng, range)
64 }
65}
66
67impl SeedableRng for RngVisitor {
68 type Seed = Vec<u8>;
69
70 fn from_seed(seed: Self::Seed) -> Self {
71 Self {
72 fail_on_exhaust: false,
73 fallback_to_thread_rng: true,
74 rng: StandardSeedableRng::from_seed(seed),
75 distribution: StandardBuffered::new(),
76 output: Vec::new(),
77 }
78 }
79}
80
81impl Visitor for RngVisitor {
82 type Output = Vec<u8>;
83
84 type Err = Error;
85
86 fn start(&mut self) {
87 self.output = Vec::new();
88 }
89
90 fn visit(&mut self, hir: &Hir, _meta: &Meta) -> Result<()> {
91 match hir.kind() {
92 HirKind::Empty => {}
93 HirKind::Literal(lit) => {
94 self.output.extend(&*lit.0);
95 }
96 HirKind::Class(cls) => match cls {
97 Class::Unicode(ucls) => {
98 if let Some(lit) = ucls.literal() {
99 self.output.extend(&lit);
100 } else {
101 let intervals = ucls.ranges();
102 match self.try_sample_range(0..intervals.len()) {
103 Ok(idx) => {
104 let interval = intervals[idx];
105 match self
106 .try_sample_range_inclusive(interval.start()..=interval.end())
107 {
108 Ok(c) => {
109 let mut buf = [0; 4];
110 c.encode_utf8(&mut buf);
111 self.output.extend(&buf);
112 }
113 Err(e) => {
114 if self.fail_on_exhaust {
115 return Err(e);
116 } else {
117 let mut buf = [0; 4];
118 interval.start().encode_utf8(&mut buf);
119 self.output.extend(&buf);
120 }
121 }
122 }
123 }
124 Err(e) => {
125 if self.fallback_to_thread_rng {
126 let interval =
127 intervals[thread_rng().gen_range(0..intervals.len())];
128 let mut buf = [0; 4];
129 thread_rng()
130 .gen_range(interval.start()..=interval.end())
131 .encode_utf8(&mut buf);
132 self.output.extend(&buf);
133 } else if self.fail_on_exhaust {
134 return Err(e);
135 } else {
136 let mut buf = [0; 4];
137 intervals[0].start().encode_utf8(&mut buf);
138 self.output.extend(&buf);
139 }
140 }
141 }
142 }
143 }
144 Class::Bytes(bcls) => {
145 if let Some(lit) = bcls.literal() {
146 self.output.extend(&lit);
147 } else {
148 let intervals = bcls.ranges();
149 match self.try_sample_range(0..intervals.len()) {
150 Ok(index) => {
151 let interval = intervals[index];
152 match self
153 .try_sample_range_inclusive(interval.start()..=interval.end())
154 {
155 Ok(c) => self.output.push(c),
156 Err(e) => {
157 if self.fail_on_exhaust {
158 return Err(e);
159 } else {
160 self.output.push(interval.start());
161 }
162 }
163 }
164 }
165 Err(e) => {
166 if self.fallback_to_thread_rng {
167 self.output.push({
168 let interval =
169 intervals[thread_rng().gen_range(0..intervals.len())];
170 thread_rng().gen_range(interval.start()..=interval.end())
171 });
172 } else if self.fail_on_exhaust {
173 return Err(e);
174 } else {
175 self.output.push(intervals[0].start());
176 }
177 }
178 }
179 }
180 }
181 },
182 HirKind::Look(_) => todo!(),
183 _ => {}
184 }
185 Ok(())
186 }
187
188 fn finish(&mut self) -> Result<Self::Output> {
189 Ok(self.output.clone())
190 }
191}
192
/// Per-node traversal bookkeeping kept next to each `Hir` on the stack used
/// by the free function `visit`.
#[derive(Default, Debug)]
pub struct Meta {
    /// Repetition nodes: how many times the sub-expression has been scheduled
    /// so far; `None` until the node is first expanded.
    repetitions: Option<u32>,
    /// Concatenation nodes: position bookkeeping for the children; `None`
    /// until the node is first expanded.
    concat_idx: Option<usize>,
    /// Alternation nodes: whether a branch has already been chosen and
    /// scheduled.
    alternation_visited: bool,
}
199
200pub fn visit(
201 hir: &Hir,
202 visitor: &mut RngVisitor,
203) -> Result<<RngVisitor as Visitor>::Output, <RngVisitor as Visitor>::Err> {
204 let mut stack = vec![(hir, Meta::default())];
205
206 visitor.start();
207
208 while let Some((top, meta)) = stack.last() {
209 visitor
210 .visit(top, meta)
211 .map_err(|e| anyhow!("Error while visiting: {}", e))?;
212 match top.kind() {
213 HirKind::Repetition(rep) => {
214 if let Some(repetitions) = meta.repetitions.as_ref() {
215 let repetitions = *repetitions;
216 if repetitions < rep.min
218 || (repetitions < rep.max.unwrap_or(repetitions + 1)
219 && visitor.try_sample().unwrap_or(false))
222 {
223 let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
225 stack.push((
226 top,
227 Meta {
228 repetitions: Some(repetitions + 1),
229 ..meta
230 },
231 ));
232 stack.push((&rep.sub, Meta::default()));
233 } else {
234 stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
236 }
237 } else {
238 let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
240 stack.push((
241 top,
242 Meta {
243 repetitions: Some(1),
244 ..meta
245 },
246 ));
247 stack.push((&rep.sub, Meta::default()))
248 }
249 }
250 HirKind::Concat(concat) => {
251 if let Some(concat_idx) = meta.concat_idx.as_ref() {
252 if concat_idx >= &concat.len() {
254 stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
256 } else {
257 let idx = *concat_idx;
258 let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
260
261 stack.push((
262 top,
263 Meta {
264 concat_idx: Some(idx + 1),
265 ..meta
266 },
267 ));
268
269 stack.push((&concat[idx], Meta::default()));
270 }
271 } else {
272 let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
274 stack.push((
275 top,
276 Meta {
277 concat_idx: Some(0),
278 ..meta
279 },
280 ));
281 stack.push((&concat[0], Meta::default()))
282 }
283 }
284 HirKind::Alternation(alt) => {
285 if meta.alternation_visited {
286 stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
288 } else {
289 let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
291 stack.push((
292 top,
293 Meta {
294 alternation_visited: true,
295 ..meta
296 },
297 ));
298 let alternation_index = visitor.try_sample_range(0..alt.len()).unwrap_or(
299 if visitor.fallback_to_thread_rng {
300 thread_rng().gen_range(0..alt.len())
301 } else {
302 0
303 },
304 );
305 stack.push((&alt[alternation_index], Meta::default()));
306 }
307 }
308 _ => {
309 stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
310 }
311 }
312 }
313
314 visitor.finish()
315}
316
317#[cfg(test)]
318mod tests {
319 use std::collections::HashSet;
320
321 use anyhow::Result;
322 use fuzzerang::StandardSeedableRng;
323 use rand::{thread_rng, Rng, SeedableRng};
324 use regex_syntax::{ast::parse::Parser, hir::translate::TranslatorBuilder};
325
326 use crate::{visit, RngVisitorBuilder, Visitor};
327
328 fn test(regex: &str, seed: Vec<u8>, bytes: bool) -> Result<Vec<u8>> {
329 let mut p = Parser::new();
331 let ast = p.parse(regex)?;
332 let hir = if bytes {
333 let mut t = TranslatorBuilder::new().unicode(false).utf8(false).build();
334 t.translate(regex, &ast)?
335 } else {
336 let mut t = TranslatorBuilder::new().unicode(true).utf8(false).build();
337 t.translate(regex, &ast)?
338 };
339 let mut visitor = RngVisitorBuilder::default()
341 .rng(StandardSeedableRng::from_seed(seed))
342 .build()?;
343 visit(&hir, &mut visitor)?;
344 let r = visitor.finish()?;
345 Ok(r)
351 }
352
353 #[test]
354 fn test_literal() -> Result<()> {
355 let regex = "A";
356 assert_eq!(test(regex, vec![], true)?, b"A");
357 Ok(())
358 }
359
360 #[test]
361 fn test_char_repetition_exact() -> Result<()> {
362 let regex = "A{2}";
363 assert_eq!(test(regex, vec![], true)?, b"AA");
364 Ok(())
365 }
366
367 #[test]
368 fn test_char_repetition_range_min() -> Result<()> {
369 let regex = "A{3,6}";
370 assert_eq!(test(regex, vec![0b00000000], true)?, b"AAA");
371 Ok(())
372 }
373
374 #[test]
375 fn test_char_repetition_range_max() -> Result<()> {
376 let regex = "A{3,6}";
377 assert_eq!(test(regex, vec![0b11111111], true)?, b"AAAAAA");
378 Ok(())
379 }
380
381 #[test]
382 fn test_class_one() -> Result<()> {
383 let regex = "[02468]";
384 assert_eq!(test(regex, vec![0b00000000], true)?, b"0");
385 assert_eq!(test(regex, vec![0b00000001], true)?, b"2");
386 assert_eq!(test(regex, vec![0b00000010], true)?, b"4");
387 assert_eq!(test(regex, vec![0b00000011], true)?, b"6");
388 assert_eq!(test(regex, vec![0b00000100], true)?, b"8");
389 Ok(())
390 }
391
392 #[test]
393 fn test_class_range_one() -> Result<()> {
394 let regex = "[0-9]";
395 let mut seen = HashSet::new();
396 for i in 0..32u8 {
397 let r = test(regex, vec![i, i + 1], true)?;
399 if let Ok(s) = String::from_utf8(r) {
400 seen.insert(s);
401 }
402 }
403 assert!(
404 seen.len() == 10,
405 "Expected 10 unique digits, got {:?}",
406 seen
407 );
408 Ok(())
409 }
410
411 #[test]
412 fn test_class_negate_one() -> Result<()> {
413 let regex = "[^0]";
414 for i in 0..(255 - 9) {
415 let r = test(regex, (i..(i + 8)).collect(), true)?;
417 if let Ok(s) = String::from_utf8(r) {
418 assert_ne!(s, "0");
419 }
420 }
421 Ok(())
422 }
423
424 #[test]
425 fn test_class_negate_range_one() -> Result<()> {
426 let regex = "[^0-9]";
427 for i in 0..(255 - 9) {
428 let r = test(regex, (i..(i + 8)).collect(), true)?;
430 if let Ok(s) = String::from_utf8(r) {
431 assert!(s.chars().all(|c| !c.is_ascii_digit()));
432 }
433 }
434 Ok(())
435 }
436
437 #[test]
438 fn test_concat() -> Result<()> {
439 let regex = "A+B+";
440 assert_eq!(
441 test(regex, vec![0xde, 0xad, 0xbe, 0xef], true)?,
442 b"AAAAAABBBB"
443 );
444 Ok(())
445 }
446
447 const JSON_REGEXES: &[&str] = &[
448 r#"\\"#,
449 r#"(\"|\\|\/|b|f|n|r|t|u)"#,
450 r#"[\da-fA-F]"#,
451 r#"[0-1]+"#,
452 r#"[0-7]+"#,
453 r#"[1-9]"#,
454 r#".*"#,
455 r#"[^*]*\*+([^/*][^*]*\*+)*"#,
456 ];
457
458 #[test]
459 fn test_json() -> Result<()> {
460 let mut rng = thread_rng();
461 for regex in JSON_REGEXES {
462 test(regex, (0..64).map(|_| rng.gen()).collect(), true)?;
464 }
465 Ok(())
466 }
467}