hashed_permutation/
kensler.rs

1//! The module for the hashed permutation implementation and the struct that stores its state.
2//!
3//! This method was first conceived by Andrew Kensler of Pixar Research, and discussed in his 2013
4//! [paper](https://graphics.pixar.com/library/MultiJitteredSampling/paper.pdf)
5//! on correlated multi-jittered sampling.
6
7use crate::error::{PermutationError, PermutationResult};
8#[cfg(feature = "use-rand")]
9use rand::prelude::*;
10use std::num::NonZeroU32;
11
/// The `HashedPermutation` struct stores the `seed` and `length` that together describe one
/// permutation. In other words, if you want to shuffle the numbers from `0..n`, then
/// `length = n`.
///
/// Because the shuffle is performed using bit arithmetic, the fields have to be 32 bit integers.
/// Unfortunately, larger types are not supported at this time.
///
/// Two values compare equal (and hash identically) exactly when both their `seed` and their
/// `length` match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct HashedPermutation {
    /// The random seed that dictates which permutation you want to use. The shuffle is
    /// deterministic, so using the same seed will yield the same permutation every time.
    pub seed: u32,

    /// The exclusive upper bound on the range of numbers to shuffle (`0..length`). The
    /// `NonZeroU32` type guarantees that this value is greater than zero, so an empty range
    /// cannot be represented.
    pub length: NonZeroU32,
}
27
28impl HashedPermutation {
29    /// Create a new instance of the hashed permutation with a random seed.
30    ///
31    /// This method creates a hashed permutation of some length and initializes the seed to some
32    /// random number created by Rust's `thread_rng`.
33    #[cfg(feature = "use-rand")]
34    pub fn new(length: NonZeroU32) -> Self {
35        // Uses thread-rng under the hood
36        let seed = rand::random();
37        HashedPermutation { length, seed }
38    }
39
40    /// Create a new instance of the hashed permutation given a length and seed
41    pub fn new_with_seed(length: NonZeroU32, seed: u32) -> Self {
42        HashedPermutation { length, seed }
43    }
44
45    /// Shuffle or permute a particular value.
46    ///
47    /// This method uses the technique described in Kensler's paper to perform an in-place shuffle
48    /// with no memory overhead.
49    // We disable the `unreadable_literal` because these literals are arbitrary and don't really
50    // need to be readable anyways.
51    #[allow(clippy::unreadable_literal)]
52    pub fn shuffle(&self, input: u32) -> PermutationResult<u32> {
53        if input >= self.length.get() {
54            return Err(PermutationError::ShuffleOutOfRange {
55                shuffle: input,
56                max_shuffle: self.length.get(),
57            });
58        }
59        let mut i = input;
60        let n = self.length.get();
61        let seed = self.seed;
62        let mut w = n - 1;
63        w |= w >> 1;
64        w |= w >> 2;
65        w |= w >> 4;
66        w |= w >> 8;
67        w |= w >> 16;
68
69        while i >= n {
70            i ^= seed;
71            i *= 0xe170893d;
72            i ^= seed >> 16;
73            i ^= (i & w) >> 4;
74            i ^= seed >> 8;
75            i *= 0x0929eb3f;
76            i ^= seed >> 23;
77            i ^= (i & w) >> 1;
78            i *= 1 | seed >> 27;
79            i *= 0x6935fa69;
80            i ^= (i & w) >> 11;
81            i *= 0x74dcb303;
82            i ^= (i & w) >> 2;
83            i *= 0x9e501cc3;
84            i ^= (i & w) >> 2;
85            i *= 0xc860a3df;
86            i &= w;
87            i ^= i >> 5;
88        }
89        Ok((i + seed) % n)
90    }
91}
92
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    /// A convenient helper method that returns a pair of lengths and seeds (in that order).
    ///
    /// The lengths and seeds are shared between several tests, so they are defined once here.
    /// Both vectors have the same number of entries so that they can be zipped into
    /// `(length, seed)` pairs.
    fn lengths_and_seeds() -> (Vec<NonZeroU32>, Vec<u32>) {
        let lengths: Vec<NonZeroU32> = [100, 5, 13, 128, 249]
            .iter()
            .map(|&x| NonZeroU32::new(x).expect("test lengths are non-zero"))
            .collect();
        let seeds = vec![100, 5, 13, 128, 249];
        assert_eq!(lengths.len(), seeds.len());
        (lengths, seeds)
    }

    #[test]
    // This method is a sanity check that tests to see if a shuffle has points that all stay within
    // the domain that they are supposed to.
    fn test_domain() {
        let (lengths, seeds) = lengths_and_seeds();

        for (&length, seed) in lengths.iter().zip(seeds) {
            let perm = HashedPermutation { seed, length };

            for i in 0..length.get() {
                let shuffled = perm
                    .shuffle(i)
                    .expect("an in-range input must be accepted");
                assert!(shuffled < length.get());
            }
        }
    }

    #[test]
    // This method checks to see that a permutation does not have any collisions and that every
    // number maps to another unique number. In other words, we are testing to see whether we have
    // a bijective function.
    fn test_bijection() {
        let (lengths, seeds) = lengths_and_seeds();

        for (&length, seed) in lengths.iter().zip(seeds) {
            let perm = HashedPermutation { seed, length };

            // Maps each output back to the input that produced it.
            let mut map = HashMap::new();

            for i in 0..length.get() {
                let res = perm.shuffle(i).unwrap();
                // `insert` returns the previous entry for the key, so anything other than `None`
                // means two inputs collided on the same output.
                assert!(map.insert(res, i).is_none());
            }

            // Every number in `0..length` must show up exactly once, both as an output (key)
            // and as an input (value), for a perfect bijection.
            let mut keys_vec: Vec<u32> = map.keys().copied().collect();
            keys_vec.sort_unstable();
            let mut vals_vec: Vec<u32> = map.values().copied().collect();
            vals_vec.sort_unstable();
            let ground_truth: Vec<u32> = (0..length.get()).collect();
            assert_eq!(ground_truth, keys_vec);
            assert_eq!(ground_truth, vals_vec);
        }
    }

    #[test]
    // This method checks that every input at or beyond `length` is rejected with an error.
    fn test_out_of_range() {
        let lengths = [1, 50, 256, 18];
        let offsets = [0, 1, 5, 15, 100];

        for &raw_length in &lengths {
            let length = NonZeroU32::new(raw_length).expect("test lengths are non-zero");
            let perm = HashedPermutation { seed: 0, length };

            for &offset in &offsets {
                let result = perm.shuffle(length.get() + offset);
                assert!(result.is_err());
            }
        }
    }
}