branchless_core 0.1.0

Algorithms and data structures designed to maximize performance on superscalar processors.
Documentation
#![cfg(target_arch = "x86_64")]
#![cfg(target_feature = "sse2")]

use core::arch::x86_64::{
    __m128i as m128, _mm_adds_epu16, _mm_adds_epu8, _mm_and_si128, _mm_bsrli_si128, _mm_cmpeq_epi8,
    _mm_cvtsi128_si32, _mm_loadu_si128, _mm_maddubs_epi16, _mm_movemask_epi8, _mm_packus_epi16,
    _mm_set1_epi8, _mm_set_epi8, _mm_shuffle_epi32, _mm_shuffle_epi8, _mm_subs_epi8,
    _mm_test_all_ones, _mm_xor_si128,
};

use crate::ip::Ipv4ParseError;

/// Packs four 2-bit lane selectors into the 8-bit immediate taken by shuffle intrinsics such as
/// `_mm_shuffle_epi32`.
///
/// `z` lands in bits 7..6, `y` in bits 5..4, `x` in bits 3..2 and `w` in bits 1..0. The arguments
/// are not masked, so a selector larger than 3 spills into the neighbouring fields.
#[allow(non_snake_case)] // keeps the spelling of the C `_MM_SHUFFLE` macro
pub const fn _MM_SHUFFLE(z: u32, y: u32, x: u32, w: u32) -> i32 {
    let upper_pair = (z << 6) | (y << 4);
    let lower_pair = (x << 2) | w;
    (upper_pair | lower_pair) as i32
}

#[cfg(test)]
extern crate std;

/// No-op replacement for `dbg!` in non-test builds.
///
/// Every accepted invocation expands to nothing: the arguments are neither evaluated nor
/// printed, so this macro can only be used in statement position.
#[cfg(not(test))]
macro_rules! dbg {
    // Bare `dbg!()`.
    () => {};
    // One or more comma-separated expressions, optionally followed by a trailing comma.
    ($($arg:expr),+ $(,)?) => {};
}

/// Maps the 8-bit hash computed in `parse_ipv4` to a row index of `PATTERNS`.
///
/// `parse_ipv4` rejects every entry that is `>= 81`; the value 255 is the filler used for hash
/// slots that have no pattern assigned.
static PATTERNS_ID: [u8; 256] = [
    38, 65, 255, 56, 73, 255, 255, 255, 255, 255, 255, 3, 255, 255, 6, 255, 255, 9, 255, 27, 255,
    12, 30, 255, 255, 255, 255, 15, 255, 33, 255, 255, 255, 255, 18, 36, 255, 255, 255, 54, 21,
    255, 39, 255, 255, 57, 255, 255, 255, 255, 255, 255, 255, 255, 24, 42, 255, 255, 255, 60, 255,
    255, 255, 255, 255, 255, 255, 255, 45, 255, 255, 63, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 48, 53, 255, 255, 66, 71, 255, 255, 16, 255, 34, 255, 255, 255, 255, 255, 255, 255, 52,
    255, 255, 22, 70, 40, 255, 255, 58, 51, 255, 255, 69, 255, 255, 255, 255, 255, 255, 255, 255,
    255, 5, 255, 255, 255, 255, 255, 255, 11, 29, 46, 255, 255, 64, 255, 255, 72, 0, 77, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 76, 255, 255, 255, 255,
    255, 255, 255, 75, 255, 80, 255, 255, 255, 26, 255, 44, 255, 7, 62, 255, 255, 25, 255, 43, 13,
    31, 61, 255, 255, 255, 255, 255, 255, 255, 255, 255, 2, 19, 37, 255, 255, 50, 55, 79, 68, 255,
    255, 255, 255, 49, 255, 255, 67, 255, 255, 255, 255, 17, 255, 35, 78, 255, 4, 255, 255, 255,
    255, 255, 255, 10, 23, 28, 41, 255, 255, 59, 255, 255, 255, 8, 255, 255, 255, 255, 255, 1, 14,
    32, 255, 255, 255, 255, 255, 255, 255, 255, 74, 255, 47, 20,
];

/// Byte-shuffle controls for `_mm_shuffle_epi8`, selected in `parse_ipv4` through `PATTERNS_ID`.
///
/// An entry below 128 is the index of the input byte to copy; 128 has its top bit set, which makes
/// the shuffle write a zero byte instead.
///
/// Output layout, as implied by the multiplier vector in `parse_ipv4`:
/// - bytes 0..=7 are four (ones, tens) pairs, one pair per octet;
/// - bytes 8, 10, 12 and 14 are the hundreds digits of the four octets;
/// - bytes 9, 11, 13 and 15 are multiplied by 0 there and do not contribute to the result.
///
/// NOTE(review): the rows appear to enumerate every combination of 1-3 digit octets
/// (3^4 = 81), ordered by the widths of the octets — confirm against the table generator.
static PATTERNS: [[u8; 16]; 81] = [
    [
        0, 128, 2, 128, 4, 128, 6, 128, 128, 128, 128, 128, 128, 128, 128, 128,
    ],
    [
        0, 128, 2, 128, 4, 128, 7, 6, 128, 128, 128, 128, 128, 128, 128, 6,
    ],
    [
        0, 128, 2, 128, 4, 128, 8, 7, 128, 128, 128, 128, 128, 128, 6, 6,
    ],
    [
        0, 128, 2, 128, 5, 4, 7, 128, 128, 128, 128, 128, 128, 4, 128, 128,
    ],
    [
        0, 128, 2, 128, 5, 4, 8, 7, 128, 128, 128, 128, 128, 4, 128, 7,
    ],
    [0, 128, 2, 128, 5, 4, 9, 8, 128, 128, 128, 128, 128, 4, 7, 7],
    [
        0, 128, 2, 128, 6, 5, 8, 128, 128, 128, 128, 128, 4, 4, 128, 128,
    ],
    [0, 128, 2, 128, 6, 5, 9, 8, 128, 128, 128, 128, 4, 4, 128, 8],
    [0, 128, 2, 128, 6, 5, 10, 9, 128, 128, 128, 128, 4, 4, 8, 8],
    [
        0, 128, 3, 2, 5, 128, 7, 128, 128, 128, 128, 2, 128, 128, 128, 128,
    ],
    [
        0, 128, 3, 2, 5, 128, 8, 7, 128, 128, 128, 2, 128, 128, 128, 7,
    ],
    [0, 128, 3, 2, 5, 128, 9, 8, 128, 128, 128, 2, 128, 128, 7, 7],
    [
        0, 128, 3, 2, 6, 5, 8, 128, 128, 128, 128, 2, 128, 5, 128, 128,
    ],
    [0, 128, 3, 2, 6, 5, 9, 8, 128, 128, 128, 2, 128, 5, 128, 8],
    [0, 128, 3, 2, 6, 5, 10, 9, 128, 128, 128, 2, 128, 5, 8, 8],
    [0, 128, 3, 2, 7, 6, 9, 128, 128, 128, 128, 2, 5, 5, 128, 128],
    [0, 128, 3, 2, 7, 6, 10, 9, 128, 128, 128, 2, 5, 5, 128, 9],
    [0, 128, 3, 2, 7, 6, 11, 10, 128, 128, 128, 2, 5, 5, 9, 9],
    [
        0, 128, 4, 3, 6, 128, 8, 128, 128, 128, 2, 2, 128, 128, 128, 128,
    ],
    [0, 128, 4, 3, 6, 128, 9, 8, 128, 128, 2, 2, 128, 128, 128, 8],
    [0, 128, 4, 3, 6, 128, 10, 9, 128, 128, 2, 2, 128, 128, 8, 8],
    [0, 128, 4, 3, 7, 6, 9, 128, 128, 128, 2, 2, 128, 6, 128, 128],
    [0, 128, 4, 3, 7, 6, 10, 9, 128, 128, 2, 2, 128, 6, 128, 9],
    [0, 128, 4, 3, 7, 6, 11, 10, 128, 128, 2, 2, 128, 6, 9, 9],
    [0, 128, 4, 3, 8, 7, 10, 128, 128, 128, 2, 2, 6, 6, 128, 128],
    [0, 128, 4, 3, 8, 7, 11, 10, 128, 128, 2, 2, 6, 6, 128, 10],
    [0, 128, 4, 3, 8, 7, 12, 11, 128, 128, 2, 2, 6, 6, 10, 10],
    [
        1, 0, 3, 128, 5, 128, 7, 128, 128, 0, 128, 128, 128, 128, 128, 128,
    ],
    [
        1, 0, 3, 128, 5, 128, 8, 7, 128, 0, 128, 128, 128, 128, 128, 7,
    ],
    [1, 0, 3, 128, 5, 128, 9, 8, 128, 0, 128, 128, 128, 128, 7, 7],
    [
        1, 0, 3, 128, 6, 5, 8, 128, 128, 0, 128, 128, 128, 5, 128, 128,
    ],
    [1, 0, 3, 128, 6, 5, 9, 8, 128, 0, 128, 128, 128, 5, 128, 8],
    [1, 0, 3, 128, 6, 5, 10, 9, 128, 0, 128, 128, 128, 5, 8, 8],
    [1, 0, 3, 128, 7, 6, 9, 128, 128, 0, 128, 128, 5, 5, 128, 128],
    [1, 0, 3, 128, 7, 6, 10, 9, 128, 0, 128, 128, 5, 5, 128, 9],
    [1, 0, 3, 128, 7, 6, 11, 10, 128, 0, 128, 128, 5, 5, 9, 9],
    [
        1, 0, 4, 3, 6, 128, 8, 128, 128, 0, 128, 3, 128, 128, 128, 128,
    ],
    [1, 0, 4, 3, 6, 128, 9, 8, 128, 0, 128, 3, 128, 128, 128, 8],
    [1, 0, 4, 3, 6, 128, 10, 9, 128, 0, 128, 3, 128, 128, 8, 8],
    [1, 0, 4, 3, 7, 6, 9, 128, 128, 0, 128, 3, 128, 6, 128, 128],
    [1, 0, 4, 3, 7, 6, 10, 9, 128, 0, 128, 3, 128, 6, 128, 9],
    [1, 0, 4, 3, 7, 6, 11, 10, 128, 0, 128, 3, 128, 6, 9, 9],
    [1, 0, 4, 3, 8, 7, 10, 128, 128, 0, 128, 3, 6, 6, 128, 128],
    [1, 0, 4, 3, 8, 7, 11, 10, 128, 0, 128, 3, 6, 6, 128, 10],
    [1, 0, 4, 3, 8, 7, 12, 11, 128, 0, 128, 3, 6, 6, 10, 10],
    [1, 0, 5, 4, 7, 128, 9, 128, 128, 0, 3, 3, 128, 128, 128, 128],
    [1, 0, 5, 4, 7, 128, 10, 9, 128, 0, 3, 3, 128, 128, 128, 9],
    [1, 0, 5, 4, 7, 128, 11, 10, 128, 0, 3, 3, 128, 128, 9, 9],
    [1, 0, 5, 4, 8, 7, 10, 128, 128, 0, 3, 3, 128, 7, 128, 128],
    [1, 0, 5, 4, 8, 7, 11, 10, 128, 0, 3, 3, 128, 7, 128, 10],
    [1, 0, 5, 4, 8, 7, 12, 11, 128, 0, 3, 3, 128, 7, 10, 10],
    [1, 0, 5, 4, 9, 8, 11, 128, 128, 0, 3, 3, 7, 7, 128, 128],
    [1, 0, 5, 4, 9, 8, 12, 11, 128, 0, 3, 3, 7, 7, 128, 11],
    [1, 0, 5, 4, 9, 8, 13, 12, 128, 0, 3, 3, 7, 7, 11, 11],
    [
        2, 1, 4, 128, 6, 128, 8, 128, 0, 0, 128, 128, 128, 128, 128, 128,
    ],
    [2, 1, 4, 128, 6, 128, 9, 8, 0, 0, 128, 128, 128, 128, 128, 8],
    [2, 1, 4, 128, 6, 128, 10, 9, 0, 0, 128, 128, 128, 128, 8, 8],
    [2, 1, 4, 128, 7, 6, 9, 128, 0, 0, 128, 128, 128, 6, 128, 128],
    [2, 1, 4, 128, 7, 6, 10, 9, 0, 0, 128, 128, 128, 6, 128, 9],
    [2, 1, 4, 128, 7, 6, 11, 10, 0, 0, 128, 128, 128, 6, 9, 9],
    [2, 1, 4, 128, 8, 7, 10, 128, 0, 0, 128, 128, 6, 6, 128, 128],
    [2, 1, 4, 128, 8, 7, 11, 10, 0, 0, 128, 128, 6, 6, 128, 10],
    [2, 1, 4, 128, 8, 7, 12, 11, 0, 0, 128, 128, 6, 6, 10, 10],
    [2, 1, 5, 4, 7, 128, 9, 128, 0, 0, 128, 4, 128, 128, 128, 128],
    [2, 1, 5, 4, 7, 128, 10, 9, 0, 0, 128, 4, 128, 128, 128, 9],
    [2, 1, 5, 4, 7, 128, 11, 10, 0, 0, 128, 4, 128, 128, 9, 9],
    [2, 1, 5, 4, 8, 7, 10, 128, 0, 0, 128, 4, 128, 7, 128, 128],
    [2, 1, 5, 4, 8, 7, 11, 10, 0, 0, 128, 4, 128, 7, 128, 10],
    [2, 1, 5, 4, 8, 7, 12, 11, 0, 0, 128, 4, 128, 7, 10, 10],
    [2, 1, 5, 4, 9, 8, 11, 128, 0, 0, 128, 4, 7, 7, 128, 128],
    [2, 1, 5, 4, 9, 8, 12, 11, 0, 0, 128, 4, 7, 7, 128, 11],
    [2, 1, 5, 4, 9, 8, 13, 12, 0, 0, 128, 4, 7, 7, 11, 11],
    [2, 1, 6, 5, 8, 128, 10, 128, 0, 0, 4, 4, 128, 128, 128, 128],
    [2, 1, 6, 5, 8, 128, 11, 10, 0, 0, 4, 4, 128, 128, 128, 10],
    [2, 1, 6, 5, 8, 128, 12, 11, 0, 0, 4, 4, 128, 128, 10, 10],
    [2, 1, 6, 5, 9, 8, 11, 128, 0, 0, 4, 4, 128, 8, 128, 128],
    [2, 1, 6, 5, 9, 8, 12, 11, 0, 0, 4, 4, 128, 8, 128, 11],
    [2, 1, 6, 5, 9, 8, 13, 12, 0, 0, 4, 4, 128, 8, 11, 11],
    [2, 1, 6, 5, 10, 9, 12, 128, 0, 0, 4, 4, 8, 8, 128, 128],
    [2, 1, 6, 5, 10, 9, 13, 12, 0, 0, 4, 4, 8, 8, 128, 12],
    [2, 1, 6, 5, 10, 9, 14, 13, 0, 0, 4, 4, 8, 8, 12, 12],
];

/// Parses a dotted-decimal IPv4 address using Mula's SIMD technique, refined by Lemire.
///
/// - <http://0x80.pl/notesen/2023-04-09-faster-parse-ipv4.html>
/// - <https://lemire.me/blog/2023/06/08/parsing-ip-addresses-crazily-fast/>
///
/// The first octet ends up in the least-significant byte of the returned value, so on a
/// little-endian target `to_ne_bytes()` yields the octets in textual order.
///
/// # Errors
///
/// - Propagates the error of `masked_load_or_die` (`Ipv4ParseError::WrongLength` for an empty
///   input or one longer than 15 bytes).
/// - `Ipv4ParseError::Invalid` when the text is not four dot-separated groups of one to three
///   decimal digits, or when any group is larger than 255.
///
/// NOTE(review): parsing stops at the first byte that is neither a digit nor a dot, so trailing
/// characters ("1.2.3.4x") are ignored rather than rejected, and leading zeros ("01.2.3.4") are
/// accepted. Confirm whether callers rely on either behaviour before tightening it.
pub fn parse_ipv4(s: &str) -> Result<u32, Ipv4ParseError> {
    let mut v: m128 = masked_load_or_die(s)?;
    // SAFETY: every table access is bounds-checked and the only raw-pointer load reads a full
    // 16-byte row of `PATTERNS`. The intrinsics themselves need SSE2 and SSSE3
    // (`_mm_shuffle_epi8`, `_mm_maddubs_epi16`).
    // NOTE(review): the file is gated on `sse2` only — confirm SSSE3 is enabled for the build.
    unsafe {
        // Bit i of `dot_mask` is set when input byte i is '.' (0x2E).
        let all_dots: m128 = _mm_set1_epi8(0x2E);
        let dot_locations: m128 = _mm_cmpeq_epi8(v, all_dots);
        let dot_mask: i32 = _mm_movemask_epi8(dot_locations);

        // XOR with '0' (0x30) maps the ASCII digits to 0..=9 and everything else to >= 10.
        // Adding 0x76 with unsigned saturation then sets the top bit exactly for the values
        // >= 10, so the movemask flags every non-digit byte (dots and zero padding included).
        let saturation_distance = _mm_set1_epi8(0x76);
        v = _mm_xor_si128(v, _mm_set1_epi8(0x30));
        v = _mm_adds_epu8(v, saturation_distance);
        let non_digit_mask = _mm_movemask_epi8(v);
        // Undo the bias: digits return to 0..=9, non-digits keep their top bit set.
        v = _mm_subs_epi8(v, saturation_distance);

        // `bad_mask`: bytes that are neither digit nor dot. The first of them ends the address;
        // `clip_mask` keeps every bit up to and including that terminator.
        let bad_mask = dot_mask ^ non_digit_mask;
        let clip_mask: i32 = bad_mask ^ (bad_mask - 1);
        let partition_mask = non_digit_mask & clip_mask;

        // Multiply-shift hash of the separator layout into an index of `PATTERNS_ID`.
        let hash_key = (((partition_mask as u64) * 0x00CF7800) >> 24) as u8;

        let hash_id = PATTERNS_ID[hash_key as usize];
        if hash_id >= 81 {
            return Err(Ipv4ParseError::Invalid);
        }
        let pattern = &PATTERNS[hash_id as usize];

        // The hash has only 256 slots, so a malformed layout can collide with a valid one.
        // Pattern bytes 0, 2, 4 and 6 hold the index of the last digit of each octet; the byte
        // right after each of them must be a dot (or the terminator after the fourth octet).
        // Rebuilding the mask from the pattern and comparing it rejects every collision and
        // guarantees that the shuffle below only picks up digits.
        let boundary = |units_slot: usize| 1i32 << (pattern[units_slot] + 1);
        let expected_mask = boundary(0) | boundary(2) | boundary(4) | boundary(6);
        if partition_mask != expected_mask {
            return Err(Ipv4ParseError::Invalid);
        }

        let shuf = _mm_loadu_si128(pattern.as_ptr() as *const m128);
        v = _mm_shuffle_epi8(v, shuf);

        // ones + 10 * tens in 16-bit lanes 0..=3, 100 * hundreds in lanes 4..=7.
        let mul_weights = _mm_set_epi8(0, 100, 0, 100, 0, 100, 0, 100, 10, 1, 10, 1, 10, 1, 10, 1);
        let mut acc = _mm_maddubs_epi16(mul_weights, v);
        // Swap the 64-bit halves and add, so every lane holds a complete octet value.
        let swapped = _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2));
        acc = _mm_adds_epu16(acc, swapped);

        // `_mm_packus_epi16` would silently clamp an octet such as 999 to 255. Each lane is at
        // most 999, so adding 0x7F00 sets bit 15 exactly when the lane exceeds 255; those bits
        // show up at the odd positions of the byte movemask.
        let overflow_bias =
            _mm_set_epi8(0x7F, 0, 0x7F, 0, 0x7F, 0, 0x7F, 0, 0x7F, 0, 0x7F, 0, 0x7F, 0, 0x7F, 0);
        let overflow_bits = _mm_movemask_epi8(_mm_adds_epu16(acc, overflow_bias));
        if overflow_bits & 0xAAAA != 0 {
            return Err(Ipv4ParseError::Invalid);
        }

        let address = _mm_cvtsi128_si32(_mm_packus_epi16(acc, acc));

        Ok(address as u32)
    }
}

fn masked_load_or_die(s: &str) -> Result<m128, Ipv4ParseError> {
    let v: m128 = unsafe { _mm_loadu_si128(s.as_ptr() as *const m128) };
    let mask = unsafe { _mm_set1_epi8(-1 as i8) };

    macro_rules! devolve {
        ($n:literal) => {
            unsafe {
                let shifted = _mm_bsrli_si128::<{ 16 - $n }>(mask);
                Ok(_mm_and_si128(shifted, v))
            }
        };
    }

    match s.len() {
        1 => devolve!(1),
        2 => devolve!(2),
        3 => devolve!(3),
        4 => devolve!(4),
        5 => devolve!(5),
        6 => devolve!(6),
        7 => devolve!(7),
        8 => devolve!(8),
        9 => devolve!(9),
        10 => devolve!(10),
        11 => devolve!(11),
        12 => devolve!(12),
        13 => devolve!(13),
        14 => devolve!(14),
        15 => devolve!(15),
        _ => Err(Ipv4ParseError::WrongLength),
    }
}

/// Returns `true` when all 16 bytes of `a` equal the corresponding bytes of `b`.
///
/// Only SSE2 intrinsics are used, matching the `target_feature = "sse2"` gate of this file.
fn are_equal(a: m128, b: m128) -> bool {
    // SAFETY: `_mm_cmpeq_epi8` and `_mm_movemask_epi8` only require SSE2, which the cfg gate of
    // this file guarantees.
    unsafe {
        // Equal bytes compare to 0xFF, so a full match sets all 16 bits of the byte mask.
        let lanes_equal = _mm_cmpeq_epi8(a, b);
        _mm_movemask_epi8(lanes_equal) == 0xFFFF
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Two strings that share their first five bytes must load to identical vectors once only
    /// that prefix is passed in: whatever follows the slice must not leak into the result.
    #[test]
    pub fn test_masked_load_masks() {
        let load_prefix = |text: &str| masked_load_or_die(&text[..5]).unwrap();

        let from_world = load_prefix("hello world");
        let from_attacker = load_prefix("hello attacker");

        assert!(are_equal(from_world, from_attacker));
    }

    /// The parsed value, viewed as native-endian bytes, matches the octets of the address.
    #[test]
    fn parse_ips() {
        let parsed = parse_ipv4("127.0.0.1").unwrap();
        let expected_octets = std::net::Ipv4Addr::new(127, 0, 0, 1).octets();

        assert_eq!(parsed.to_ne_bytes(), expected_octets);
    }
}