1use core::iter::once;
2use std::io::{Result, Write};
3
4use crypto_bigint::{CheckedSub, NonZero, One, U256, Word};
5use secrecy::{ExposeSecret, ExposeSecretMut, SecretBox};
6
7use super::{
8 Eval, EvalContext, chars::Chars, context::Context, generator::Generator,
9 util::u256_saturating_pow,
10};
11
/// A node in a pattern tree. Every node describes a finite set of outputs:
/// `size` reports how many there are and `write_to` renders the one selected
/// by an index.
#[derive(Clone, Debug, PartialEq)]
pub enum Node {
    /// Fixed text, written verbatim; always exactly one output.
    Literal(Box<str>),

    /// A character class; size and rendering are delegated to [`Chars`].
    Chars(Chars),

    /// Nodes rendered one after another. The output count is the (saturating)
    /// product of the element sizes; an empty list has one, empty, output.
    List(Box<[Node]>),

    /// `Count(node, min, max)`: `node` repeated between `min` and `max` times,
    /// both bounds inclusive.
    Count(Box<Node>, u32, u32),

    /// A generator; size and rendering are delegated to [`Generator`].
    Generator(Generator),
}
30
31impl EvalContext for Node {
32 type Context<'a> = Context<'a>;
33
34 fn size(&self, context: &Context) -> NonZero<U256> {
35 match *self {
36 Node::Literal(_) => NonZero::ONE,
37 Node::Chars(ref chars) => chars.size(),
38 Node::List(ref nodes) => {
39 NonZero::new(nodes.into_iter().fold(U256::ONE, |acc, node| {
40 acc.saturating_mul(&node.size(context))
41 }))
42 .unwrap()
43 }
44
45 Node::Count(ref node, min, max) => {
46 let n = node.size(context);
47 if n.is_one().into() {
48 return NonZero::new((max - min + 1).into()).unwrap();
49 }
50 let k = min;
55 let l = max;
56 let mut x = U256::ZERO;
57 u256_saturating_pow(&n, (l + 1).into(), &mut x);
58 let mut y = U256::ZERO;
59 u256_saturating_pow(&n, Word::from(k), &mut y);
60 if x == U256::MAX && y == U256::MAX {
61 return NonZero::MAX;
63 }
64 x = x.checked_sub(&y).unwrap();
65 let (x, rem) = x.div_rem(&NonZero::new(n.saturating_sub(&U256::ONE)).unwrap());
66 assert!(bool::from(rem.is_zero()));
67 NonZero::new(x).unwrap()
68 }
69
70 Node::Generator(ref generator) => generator.size(context),
71 }
72 }
73
74 fn write_to(
75 &self,
76 context: &Context,
77 w: &mut dyn Write,
78 index: &mut dyn ExposeSecretMut<U256>,
79 ) -> Result<()> {
80 match *self {
81 Node::Literal(ref s) => w.write_all(s.as_bytes()),
82 Node::Chars(ref chars) => chars.write_to(w, index),
83
84 Node::List(ref nodes) => nodes
85 .into_iter()
86 .try_fold(index, |index, node| {
87 let mut node_index = SecretBox::init_with_mut(|node_index| {
88 let index = index.expose_secret_mut();
89 (*index, *node_index) = index.div_rem(&node.size(context));
90 });
91 node.write_to(context, w, &mut node_index)?;
92 Ok(index)
93 })
94 .map(|index| {
95 assert!(bool::from(index.expose_secret_mut().is_zero()));
96 }),
97
98 Node::Count(ref node, min, max) => {
99 let node = node.as_ref();
100 let base = SecretBox::init_with(|| node.size(context));
101 let mut count = min;
102 let mut n: SecretBox<U256> = SecretBox::init_with_mut(|n| {
103 u256_saturating_pow(base.expose_secret(), Word::from(min), n)
104 });
105 let n = n.expose_secret_mut();
106 while n <= index.expose_secret_mut() {
107 count += 1;
108 *index.expose_secret_mut() -= *n;
109 *n = n.saturating_mul(base.expose_secret());
110 }
111 assert!(count <= max);
112 for _ in 0..count {
113 let mut node_index = SecretBox::init_with_mut(|node_index| {
114 let index = index.expose_secret_mut();
115 (*index, *node_index) = index.div_rem(base.expose_secret());
116 });
117 node.write_to(context, w, &mut node_index)?;
118 }
119 assert!(bool::from(index.expose_secret_mut().is_zero()));
120 Ok(())
121 }
122
123 Node::Generator(ref generator) => generator.write_to(context, w, index),
124 }
125 }
126}
127
128impl From<Chars> for Node {
129 fn from(chars: Chars) -> Self {
130 Node::Chars(chars)
131 }
132}
133
134impl From<Generator> for Node {
135 fn from(generator: Generator) -> Self {
136 Node::Generator(generator)
137 }
138}
139
140impl FromIterator<Node> for Node {
141 fn from_iter<T: IntoIterator<Item = Node>>(iter: T) -> Self {
142 let mut iter = iter.into_iter().peekable();
143 let Some(node) = iter.next() else {
144 return Node::List(Box::default());
145 };
146 if iter.peek().is_none() {
147 return node;
148 }
149 Node::List(once(node).chain(iter).collect())
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::{super::util::*, *};

    use num_traits::PrimInt;

    /// `Count` over the lowercase alphabet `a..=z`, repeated `min..=max` times.
    fn lowercase_count(min: u32, max: u32) -> Node {
        let letter: Node = Chars::from_ranges([('a', 'z')]).into();
        Node::Count(Box::new(letter), min, max)
    }

    #[test]
    fn test_counts() {
        let context = Context::empty();

        // Repetitions of the literal "a":
        // (expected text, min, max, index, expected size when checked).
        let literal_cases: [(&str, u32, u32, u32, Option<u32>); 7] = [
            ("a", 1, 1, 0, Some(1)),
            ("aa", 2, 2, 0, Some(1)),
            ("a", 1, 5, 0, Some(5)),
            ("aa", 1, 5, 1, None),
            ("aaaaa", 1, 5, 4, None),
            ("", 0, 1, 0, Some(2)),
            ("a", 0, 1, 1, None),
        ];
        for (want, min, max, index, want_size) in literal_cases {
            let count = Node::Count(Box::new(Node::Literal("a".into())), min, max);
            let rendered = format_at_ctx(&count, &context, U256::from_u32(index));
            assert_eq!(want, &rendered);
            if let Some(size) = want_size {
                assert_eq!(U256::from_u32(size), *count.size(&context));
            }
        }

        // Checks that each (expected text, index) pair renders as expected.
        let render_all = |node: &Node, cases: &[(&str, u32)]| {
            for &(want, index) in cases {
                assert_eq!(want, &format_at_ctx(node, &context, U256::from_u32(index)));
            }
        };

        let one_to_five = lowercase_count(1, 5);
        assert_eq!(U256::from_u32(12356630), *one_to_five.size(&context));
        render_all(
            &one_to_five,
            &[
                ("a", 0),
                ("b", 1),
                ("aa", 26),
                ("ba", 27),
                ("zzzzz", 12356629),
            ],
        );

        let two_to_five = lowercase_count(2, 5);
        assert_eq!(U256::from_u32(12356604), *two_to_five.size(&context));
        render_all(
            &two_to_five,
            &[
                ("aa", 0),
                ("ba", 1),
                ("za", 25),
                ("ab", 26),
                ("zz", 675),
                ("aaa", 676),
                ("zzzzz", 12356603),
            ],
        );
    }

    #[test]
    fn test_count_single() {
        let context = Context::empty();
        let literal = Node::Literal("a".into());
        // (expected text, min, max, index)
        let cases: [(&str, u32, u32, u32); 8] = [
            ("", 0, 5, 0),
            ("a", 0, 5, 1),
            ("aaaaa", 0, 5, 5),
            ("a", 1, 5, 0),
            ("aaaa", 1, 5, 3),
            ("aaaaa", 1, 5, 4),
            ("aaaa", 4, 10, 0),
            ("aaaaa", 4, 10, 1),
        ];
        cases.into_iter().for_each(|(want, min, max, index)| {
            let count = Node::Count(Box::new(literal.clone()), min, max);
            assert_eq!(want, &format_at_ctx(&count, &context, U256::from_u32(index)));
        });
    }

    #[test]
    fn test_lists() {
        let context = Context::empty();
        let letter = || -> Node { Chars::from_ranges([('a', 'z')]).into() };
        // (expected text, number of letters in the list, index)
        let cases: [(&str, u32, u32); 9] = [
            ("a", 1, 0),
            ("b", 1, 1),
            ("z", 1, 25),
            ("aa", 2, 0),
            ("ba", 2, 1),
            ("za", 2, 25),
            ("ab", 2, 26),
            ("zz", 2, 675),
            ("aaaaa", 5, 0),
        ];
        for (want, rep, index) in cases {
            let node: Node = (0..rep).map(|_| letter()).collect();
            assert_eq!(U256::from_u32(PrimInt::pow(26u32, rep)), *node.size(&context));
            assert_eq!(want, &format_at_ctx(&node, &context, U256::from_u32(index)));
        }
    }

    #[test]
    fn test_generators() {
        let context = Context::default();
        let words = Node::from(Generator::new("word"));
        assert_eq!(U256::from_u32(7776), *words.size(&context));
        for (want, index) in [("abacus", 0u32), ("zoom", 7775u32)] {
            assert_eq!(want, &format_at_ctx(&words, &context, U256::from_u32(index)));
        }
    }
}