go-brrr 0.1.0

Token-efficient code analysis for LLMs - Rust implementation
Documentation
//! SIMD-accelerated byte searching utilities.
//!
//! This module provides thin wrappers around the `memchr` crate, which uses
//! SIMD instructions when available for fast byte searching. These functions
//! are portable and automatically select the best implementation for the
//! target architecture (SSE2, AVX2, NEON, or fallback).
//!
//! # Performance
//!
//! The `memchr` crate typically achieves 10-20x speedup over naive byte-by-byte
//! iteration on x86_64 with AVX2, and 4-8x on aarch64 with NEON.
//!
//! # Example
//!
//! ```
//! use go_brrr::util::simd::{find_byte, find_bytes, count_byte};
//!
//! let haystack = b"hello\nworld\n";
//!
//! // Find first newline
//! assert_eq!(find_byte(haystack, b'\n'), Some(5));
//!
//! // Find first occurrence of 'e' or 'o'
//! assert_eq!(find_bytes(haystack, &[b'e', b'o']), Some(1));
//!
//! // Count newlines
//! assert_eq!(count_byte(haystack, b'\n'), 2);
//! ```

use memchr::{memchr, memchr2, memchr3, Memchr};

/// Find the first occurrence of a byte in a slice.
///
/// Uses SIMD-accelerated search when available (SSE2/AVX2 on x86_64, NEON on aarch64).
///
/// # Arguments
///
/// * `haystack` - The byte slice to search in
/// * `needle` - The byte value to search for
///
/// # Returns
///
/// The index of the first occurrence, or `None` if not found.
///
/// # Performance
///
/// O(n) time complexity, but with SIMD acceleration typically 10-20x faster
/// than naive iteration on modern CPUs.
#[inline]
pub fn find_byte(haystack: &[u8], needle: u8) -> Option<usize> {
    memchr(needle, haystack)
}

/// Find the first occurrence of any byte from a set of needles.
///
/// Optimized for 1-3 needles using specialized `memchr2`/`memchr3` functions.
/// For more than 3 needles, falls back to iterating with `memchr`.
///
/// # Arguments
///
/// * `haystack` - The byte slice to search in
/// * `needles` - The byte values to search for (1-255 bytes supported)
///
/// # Returns
///
/// The index of the first occurrence of any needle, or `None` if none found.
///
/// # Panics
///
/// Panics if `needles` is empty.
///
/// # Performance
///
/// - 1 needle: Uses `memchr` (fastest)
/// - 2 needles: Uses `memchr2` (nearly as fast)
/// - 3 needles: Uses `memchr3`
/// - 4+ needles: Iterates through haystack with multi-byte matcher
#[inline]
pub fn find_bytes(haystack: &[u8], needles: &[u8]) -> Option<usize> {
    match needles.len() {
        0 => panic!("find_bytes requires at least one needle"),
        1 => memchr(needles[0], haystack),
        2 => memchr2(needles[0], needles[1], haystack),
        3 => memchr3(needles[0], needles[1], needles[2], haystack),
        _ => find_bytes_many(haystack, needles),
    }
}

/// Internal implementation for 4+ needles.
///
/// Uses a lookup table for O(1) byte membership checking combined with
/// SIMD-accelerated iteration.
fn find_bytes_many(haystack: &[u8], needles: &[u8]) -> Option<usize> {
    // Build a 256-bit lookup table for O(1) membership testing
    let mut lookup = [false; 256];
    for &needle in needles {
        lookup[needle as usize] = true;
    }

    // Iterate through haystack checking membership
    // Note: For very long haystacks with many needles, consider using
    // memchr::memmem or aho-corasick for better performance
    haystack
        .iter()
        .position(|&byte| lookup[byte as usize])
}

/// Count the number of occurrences of a byte in a slice.
///
/// Uses SIMD-accelerated counting when available.
///
/// # Arguments
///
/// * `haystack` - The byte slice to search in
/// * `needle` - The byte value to count
///
/// # Returns
///
/// The number of occurrences of `needle` in `haystack`.
///
/// # Performance
///
/// O(n) time complexity with SIMD acceleration. On x86_64 with AVX2,
/// can process 32 bytes per cycle.
#[inline]
pub fn count_byte(haystack: &[u8], needle: u8) -> usize {
    Memchr::new(needle, haystack).count()
}

/// Iterator over all positions of a byte in a slice.
///
/// Useful when you need all positions, not just the first or count.
///
/// # Example
///
/// ```
/// use go_brrr::util::simd::find_all_bytes;
///
/// let haystack = b"a.b.c.d";
/// let positions: Vec<_> = find_all_bytes(haystack, b'.').collect();
/// assert_eq!(positions, vec![1, 3, 5]);
/// ```
#[inline]
pub fn find_all_bytes(haystack: &[u8], needle: u8) -> impl Iterator<Item = usize> + '_ {
    Memchr::new(needle, haystack)
}

/// Find the last occurrence of a byte in a slice.
///
/// Uses SIMD-accelerated reverse search when available.
///
/// # Arguments
///
/// * `haystack` - The byte slice to search in
/// * `needle` - The byte value to search for
///
/// # Returns
///
/// The index of the last occurrence, or `None` if not found.
#[inline]
pub fn rfind_byte(haystack: &[u8], needle: u8) -> Option<usize> {
    memchr::memrchr(needle, haystack)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_byte_found() {
        assert_eq!(find_byte(b"hello", b'e'), Some(1));
        assert_eq!(find_byte(b"hello", b'h'), Some(0));
        assert_eq!(find_byte(b"hello", b'o'), Some(4));
    }

    #[test]
    fn test_find_byte_not_found() {
        assert_eq!(find_byte(b"hello", b'x'), None);
        assert_eq!(find_byte(b"", b'a'), None);
    }

    #[test]
    fn test_find_bytes_single() {
        assert_eq!(find_bytes(b"hello", &[b'e']), Some(1));
    }

    #[test]
    fn test_find_bytes_double() {
        assert_eq!(find_bytes(b"hello", &[b'x', b'e']), Some(1));
        assert_eq!(find_bytes(b"hello", &[b'e', b'l']), Some(1));
    }

    #[test]
    fn test_find_bytes_triple() {
        assert_eq!(find_bytes(b"hello", &[b'x', b'y', b'o']), Some(4));
    }

    #[test]
    fn test_find_bytes_many() {
        assert_eq!(find_bytes(b"hello", &[b'w', b'x', b'y', b'o']), Some(4));
        assert_eq!(find_bytes(b"hello", &[b'w', b'x', b'y', b'z']), None);
    }

    #[test]
    #[should_panic(expected = "requires at least one needle")]
    fn test_find_bytes_empty_panics() {
        find_bytes(b"hello", &[]);
    }

    #[test]
    fn test_count_byte() {
        assert_eq!(count_byte(b"hello", b'l'), 2);
        assert_eq!(count_byte(b"hello", b'x'), 0);
        assert_eq!(count_byte(b"aaaaaa", b'a'), 6);
        assert_eq!(count_byte(b"", b'a'), 0);
    }

    #[test]
    fn test_find_all_bytes() {
        let positions: Vec<_> = find_all_bytes(b"a.b.c", b'.').collect();
        assert_eq!(positions, vec![1, 3]);
    }

    #[test]
    fn test_rfind_byte() {
        assert_eq!(rfind_byte(b"hello", b'l'), Some(3));
        assert_eq!(rfind_byte(b"hello", b'h'), Some(0));
        assert_eq!(rfind_byte(b"hello", b'x'), None);
    }

    #[test]
    fn test_newline_operations() {
        let text = b"line1\nline2\nline3\n";
        assert_eq!(find_byte(text, b'\n'), Some(5));
        assert_eq!(count_byte(text, b'\n'), 3);
        assert_eq!(rfind_byte(text, b'\n'), Some(17));
    }

    #[test]
    fn test_null_byte() {
        let data = b"hello\0world";
        assert_eq!(find_byte(data, b'\0'), Some(5));
    }
}