Skip to main content

onepass_seed/expr/
node.rs

1use core::iter::once;
2use std::io::{Result, Write};
3
4use crypto_bigint::{CheckedSub, NonZero, One, U256, Word};
5use secrecy::{ExposeSecret, ExposeSecretMut, SecretBox};
6
7use super::{
8    Eval, EvalContext, chars::Chars, context::Context, generator::Generator,
9    util::u256_saturating_pow,
10};
11
12/// AST representation for [`Expr`][super::Expr] nodes.
13#[derive(Clone, Debug, PartialEq)]
14pub enum Node {
15    /// String literal.
16    Literal(Box<str>),
17
18    /// Character class (see [`Chars`].)
19    Chars(Chars),
20
21    /// Sequence of nodes.
22    List(Box<[Node]>),
23
24    /// Variable count from min to max.
25    Count(Box<Node>, u32, u32),
26
27    /// [`Generator`] call.
28    Generator(Generator),
29}
30
31impl EvalContext for Node {
32    type Context<'a> = Context<'a>;
33
34    fn size(&self, context: &Context) -> NonZero<U256> {
35        match *self {
36            Node::Literal(_) => NonZero::ONE,
37            Node::Chars(ref chars) => chars.size(),
38            Node::List(ref nodes) => {
39                NonZero::new(nodes.into_iter().fold(U256::ONE, |acc, node| {
40                    acc.saturating_mul(&node.size(context))
41                }))
42                .unwrap()
43            }
44
45            Node::Count(ref node, min, max) => {
46                let n = node.size(context);
47                if n.is_one().into() {
48                    return NonZero::new((max - min + 1).into()).unwrap();
49                }
50                // Closed form of n^k + … + n^l
51                //              = n^k (1 + … + n^(l-k))
52                //              = n^k (n^(l-k+1) - 1) / (n - 1)
53                //              = (n^(l+1) - n^k) / (n - 1)
54                let k = min;
55                let l = max;
56                let mut x = U256::ZERO;
57                u256_saturating_pow(&n, (l + 1).into(), &mut x);
58                let mut y = U256::ZERO;
59                u256_saturating_pow(&n, Word::from(k), &mut y);
60                if x == U256::MAX && y == U256::MAX {
61                    // Assume we got an overflow.
62                    return NonZero::MAX;
63                }
64                x = x.checked_sub(&y).unwrap();
65                let (x, rem) = x.div_rem(&NonZero::new(n.saturating_sub(&U256::ONE)).unwrap());
66                assert!(bool::from(rem.is_zero()));
67                NonZero::new(x).unwrap()
68            }
69
70            Node::Generator(ref generator) => generator.size(context),
71        }
72    }
73
74    fn write_to(
75        &self,
76        context: &Context,
77        w: &mut dyn Write,
78        index: &mut dyn ExposeSecretMut<U256>,
79    ) -> Result<()> {
80        match *self {
81            Node::Literal(ref s) => w.write_all(s.as_bytes()),
82            Node::Chars(ref chars) => chars.write_to(w, index),
83
84            Node::List(ref nodes) => nodes
85                .into_iter()
86                .try_fold(index, |index, node| {
87                    let mut node_index = SecretBox::init_with_mut(|node_index| {
88                        let index = index.expose_secret_mut();
89                        (*index, *node_index) = index.div_rem(&node.size(context));
90                    });
91                    node.write_to(context, w, &mut node_index)?;
92                    Ok(index)
93                })
94                .map(|index| {
95                    assert!(bool::from(index.expose_secret_mut().is_zero()));
96                }),
97
98            Node::Count(ref node, min, max) => {
99                let node = node.as_ref();
100                let base = SecretBox::init_with(|| node.size(context));
101                let mut count = min;
102                let mut n: SecretBox<U256> = SecretBox::init_with_mut(|n| {
103                    u256_saturating_pow(base.expose_secret(), Word::from(min), n)
104                });
105                let n = n.expose_secret_mut();
106                while n <= index.expose_secret_mut() {
107                    count += 1;
108                    *index.expose_secret_mut() -= *n;
109                    *n = n.saturating_mul(base.expose_secret());
110                }
111                assert!(count <= max);
112                for _ in 0..count {
113                    let mut node_index = SecretBox::init_with_mut(|node_index| {
114                        let index = index.expose_secret_mut();
115                        (*index, *node_index) = index.div_rem(base.expose_secret());
116                    });
117                    node.write_to(context, w, &mut node_index)?;
118                }
119                assert!(bool::from(index.expose_secret_mut().is_zero()));
120                Ok(())
121            }
122
123            Node::Generator(ref generator) => generator.write_to(context, w, index),
124        }
125    }
126}
127
128impl From<Chars> for Node {
129    fn from(chars: Chars) -> Self {
130        Node::Chars(chars)
131    }
132}
133
134impl From<Generator> for Node {
135    fn from(generator: Generator) -> Self {
136        Node::Generator(generator)
137    }
138}
139
140impl FromIterator<Node> for Node {
141    fn from_iter<T: IntoIterator<Item = Node>>(iter: T) -> Self {
142        let mut iter = iter.into_iter().peekable();
143        let Some(node) = iter.next() else {
144            return Node::List(Box::default());
145        };
146        if iter.peek().is_none() {
147            return node;
148        }
149        Node::List(once(node).chain(iter).collect())
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::{super::util::*, *};

    use num_traits::PrimInt;

    /// The 26-letter character class `[a-z]` as a node.
    fn lowercase_letters() -> Node {
        Chars::from_ranges([('a', 'z')]).into()
    }

    #[test]
    fn test_counts() {
        let context = Context::empty();
        let render = |node: &Node, index: u32| format_at_ctx(node, &context, U256::from_u32(index));

        // `a{min,max}`: (expected output, min, max, index, expected size if checked).
        let literal_cases = [
            ("a", 1, 1, 0, Some(1)),
            ("aa", 2, 2, 0, Some(1)),
            ("a", 1, 5, 0, Some(5)),
            ("aa", 1, 5, 1, None),
            ("aaaaa", 1, 5, 4, None),
            ("", 0, 1, 0, Some(2)),
            ("a", 0, 1, 1, None),
        ];
        for (want, min, max, index, want_size) in literal_cases {
            let repeated = Node::Count(Box::new(Node::Literal("a".into())), min, max);
            assert_eq!(want, &render(&repeated, index));
            if let Some(size) = want_size {
                assert_eq!(U256::from_u32(size), *repeated.size(&context));
            }
        }

        // `[a-z]{1,5}`: shorter strings come first; within one length the
        // first character varies fastest.
        let one_to_five = Node::Count(Box::new(lowercase_letters()), 1, 5);
        assert_eq!(U256::from_u32(12356630), *one_to_five.size(&context));
        for (want, index) in [
            ("a", 0),
            ("b", 1),
            ("aa", 26),
            ("ba", 27),
            ("zzzzz", 12356629),
        ] {
            assert_eq!(want, &render(&one_to_five, index));
        }

        // `[a-z]{2,5}`: same ordering, starting at length two.
        let two_to_five = Node::Count(Box::new(lowercase_letters()), 2, 5);
        assert_eq!(U256::from_u32(12356604), *two_to_five.size(&context));
        for (want, index) in [
            ("aa", 0),
            ("ba", 1),
            ("za", 25),
            ("ab", 26),
            ("zz", 675),
            ("aaa", 676),
            ("zzzzz", 12356603),
        ] {
            assert_eq!(want, &render(&two_to_five, index));
        }
    }

    #[test]
    fn test_count_single() {
        let context = Context::empty();
        let literal = Node::Literal("a".into());
        // (expected output, min, max, index) for `a{min,max}`.
        let cases = [
            ("", 0, 5, 0),
            ("a", 0, 5, 1),
            ("aaaaa", 0, 5, 5),
            ("a", 1, 5, 0),
            ("aaaa", 1, 5, 3),
            ("aaaaa", 1, 5, 4),
            ("aaaa", 4, 10, 0),
            ("aaaaa", 4, 10, 1),
        ];
        for (want, min, max, index) in cases {
            let repeated = Node::Count(Box::new(literal.clone()), min, max);
            assert_eq!(
                want,
                &format_at_ctx(&repeated, &context, U256::from_u32(index))
            );
        }
    }

    #[test]
    fn test_lists() {
        let context = Context::empty();
        // (expected output, number of `[a-z]` nodes, index).
        let cases = [
            ("a", 1, 0),
            ("b", 1, 1),
            ("z", 1, 25),
            ("aa", 2, 0),
            ("ba", 2, 1),
            ("za", 2, 25),
            ("ab", 2, 26),
            ("zz", 2, 675),
            ("aaaaa", 5, 0),
        ];
        for (want, rep, index) in cases {
            let list: Node = (0..rep).map(|_| lowercase_letters()).collect();
            let size = 26.pow(rep as u32);
            assert_eq!(U256::from_u32(size), *list.size(&context));
            assert_eq!(want, &format_at_ctx(&list, &context, U256::from_u32(index)));
        }
    }

    #[test]
    fn test_generators() {
        let context = Context::default();
        let word = Node::from(Generator::new("word"));
        assert_eq!(U256::from_u32(7776), *word.size(&context));
        for (want, index) in [("abacus", U256::ZERO), ("zoom", U256::from_u32(7775))] {
            assert_eq!(want, &format_at_ctx(&word, &context, index));
        }
    }
}