samsa 0.1.8

Rust-native Kafka/Redpanda protocol and client implementation.
Documentation
//! Deserialize data from the bytecode protocol.
use bytes::Bytes;
use nom::{
    bytes::complete::take,
    combinator::map,
    error::{ErrorKind, ParseError},
    multi::many_m_n,
    number::complete::{be_i16, be_i32, be_u16, be_u32},
    Err::*,
    IResult,
    Needed::Unknown,
};
use nombytes::NomBytes;
use num_traits::FromPrimitive;

use crate::error::KafkaCode;

pub fn parse_kafka_code(s: NomBytes) -> IResult<NomBytes, KafkaCode> {
    map(be_i16, |n| match FromPrimitive::from_i16(n) {
        Some(e) => e,
        None => {
            tracing::error!("Unhandled Kakfa code :: {}", n);
            KafkaCode::Unknown
        }
    })(s)
}

// NOTE:: Redpanda sends us values that are 2* the correct value... so make sure to
// divide your values by 2 when you go to use them
pub fn take_varint<E>(i: NomBytes) -> nom::IResult<NomBytes, usize, E>
where
    E: ParseError<NomBytes>,
{
    let mut res: usize = 0;
    let mut count: usize = 0;
    let mut remainder = i;
    loop {
        let byte = match take::<usize, NomBytes, ()>(1)(remainder) {
            Ok((rest, bytes)) => {
                remainder = rest;
                let bytes = bytes.to_bytes();
                bytes.first().cloned().unwrap()
            }
            Err(_) => return Err(Incomplete(Unknown)),
        };
        res += ((byte as usize) & 127)
            .checked_shl((count * 7).try_into().unwrap_or(u32::MAX))
            .ok_or_else(|| Error(E::from_error_kind(remainder.clone(), ErrorKind::MapOpt)))?;
        count += 1;
        if (byte >> 7) == 0 {
            return Ok((remainder, res));
        }
    }
}

pub fn parse_string(s: NomBytes) -> IResult<NomBytes, Bytes> {
    let (s, length) = be_u16(s)?;
    let (s, string) = take(length)(s)?;
    Ok((s, string.into_bytes()))
}

pub fn parse_bytes(s: NomBytes) -> IResult<NomBytes, Bytes> {
    let (s, length) = be_u32(s)?;
    let (s, string) = take(length)(s)?;
    Ok((s, string.into_bytes()))
}

pub fn parse_array<O, E, F>(f: F) -> impl FnMut(NomBytes) -> IResult<NomBytes, Vec<O>, E>
where
    F: nom::Parser<NomBytes, O, E> + Copy,
    E: nom::error::ParseError<NomBytes>,
{
    move |input: NomBytes| {
        let i = input.clone();
        let (i, length) = be_i32(i)?;
        if length == -1 {
            return Ok((i, vec![]));
        }
        many_m_n(length as usize, length as usize, f)(i)
    }
}

pub fn parse_varint_array<O, E, F>(f: F) -> impl FnMut(NomBytes) -> IResult<NomBytes, Vec<O>, E>
where
    F: nom::Parser<NomBytes, O, E> + Copy,
    E: nom::error::ParseError<NomBytes>,
{
    move |input: NomBytes| {
        let i = input.clone();
        let (i, length) = take_varint(i)?;
        if length == 0 {
            return Ok((i, vec![]));
        }
        many_m_n(length / 2, length / 2, f)(i)
    }
}

pub fn parse_boolean(s: NomBytes) -> IResult<NomBytes, bool> {
    let (s, b) = take(1_usize)(s)?;
    let b = b != NomBytes::from(b"\x00" as &[u8]);
    Ok((s, b))
}

pub fn parse_nullable_string(s: NomBytes) -> IResult<NomBytes, Option<Bytes>> {
    let (s, length) = be_i16(s)?;
    if length == -1 {
        return Ok((s, None));
    }

    let (s, string) = take(length as u16)(s)?;
    Ok((s, Some(string.into_bytes())))
}

pub fn parse_nullable_bytes(s: NomBytes) -> IResult<NomBytes, Option<Bytes>> {
    let (s, length) = be_i32(s)?;
    if length == -1 {
        return Ok((s, None));
    }

    let (s, bytes) = take(length as u32)(s)?;
    Ok((s, Some(bytes.into_bytes())))
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_varint_simple() {
        assert_eq!(
            take_varint::<()>(NomBytes::from(b"\x0b\x01\x02\x03" as &[u8])),
            Ok((NomBytes::from(b"\x01\x02\x03" as &[u8]), 11))
        );
    }

    #[test]
    fn parse_varint_twobyte() {
        assert_eq!(
            take_varint::<()>(NomBytes::from(b"\x84\x02\x04\x05\x06" as &[u8])),
            Ok((NomBytes::from(b"\x04\x05\x06" as &[u8]), 260))
        );
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn parse_varlong() {
        assert_eq!(
            take_varint::<()>(NomBytes::from(
                b"\xff\xff\xff\xff\xff\xff\xff\xff\x7f\x04\x05\x06" as &[u8]
            )),
            Ok((
                NomBytes::from(b"\x04\x05\x06" as &[u8]),
                9223372036854775807
            ))
        );
    }

    #[test]
    fn test_parse_string() {
        let buf = NomBytes::from(b"\x00\x04\x72\x75\x73\x74" as &[u8]);

        assert_eq!(
            parse_string(buf).unwrap().1,
            NomBytes::from(b"\x72\x75\x73\x74" as &[u8]).to_bytes()
        );
    }

    #[test]
    fn test_parse_bool_false() {
        let buf = NomBytes::from(b"\x00" as &[u8]);

        assert!(!parse_boolean(buf).unwrap().1);
    }

    #[test]
    fn test_parse_bool_true() {
        let buf = NomBytes::from(b"\x01" as &[u8]);

        assert!(parse_boolean(buf).unwrap().1);
    }

    #[test]
    fn test_parse_bool_two() {
        let buf = NomBytes::from(b"\x02" as &[u8]);

        assert!(parse_boolean(buf).unwrap().1);
    }

    #[test]
    fn test_parse_array() {
        let buf = NomBytes::from(
            [
                0, 0, 0, 2, // array size
                0, 4, 114, 117, 115, 116, // string
                0, 4, 114, 117, 115, 116, // string
                0, 0, 0, // leftover input
            ]
            .as_slice(),
        );

        assert_eq!(
            parse_array(parse_string)(buf).unwrap().1,
            vec![String::from("rust"), String::from("rust")]
        );
    }
}