1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//! The module for the hashed permutation implementation and the struct that stores its state.
//!
//! This method was first conceived by Andrew Kensler of Pixar Research, and discussed in his 2013
//! [paper](https://graphics.pixar.com/library/MultiJitteredSampling/paper.pdf)
//! on correlated multi-jittered sampling.

use crate::error::{PermutationError, PermutationResult};

/// The parameters of a hashed permutation: the `seed` and the `length` of the range to permute.
///
/// In other words, if you want to shuffle the numbers from `0..n`, then `length = n`.
///
/// Because the shuffle is performed using bit arithmetic, the fields have to be 32 bit integers.
/// Unfortunately, larger types are not supported at this time.
///
/// The struct is a plain pair of integers, so it is cheap to copy, and it can be compared and
/// hashed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct HashedPermutation {
    /// The random seed that dictates which permutation you want to use. The shuffle is
    /// deterministic, so using the same seed will yield the same permutation every time.
    pub seed: u32,

    /// The upper bound on the range of numbers to shuffle (from `0..length`). This value must be
    /// greater than zero: with a length of zero no input is in range, so every call to `shuffle`
    /// returns an error.
    pub length: u32,
}

impl HashedPermutation {
    /// Shuffle or permute a particular value.
    ///
    /// This method uses the technique described in Kensler's paper to perform an in-place shuffle
    /// with no memory overhead: `input` is pushed through a hash that is a bijection on the
    /// smallest power-of-two range covering `0..length`, and the hash is re-applied ("cycle
    /// walking") until the value lands back inside `0..length`.
    ///
    /// For a fixed `seed` and `length`, distinct inputs in `0..length` map to distinct outputs in
    /// `0..length`.
    ///
    /// # Errors
    ///
    /// Returns `PermutationError::ShuffleOutOfRange` if `input >= length`. In particular, a
    /// permutation with `length == 0` rejects every input.
    pub fn shuffle(&self, input: u32) -> PermutationResult<u32> {
        let n = self.length;
        if input >= n {
            return Err(PermutationError::ShuffleOutOfRange {
                shuffle: input,
                max_shuffle: n,
            });
        }
        let seed = self.seed;

        // `n >= 1` here (otherwise the guard above returned), so `n - 1` cannot underflow.
        // Smear the highest set bit of `n - 1` downwards to get the all-ones mask of the smallest
        // power-of-two range that contains `0..n`.
        let mut w = n - 1;
        w |= w >> 1;
        w |= w >> 2;
        w |= w >> 4;
        w |= w >> 8;
        w |= w >> 16;

        let mut i = input;
        // The hash rounds must run at least once (a `do { .. } while` in the reference code);
        // a plain `while i >= n` would never execute because `input < n` is already guaranteed.
        // All multiplications are modulo 2^32, hence `wrapping_mul`.
        loop {
            i ^= seed;
            i = i.wrapping_mul(0xe170_893d);
            i ^= seed >> 16;
            i ^= (i & w) >> 4;
            i ^= seed >> 8;
            i = i.wrapping_mul(0x0929_eb3f);
            i ^= seed >> 23;
            i ^= (i & w) >> 1;
            // `1 | ..` keeps the multiplier odd, which keeps the multiplication invertible.
            i = i.wrapping_mul(1 | (seed >> 27));
            i = i.wrapping_mul(0x6935_fa69);
            i ^= (i & w) >> 11;
            i = i.wrapping_mul(0x74dc_b303);
            i ^= (i & w) >> 2;
            i = i.wrapping_mul(0x9e50_1cc3);
            i ^= (i & w) >> 2;
            i = i.wrapping_mul(0xc860_a3df);
            i &= w;
            i ^= i >> 5;

            // Cycle walking: values in `n..=w` are outside the target range, so hash again.
            if i < n {
                break;
            }
        }

        // Rotate by the seed. The sum is formed in 64 bits: in 32 bits it could overflow, and a
        // wrapped sum would make two different `i` collide whenever `n` is not a power of two.
        let rotated = (u64::from(i) + u64::from(seed)) % u64::from(n);
        // `rotated < n <= u32::MAX`, so the narrowing cast is lossless.
        Ok(rotated as u32)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    /// Shared fixtures for the tests below, returned as `(lengths, seeds)`.
    ///
    /// The two vectors have the same number of entries so that they can be zipped into
    /// `(length, seed)` pairs. Keeping the values in one place avoids repeating them in every
    /// test.
    fn lengths_and_seeds() -> (Vec<u32>, Vec<u32>) {
        let seeds: Vec<u32> = vec![100, 5, 13, 128, 249];
        let lengths: Vec<u32> = vec![100, 5, 13, 128, 249];
        assert_eq!(lengths.len(), seeds.len());
        (lengths, seeds)
    }

    // Sanity check: every in-range input must shuffle successfully, and the shuffled value must
    // itself stay inside `0..length`.
    #[test]
    fn test_domain() {
        let (lengths, seeds) = lengths_and_seeds();

        for (length, seed) in lengths.into_iter().zip(seeds) {
            let perm = HashedPermutation { seed, length };

            for input in 0..length {
                let outcome = perm.shuffle(input);
                assert!(outcome.is_ok());
                let output = outcome.unwrap();
                assert!(output < length);
            }
        }
    }

    // Checks that the shuffle is a bijection on `0..length`: no two inputs may share an output,
    // and the set of outputs must be exactly `0..length`.
    #[test]
    fn test_bijection() {
        let (lengths, seeds) = lengths_and_seeds();

        for (length, seed) in lengths.into_iter().zip(seeds) {
            let perm = HashedPermutation { seed, length };

            // Maps each output back to the input that produced it.
            let mut preimage_of: HashMap<u32, u32> = HashMap::new();

            for input in 0..length {
                let output = perm.shuffle(input).unwrap();
                // `insert` hands back the previous entry, which exists only on a collision.
                let collision = preimage_of.insert(output, input);
                assert!(collision.is_none());
            }

            // Both the outputs (keys) and the inputs (values) must cover `0..length` exactly.
            let expected: Vec<u32> = (0..length).collect();

            let mut outputs: Vec<u32> = preimage_of.keys().copied().collect();
            outputs.sort();
            assert_eq!(expected, outputs);

            let mut inputs: Vec<u32> = preimage_of.values().copied().collect();
            inputs.sort();
            assert_eq!(expected, inputs);
        }
    }

    // Any input at or beyond `length` must be rejected with an error.
    #[test]
    fn test_out_of_range() {
        const LENGTHS: [u32; 4] = [1, 50, 256, 18];
        const OFFSETS: [u32; 5] = [0, 1, 5, 15, 100];

        for &length in LENGTHS.iter() {
            let perm = HashedPermutation { seed: 0, length };

            for &offset in OFFSETS.iter() {
                let too_large = length + offset;
                assert!(perm.shuffle(too_large).is_err());
            }
        }
    }
}