#![allow(unsafe_op_in_unsafe_fn)]
use bytes::Bytes;
use crate::resp2::Frame;
/// Returns the index of the first `\r` byte in `buf` at or after index `from`.
///
/// # Safety
///
/// A `\r` byte must be present in `buf` at some index `>= from`. The scan does
/// no bounds checking, so without such a byte it reads past the end of the
/// slice.
#[inline(always)]
unsafe fn find_cr(buf: &[u8], from: usize) -> usize {
    let mut idx = from;
    loop {
        // SAFETY: the caller guarantees a `\r` exists at or after `from`, so
        // every index visited before the match is inside `buf`.
        if *buf.get_unchecked(idx) == b'\r' {
            return idx;
        }
        idx += 1;
    }
}
/// Parses an optionally `-`-prefixed run of ASCII decimal digits into an `i64`
/// without validating the input.
///
/// The magnitude is accumulated in a `u64` with wrapping arithmetic and only
/// then converted to a signed value. Accumulating directly in an `i64` and
/// negating afterwards overflows for `i64::MIN`, whose magnitude (2^63) does
/// not fit in a positive `i64`; that overflow panics in builds with overflow
/// checks enabled. Two's-complement wrapping yields the correct bit pattern
/// for every in-range value, including `i64::MIN`.
///
/// An empty slice yields `0`.
///
/// # Safety
///
/// This implementation performs no unchecked memory access; the function stays
/// `unsafe` to keep its signature. The result is only meaningful when every
/// byte after the optional leading `-` is an ASCII digit and the value fits in
/// an `i64`. Violations are caught by a `debug_assert!` in debug builds and
/// produce an unspecified (wrapped) value otherwise.
#[inline(always)]
unsafe fn parse_i64_unchecked(buf: &[u8]) -> i64 {
    let (negative, digits) = match buf.split_first() {
        Some((b'-', rest)) => (true, rest),
        _ => (false, buf),
    };

    let mut magnitude: u64 = 0;
    for &byte in digits {
        debug_assert!(byte.is_ascii_digit(), "non-digit byte in integer");
        magnitude = magnitude
            .wrapping_mul(10)
            .wrapping_add(u64::from(byte.wrapping_sub(b'0')));
    }

    // Reinterpreting the bits maps 2^63 onto `i64::MIN`, and `wrapping_neg`
    // leaves `i64::MIN` unchanged, so the most negative value round-trips.
    let value = magnitude as i64;
    if negative {
        value.wrapping_neg()
    } else {
        value
    }
}
/// Parses a run of ASCII decimal digits into a `usize` without validating the
/// input. An empty slice yields `0`.
///
/// # Safety
///
/// The function stays `unsafe` to keep its signature. The result is only
/// meaningful when every byte of `buf` is an ASCII digit and the value fits in
/// a `usize`.
#[inline(always)]
unsafe fn parse_usize_unchecked(buf: &[u8]) -> usize {
    buf.iter()
        .fold(0usize, |acc, &digit| acc * 10 + usize::from(digit - b'0'))
}
unsafe fn parse_inner(input: &Bytes, pos: usize) -> (Frame, usize) {
let buf = input.as_ref();
let tag = *buf.get_unchecked(pos);
match tag {
b'+' => {
let cr = find_cr(buf, pos + 1);
(Frame::SimpleString(input.slice(pos + 1..cr)), cr + 2)
}
b'-' => {
let cr = find_cr(buf, pos + 1);
(Frame::Error(input.slice(pos + 1..cr)), cr + 2)
}
b':' => {
let cr = find_cr(buf, pos + 1);
let v = parse_i64_unchecked(buf.get_unchecked(pos + 1..cr));
(Frame::Integer(v), cr + 2)
}
b'$' => {
let cr = find_cr(buf, pos + 1);
let len_slice = buf.get_unchecked(pos + 1..cr);
if *len_slice.get_unchecked(0) == b'-' {
return (Frame::BulkString(None), cr + 2);
}
let len = parse_usize_unchecked(len_slice);
if len == 0 {
return (Frame::BulkString(Some(Bytes::new())), cr + 4);
}
let data_start = cr + 2;
let data_end = data_start + len;
(
Frame::BulkString(Some(input.slice(data_start..data_end))),
data_end + 2,
)
}
b'*' => {
let cr = find_cr(buf, pos + 1);
let len_slice = buf.get_unchecked(pos + 1..cr);
if *len_slice.get_unchecked(0) == b'-' {
return (Frame::Array(None), cr + 2);
}
let count = parse_usize_unchecked(len_slice);
if count == 0 {
return (Frame::Array(Some(Vec::new())), cr + 2);
}
let mut cursor = cr + 2;
let mut items = Vec::with_capacity(count);
for _ in 0..count {
let (frame, next) = parse_inner(input, cursor);
items.push(frame);
cursor = next;
}
(Frame::Array(Some(items)), cursor)
}
_ => core::hint::unreachable_unchecked(),
}
}
/// Parses one frame from the front of `input` without validating it, returning
/// the frame together with the bytes that follow it.
///
/// # Safety
///
/// `input` must begin with a complete, well-formed frame: a recognised tag
/// byte, `\r`-terminated header lines, and lengths/counts that match the data
/// actually present. Malformed or truncated input is undefined behaviour.
pub unsafe fn parse_frame_unchecked(input: Bytes) -> (Frame, Bytes) {
    let parsed = parse_inner(&input, 0);
    let remainder = input.slice(parsed.1..);
    (parsed.0, remainder)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resp2;

    /// Well-formed wire inputs fed to both the checked and unchecked parsers.
    const WIRE_CASES: &[&str] = &[
        // Simple string and error lines.
        "+OK\r\n",
        "-ERR fail\r\n",
        // Integers: positive, negative, zero.
        ":42\r\n",
        ":-123\r\n",
        ":0\r\n",
        // Bulk strings: populated, empty, null.
        "$5\r\nhello\r\n",
        "$0\r\n\r\n",
        "$-1\r\n",
        // Arrays: empty, null, flat, nested.
        "*0\r\n",
        "*-1\r\n",
        "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
        "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
        "*2\r\n*1\r\n:1\r\n+OK\r\n",
    ];

    /// The unchecked parser must return the same frame and the same remaining
    /// bytes as the checked parser for every case.
    #[test]
    fn unchecked_matches_safe() {
        for &wire in WIRE_CASES {
            let bytes = Bytes::from(wire);
            let (checked_frame, checked_rest) = resp2::parse_frame(bytes.clone()).unwrap();
            // SAFETY: every entry in `WIRE_CASES` is a complete, well-formed frame.
            let (fast_frame, fast_rest) = unsafe { parse_frame_unchecked(bytes) };
            assert_eq!(checked_frame, fast_frame, "mismatch for: {wire}");
            assert_eq!(checked_rest, fast_rest, "rest mismatch for: {wire}");
        }
    }
}