regenerator/
lib.rs

1use std::ops::{Range, RangeInclusive};
2
3use anyhow::{anyhow, Error, Result};
4use derive_builder::Builder;
5use fuzzerang::{StandardBuffered, StandardSeedableRng, TryDistribution, TryRanged};
6use rand::{thread_rng, Rng, SeedableRng};
7use regex_syntax::hir::{Class, Hir, HirKind};
8
9pub trait Visitor {
10    type Output;
11    type Err;
12
13    fn start(&mut self);
14
15    fn finish(&mut self) -> Result<Self::Output, Self::Err>;
16
17    fn visit(&mut self, _hir: &Hir, _meta: &Meta) -> Result<(), Self::Err> {
18        Ok(())
19    }
20}
21
22#[derive(Builder)]
23pub struct RngVisitor {
24    #[builder(default = "false")]
25    /// Whether to fail, returning an error, when the underlying RNG is exhausted
26    /// of data.
27    fail_on_exhaust: bool,
28    #[builder(default = "true")]
29    /// Whether to fall back to using a real random number generator when the
30    /// current one is exhausted. This can help when generating the output of
31    /// grammars that progress infinitely under the default (first-path) strategy.
32    fallback_to_thread_rng: bool,
33    /// The RNG to use. This is a `StandardSeedableRng` only, for now.
34    rng: StandardSeedableRng,
35    #[builder(default = "StandardBuffered::new()")]
36    /// The buffered distribution to use. This is a `StandardBuffered` only, for now.
37    distribution: StandardBuffered,
38    #[builder(default)]
39    /// The progressively generated output.
40    output: Vec<u8>,
41}
42
43impl RngVisitor {
44    pub fn try_sample<T>(&mut self) -> Result<T>
45    where
46        StandardBuffered: TryDistribution<T>,
47    {
48        self.distribution.try_sample(&mut self.rng)
49    }
50
51    pub fn try_sample_range<T>(&mut self, range: Range<T>) -> Result<T>
52    where
53        StandardBuffered: TryRanged<T>,
54    {
55        self.distribution.try_sample_range(&mut self.rng, range)
56    }
57
58    pub fn try_sample_range_inclusive<T>(&mut self, range: RangeInclusive<T>) -> Result<T>
59    where
60        StandardBuffered: TryRanged<T>,
61    {
62        self.distribution
63            .try_sample_range_inclusive(&mut self.rng, range)
64    }
65}
66
67impl SeedableRng for RngVisitor {
68    type Seed = Vec<u8>;
69
70    fn from_seed(seed: Self::Seed) -> Self {
71        Self {
72            fail_on_exhaust: false,
73            fallback_to_thread_rng: true,
74            rng: StandardSeedableRng::from_seed(seed),
75            distribution: StandardBuffered::new(),
76            output: Vec::new(),
77        }
78    }
79}
80
81impl Visitor for RngVisitor {
82    type Output = Vec<u8>;
83
84    type Err = Error;
85
86    fn start(&mut self) {
87        self.output = Vec::new();
88    }
89
90    fn visit(&mut self, hir: &Hir, _meta: &Meta) -> Result<()> {
91        match hir.kind() {
92            HirKind::Empty => {}
93            HirKind::Literal(lit) => {
94                self.output.extend(&*lit.0);
95            }
96            HirKind::Class(cls) => match cls {
97                Class::Unicode(ucls) => {
98                    if let Some(lit) = ucls.literal() {
99                        self.output.extend(&lit);
100                    } else {
101                        let intervals = ucls.ranges();
102                        match self.try_sample_range(0..intervals.len()) {
103                            Ok(idx) => {
104                                let interval = intervals[idx];
105                                match self
106                                    .try_sample_range_inclusive(interval.start()..=interval.end())
107                                {
108                                    Ok(c) => {
109                                        let mut buf = [0; 4];
110                                        c.encode_utf8(&mut buf);
111                                        self.output.extend(&buf);
112                                    }
113                                    Err(e) => {
114                                        if self.fail_on_exhaust {
115                                            return Err(e);
116                                        } else {
117                                            let mut buf = [0; 4];
118                                            interval.start().encode_utf8(&mut buf);
119                                            self.output.extend(&buf);
120                                        }
121                                    }
122                                }
123                            }
124                            Err(e) => {
125                                if self.fallback_to_thread_rng {
126                                    let interval =
127                                        intervals[thread_rng().gen_range(0..intervals.len())];
128                                    let mut buf = [0; 4];
129                                    thread_rng()
130                                        .gen_range(interval.start()..=interval.end())
131                                        .encode_utf8(&mut buf);
132                                    self.output.extend(&buf);
133                                } else if self.fail_on_exhaust {
134                                    return Err(e);
135                                } else {
136                                    let mut buf = [0; 4];
137                                    intervals[0].start().encode_utf8(&mut buf);
138                                    self.output.extend(&buf);
139                                }
140                            }
141                        }
142                    }
143                }
144                Class::Bytes(bcls) => {
145                    if let Some(lit) = bcls.literal() {
146                        self.output.extend(&lit);
147                    } else {
148                        let intervals = bcls.ranges();
149                        match self.try_sample_range(0..intervals.len()) {
150                            Ok(index) => {
151                                let interval = intervals[index];
152                                match self
153                                    .try_sample_range_inclusive(interval.start()..=interval.end())
154                                {
155                                    Ok(c) => self.output.push(c),
156                                    Err(e) => {
157                                        if self.fail_on_exhaust {
158                                            return Err(e);
159                                        } else {
160                                            self.output.push(interval.start());
161                                        }
162                                    }
163                                }
164                            }
165                            Err(e) => {
166                                if self.fallback_to_thread_rng {
167                                    self.output.push({
168                                        let interval =
169                                            intervals[thread_rng().gen_range(0..intervals.len())];
170                                        thread_rng().gen_range(interval.start()..=interval.end())
171                                    });
172                                } else if self.fail_on_exhaust {
173                                    return Err(e);
174                                } else {
175                                    self.output.push(intervals[0].start());
176                                }
177                            }
178                        }
179                    }
180                }
181            },
182            HirKind::Look(_) => todo!(),
183            _ => {}
184        }
185        Ok(())
186    }
187
188    fn finish(&mut self) -> Result<Self::Output> {
189        Ok(self.output.clone())
190    }
191}
192
/// Traversal state stored next to each node on the [`visit`] stack.
///
/// Every field starts out unset (`Meta::default()`) when a node is first
/// pushed, and only the field relevant to the node's kind is ever updated.
#[derive(Debug, Default)]
pub struct Meta {
    /// Repetition nodes: how many times the subexpression has been pushed.
    repetitions: Option<u32>,
    /// Concatenation nodes: index tracking progress through the children.
    concat_idx: Option<usize>,
    /// Alternation nodes: whether a branch has already been chosen.
    alternation_visited: bool,
}
199
200pub fn visit(
201    hir: &Hir,
202    visitor: &mut RngVisitor,
203) -> Result<<RngVisitor as Visitor>::Output, <RngVisitor as Visitor>::Err> {
204    let mut stack = vec![(hir, Meta::default())];
205
206    visitor.start();
207
208    while let Some((top, meta)) = stack.last() {
209        visitor
210            .visit(top, meta)
211            .map_err(|e| anyhow!("Error while visiting: {}", e))?;
212        match top.kind() {
213            HirKind::Repetition(rep) => {
214                if let Some(repetitions) = meta.repetitions.as_ref() {
215                    let repetitions = *repetitions;
216                    // This isn't the first visit
217                    if repetitions < rep.min
218                        || (repetitions < rep.max.unwrap_or(repetitions + 1)
219                            // Even if we fall back to thread RNG, we don't want to take repetitions
220                            // once we exhaust our own RNG
221                            && visitor.try_sample().unwrap_or(false))
222                    {
223                        // We may visit the subexpression again
224                        let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
225                        stack.push((
226                            top,
227                            Meta {
228                                repetitions: Some(repetitions + 1),
229                                ..meta
230                            },
231                        ));
232                        stack.push((&rep.sub, Meta::default()));
233                    } else {
234                        // We must not visit the subexpression again
235                        stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
236                    }
237                } else {
238                    // This is the first visit
239                    let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
240                    stack.push((
241                        top,
242                        Meta {
243                            repetitions: Some(1),
244                            ..meta
245                        },
246                    ));
247                    stack.push((&rep.sub, Meta::default()))
248                }
249            }
250            HirKind::Concat(concat) => {
251                if let Some(concat_idx) = meta.concat_idx.as_ref() {
252                    // Not the first visit
253                    if concat_idx >= &concat.len() {
254                        // We have visited all subexpressions
255                        stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
256                    } else {
257                        let idx = *concat_idx;
258                        // We must visit the next subexpression
259                        let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
260
261                        stack.push((
262                            top,
263                            Meta {
264                                concat_idx: Some(idx + 1),
265                                ..meta
266                            },
267                        ));
268
269                        stack.push((&concat[idx], Meta::default()));
270                    }
271                } else {
272                    // First visit
273                    let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
274                    stack.push((
275                        top,
276                        Meta {
277                            concat_idx: Some(0),
278                            ..meta
279                        },
280                    ));
281                    stack.push((&concat[0], Meta::default()))
282                }
283            }
284            HirKind::Alternation(alt) => {
285                if meta.alternation_visited {
286                    // We have visited all subexpressions
287                    stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
288                } else {
289                    // We must visit the next subexpression
290                    let (top, meta) = stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
291                    stack.push((
292                        top,
293                        Meta {
294                            alternation_visited: true,
295                            ..meta
296                        },
297                    ));
298                    let alternation_index = visitor.try_sample_range(0..alt.len()).unwrap_or(
299                        if visitor.fallback_to_thread_rng {
300                            thread_rng().gen_range(0..alt.len())
301                        } else {
302                            0
303                        },
304                    );
305                    stack.push((&alt[alternation_index], Meta::default()));
306                }
307            }
308            _ => {
309                stack.pop().ok_or_else(|| anyhow!("Stack underflow"))?;
310            }
311        }
312    }
313
314    visitor.finish()
315}
316
317#[cfg(test)]
318mod tests {
319    use std::collections::HashSet;
320
321    use anyhow::Result;
322    use fuzzerang::StandardSeedableRng;
323    use rand::{thread_rng, Rng, SeedableRng};
324    use regex_syntax::{ast::parse::Parser, hir::translate::TranslatorBuilder};
325
326    use crate::{visit, RngVisitorBuilder, Visitor};
327
328    fn test(regex: &str, seed: Vec<u8>, bytes: bool) -> Result<Vec<u8>> {
329        // println!("Testing regex: {}", regex);
330        let mut p = Parser::new();
331        let ast = p.parse(regex)?;
332        let hir = if bytes {
333            let mut t = TranslatorBuilder::new().unicode(false).utf8(false).build();
334            t.translate(regex, &ast)?
335        } else {
336            let mut t = TranslatorBuilder::new().unicode(true).utf8(false).build();
337            t.translate(regex, &ast)?
338        };
339        // println!("HIR: {:?}", hir);
340        let mut visitor = RngVisitorBuilder::default()
341            .rng(StandardSeedableRng::from_seed(seed))
342            .build()?;
343        visit(&hir, &mut visitor)?;
344        let r = visitor.finish()?;
345        // println!("Result: {:?}", r);
346        // match String::from_utf8(r.clone()) {
347        //     Ok(s) => println!("Result (string): {}", s),
348        //     Err(e) => eprintln!("Result (invalid UTF-8): {:?}", e),
349        // }
350        Ok(r)
351    }
352
353    #[test]
354    fn test_literal() -> Result<()> {
355        let regex = "A";
356        assert_eq!(test(regex, vec![], true)?, b"A");
357        Ok(())
358    }
359
360    #[test]
361    fn test_char_repetition_exact() -> Result<()> {
362        let regex = "A{2}";
363        assert_eq!(test(regex, vec![], true)?, b"AA");
364        Ok(())
365    }
366
367    #[test]
368    fn test_char_repetition_range_min() -> Result<()> {
369        let regex = "A{3,6}";
370        assert_eq!(test(regex, vec![0b00000000], true)?, b"AAA");
371        Ok(())
372    }
373
374    #[test]
375    fn test_char_repetition_range_max() -> Result<()> {
376        let regex = "A{3,6}";
377        assert_eq!(test(regex, vec![0b11111111], true)?, b"AAAAAA");
378        Ok(())
379    }
380
381    #[test]
382    fn test_class_one() -> Result<()> {
383        let regex = "[02468]";
384        assert_eq!(test(regex, vec![0b00000000], true)?, b"0");
385        assert_eq!(test(regex, vec![0b00000001], true)?, b"2");
386        assert_eq!(test(regex, vec![0b00000010], true)?, b"4");
387        assert_eq!(test(regex, vec![0b00000011], true)?, b"6");
388        assert_eq!(test(regex, vec![0b00000100], true)?, b"8");
389        Ok(())
390    }
391
392    #[test]
393    fn test_class_range_one() -> Result<()> {
394        let regex = "[0-9]";
395        let mut seen = HashSet::new();
396        for i in 0..32u8 {
397            // Print i as binary
398            let r = test(regex, vec![i, i + 1], true)?;
399            if let Ok(s) = String::from_utf8(r) {
400                seen.insert(s);
401            }
402        }
403        assert!(
404            seen.len() == 10,
405            "Expected 10 unique digits, got {:?}",
406            seen
407        );
408        Ok(())
409    }
410
411    #[test]
412    fn test_class_negate_one() -> Result<()> {
413        let regex = "[^0]";
414        for i in 0..(255 - 9) {
415            // Print i as binary
416            let r = test(regex, (i..(i + 8)).collect(), true)?;
417            if let Ok(s) = String::from_utf8(r) {
418                assert_ne!(s, "0");
419            }
420        }
421        Ok(())
422    }
423
424    #[test]
425    fn test_class_negate_range_one() -> Result<()> {
426        let regex = "[^0-9]";
427        for i in 0..(255 - 9) {
428            // Print i as binary
429            let r = test(regex, (i..(i + 8)).collect(), true)?;
430            if let Ok(s) = String::from_utf8(r) {
431                assert!(s.chars().all(|c| !c.is_ascii_digit()));
432            }
433        }
434        Ok(())
435    }
436
437    #[test]
438    fn test_concat() -> Result<()> {
439        let regex = "A+B+";
440        assert_eq!(
441            test(regex, vec![0xde, 0xad, 0xbe, 0xef], true)?,
442            b"AAAAAABBBB"
443        );
444        Ok(())
445    }
446
447    const JSON_REGEXES: &[&str] = &[
448        r#"\\"#,
449        r#"(\"|\\|\/|b|f|n|r|t|u)"#,
450        r#"[\da-fA-F]"#,
451        r#"[0-1]+"#,
452        r#"[0-7]+"#,
453        r#"[1-9]"#,
454        r#".*"#,
455        r#"[^*]*\*+([^/*][^*]*\*+)*"#,
456    ];
457
458    #[test]
459    fn test_json() -> Result<()> {
460        let mut rng = thread_rng();
461        for regex in JSON_REGEXES {
462            // println!("Testing regex: {}", regex);
463            test(regex, (0..64).map(|_| rng.gen()).collect(), true)?;
464        }
465        Ok(())
466    }
467}