/// Byte-level scanning helpers for locating HTTP framing delimiters in a buffer.
///
/// In this crate it is implemented for `[u8]`, with the backend (SIMD or scalar)
/// chosen at compile time for the target architecture.
pub trait ByteSearch {
    /// Returns the index of the first byte equal to `needle`, or `None` if the
    /// byte does not occur.
    fn find_byte(&self, needle: u8) -> Option<usize>;

    /// Returns the index of the `\r` of the first `\r\n` pair whose `\r` is at
    /// index `start` or later, or `None` if there is no such pair.
    fn find_crlf(&self, start: usize) -> Option<usize>;

    /// Returns the index one past the first `\r\n\r\n` sequence whose leading
    /// `\r` is at index `start` or later (the offset of the first byte after the
    /// header block), or `None` if there is no such sequence.
    fn find_header_end(&self, start: usize) -> Option<usize>;
}
/// Generates the `find_byte`, `find_crlf`, and `find_header_end` backend
/// functions for SIMD targets whose comparison mask carries one bit per byte
/// lane of a 16-byte vector.
///
/// The invoking module must already have `HttpChar` and the `simd_splat!`,
/// `simd_load!`, and `simd_mask!` macros in scope; they are not defined here.
macro_rules! impl_bitmask_byte_search {
    () => {
        /// Returns the index of the first occurrence of `needle` in `buf`.
        #[inline]
        pub(super) fn find_byte(buf: &[u8], needle: u8) -> Option<usize> {
            let len = buf.len();
            // SAFETY: every 16-byte load reads `buf[i..i + 16]` with `i + 16 <= len`,
            // and the scalar tail only reads indices `i < len`.
            unsafe {
                let nv = simd_splat!(needle);
                let mut i = 0;
                while i + 16 <= len {
                    let chunk = simd_load!(buf.as_ptr().add(i));
                    let mask = simd_mask!(chunk, nv);
                    if mask != 0 {
                        // Lowest set bit corresponds to the first matching lane.
                        return Some(i + mask.trailing_zeros() as usize);
                    }
                    i += 16;
                }
                while i < len {
                    if *buf.get_unchecked(i) == needle {
                        return Some(i);
                    }
                    i += 1;
                }
                None
            }
        }

        /// Returns the index of the first `\r` at a position in `start..end` that is
        /// immediately followed by `\n`.
        ///
        /// The caller must pass `end < buf.len()` so the byte after every candidate
        /// `\r` is in bounds; the unchecked reads below rely on it.
        #[inline]
        pub(super) fn find_crlf(buf: &[u8], start: usize, end: usize) -> Option<usize> {
            let len = buf.len();
            // Empty search range. Returning here also keeps `i + 16` below from
            // overflowing (and wrapping past the bounds check in release builds)
            // when `start` is close to `usize::MAX`.
            if start >= end {
                return None;
            }
            debug_assert!(end < len, "find_crlf: `end` must leave room for the LF byte");
            // SAFETY: loads read `buf[i..i + 16]` with `i + 16 <= len`; every other
            // unchecked read is at `p` or `p + 1` for some `p < end`, and `end < len`
            // is the caller's precondition.
            unsafe {
                let cr = simd_splat!(HttpChar::CarriageReturn.as_u8());
                let mut i = start;
                while i + 16 <= len {
                    let chunk = simd_load!(buf.as_ptr().add(i));
                    let mut mask = simd_mask!(chunk, cr);
                    while mask != 0 {
                        let bit = mask.trailing_zeros() as usize;
                        let pos = i + bit;
                        if pos < end && *buf.get_unchecked(pos + 1) == HttpChar::LineFeed {
                            return Some(pos);
                        }
                        // Drop the lowest set bit and try the next CR in this chunk.
                        mask &= mask - 1;
                    }
                    i += 16;
                }
                while i < end {
                    if *buf.get_unchecked(i) == HttpChar::CarriageReturn
                        && *buf.get_unchecked(i + 1) == HttpChar::LineFeed
                    {
                        return Some(i);
                    }
                    i += 1;
                }
                None
            }
        }

        /// Returns the index one past the first `\r\n\r\n` whose leading `\r` lies in
        /// `start..end`.
        ///
        /// The caller must pass `end + 2 < buf.len()` so all four bytes of every
        /// candidate are in bounds; the unchecked reads below rely on it.
        #[inline]
        pub(super) fn find_header_end(buf: &[u8], start: usize, end: usize) -> Option<usize> {
            let len = buf.len();
            // Empty search range; also prevents `i + 16` overflow for huge `start`.
            if start >= end {
                return None;
            }
            debug_assert!(
                end.saturating_add(2) < len,
                "find_header_end: `end` must leave room for the trailing LF CR LF"
            );
            // SAFETY: loads read `buf[i..i + 16]` with `i + 16 <= len`; every other
            // unchecked read is at `p..=p + 3` for some `p < end`, and `end + 2 < len`
            // is the caller's precondition.
            unsafe {
                let cr = simd_splat!(HttpChar::CarriageReturn.as_u8());
                let mut i = start;
                while i + 16 <= len {
                    let chunk = simd_load!(buf.as_ptr().add(i));
                    let mut mask = simd_mask!(chunk, cr);
                    while mask != 0 {
                        let bit = mask.trailing_zeros() as usize;
                        let pos = i + bit;
                        if pos < end
                            && *buf.get_unchecked(pos + 1) == HttpChar::LineFeed
                            && *buf.get_unchecked(pos + 2) == HttpChar::CarriageReturn
                            && *buf.get_unchecked(pos + 3) == HttpChar::LineFeed
                        {
                            return Some(pos + 4);
                        }
                        mask &= mask - 1;
                    }
                    i += 16;
                }
                while i < end {
                    if *buf.get_unchecked(i) == HttpChar::CarriageReturn
                        && *buf.get_unchecked(i + 1) == HttpChar::LineFeed
                        && *buf.get_unchecked(i + 2) == HttpChar::CarriageReturn
                        && *buf.get_unchecked(i + 3) == HttpChar::LineFeed
                    {
                        return Some(i + 4);
                    }
                    i += 1;
                }
                None
            }
        }
    };
}
/// x86_64 backend (named for SSE2) built from `impl_bitmask_byte_search!`.
#[cfg(target_arch = "x86_64")]
// NOTE(review): these lints are presumably triggered by casts inside the SIMD
// primitives (e.g. `u8 -> i8` splat, movemask result to unsigned, `*const u8`
// to vector-pointer loads) — confirm against `define_simd_primitives!`.
#[allow(
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss,
    clippy::cast_ptr_alignment
)]
mod sse2 {
    // Must precede `impl_bitmask_byte_search!()`: it presumably defines the
    // `simd_splat!`, `simd_load!` and `simd_mask!` macros that expansion uses,
    // and `macro_rules!` macros are only visible after their definition.
    crate::simd::define_simd_primitives!();
    use crate::ascii::HttpChar;
    // Expands to this backend's `find_byte`, `find_crlf` and `find_header_end`.
    impl_bitmask_byte_search!();
}
/// AArch64 backend using NEON intrinsics.
///
/// NEON has no direct byte `movemask`, so comparison results are narrowed into a
/// `u64` carrying 4 bits per byte lane (see `neon_movemask`); bit offsets are
/// divided by 4 to recover byte indices.
#[cfg(target_arch = "aarch64")]
mod neon {
    use std::arch::aarch64::{
        vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vreinterpret_u64_u8, vreinterpretq_u16_u8,
        vshrn_n_u16,
    };

    use crate::ascii::HttpChar;

    /// Packs a byte-wise comparison result (every lane `0x00` or `0xFF`) into a
    /// `u64` holding one 4-bit nibble per lane, lane 0 in the lowest nibble.
    ///
    /// # Safety
    ///
    /// Requires NEON to be available on the executing CPU; it is enabled by
    /// default on standard aarch64 Rust targets.
    #[inline]
    unsafe fn neon_movemask(cmp: std::arch::aarch64::uint8x16_t) -> u64 {
        // Shifting each 16-bit lane right by 4 and narrowing to 8 bits keeps the
        // top nibble of the low byte and the bottom nibble of the high byte.
        let narrowed = vshrn_n_u16(vreinterpretq_u16_u8(cmp), 4);
        vget_lane_u64(vreinterpret_u64_u8(narrowed), 0)
    }

    /// Returns the index of the first occurrence of `needle` in `buf`.
    #[inline]
    pub(super) fn find_byte(buf: &[u8], needle: u8) -> Option<usize> {
        let len = buf.len();
        // SAFETY: every 16-byte load reads `buf[i..i + 16]` with `i + 16 <= len`,
        // and the scalar tail only reads indices `i < len`.
        unsafe {
            let nv = vdupq_n_u8(needle);
            let mut i = 0;
            while i + 16 <= len {
                let chunk = vld1q_u8(buf.as_ptr().add(i));
                let cmp = vceqq_u8(chunk, nv);
                let bits = neon_movemask(cmp);
                if bits != 0 {
                    // Four mask bits per byte lane.
                    return Some(i + (bits.trailing_zeros() as usize) / 4);
                }
                i += 16;
            }
            while i < len {
                if *buf.get_unchecked(i) == needle {
                    return Some(i);
                }
                i += 1;
            }
            None
        }
    }

    /// Returns the index of the first `\r` at a position in `start..end` that is
    /// immediately followed by `\n`.
    ///
    /// The caller must pass `end < buf.len()` so the byte after every candidate
    /// `\r` is in bounds; the unchecked reads below rely on it.
    #[inline]
    pub(super) fn find_crlf(buf: &[u8], start: usize, end: usize) -> Option<usize> {
        let len = buf.len();
        // Empty search range. Returning here also keeps `i + 16` below from
        // overflowing (and wrapping past the bounds check in release builds)
        // when `start` is close to `usize::MAX`.
        if start >= end {
            return None;
        }
        debug_assert!(end < len, "find_crlf: `end` must leave room for the LF byte");
        // SAFETY: loads read `buf[i..i + 16]` with `i + 16 <= len`; every other
        // unchecked read is at `p` or `p + 1` for some `p < end`, and `end < len`
        // is the caller's precondition.
        unsafe {
            let cr = vdupq_n_u8(HttpChar::CarriageReturn.as_u8());
            let mut i = start;
            while i + 16 <= len {
                let chunk = vld1q_u8(buf.as_ptr().add(i));
                let cmp = vceqq_u8(chunk, cr);
                let mut bits = neon_movemask(cmp);
                while bits != 0 {
                    let bit_pos = bits.trailing_zeros() as usize;
                    let pos = i + bit_pos / 4;
                    if pos < end && *buf.get_unchecked(pos + 1) == HttpChar::LineFeed {
                        return Some(pos);
                    }
                    // Clear this lane's whole nibble before looking for the next CR.
                    bits &= !(0xFu64 << (bit_pos & !3));
                }
                i += 16;
            }
            while i < end {
                if *buf.get_unchecked(i) == HttpChar::CarriageReturn
                    && *buf.get_unchecked(i + 1) == HttpChar::LineFeed
                {
                    return Some(i);
                }
                i += 1;
            }
            None
        }
    }

    /// Returns the index one past the first `\r\n\r\n` whose leading `\r` lies in
    /// `start..end`.
    ///
    /// The caller must pass `end + 2 < buf.len()` so all four bytes of every
    /// candidate are in bounds; the unchecked reads below rely on it.
    #[inline]
    pub(super) fn find_header_end(buf: &[u8], start: usize, end: usize) -> Option<usize> {
        let len = buf.len();
        // Empty search range; also prevents `i + 16` overflow for huge `start`.
        if start >= end {
            return None;
        }
        debug_assert!(
            end.saturating_add(2) < len,
            "find_header_end: `end` must leave room for the trailing LF CR LF"
        );
        // SAFETY: loads read `buf[i..i + 16]` with `i + 16 <= len`; every other
        // unchecked read is at `p..=p + 3` for some `p < end`, and `end + 2 < len`
        // is the caller's precondition.
        unsafe {
            let cr = vdupq_n_u8(HttpChar::CarriageReturn.as_u8());
            let mut i = start;
            while i + 16 <= len {
                let chunk = vld1q_u8(buf.as_ptr().add(i));
                let cmp = vceqq_u8(chunk, cr);
                let mut bits = neon_movemask(cmp);
                while bits != 0 {
                    let bit_pos = bits.trailing_zeros() as usize;
                    let pos = i + bit_pos / 4;
                    if pos < end
                        && *buf.get_unchecked(pos + 1) == HttpChar::LineFeed
                        && *buf.get_unchecked(pos + 2) == HttpChar::CarriageReturn
                        && *buf.get_unchecked(pos + 3) == HttpChar::LineFeed
                    {
                        return Some(pos + 4);
                    }
                    bits &= !(0xFu64 << (bit_pos & !3));
                }
                i += 16;
            }
            while i < end {
                if *buf.get_unchecked(i) == HttpChar::CarriageReturn
                    && *buf.get_unchecked(i + 1) == HttpChar::LineFeed
                    && *buf.get_unchecked(i + 2) == HttpChar::CarriageReturn
                    && *buf.get_unchecked(i + 3) == HttpChar::LineFeed
                {
                    return Some(i + 4);
                }
                i += 1;
            }
            None
        }
    }
}
/// wasm32 backend (requires the `simd128` target feature) built from
/// `impl_bitmask_byte_search!`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
// NOTE(review): presumably silences a sign-changing cast of the lane mask inside
// the SIMD primitives — confirm against `define_simd_primitives!`.
#[allow(clippy::cast_sign_loss)]
mod wasm_simd {
    // Must precede `impl_bitmask_byte_search!()`: it presumably defines the
    // `simd_splat!`, `simd_load!` and `simd_mask!` macros that expansion uses,
    // and `macro_rules!` macros are only visible after their definition.
    crate::simd::define_simd_primitives!();
    use crate::ascii::HttpChar;
    // Expands to this backend's `find_byte`, `find_crlf` and `find_header_end`.
    impl_bitmask_byte_search!();
}
/// Portable fallback used on targets without one of the SIMD backends above.
///
/// Uses ordinary bounds-checked indexing, so no `unsafe` is needed.
#[cfg(not(any(
    target_arch = "x86_64",
    target_arch = "aarch64",
    all(target_arch = "wasm32", target_feature = "simd128")
)))]
mod scalar {
    use crate::ascii::HttpChar;

    /// Index of the first byte equal to `needle`, if any.
    #[inline]
    pub(super) fn find_byte(buf: &[u8], needle: u8) -> Option<usize> {
        buf.iter().position(|byte| *byte == needle)
    }

    /// First CR position in `start..end` whose next byte is LF.
    #[inline]
    pub(super) fn find_crlf(buf: &[u8], start: usize, end: usize) -> Option<usize> {
        (start..end).find(|&at| {
            buf[at] == HttpChar::CarriageReturn && buf[at + 1] == HttpChar::LineFeed
        })
    }

    /// One past the first CR LF CR LF whose leading CR lies in `start..end`.
    #[inline]
    pub(super) fn find_header_end(buf: &[u8], start: usize, end: usize) -> Option<usize> {
        (start..end)
            .find(|&at| {
                buf[at] == HttpChar::CarriageReturn
                    && buf[at + 1] == HttpChar::LineFeed
                    && buf[at + 2] == HttpChar::CarriageReturn
                    && buf[at + 3] == HttpChar::LineFeed
            })
            .map(|at| at + 4)
    }
}
/// Calls `$func(args...)` on the backend module chosen for the current target
/// at compile time: `sse2`, `neon`, `wasm_simd`, or the `scalar` fallback.
///
/// Exactly one of the cfg-gated aliases below survives configuration, so
/// `backend` always names a single module.
macro_rules! dispatch {
    ($func:ident $(, $param:expr)* $(,)?) => {{
        #[cfg(target_arch = "x86_64")]
        use self::sse2 as backend;
        #[cfg(target_arch = "aarch64")]
        use self::neon as backend;
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        use self::wasm_simd as backend;
        #[cfg(not(any(
            target_arch = "x86_64",
            target_arch = "aarch64",
            all(target_arch = "wasm32", target_feature = "simd128")
        )))]
        use self::scalar as backend;

        backend::$func($($param),*)
    }};
}
impl ByteSearch for [u8] {
    #[inline]
    fn find_byte(&self, needle: u8) -> Option<usize> {
        dispatch!(find_byte, self, needle)
    }

    #[inline]
    fn find_crlf(&self, start: usize) -> Option<usize> {
        let len = self.len();
        // A CRLF needs two bytes, so its CR can sit at most at `len - 2`;
        // `end` is the exclusive bound on candidate CR positions.
        if len < 2 {
            return None;
        }
        let end = len - 1;
        // Nothing to search. Rejecting this here also keeps the SIMD backends from
        // computing `start + 16`, which would overflow for `start` near
        // `usize::MAX` and, in release builds, wrap past their bounds check.
        if start >= end {
            return None;
        }
        dispatch!(find_crlf, self, start, end)
    }

    #[inline]
    fn find_header_end(&self, start: usize) -> Option<usize> {
        let len = self.len();
        // "\r\n\r\n" needs four bytes, so its leading CR can sit at most at
        // `len - 4`; `end` is the exclusive bound on candidate CR positions.
        if len < 4 {
            return None;
        }
        let end = len - 3;
        // Same empty-range / `start + 16` overflow guard as `find_crlf`.
        if start >= end {
            return None;
        }
        dispatch!(find_header_end, self, start, end)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn find_byte_empty() {
        assert_eq!(b"".find_byte(b'x'), None);
    }

    #[test]
    fn find_byte_single_match() {
        assert_eq!(b"x".find_byte(b'x'), Some(0));
    }

    #[test]
    fn find_byte_single_no_match() {
        assert_eq!(b"a".find_byte(b'x'), None);
    }

    #[test]
    fn find_byte_at_simd_boundary() {
        for pos in 0..16 {
            let mut buf = [b'.'; 16];
            buf[pos] = b'x';
            assert_eq!(buf.find_byte(b'x'), Some(pos), "pos={pos}");
        }
    }

    #[test]
    fn find_byte_in_scalar_tail() {
        let mut buf = [b'.'; 17];
        buf[16] = b'x';
        assert_eq!(buf.find_byte(b'x'), Some(16));
    }

    #[test]
    fn find_byte_across_two_simd_chunks() {
        let mut buf = [b'.'; 33];
        buf[20] = b'x';
        assert_eq!(buf.find_byte(b'x'), Some(20));
    }

    #[test]
    fn find_byte_no_match_large() {
        let buf = [b'.'; 64];
        assert_eq!(buf.find_byte(b'x'), None);
    }

    #[test]
    fn find_crlf_empty() {
        assert_eq!(b"".find_crlf(0), None);
    }

    #[test]
    fn find_crlf_single_byte() {
        assert_eq!(b"\r".find_crlf(0), None);
    }

    #[test]
    fn find_crlf_at_start() {
        assert_eq!(b"\r\nrest".find_crlf(0), Some(0));
    }

    #[test]
    fn find_crlf_at_simd_boundary() {
        let mut buf = [b'.'; 16];
        buf[14] = b'\r';
        buf[15] = b'\n';
        assert_eq!(buf.find_crlf(0), Some(14));
    }

    #[test]
    fn find_crlf_spanning_simd_boundary() {
        let mut buf = [b'.'; 17];
        buf[15] = b'\r';
        buf[16] = b'\n';
        assert_eq!(buf.find_crlf(0), Some(15));
    }

    #[test]
    fn find_crlf_in_scalar_tail() {
        let mut buf = [b'.'; 19];
        buf[17] = b'\r';
        buf[18] = b'\n';
        assert_eq!(buf.find_crlf(0), Some(17));
    }

    #[test]
    fn find_crlf_bare_cr_no_lf() {
        let mut buf = [b'.'; 5];
        buf[2] = b'\r';
        assert_eq!(buf.find_crlf(0), None);
    }

    #[test]
    fn find_crlf_with_start_offset() {
        let buf = b"first\r\nsecond\r\n";
        assert_eq!(buf.find_crlf(0), Some(5));
        assert_eq!(buf.find_crlf(6), Some(13));
    }

    #[test]
    fn find_crlf_start_at_or_past_end() {
        let buf = b"ab\r\n";
        assert_eq!(buf.find_crlf(2), Some(2));
        assert_eq!(buf.find_crlf(3), None);
        assert_eq!(buf.find_crlf(4), None);
        assert_eq!(buf.find_crlf(100), None);
    }

    #[test]
    fn find_crlf_start_skips_earlier_match_across_chunks() {
        let mut buf = [b'.'; 40];
        buf[3] = b'\r';
        buf[4] = b'\n';
        buf[30] = b'\r';
        buf[31] = b'\n';
        assert_eq!(buf.find_crlf(0), Some(3));
        assert_eq!(buf.find_crlf(4), Some(30));
    }

    #[test]
    fn find_crlf_dense_cr_no_lf() {
        let buf = [b'\r'; 32];
        assert_eq!(buf.find_crlf(0), None);
    }

    #[test]
    fn find_crlf_dense_cr_with_one_lf() {
        let mut buf = [b'\r'; 33];
        buf[32] = b'\n';
        assert_eq!(buf.find_crlf(0), Some(31));
    }

    #[test]
    fn find_header_end_empty() {
        assert_eq!(b"".find_header_end(0), None);
    }

    #[test]
    fn find_header_end_too_short() {
        assert_eq!(b"\r\n\r".find_header_end(0), None);
    }

    #[test]
    fn find_header_end_exact() {
        assert_eq!(b"\r\n\r\n".find_header_end(0), Some(4));
    }

    #[test]
    fn find_header_end_after_headers() {
        let buf = b"Host: localhost\r\n\r\nbody";
        assert_eq!(buf.find_header_end(0), Some(19));
    }

    #[test]
    fn find_header_end_at_simd_boundary() {
        let mut buf = [b'.'; 20];
        buf[12] = b'\r';
        buf[13] = b'\n';
        buf[14] = b'\r';
        buf[15] = b'\n';
        assert_eq!(buf.find_header_end(0), Some(16));
    }

    #[test]
    fn find_header_end_spanning_simd_boundary() {
        let mut buf = [b'.'; 20];
        buf[14] = b'\r';
        buf[15] = b'\n';
        buf[16] = b'\r';
        buf[17] = b'\n';
        assert_eq!(buf.find_header_end(0), Some(18));
    }

    #[test]
    fn find_header_end_in_scalar_tail() {
        let mut buf = [b'.'; 21];
        buf[17] = b'\r';
        buf[18] = b'\n';
        buf[19] = b'\r';
        buf[20] = b'\n';
        assert_eq!(buf.find_header_end(0), Some(21));
    }

    #[test]
    fn find_header_end_with_start_offset() {
        let buf = b"GET / HTTP/1.1\r\nHost: h\r\n\r\n";
        assert_eq!(buf.find_header_end(0), Some(27));
        assert_eq!(buf.find_header_end(16), Some(27));
    }

    #[test]
    fn find_header_end_start_at_or_past_end() {
        let buf = b"a\r\n\r\n";
        assert_eq!(buf.find_header_end(1), Some(5));
        assert_eq!(buf.find_header_end(2), None);
        assert_eq!(buf.find_header_end(5), None);
        assert_eq!(buf.find_header_end(100), None);
    }

    #[test]
    fn find_header_end_no_match() {
        let buf = b"GET / HTTP/1.1\r\nHost: h\r\n";
        assert_eq!(buf.find_header_end(0), None);
    }

    #[test]
    fn find_header_end_many_crs_no_terminator() {
        let buf = [b'\r'; 64];
        assert_eq!(buf.find_header_end(0), None);
    }

    #[test]
    fn find_header_end_dense_crs_with_terminator() {
        let mut buf = [b'\r'; 64];
        buf[31] = b'\n';
        buf[33] = b'\n';
        assert_eq!(&buf[30..34], b"\r\n\r\n");
        assert_eq!(buf.find_header_end(0), Some(34));
    }
}