// pgm-extra 1.3.0
//
// PGM-Extra: high-performance learned index structures for Rust.
// Documentation
/// Lower edge of a search window: `pos - eps`, floored at `0` instead of underflowing.
#[inline]
pub fn pgm_sub_eps(pos: usize, eps: usize) -> usize {
    match pos.checked_sub(eps) {
        Some(start) => start,
        None => 0,
    }
}

/// Exclusive upper edge of a search window: `pos + eps + 2`, clamped to `size`.
///
/// The additions saturate, so values near `usize::MAX` cannot overflow. The extra `+ 2`
/// appears to mirror the reference PGM-index `PGM_ADD_ERR` macro (TODO confirm intent).
#[inline]
pub fn pgm_add_eps(pos: usize, eps: usize, size: usize) -> usize {
    let with_slack = pos.saturating_add(eps).saturating_add(2);
    with_slack.min(size)
}

/// Lower-bound search over `keys[lo..hi]`: returns the first index whose key is `>= key`.
///
/// `hi` is clamped to `keys.len()`. If nothing in the window is `>= key`, the result is the
/// clamped `hi`. An empty window (including `lo >= keys.len()`) returns `lo` unchanged.
/// `keys[lo..hi]` is assumed to be sorted ascending.
///
/// The loop always halves the remaining span and picks the new start with a conditional
/// expression, so the compiler can emit a conditional move instead of a data-dependent
/// branch.
#[inline]
pub fn binary_search_branchless<K: Ord>(keys: &[K], key: &K, lo: usize, hi: usize) -> usize {
    let end = hi.min(keys.len());
    // `lo >= min(hi, len)` is exactly `lo >= hi || lo >= len`.
    if lo >= end {
        return lo;
    }

    let window = &keys[lo..end];
    let mut first = 0usize;
    let mut remaining = window.len();

    while remaining > 1 {
        let step = remaining / 2;
        let probe = first + step;
        first = if window[probe] < *key { probe } else { first };
        remaining -= step;
    }

    // `first` is the last candidate; step past it if it is still too small.
    lo + first + usize::from(window[first] < *key)
}

/// Linear lower-bound scan over `keys[lo..hi]`: returns the first index whose key is
/// `>= key`.
///
/// `hi` is clamped to `keys.len()`. If no key in the window qualifies, the clamped `hi` is
/// returned. An empty window returns `lo`. The scan walks fixed 4-element blocks, then
/// the leftover tail.
#[inline]
pub fn linear_search<K: Ord>(keys: &[K], key: &K, lo: usize, hi: usize) -> usize {
    let end = hi.min(keys.len());
    if end <= lo {
        return lo;
    }

    let window = &keys[lo..end];
    let mut blocks = window.chunks_exact(4);
    let mut offset = 0usize;

    for block in &mut blocks {
        for (j, candidate) in block.iter().enumerate() {
            if *candidate >= *key {
                return lo + offset + j;
            }
        }
        offset += 4;
    }

    // Fewer than four keys remain after the full blocks.
    for (j, candidate) in blocks.remainder().iter().enumerate() {
        if *candidate >= *key {
            return lo + offset + j;
        }
    }

    end
}

/// Largest window width (`hi - lo`, after clamping) that `adaptive_search` hands to
/// `linear_search`. Wider windows go to `binary_search_branchless`.
const LINEAR_SEARCH_THRESHOLD: usize = 64;

/// Lower-bound scan over `keys[lo..hi]`: returns the first index whose key is `>= key`.
///
/// This has the same contract as `linear_search`: `hi` is clamped to `keys.len()`, an
/// empty window returns `lo`, and when nothing qualifies the clamped `hi` is returned.
///
/// NOTE(review): despite the name, this body contains no explicit SIMD. It is a scalar
/// `Iterator::position` scan, which the compiler may or may not auto-vectorize. The
/// `Copy` bound is not needed by this body and is kept only for signature compatibility.
#[inline]
pub fn linear_search_simd<K: Ord + Copy>(keys: &[K], key: &K, lo: usize, hi: usize) -> usize {
    let hi = hi.min(keys.len());
    if hi <= lo {
        return lo;
    }

    keys[lo..hi]
        .iter()
        .position(|k| k >= key)
        .map_or(hi, |offset| lo + offset)
}

/// Lower-bound search over `keys[lo..hi]` that picks its strategy by window width.
///
/// `hi` is clamped to `keys.len()`, and an empty window returns `lo`. Windows no wider than
/// `LINEAR_SEARCH_THRESHOLD` use `linear_search`. Wider ones use
/// `binary_search_branchless`. Both return the first index whose key is `>= key`, or the
/// clamped `hi`.
#[inline]
pub fn adaptive_search<K: Ord>(keys: &[K], key: &K, lo: usize, hi: usize) -> usize {
    let end = keys.len().min(hi);
    match end.checked_sub(lo) {
        // `end < lo` (None) or `end == lo` (Some(0)): the window is empty.
        None | Some(0) => lo,
        Some(width) if width <= LINEAR_SEARCH_THRESHOLD => linear_search(keys, key, lo, end),
        Some(_) => binary_search_branchless(keys, key, lo, end),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    // `use alloc::vec` imports both the `vec!` macro and the `alloc::vec` module. The tests
    // therefore spell `vec::Vec` explicitly instead of relying on a `Vec` prelude entry,
    // which a `no_std` build does not provide.
    use alloc::vec;

    #[test]
    fn test_pgm_sub_eps() {
        assert_eq!(pgm_sub_eps(10, 3), 7);
        assert_eq!(pgm_sub_eps(3, 10), 0);
        assert_eq!(pgm_sub_eps(0, 5), 0);
        assert_eq!(pgm_sub_eps(5, 5), 0);
    }

    #[test]
    fn test_pgm_add_eps() {
        assert_eq!(pgm_add_eps(10, 3, 100), 15);
        assert_eq!(pgm_add_eps(95, 10, 100), 100);
        assert_eq!(pgm_add_eps(0, 5, 10), 7);
        // The `+ 2` slack is still clamped to `size` for tiny inputs.
        assert_eq!(pgm_add_eps(0, 0, 1), 1);
        // Saturating arithmetic must not overflow near `usize::MAX`.
        assert_eq!(pgm_add_eps(usize::MAX, 1, usize::MAX), usize::MAX);
    }

    #[test]
    fn test_binary_search_branchless() {
        let keys = vec![1, 3, 5, 7, 9, 11, 13, 15];
        assert_eq!(binary_search_branchless(&keys, &5, 0, 8), 2);
        assert_eq!(binary_search_branchless(&keys, &6, 0, 8), 3);
        assert_eq!(binary_search_branchless(&keys, &1, 0, 8), 0);
        assert_eq!(binary_search_branchless(&keys, &0, 0, 8), 0);
    }

    #[test]
    fn test_binary_search_branchless_bounds() {
        let keys = vec![1, 3, 5, 7, 9, 11, 13, 15];
        // Key past every element: returns the window end.
        assert_eq!(binary_search_branchless(&keys, &100, 0, 8), 8);
        assert_eq!(binary_search_branchless(&keys, &15, 0, 8), 7);
        // `hi` beyond the slice is clamped.
        assert_eq!(binary_search_branchless(&keys, &6, 0, 100), 3);
        // Sub-window and single-element windows.
        assert_eq!(binary_search_branchless(&keys, &6, 4, 8), 4);
        assert_eq!(binary_search_branchless(&keys, &10, 4, 5), 5);
        // Empty windows return `lo` unchanged.
        assert_eq!(binary_search_branchless(&keys, &6, 5, 5), 5);
        assert_eq!(binary_search_branchless(&keys, &6, 10, 20), 10);
    }

    #[test]
    fn test_linear_search() {
        let keys = vec![1, 3, 5, 7, 9];
        assert_eq!(linear_search(&keys, &5, 0, 5), 2);
        assert_eq!(linear_search(&keys, &6, 0, 5), 3);
        assert_eq!(linear_search(&keys, &0, 0, 5), 0);
        assert_eq!(linear_search(&keys, &100, 0, 5), 5);
    }

    #[test]
    fn test_linear_search_unrolled_and_bounds() {
        // Ten keys exercise two full 4-wide blocks plus a 2-element tail.
        let keys: vec::Vec<u32> = (0..10).collect();
        for k in 0..=10u32 {
            assert_eq!(linear_search(&keys, &k, 0, 10), k as usize);
        }
        assert_eq!(linear_search(&keys, &2, 5, 9), 5);
        assert_eq!(linear_search(&keys, &100, 0, 50), 10);
        assert_eq!(linear_search(&keys, &3, 7, 7), 7);
    }

    #[test]
    fn test_adaptive_search() {
        // Even keys 0, 2, ..., 398, so the lower bound of 101 is index 51 (key 102).
        let keys: vec::Vec<u64> = (0..200).map(|x| x * 2).collect();
        // Windows no wider than the threshold take the linear path.
        assert_eq!(adaptive_search(&keys, &101, 40, 60), 51);
        assert_eq!(adaptive_search(&keys, &101, 0, LINEAR_SEARCH_THRESHOLD), 51);
        // Wider windows take the binary path.
        assert_eq!(adaptive_search(&keys, &101, 0, LINEAR_SEARCH_THRESHOLD + 1), 51);
        assert_eq!(adaptive_search(&keys, &101, 0, 200), 51);
        // `hi` is clamped and empty windows return `lo`.
        assert_eq!(adaptive_search(&keys, &1000, 0, 500), 200);
        assert_eq!(adaptive_search(&keys, &5, 60, 10), 60);
        // Both strategies must agree with a plain linear scan.
        for k in 0..=400u64 {
            assert_eq!(
                adaptive_search(&keys, &k, 0, 200),
                linear_search(&keys, &k, 0, 200),
                "Mismatch for key {}",
                k
            );
        }
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_linear_search_u64() {
        let keys: vec::Vec<u64> = (0..100).collect();

        for &key in &[0u64, 50, 99, 100] {
            let expected = linear_search(&keys, &key, 0, 100);
            let simd_result = linear_search_simd(&keys, &key, 0, 100);
            assert_eq!(simd_result, expected, "Mismatch for key {}", key);
        }

        assert_eq!(linear_search_simd(&keys, &25, 10, 50), 25);
        assert_eq!(linear_search_simd(&keys, &5, 10, 50), 10);
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_simd_linear_search_u32() {
        let keys: vec::Vec<u32> = (0..100).collect();

        for &key in &[0u32, 50, 99, 100] {
            let expected = linear_search(&keys, &key, 0, 100);
            let simd_result = linear_search_simd(&keys, &key, 0, 100);
            assert_eq!(simd_result, expected, "Mismatch for key {}", key);
        }
    }
}