use bytes::Bytes;
use memchr::memmem::Finder;
/// Capacity of the fixed offset table inside `PayloadParts`.
///
/// Splitting stops looking for delimiters once `MAX_PARTS - 1` parts have been
/// recorded; whatever remains of the payload is then stored as the final part.
pub const MAX_PARTS: usize = 32;
/// A payload together with the byte ranges of its delimiter-separated parts.
///
/// The ranges live in a fixed-size table of `MAX_PARTS` entries and index into
/// `payload`, so reading a part is a slice of the original buffer.
#[derive(Debug)]
pub struct PayloadParts {
// The complete, unsplit payload that every recorded range indexes into.
payload: Bytes,
// `(start, end)` byte offsets of each part, `end` exclusive. Only the first
// `count` entries are meaningful. Offsets are stored as `u32`.
offsets: [(u32, u32); MAX_PARTS],
// Number of entries of `offsets` that have been filled in so far.
count: usize,
// Byte offset at which the next lazy delimiter search starts.
scan_cursor: usize,
// Set once no further parts will be recorded for this payload.
finished: bool,
}
impl PayloadParts {
    /// Eagerly splits `payload` on every occurrence of `delimiter`.
    ///
    /// An empty `delimiter` yields the whole payload as a single part. At most
    /// [`MAX_PARTS`] parts are produced: after `MAX_PARTS - 1` parts have been
    /// recorded, the rest of the payload (further delimiters included) becomes
    /// the final part.
    ///
    /// Offsets are stored as `u32`, so payloads larger than `u32::MAX` bytes
    /// are not supported.
    #[inline]
    pub fn split(payload: Bytes, delimiter: &[u8]) -> Self {
        let finder = Finder::new(delimiter);
        Self::split_with_finder(payload, &finder, delimiter.len())
    }

    /// Eagerly splits `payload` using a pre-built `finder`.
    ///
    /// `delim_len` is the number of bytes skipped after each match and is
    /// expected to equal the length of the finder's needle. A `delim_len` of
    /// zero yields the whole payload as a single part. If `delim_len` would
    /// move the scan position past the end of the payload, splitting stops
    /// without recording a further part instead of panicking.
    ///
    /// The same [`MAX_PARTS`] cap and `u32` offset limit as [`Self::split`]
    /// apply.
    #[inline]
    pub fn split_with_finder(payload: Bytes, finder: &Finder<'_>, delim_len: usize) -> Self {
        let mut parts = Self::new_lazy(payload);
        while !parts.finished {
            parts.advance(finder, delim_len);
        }
        if delim_len != 0 {
            // A completed eager delimiter scan has always left the cursor at
            // the end of the payload; keep that so `Debug` output is stable.
            parts.scan_cursor = parts.payload.len();
        }
        parts
    }

    /// Wraps `payload` without scanning it.
    ///
    /// No parts are available until [`Self::ensure`] is called.
    #[inline]
    pub fn new_lazy(payload: Bytes) -> Self {
        Self {
            payload,
            offsets: [(0u32, 0u32); MAX_PARTS],
            count: 0,
            scan_cursor: 0,
            finished: false,
        }
    }

    /// Scans just far enough that part `index` is recorded, if it exists.
    ///
    /// Does nothing when the part is already recorded or scanning has already
    /// finished. `finder` and `delim_len` follow the same rules as in
    /// [`Self::split_with_finder`]; in particular a `delim_len` of zero makes
    /// the unscanned remainder of the payload the final part.
    #[inline]
    pub fn ensure(&mut self, index: usize, finder: &Finder<'_>, delim_len: usize) {
        while self.count <= index && !self.finished {
            self.advance(finder, delim_len);
        }
    }

    /// Number of parts recorded so far.
    ///
    /// For a lazily built value this only counts what [`Self::ensure`] has
    /// discovered up to now.
    #[inline]
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when no part has been recorded yet.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Borrows part `index`, or an empty slice when it has not been recorded.
    #[inline]
    pub fn get(&self, index: usize) -> &[u8] {
        match self.bounds(index) {
            Some((start, end)) => &self.payload[start..end],
            None => &[],
        }
    }

    /// Returns part `index` as a `Bytes` sliced from the payload, or an empty
    /// `Bytes` when it has not been recorded.
    #[inline]
    pub fn get_bytes(&self, index: usize) -> Bytes {
        match self.bounds(index) {
            Some((start, end)) => self.payload.slice(start..end),
            None => Bytes::new(),
        }
    }

    /// The complete, unsplit payload.
    #[inline]
    pub fn payload(&self) -> &Bytes {
        &self.payload
    }

    /// Iterates over the parts recorded so far, in order.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
        (0..self.count).map(move |i| self.get(i))
    }

    /// Byte range of part `index`, if that part has been recorded.
    #[inline]
    fn bounds(&self, index: usize) -> Option<(usize, usize)> {
        self.offsets[..self.count]
            .get(index)
            .map(|&(start, end)| (start as usize, end as usize))
    }

    /// Records one more part, or marks scanning as finished.
    ///
    /// Callers must only invoke this while `finished` is false; every path
    /// that fills the last table slot also sets `finished`, so the table
    /// cannot overflow.
    #[inline]
    fn advance(&mut self, finder: &Finder<'_>, delim_len: usize) {
        debug_assert!(!self.finished && self.count < MAX_PARTS);
        let total = self.payload.len();
        let from = self.scan_cursor;
        if from > total {
            // A `delim_len` longer than the matched needle pushed the cursor
            // past the end; there is nothing left to slice.
            self.finished = true;
            return;
        }
        // Stop searching once only the final slot is left, or when there is
        // no delimiter to search for: the remainder becomes the last part.
        let hit = if delim_len != 0 && self.count < MAX_PARTS - 1 {
            finder.find(&self.payload[from..])
        } else {
            None
        };
        match hit {
            Some(pos) => {
                let part_end = from + pos;
                self.push_part(from, part_end);
                // Saturate so an oversized `delim_len` cannot wrap around.
                self.scan_cursor = part_end.saturating_add(delim_len);
            }
            None => {
                self.push_part(from, total);
                self.finished = true;
            }
        }
    }

    /// Appends the range `start..end` to the offset table.
    #[inline]
    fn push_part(&mut self, start: usize, end: usize) {
        // Offsets are narrowed to `u32`; larger payloads are unsupported.
        debug_assert!(end <= u32::MAX as usize, "part offsets must fit in u32");
        self.offsets[self.count] = (start as u32, end as u32);
        self.count += 1;
    }
}
/// Looks up `header_name` in a block of `Name: value` lines.
///
/// Lines end at `\r`, `\n` or `\r\n`. The name is compared ASCII
/// case-insensitively and must be followed immediately by `:`. Spaces and tabs
/// right after the colon are skipped; the rest of the line is returned as-is,
/// trailing whitespace included. The first matching line wins.
///
/// Returns `None` when either argument is empty or no line matches.
#[inline]
pub fn extract_header_value<'a>(headers: &'a [u8], header_name: &[u8]) -> Option<&'a [u8]> {
    let name_len = header_name.len();
    if headers.is_empty() || name_len == 0 {
        return None;
    }

    let mut rest = headers;
    while !rest.is_empty() {
        // Cut the current line off at the first CR or LF (or at end of input).
        let eol = memchr::memchr2(b'\r', b'\n', rest).unwrap_or(rest.len());
        let (line, terminator_and_more) = rest.split_at(eol);

        // A match needs the full name plus at least the ':' after it.
        if line.len() > name_len {
            let (candidate, after_name) = line.split_at(name_len);
            if candidate.eq_ignore_ascii_case(header_name) && after_name[0] == b':' {
                let raw_value = &after_name[1..];
                let padding = raw_value
                    .iter()
                    .take_while(|&&byte| byte == b' ' || byte == b'\t')
                    .count();
                return Some(&raw_value[padding..]);
            }
        }

        // Consume at most one CR followed by at most one LF. When the line ran
        // to the end of the input nothing is left and the loop terminates.
        let after_cr = terminator_and_more
            .strip_prefix(b"\r")
            .unwrap_or(terminator_and_more);
        rest = after_cr.strip_prefix(b"\n").unwrap_or(after_cr);
    }

    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Eagerly splits a static string on `delim`.
    fn split_str(input: &'static str, delim: &[u8]) -> PayloadParts {
        PayloadParts::split(Bytes::from(input), delim)
    }

    /// Asserts that `parts` holds exactly `expected`, in order.
    fn assert_parts(parts: &PayloadParts, expected: &[&str]) {
        assert_eq!(parts.len(), expected.len());
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(parts.get(index), want.as_bytes());
        }
    }

    /// Asserts the result of looking up `name` in `headers`.
    fn assert_header(headers: &str, name: &str, expected: Option<&str>) {
        assert_eq!(
            extract_header_value(headers.as_bytes(), name.as_bytes()),
            expected.map(str::as_bytes)
        );
    }

    #[test]
    fn test_split_basic() {
        assert_parts(&split_str("a;;;b;;;c", b";;;"), &["a", "b", "c"]);
    }

    #[test]
    fn test_split_no_delimiter() {
        assert_parts(&split_str("hello world", b";;;"), &["hello world"]);
    }

    #[test]
    fn test_split_empty_parts() {
        assert_parts(&split_str("a;;;;;;b", b";;;"), &["a", "", "b"]);
    }

    #[test]
    fn test_split_empty_payload() {
        assert_parts(&split_str("", b";;;"), &[""]);
    }

    #[test]
    fn test_split_trailing_delimiter() {
        assert_parts(&split_str("a;;;b;;;", b";;;"), &["a", "b", ""]);
    }

    #[test]
    fn test_split_single_char_delimiter() {
        assert_parts(&split_str("a|b|c|d", b"|"), &["a", "b", "c", "d"]);
    }

    #[test]
    fn test_get_out_of_bounds() {
        let parts = split_str("a;;;b", b";;;");
        // Indices past the last part must yield an empty slice, not panic.
        for &(index, want) in &[(0usize, "a"), (1, "b"), (2, ""), (100, "")] {
            assert_eq!(parts.get(index), want.as_bytes());
        }
    }

    #[test]
    fn test_get_bytes_zero_copy() {
        let payload = Bytes::from("hello;;;world");
        let parts = PayloadParts::split(payload.clone(), b";;;");
        let (head, tail) = (parts.get_bytes(0), parts.get_bytes(1));
        assert_eq!(&head[..], b"hello");
        assert_eq!(&tail[..], b"world");
        // The first part must start at the very same address as the payload.
        assert_eq!(head.as_ptr(), payload.as_ptr());
    }

    #[test]
    fn test_extract_header_basic() {
        let headers = "Content-Type: application/json\r\nX-Custom: value\r\n";
        assert_header(headers, "Content-Type", Some("application/json"));
        assert_header(headers, "X-Custom", Some("value"));
    }

    #[test]
    fn test_extract_header_case_insensitive() {
        let headers = "Content-Type: application/json\r\n";
        assert_header(headers, "content-type", Some("application/json"));
        assert_header(headers, "CONTENT-TYPE", Some("application/json"));
    }

    #[test]
    fn test_extract_header_with_whitespace() {
        // Leading padding is skipped; trailing whitespace is kept verbatim.
        assert_header(
            "X-Custom: value with spaces \r\n",
            "X-Custom",
            Some("value with spaces "),
        );
    }

    #[test]
    fn test_extract_header_not_found() {
        assert_header("Content-Type: application/json\r\n", "X-Missing", None);
    }

    #[test]
    fn test_extract_header_empty() {
        assert_header("", "Content-Type", None);
        assert_header("Content-Type: value", "", None);
    }

    #[test]
    fn test_extract_header_no_crlf() {
        let headers = "Content-Type: value\nX-Other: other\n";
        assert_header(headers, "Content-Type", Some("value"));
        assert_header(headers, "X-Other", Some("other"));
    }
}