1use integral_exponential_polynomial::Polynomial;
10use num_bigint::{BigInt, BigUint};
11use num_integer::Integer;
12use num_traits::{One, ToPrimitive};
13use std::collections::HashMap;
14use std::fmt;
15
16type CountingFunctionCacheInternal = HashMap<Vec<usize>, Polynomial>;
17
18#[derive(Clone, Default)]
20pub struct CountingFunctionCache {
21 internal: CountingFunctionCacheInternal,
22}
23
24impl fmt::Debug for CountingFunctionCache {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 for (k, v) in self.internal.iter() {
27 writeln!(f, "{:?}: {}", k, v)?;
28 }
29 Ok(())
30 }
31}
32
33fn get_normal_tuple_counting_function(sets: &[usize]) -> Polynomial {
34 let sum: usize = sets.iter().sum();
35 Polynomial::one_term(BigInt::one(), BigUint::from(sum))
36}
37
38fn generate_counting_function(
39 cache: &mut CountingFunctionCacheInternal,
40 sets: &[usize],
41 mask: usize,
42) -> Polynomial {
43 let sets = sets
44 .iter()
45 .enumerate()
46 .filter_map(|(index, &item)| {
47 if mask & 1 << index != 0 {
48 Some(item)
49 } else {
50 None
51 }
52 })
53 .collect::<Vec<usize>>();
54
55 if let Some(result) = cache.get(&sets) {
56 return result.clone();
57 }
58 let mask_all = (1 << sets.len()) - 1;
59 let mut result = get_normal_tuple_counting_function(&sets);
60 for mask in 1..mask_all {
65 result -= generate_counting_function(cache, &sets, mask);
66 }
67 cache.insert(sets, result.clone());
68 result
69}
70
71#[derive(Clone)]
73pub struct Counter {
74 polynomials: Vec<Polynomial>,
75}
76
77impl Counter {
78 pub fn new(cache: &mut CountingFunctionCache, sets: &[usize]) -> Counter {
80 let mask_num = 1 << sets.len();
81 let mask_all = mask_num - 1;
82 let result_base = get_normal_tuple_counting_function(&sets);
83 let polynomials = (0..mask_num)
84 .map(|mask| {
85 let mut result = result_base.clone();
86 for submask in 1..mask_all {
90 if submask | mask != mask_all {
91 result -= generate_counting_function(&mut cache.internal, &sets, submask);
92 }
93 }
94 result
95 })
96 .collect::<Vec<_>>();
97 Counter { polynomials }
98 }
99
100 fn count(&self, present_mask: usize, length: usize) -> BigUint {
101 let result = self.polynomials[present_mask].apply(length);
102 result.to_biguint().unwrap()
103 }
104
105 pub fn count_total(&self, length: usize) -> BigUint {
108 self.count(0, length)
109 }
110}
111
112#[derive(Clone)]
114pub struct Generator<'a, T: Clone> {
115 sets: &'a [Vec<T>],
116 counter: Counter,
117}
118
119impl<'a, T: 'a + Clone> Generator<'a, T> {
120 pub fn new(cache: &mut CountingFunctionCache, sets: &'a [Vec<T>]) -> Generator<'a, T> {
131 let num_sets = sets.iter().map(|s| s.len()).collect::<Vec<_>>();
132 let counter = Counter::new(cache, &num_sets);
133 Generator { sets, counter }
134 }
135
136 fn generate_internal<CountFn>(
137 &self,
138 length: usize,
139 value: BigUint,
140 count_fn: &mut CountFn,
141 ) -> Vec<T>
142 where
143 CountFn: FnMut(usize, usize) -> BigUint,
144 {
145 let mut result = Vec::with_capacity(length);
146 let mut present_mask = 0;
147 let mut value = value % count_fn(0, length);
148 for i in 0..length {
149 let left_length = length - i - 1;
150 for (j, set) in self.sets.iter().enumerate() {
151 let next_mask = present_mask | 1 << j;
152 let item_span = count_fn(next_mask, left_length);
153 let total_span = item_span.clone() * BigUint::from(set.len());
154 if value >= total_span {
155 value -= total_span;
156 continue;
157 }
158 let (quo, rem) = value.div_rem(&item_span);
159 value = rem;
160 present_mask = next_mask;
161 result.push(set[quo.to_usize().unwrap()].clone());
162 break;
163 }
164 }
165 result
166 }
167
168 pub fn generate(&self, length: usize, value: BigUint) -> Vec<T> {
177 self.generate_internal(length, value, &mut |mask, len| {
178 self.counter.count(mask, len)
179 })
180 }
181}
182
183#[derive(Clone)]
189pub struct GeneratorWithCache<'a, T: Clone> {
190 generator: &'a Generator<'a, T>,
191 cache: Vec<Vec<Option<BigUint>>>,
192}
193
194impl<'a, T: 'a + Clone> GeneratorWithCache<'a, T> {
195 pub fn from_generator(generator: &'a Generator<'a, T>) -> GeneratorWithCache<'a, T> {
196 let mask_num = 1 << generator.sets.len();
197 GeneratorWithCache {
198 generator,
199 cache: vec![vec![]; mask_num],
200 }
201 }
202
203 pub fn generate(&mut self, length: usize, value: BigUint) -> Vec<T> {
204 self.generator
205 .generate_internal(length, value, &mut |mask, len| {
206 let cache_line = &mut self.cache[mask];
207 if cache_line.len() <= len {
208 let fillup_len = len - cache_line.len() + 1;
209 cache_line.append(&mut vec![None; fillup_len]);
210 } else if let Some(result) = cache_line[length].as_ref() {
211 return result.clone();
212 }
213 let result = self.generator.counter.count(mask, len);
214 cache_line[length] = Some(result.clone());
215 result
216 })
217 }
218}
219
/// Cursor for one position of the sequence being enumerated.
#[derive(Clone, Copy, Debug)]
struct ItemState {
    // Index into `sets` of the set this position draws from.
    set: usize,
    // Index of the element within that set.
    index: usize,
    // Bitmask of the sets used by this position and every earlier one.
    present_mask: usize,
}

/// Iterates, in the same order as `Generator::generate` ranks them, over every sequence of
/// `length` elements that contains at least one element of each set.
///
/// Yields nothing when no such sequence exists (fewer positions than sets, or an empty
/// set). With zero sets, the single empty sequence is produced for `length == 0`.
#[derive(Clone)]
pub struct Enumerator<'a, T: Clone> {
    sets: &'a [Vec<T>],
    // The sequence `next` will return, or `None` once the enumeration is exhausted.
    next_state: Option<Vec<ItemState>>,
}

impl<'a, T: 'a + Clone> Enumerator<'a, T> {
    /// Creates an enumerator over sequences of `length` elements drawn from `sets`.
    pub fn new(sets: &'a [Vec<T>], length: usize) -> Enumerator<'a, T> {
        Enumerator {
            sets,
            next_state: Self::initial_state(sets, length),
        }
    }

    /// Lexicographically first valid sequence, or `None` if no valid sequence exists.
    fn initial_state(sets: &[Vec<T>], length: usize) -> Option<Vec<ItemState>> {
        if sets.is_empty() {
            return if length == 0 { Some(Vec::new()) } else { None };
        }
        if length < sets.len() || sets.iter().any(|set| set.is_empty()) {
            return None;
        }
        // Leading positions use set 0; the last `sets.len() - 1` positions introduce
        // sets 1, 2, ... in order, so every set appears as late as possible.
        let mut state = vec![
            ItemState {
                set: 0,
                index: 0,
                present_mask: 1,
            };
            length
        ];
        let offset = length - sets.len();
        let mut mask = 1;
        for i in 1..sets.len() {
            mask |= 1 << i;
            state[offset + i] = ItemState {
                set: i,
                index: 0,
                present_mask: mask,
            };
        }
        Some(state)
    }

    /// Steps `state` to the next sequence that uses every set.
    ///
    /// Returns `false` (leaving `state` unspecified) when there is no further sequence.
    fn advance(state: &mut [ItemState], sets: &[Vec<T>]) -> bool {
        let mask_all = (1 << sets.len()) - 1;
        let last = match state.len().checked_sub(1) {
            Some(last) => last,
            // The empty sequence is the only one of length zero; it has no successor.
            None => return false,
        };
        loop {
            // Odometer increment: bump the last position, carrying leftwards through
            // (set, index) pairs.
            let mut cur = last;
            loop {
                let item = &mut state[cur];
                item.index += 1;
                if item.index < sets[item.set].len() {
                    break;
                }
                item.index = 0;
                item.set += 1;
                if item.set < sets.len() {
                    break;
                }
                if cur == 0 {
                    return false;
                }
                item.set = 0;
                cur -= 1;
            }
            // Recompute the running masks from the left-most changed position.
            let mut mask = if cur == 0 {
                0
            } else {
                state[cur - 1].present_mask
            };
            for item in &mut state[cur..] {
                mask |= 1 << item.set;
                item.present_mask = mask;
            }
            // Skip sequences that still miss some set.
            if mask == mask_all {
                return true;
            }
        }
    }

    /// Moves `next_state` to the following valid sequence, or to `None` when exhausted.
    fn find_next(&mut self) {
        let sets = self.sets;
        let advanced = match self.next_state.as_mut() {
            Some(state) => Self::advance(state, sets),
            None => false,
        };
        if !advanced {
            self.next_state = None;
        }
    }
}

impl<'a, T: 'a + Clone> Iterator for Enumerator<'a, T> {
    type Item = Vec<T>;

    fn next(&mut self) -> Option<Self::Item> {
        let sets = self.sets;
        let current = self
            .next_state
            .as_ref()?
            .iter()
            .map(|item| sets[item.set][item.index].clone())
            .collect();
        self.find_next();
        Some(current)
    }
}
332
#[cfg(test)]
mod test {
    use super::{Counter, CountingFunctionCache, Enumerator, Generator};
    use num_bigint::BigUint;
    use num_iter::range;
    use num_traits::{One, Zero};

    /// Upper-case letters, lower-case letters and digits: three sets of distinct sizes.
    fn test_sets1() -> Vec<Vec<u8>> {
        vec![b"ABCD".to_vec(), b"abc".to_vec(), b"01".to_vec()]
    }

    /// Checks that `generate` is periodic in `value`, with the total count as period.
    struct ResultCycleTest<'a> {
        generator: &'a Generator<'a, u8>,
    }

    impl<'a> ResultCycleTest<'a> {
        fn run(&self, length: usize) {
            let period = self.generator.counter.count(0, length);
            for value in range(BigUint::zero(), period.clone()) {
                let shifted = &period + &value;
                let first = self.generator.generate(length, value);
                let second = self.generator.generate(length, shifted);
                assert_eq!(first, second);
            }
        }
    }

    #[test]
    fn result_cycle_test() {
        let sets = test_sets1();
        let mut cache = CountingFunctionCache::default();
        let generator = Generator::new(&mut cache, &sets);
        let test = ResultCycleTest {
            generator: &generator,
        };
        for length in 3..=4 {
            test.run(length);
        }
    }

    /// Checks that ranks 0, 1, 2, ... decode to exactly the enumeration order, and that
    /// the enumeration length equals the total count.
    struct ResultMatchTest<'a> {
        generator: &'a Generator<'a, u8>,
    }

    impl<'a> ResultMatchTest<'a> {
        fn run(&self, length: usize) {
            let mut value = BigUint::zero();
            for expected in Enumerator::new(self.generator.sets, length) {
                let decoded = self.generator.generate(length, value.clone());
                assert_eq!(decoded, expected);
                value += BigUint::one();
            }
            assert_eq!(value, self.generator.counter.count(0, length));
        }
    }

    #[test]
    fn result_match_test() {
        let sets = test_sets1();
        let mut cache = CountingFunctionCache::default();
        let generator = Generator::new(&mut cache, &sets);
        let test = ResultMatchTest {
            generator: &generator,
        };
        for length in 3..=4 {
            test.run(length);
        }
    }

    /// Whether no sequence of `length` elements can use every set of the given sizes.
    fn count_is_zero(cache: &mut CountingFunctionCache, set_lens: &[usize], length: usize) -> bool {
        Counter::new(cache, set_lens).count(0, length).is_zero()
    }

    #[test]
    fn zero_count_test() {
        let mut cache = CountingFunctionCache::default();
        assert!(count_is_zero(&mut cache, &[], 5));
        assert!(count_is_zero(&mut cache, &[5, 0, 8], 5));
        assert!(count_is_zero(&mut cache, &[5, 2, 8], 2));
        assert!(!count_is_zero(&mut cache, &[5, 1, 8], 5));
    }

    /// Decodes rank 0; used by the `should_panic` tests for impossible requests.
    fn panic_generating_test(length: usize, sets: Vec<Vec<u8>>) {
        let mut cache = CountingFunctionCache::default();
        Generator::new(&mut cache, &sets).generate(length, BigUint::zero());
    }

    #[test]
    #[should_panic]
    fn empty_sets_generating_test() {
        panic_generating_test(5, vec![]);
    }

    #[test]
    #[should_panic]
    fn empty_set_in_sets_generating_test() {
        panic_generating_test(5, vec![b"01".to_vec(), Vec::new(), b"AB".to_vec()]);
    }

    #[test]
    #[should_panic]
    fn length_not_enough_generating_test() {
        panic_generating_test(2, vec![b"01".to_vec(), b"AB".to_vec(), b"ab".to_vec()]);
    }
}