smt_str/
sampling.rs

1use bit_set::BitSet;
2
3use rand::{rng, seq::IteratorRandom};
4
5use crate::{
6    automata::{TransitionType, NFA},
7    re::{deriv::DerivativeBuilder, ReBuilder, ReOp, Regex},
8    SmtString,
9};
10
11/// The result of sampling from a regex or automaton.
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub enum SampleResult {
14    /// Founda word in the language
15    Sampled(SmtString),
16    /// The language tried to sample from is empty
17    Empty,
18    /// Maximum depth was reached without finding a word.
19    MaxDepth,
20}
21
22impl SampleResult {
23    /// Unwraps the sampled string.
24    /// Panics if the sampling was not successfull.
25    pub fn unwrap(self) -> SmtString {
26        match self {
27            SampleResult::Sampled(s) => s,
28            _ => panic!("called `unwrap` on empty value"),
29        }
30    }
31
32    /// Return true if sampling was successfull and this result carries a value.
33    /// Othwerwise returns false.
34    pub fn success(&self) -> bool {
35        matches!(self, SampleResult::Sampled(_))
36    }
37}
38
39/// Tries to sample a word that is accepted by the regex.
40/// The function aborts if no word is found after `max_depth` steps.
41/// If `comp` is set to `true`, the function will return a word that is not accepted by the regex.
42/// In other words, the function will sample a word from the complement of the regex's language.
43pub fn sample_regex(
44    regex: &Regex,
45    builder: &mut ReBuilder,
46    max_depth: usize,
47    comp: bool,
48) -> SampleResult {
49    fn fast_sample(re: &Regex, d: usize, max: usize) -> SampleResult {
50        if d > max {
51            return SampleResult::MaxDepth;
52        }
53        match re.op() {
54            ReOp::Literal(w) => SampleResult::Sampled(w.clone()),
55            ReOp::Range(r) => {
56                if let Some(r) = r.choose().map(|c| c.into()) {
57                    SampleResult::Sampled(r)
58                } else {
59                    SampleResult::Empty
60                }
61            }
62            ReOp::None => SampleResult::Empty,
63            ReOp::Any | ReOp::All => SampleResult::Sampled(SmtString::from("a")),
64            ReOp::Concat(rs) => {
65                let mut res = SmtString::empty();
66                for r in rs {
67                    match fast_sample(r, d + 1, max) {
68                        SampleResult::Sampled(s) => res.append(&s),
69                        SampleResult::Empty => return SampleResult::Empty,
70                        SampleResult::MaxDepth => return SampleResult::MaxDepth,
71                    }
72                }
73                SampleResult::Sampled(res)
74            }
75            ReOp::Comp(comped) => match comped.op() {
76                ReOp::Literal(s) => {
77                    if s.is_empty() {
78                        SampleResult::Sampled("a".into())
79                    } else {
80                        SampleResult::Sampled(SmtString::empty())
81                    }
82                }
83                ReOp::Range(range) => {
84                    for c in range.complement() {
85                        if let Some(c) = c.choose() {
86                            return SampleResult::Sampled(c.into());
87                        }
88                    }
89                    SampleResult::Empty
90                }
91                ReOp::None => SampleResult::Sampled("a".into()),
92                ReOp::Any => SampleResult::Sampled("aa".into()),
93                ReOp::All => SampleResult::Empty,
94                ReOp::Comp(r) => fast_sample(r, d + 1, max), // Double complement
95                _ => SampleResult::MaxDepth,
96            },
97            ReOp::Union(rs) => {
98                let mut max_reached = false;
99                for r in rs {
100                    match fast_sample(r, d + 1, max) {
101                        SampleResult::Sampled(s) => return SampleResult::Sampled(s),
102                        SampleResult::Empty => (),
103                        SampleResult::MaxDepth => max_reached = true,
104                    }
105                }
106                if max_reached {
107                    SampleResult::MaxDepth
108                } else {
109                    SampleResult::Empty
110                }
111            }
112            ReOp::Star(_) | ReOp::Opt(_) => SampleResult::Sampled(SmtString::empty()),
113            ReOp::Plus(r) => fast_sample(r, d + 1, max),
114            ReOp::Pow(r, e) => match fast_sample(r, d + 1, max) {
115                SampleResult::Sampled(s) => SampleResult::Sampled(s.repeat(*e as usize)),
116                SampleResult::Empty => SampleResult::Empty,
117                SampleResult::MaxDepth => SampleResult::MaxDepth,
118            },
119            ReOp::Loop(r, l, u) if l <= u => match fast_sample(r, d + 1, max) {
120                SampleResult::Sampled(s) => SampleResult::Sampled(s.repeat(*l as usize)),
121                SampleResult::Empty => SampleResult::Empty,
122                SampleResult::MaxDepth => SampleResult::MaxDepth,
123            },
124            ReOp::Loop(_, _, _) => SampleResult::Empty,
125            _ => SampleResult::MaxDepth,
126        }
127    }
128
129    if !comp {
130        match fast_sample(regex, 0, max_depth) {
131            SampleResult::Sampled(s) => return SampleResult::Sampled(s),
132            SampleResult::Empty => return SampleResult::Empty,
133            SampleResult::MaxDepth => (),
134        }
135    }
136
137    let mut w = SmtString::empty();
138    let mut deriver = DerivativeBuilder::default();
139
140    let mut i = 0;
141    let mut re = regex.clone();
142
143    let done = |re: &Regex| {
144        if comp {
145            !re.nullable()
146        } else {
147            re.nullable()
148        }
149    };
150
151    if done(&re) {
152        return SampleResult::Sampled(w);
153    }
154
155    while !done(&re) && i < max_depth {
156        let next = if let Some(c) = re
157            .first()
158            .iter()
159            .choose(&mut rng())
160            .and_then(|c| c.choose())
161        {
162            c
163        } else {
164            return SampleResult::Empty;
165        };
166        w.push(next);
167        re = deriver.deriv(&re, next, builder);
168        i += 1;
169    }
170
171    if done(&re) {
172        SampleResult::Sampled(w)
173    } else {
174        SampleResult::MaxDepth
175    }
176}
177
178/// Tries to sample a word that is accepted or not accepted by the NFA.
179/// Randomly picks transitions to follow until a final state is reached.
180/// Once a final state is reached, the function returns the word that was sampled.
181/// The function aborts if no word is found after `max_depth` transitions.
182/// If `comp` is set to `true`, the function will return a word that is not accepted by the NFA.
183/// In other words, the function will sample a word from the complement of the NFA's language.
184///
185/// The NFA should be trim. Othwerwise the function returns `SampleResult::Empty` even though
186/// it is not. That happens if it runs into a state from which is cannot make progress anymore.
187/// Such states do not occur in trim automata.
188pub fn sample_nfa(nfa: &NFA, max: usize, comp: bool) -> SampleResult {
189    let mut w = SmtString::empty();
190    let mut states = BitSet::new();
191    if let Some(q0) = nfa.initial() {
192        states = BitSet::from_iter(nfa.epsilon_closure(q0).unwrap());
193    }
194    let mut i = 0;
195
196    let done = |s: &BitSet| {
197        if comp {
198            !s.iter().any(|q| nfa.is_final(q))
199        } else {
200            s.iter().any(|q| nfa.is_final(q))
201        }
202    };
203
204    while i <= max {
205        i += 1;
206        // Check if the current state set contains a final state
207        if done(&states) {
208            return SampleResult::Sampled(w);
209        }
210
211        // Collect all transitions from the current state set
212        let mut transitions = Vec::new();
213        for q in states.iter() {
214            transitions.extend(nfa.transitions_from(q).unwrap());
215        }
216        // Pick a random transition
217        let transition = match transitions.iter().choose(&mut rng()) {
218            Some(t) => t,
219            None => return SampleResult::Empty,
220        };
221        // Pick a random character from the transition
222        let c = match transition.get_type() {
223            TransitionType::Range(r) => r.choose(),
224            TransitionType::NotRange(nr) => {
225                let r = nr.complement();
226                r.into_iter()
227                    .filter(|r| !r.is_empty())
228                    .choose(&mut rng())
229                    .and_then(|r| r.choose())
230            }
231            TransitionType::Epsilon => None,
232        };
233        match c {
234            Some(c) => {
235                w.push(c);
236                // set the next state set to the epsilon closure of the destination state
237                states = BitSet::from_iter(
238                    states
239                        .iter()
240                        .flat_map(|s| nfa.consume(s, c))
241                        .flatten()
242                        .flat_map(|q| nfa.epsilon_closure(q))
243                        .flatten(),
244                );
245            }
246            None => continue,
247        }
248    }
249
250    SampleResult::MaxDepth
251}
252
#[cfg(test)]
mod tests {

    use quickcheck_macros::quickcheck;
    use smallvec::smallvec;

    use crate::alphabet::CharRange;

    use super::*;

    #[test]
    fn sample_result_success() {
        assert!(SampleResult::Sampled(SmtString::empty()).success());
        assert!(!SampleResult::Empty.success());
        assert!(!SampleResult::MaxDepth.success());
    }

    #[test]
    #[should_panic]
    fn unwrap_panics_on_max_depth() {
        SampleResult::MaxDepth.unwrap();
    }

    #[test]
    fn sample_const() {
        let mut builder = ReBuilder::default();
        let regex = builder.to_re("foo".into());

        assert_eq!(
            sample_regex(&regex, &mut builder, 3, false).unwrap(),
            "foo".into()
        );
        assert_eq!(
            sample_regex(&regex, &mut builder, 10, false).unwrap(),
            "foo".into()
        );
    }

    #[test]
    fn sample_complement_of_literal() {
        let mut builder = ReBuilder::default();
        let regex = builder.to_re("foo".into());

        let w = sample_regex(&regex, &mut builder, 10, true).unwrap();
        assert_ne!(w, SmtString::from("foo"));
    }

    #[test]
    fn sample_with_optional_characters() {
        let mut builder = ReBuilder::default();

        // fo(o|bar)
        let o = builder.to_re("o".into());
        let fo = builder.to_re("fo".into());
        let bar = builder.to_re("bar".into());
        let o_or_bar = builder.union(smallvec![o, bar]);
        let regex = builder.concat(smallvec![fo, o_or_bar]);

        assert!(sample_regex(&regex, &mut builder, 5, false).success());
    }

    #[quickcheck]
    fn sample_character_range(range: CharRange) {
        let mut builder = ReBuilder::default();
        let regex = builder.range(range);

        assert!(sample_regex(&regex, &mut builder, 1, false).success());
        assert!(sample_regex(&regex, &mut builder, 3, false).success());
    }

    #[quickcheck]
    fn sample_character_range_pow(range: CharRange, n: u32) {
        let n = n % 100;
        let mut builder = ReBuilder::default();
        let regex = builder.range(range);
        let regex = builder.pow(regex, n);

        assert!(sample_regex(&regex, &mut builder, n as usize, false).success());
    }

    #[quickcheck]
    fn sample_alternatives(rs: Vec<CharRange>) {
        let n = rs.len();
        let mut builder = ReBuilder::default();
        let rs = rs.into_iter().map(|r| builder.range(r)).collect();
        let regex = builder.union(rs);

        if n > 0 {
            assert!(sample_regex(&regex, &mut builder, 1, false).success());
        } else {
            assert!(!sample_regex(&regex, &mut builder, 10, false).success());
        }
    }

    /// Regression test: a union of overlapping character ranges must be sampleable.
    #[test]
    fn sampling_alternatives_bug() {
        let rs = vec![
            CharRange::new(2u32, 5u32),
            CharRange::new(3u32, 6u32),
            CharRange::new(1u32, 4u32),
        ];

        let mut builder = ReBuilder::default();
        let rs = rs.into_iter().map(|r| builder.range(r)).collect();
        let regex = builder.union(rs);

        assert!(sample_regex(&regex, &mut builder, 1, false).success());
    }

    #[quickcheck]
    fn sample_opt(r: CharRange) {
        let mut builder = ReBuilder::default();
        let r = builder.range(r);
        let regex = builder.opt(r);

        assert!(sample_regex(&regex, &mut builder, 0, false).success());
        assert!(sample_regex(&regex, &mut builder, 1, false).success());
    }

    #[test]
    fn sample_empty_string() {
        let mut builder = ReBuilder::default();
        let regex = builder.epsilon();

        assert!(sample_regex(&regex, &mut builder, 0, false).success());
    }

    #[test]
    fn sample_empty_regex() {
        let mut builder = ReBuilder::default();
        let regex = builder.none();

        assert_eq!(
            sample_regex(&regex, &mut builder, 0, false),
            SampleResult::Empty
        );
        assert_eq!(
            sample_regex(&regex, &mut builder, 20, false),
            SampleResult::Empty
        );
    }

    #[test]
    fn sample_all() {
        let mut builder = ReBuilder::default();
        let regex = builder.all();

        assert!(sample_regex(&regex, &mut builder, 0, false).success());
        assert!(sample_regex(&regex, &mut builder, 20, false).success());
    }

    #[test]
    fn sample_any() {
        let mut builder = ReBuilder::default();
        let regex = builder.allchar();
        assert!(sample_regex(&regex, &mut builder, 20, false).success());
    }

    #[test]
    fn test_sample_nfa_accepts_word() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'a')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert_eq!(sample, SampleResult::Sampled(SmtString::from("a")));
    }

    #[test]
    fn test_sample_nfa_complement() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'a')))
            .unwrap();

        // The empty word is rejected by the NFA, so it is a valid complement sample.
        let sample = sample_nfa(&nfa, 10, true);
        assert_eq!(sample, SampleResult::Sampled(SmtString::empty()));
    }

    #[test]
    fn test_sample_nfa_rejects_unreachable_final_state() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state(); // Final state, but not reachable

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert_eq!(sample, SampleResult::Empty);
    }

    #[test]
    fn test_sample_nfa_handles_epsilon_transitions() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        let q2 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q2).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Epsilon).unwrap();
        nfa.add_transition(q1, q2, TransitionType::Range(CharRange::new('b', 'b')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert_eq!(sample, SampleResult::Sampled(SmtString::from("b")));
    }

    #[test]
    fn test_sample_nfa_stops_at_max_depth() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        let q2 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q2).unwrap();

        // Every accepted word has length two.
        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'z')))
            .unwrap();
        nfa.add_transition(q1, q2, TransitionType::Range(CharRange::new('a', 'z')))
            .unwrap();

        let sample = sample_nfa(&nfa, 1, false); // Very low max depth
        assert_eq!(sample, SampleResult::MaxDepth); // Should not reach q2 in one step
    }

    #[test]
    fn test_sample_nfa_handles_not_range_transitions() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q1, TransitionType::NotRange(CharRange::new('x', 'z')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert!(sample.success()); // Should produce a valid word
        if let SampleResult::Sampled(word) = sample {
            assert!(
                !word.contains_char('x') && !word.contains_char('y') && !word.contains_char('z')
            );
        }
    }

    #[test]
    fn test_sample_nfa_multiple_paths() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();
        let q2 = nfa.new_state();
        let q3 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q3).unwrap();

        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::new('a', 'a')))
            .unwrap();
        nfa.add_transition(q1, q3, TransitionType::Range(CharRange::new('b', 'b')))
            .unwrap();
        nfa.add_transition(q0, q2, TransitionType::Range(CharRange::new('x', 'x')))
            .unwrap();
        nfa.add_transition(q2, q3, TransitionType::Range(CharRange::new('y', 'y')))
            .unwrap();

        let sample = sample_nfa(&nfa, 10, false);
        assert!(
            sample == SampleResult::Sampled(SmtString::from("ab"))
                || sample == SampleResult::Sampled(SmtString::from("xy"))
        );
    }

    #[test]
    fn test_sample_nfa_leaves_loops() {
        let mut nfa = NFA::new();
        let q0 = nfa.new_state();
        let q1 = nfa.new_state();

        nfa.set_initial(q0).unwrap();
        nfa.add_final(q1).unwrap();

        nfa.add_transition(q0, q0, TransitionType::Range(CharRange::singleton('a')))
            .unwrap();
        nfa.add_transition(q0, q1, TransitionType::Range(CharRange::singleton('b')))
            .unwrap();

        match sample_nfa(&nfa, 100, false) {
            SampleResult::Sampled(w) => {
                let l = w.len();
                let mut expected = SmtString::from("a").repeat(l - 1);
                expected.push('b');
                assert_eq!(w, expected);
            }
            other => panic!("expected a sampled word, got {other:?}"),
        }
    }
}