zerotrie 0.2.4

A data structure that efficiently maps strings to integers
Documentation
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use super::*;
use crate::error::ZeroTrieBuildError;
use alloc::vec;
use alloc::vec::Vec;

/// To speed up the search algorithm, we limit the number of times the level-2 parameter (q)
/// can hit its max value (initially `Q_FAST_MAX`) before we try the next level-1 parameter (p).
/// In practice, this has a small impact on the resulting perfect hash, resulting in about
/// 1 in 10000 hash maps that fall back to the slow path.
const MAX_L2_SEARCH_MISSES: usize = 24;

/// Directly compute the perfect hash function.
///
/// Returns `(p, [q_0, q_1, ..., q_(N-1)])`, or an error if the PHF could not be computed.
#[allow(unused_labels)] // for readability
#[allow(clippy::indexing_slicing)] // carefully reviewed to not panic
pub fn find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError> {
    let n_usize = bytes.len();

    let mut p = 0u8;
    let mut qq = vec![0u8; n_usize];

    let mut bqs = vec![0u8; n_usize];
    let mut seen = vec![false; n_usize];
    let max_allowable_p = P_FAST_MAX;
    let mut max_allowable_q = Q_FAST_MAX;

    #[allow(non_snake_case)]
    let N = if n_usize > 0 && n_usize < 256 {
        n_usize as u8
    } else {
        debug_assert!(n_usize == 0 || n_usize == 256);
        return Ok((p, qq));
    };

    'p_loop: loop {
        // Vec of tuples: (index, bucket count)
        let mut buckets: Vec<(usize, Vec<u8>)> = (0..n_usize).map(|i| (i, vec![])).collect();
        for byte in bytes {
            let l1 = f1(*byte, p, N) as usize;
            buckets[l1].1.push(*byte);
        }
        buckets.sort_by_key(|(_, v)| -(v.len() as isize));
        // println!("New P: p={p:?}, buckets={buckets:?}");
        let mut i = 0;
        let mut num_max_q = 0;
        bqs.fill(0);
        seen.fill(false);
        'q_loop: loop {
            // Loop condition: exit when i is beyond the buckets length
            if i == buckets.len() {
                for (local_j, real_j) in buckets.iter().map(|(j, _)| *j).enumerate() {
                    debug_assert!(local_j < n_usize); // comes from .enumerate()
                    debug_assert!(real_j < n_usize); // first item of bucket tuple is an index
                    qq[real_j] = bqs[local_j];
                }
                // println!("Success: p={p:?}, num_max_q={num_max_q:?}, bqs={bqs:?}, qq={qq:?}");
                // if num_max_q > 0 {
                //     println!("num_max_q={num_max_q:?}");
                // }
                return Ok((p, qq));
            }
            let mut bucket = buckets[i].1.as_slice();
            'byte_loop: for (j, byte) in bucket.iter().enumerate() {
                let l2 = f2(*byte, bqs[i], N) as usize;
                if seen[l2] {
                    // println!("Skipping Q: p={p:?}, i={i:?}, byte={byte:}, q={i:?}, l2={:?}", f2(*byte, bqs[i], N));
                    for k_byte in &bucket[0..j] {
                        let l2 = f2(*k_byte, bqs[i], N) as usize;
                        assert!(seen[l2]);
                        seen[l2] = false;
                    }
                    'reset_loop: loop {
                        if bqs[i] < max_allowable_q {
                            bqs[i] += 1;
                            continue 'q_loop;
                        }
                        num_max_q += 1;
                        bqs[i] = 0;
                        if i == 0 || num_max_q > MAX_L2_SEARCH_MISSES {
                            if p == max_allowable_p && max_allowable_q != Q_REAL_MAX {
                                // println!("Could not solve fast function: trying again: {bytes:?}");
                                max_allowable_q = Q_REAL_MAX;
                                p = 0;
                                continue 'p_loop;
                            } else if p == max_allowable_p {
                                // If a fallback algorithm for `p` is added, relax this assertion
                                // and re-run the loop with a higher `max_allowable_p`.
                                debug_assert_eq!(max_allowable_p, P_REAL_MAX);
                                // println!("Could not solve PHF function");
                                return Err(ZeroTrieBuildError::CouldNotSolvePerfectHash);
                            } else {
                                p += 1;
                                continue 'p_loop;
                            }
                        }
                        i -= 1;
                        bucket = buckets[i].1.as_slice();
                        for byte in bucket {
                            let l2 = f2(*byte, bqs[i], N) as usize;
                            assert!(seen[l2]);
                            seen[l2] = false;
                        }
                    }
                } else {
                    // println!("Marking as seen: i={i:?}, byte={byte:}, l2={:?}", f2(*byte, bqs[i], N));
                    let l2 = f2(*byte, bqs[i], N) as usize;
                    seen[l2] = true;
                }
            }
            // println!("Found Q: i={i:?}, q={:?}", bqs[i]);
            i += 1;
        }
    }
}

impl PerfectByteHashMap<Vec<u8>> {
    /// Computes a new [`PerfectByteHashMap`].
    ///
    /// (this is a doc-hidden API)
    #[allow(clippy::indexing_slicing)] // carefully reviewed to not panic
    pub fn try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError> {
        let n_usize = keys.len();
        let n = n_usize as u8;
        let (p, mut qq) = find(keys)?;
        let mut keys_permuted = vec![0; n_usize];
        for key in keys {
            let l1 = f1(*key, p, n) as usize;
            let q = qq[l1];
            let l2 = f2(*key, q, n) as usize;
            keys_permuted[l2] = *key;
        }
        let mut result = Vec::with_capacity(n_usize * 2 + 1);
        result.push(p);
        result.append(&mut qq);
        result.append(&mut keys_permuted);
        Ok(Self(result))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    extern crate std;
    use std::print;
    use std::println;

    fn print_byte_to_stdout(byte: u8) {
        let c = char::from(byte);
        if c.is_ascii_alphanumeric() {
            print!("'{c}'");
        } else {
            print!("0x{byte:X}");
        }
    }

    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        let mut bytes: Vec<u8> =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".into();
        let mut rng = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        bytes.partial_shuffle(&mut rng, len).0.into()
    }

    #[test]
    fn test_random_distributions() {
        let mut p_distr = vec![0; 256];
        let mut q_distr = vec![0; 256];
        for len in 0..50 {
            for seed in 0..50 {
                let bytes = random_alphanums(seed, len);
                let (p, qq) = find(bytes.as_slice()).unwrap();
                p_distr[p as usize] += 1;
                for q in qq {
                    q_distr[q as usize] += 1;
                }
            }
        }
        println!("p_distr: {p_distr:?}");
        println!("q_distr: {q_distr:?}");

        let fast_p = p_distr[0..=P_FAST_MAX as usize].iter().sum::<usize>();
        let slow_p = p_distr[(P_FAST_MAX + 1) as usize..].iter().sum::<usize>();
        let fast_q = q_distr[0..=Q_FAST_MAX as usize].iter().sum::<usize>();
        let slow_q = q_distr[(Q_FAST_MAX + 1) as usize..].iter().sum::<usize>();

        assert_eq!(2500, fast_p);
        assert_eq!(0, slow_p);
        assert_eq!(61243, fast_q);
        assert_eq!(7, slow_q);

        let bytes = random_alphanums(0, 16);

        #[allow(non_snake_case)]
        let N = u8::try_from(bytes.len()).unwrap();

        let (p, qq) = find(bytes.as_slice()).unwrap();

        println!("Results:");
        for byte in bytes.iter() {
            print_byte_to_stdout(*byte);
            let l1 = f1(*byte, p, N) as usize;
            let q = qq[l1];
            let l2 = f2(*byte, q, N) as usize;
            println!(" => l1 {l1} => q {q} => l2 {l2}");
        }
    }
}