use bytes::{Buf, BufMut};
use crate::ProtocolError;
use crate::primitives::fixed::{get_i32, put_i32};
use crate::primitives::varint::{get_uvarint, put_uvarint, uvarint_len};
/// Writes the length prefix of a non-nullable array with `n` elements.
///
/// * Classic (`flexible == false`): `n` is written as a 4-byte `i32` via `put_i32`.
/// * Compact (`flexible == true`): `n + 1` is written as an unsigned varint via
///   `put_uvarint`. The raw value `0` is reserved for a null array, so a
///   non-nullable array must never encode to it.
///
/// # Panics
///
/// Panics with `"array too large"` if `n` does not fit in an `i32` (classic), or if
/// `n + 1` does not fit in a `u32` (compact).
pub fn put_array_len<B: BufMut>(buf: &mut B, n: usize, flexible: bool) {
    if flexible {
        // `checked_add` stops `usize::MAX + 1` from wrapping to 0 in release builds,
        // which would otherwise silently emit the compact *null* marker.
        let encoded = n
            .checked_add(1)
            .and_then(|v| u32::try_from(v).ok())
            .expect("array too large");
        put_uvarint(buf, encoded);
    } else {
        put_i32(buf, i32::try_from(n).expect("array too large"));
    }
}
/// Writes the length prefix of a nullable array; `None` encodes a null array.
///
/// * Classic (`flexible == false`): null is `-1`, otherwise `n`, written as a 4-byte
///   `i32` via `put_i32`.
/// * Compact (`flexible == true`): null is `0`, otherwise `n + 1`, written as an
///   unsigned varint via `put_uvarint`.
///
/// # Panics
///
/// Panics with `"array too large"` if `n` does not fit in an `i32` (classic), or if
/// `n + 1` does not fit in a `u32` (compact).
pub fn put_nullable_array_len<B: BufMut>(buf: &mut B, len: Option<usize>, flexible: bool) {
    match (flexible, len) {
        (false, None) => put_i32(buf, -1),
        (false, Some(n)) => put_i32(buf, i32::try_from(n).expect("array too large")),
        (true, None) => put_uvarint(buf, 0),
        (true, Some(n)) => {
            // Checked so that `usize::MAX + 1` cannot wrap to 0 (the null marker) in
            // release builds and turn a present array into a null one on the wire.
            let encoded = n
                .checked_add(1)
                .and_then(|v| u32::try_from(v).ok())
                .expect("array too large");
            put_uvarint(buf, encoded);
        }
    }
}
/// Returns how many bytes `put_array_len` writes for an array of `n` elements.
///
/// Classic encoding always uses 4 bytes. Compact encoding uses the varint length of
/// `n + 1`, as computed by `uvarint_len`.
///
/// # Panics
///
/// In compact mode, panics with `"array too large"` (the same message `put_array_len`
/// uses) if `n + 1` does not fit in a `u32`.
#[must_use]
pub fn array_len_prefix_len(n: usize, flexible: bool) -> usize {
    if !flexible {
        return 4;
    }
    // Checked add: a wrapped `n + 1 == 0` would otherwise report the size of the
    // null marker instead of failing.
    let encoded = n
        .checked_add(1)
        .and_then(|v| u32::try_from(v).ok())
        .expect("array too large");
    uvarint_len(encoded)
}
/// Returns how many bytes `put_nullable_array_len` writes for `len`.
///
/// Classic encoding always uses 4 bytes, whether the array is null or not. Compact
/// encoding uses the varint length of `0` for null, or of `n + 1` otherwise.
///
/// # Panics
///
/// In compact mode, panics with `"array too large"` if `n + 1` does not fit in a `u32`.
#[must_use]
pub fn nullable_array_len_prefix_len(len: Option<usize>, flexible: bool) -> usize {
    match (flexible, len) {
        (false, _) => 4,
        (true, None) => uvarint_len(0),
        (true, Some(n)) => {
            // Checked add so `usize::MAX + 1` cannot wrap to the null marker.
            let encoded = n
                .checked_add(1)
                .and_then(|v| u32::try_from(v).ok())
                .expect("array too large");
            uvarint_len(encoded)
        }
    }
}
/// Reads the length prefix of a non-nullable array.
///
/// * Classic (`flexible == false`): reads an `i32` via `get_i32`.
/// * Compact (`flexible == true`): reads an unsigned varint via `get_uvarint` and
///   subtracts one, because compact arrays store `len + 1`.
///
/// The returned length comes straight from the wire and is not bounded here. Callers
/// should validate it before using it to size an allocation.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidValue` in three cases:
/// * the prefix encodes null (compact `0`);
/// * the prefix encodes a negative length (classic);
/// * the decoded length does not fit in `usize`.
///
/// Errors from `get_uvarint` / `get_i32` are propagated with `?`.
pub fn get_array_len<B: Buf>(buf: &mut B, flexible: bool) -> Result<usize, ProtocolError> {
    if flexible {
        let raw = get_uvarint(buf)?;
        if raw == 0 {
            return Err(ProtocolError::InvalidValue(
                "non-nullable array was null (compact encoding)",
            ));
        }
        // Checked conversion instead of `as`: a wire value must never be silently
        // truncated, even on targets with a narrow `usize`.
        usize::try_from(raw - 1)
            .map_err(|_| ProtocolError::InvalidValue("array length does not fit in usize"))
    } else {
        let n = get_i32(buf)?;
        if n < 0 {
            return Err(ProtocolError::InvalidValue(
                "non-nullable array had negative length",
            ));
        }
        // `n` is untrusted input, so report a conversion failure instead of panicking.
        usize::try_from(n)
            .map_err(|_| ProtocolError::InvalidValue("array length does not fit in usize"))
    }
}
pub fn get_nullable_array_len<B: Buf>(
buf: &mut B,
flexible: bool,
) -> Result<Option<usize>, ProtocolError> {
if flexible {
let raw = get_uvarint(buf)?;
if raw == 0 {
Ok(None)
} else {
Ok(Some((raw - 1) as usize))
}
} else {
let n = get_i32(buf)?;
if n < 0 {
Ok(None)
} else {
Ok(Some(usize::try_from(n).expect("n is non-negative")))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;
    use bytes::BytesMut;

    #[test]
    fn non_flex_empty_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 0, false);
        assert!(buf.len() == 4, "prefix must be 4 bytes");
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, false).unwrap() == 0);
        assert!(cur.is_empty());
    }

    #[test]
    fn non_flex_three_element_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 3, false);
        assert!(buf.len() == 4);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, false).unwrap() == 3);
        assert!(cur.is_empty());
    }

    #[test]
    fn non_flex_max_length_roundtrip() {
        let max = usize::try_from(i32::MAX).unwrap();
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, max, false);
        assert!(buf.len() == 4);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, false).unwrap() == max);
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_empty_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 0, true);
        assert!(&buf[..] == &[0x01]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, true).unwrap() == 0);
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_three_element_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 3, true);
        assert!(&buf[..] == &[0x04]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, true).unwrap() == 3);
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_prefix_len_matches_encoder_across_varint_boundaries() {
        for &n in &[126usize, 127, 16_382, 16_383] {
            let mut buf = BytesMut::new();
            put_array_len(&mut buf, n, true);
            assert!(buf.len() == array_len_prefix_len(n, true));
            let mut cur = &buf[..];
            assert!(get_array_len(&mut cur, true).unwrap() == n);
            assert!(cur.is_empty());
        }
    }

    #[test]
    fn non_flex_nullable_null_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, None, false);
        assert!(&buf[..] == &[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, false).unwrap().is_none());
        assert!(cur.is_empty());
    }

    #[test]
    fn non_flex_nullable_some_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, Some(3), false);
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, false).unwrap() == Some(3));
        assert!(cur.is_empty());
    }

    #[test]
    fn non_flex_nullable_any_negative_is_null() {
        let bytes = (-5i32).to_be_bytes();
        let mut cur = &bytes[..];
        assert!(get_nullable_array_len(&mut cur, false).unwrap().is_none());
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_nullable_null_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, None, true);
        assert!(&buf[..] == &[0x00]);
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, true).unwrap().is_none());
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_nullable_some_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, Some(3), true);
        assert!(&buf[..] == &[0x04]);
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, true).unwrap() == Some(3));
        assert!(cur.is_empty());
    }

    #[test]
    fn array_len_prefix_len_non_flex() {
        assert!(array_len_prefix_len(0, false) == 4);
        assert!(array_len_prefix_len(100, false) == 4);
    }

    #[test]
    fn array_len_prefix_len_flex() {
        assert!(array_len_prefix_len(0, true) == 1);
        assert!(array_len_prefix_len(126, true) == 1);
        assert!(array_len_prefix_len(127, true) == 2);
    }

    #[test]
    fn nullable_prefix_len_non_flex_always_4() {
        assert!(nullable_array_len_prefix_len(None, false) == 4);
        assert!(nullable_array_len_prefix_len(Some(3), false) == 4);
    }

    #[test]
    fn nullable_prefix_len_flex_null_is_1() {
        assert!(nullable_array_len_prefix_len(None, true) == 1);
    }

    #[test]
    fn nullable_prefix_len_flex_some() {
        assert!(nullable_array_len_prefix_len(Some(3), true) == 1);
        assert!(nullable_array_len_prefix_len(Some(126), true) == 1);
        assert!(nullable_array_len_prefix_len(Some(127), true) == 2);
    }

    #[test]
    fn non_nullable_rejects_null_non_flex() {
        let bytes = (-1i32).to_be_bytes();
        let mut cur = &bytes[..];
        assert!(matches!(
            get_array_len(&mut cur, false),
            Err(ProtocolError::InvalidValue(_))
        ));
    }

    #[test]
    fn non_nullable_rejects_null_flex() {
        let bytes = [0x00u8];
        let mut cur = &bytes[..];
        assert!(matches!(
            get_array_len(&mut cur, true),
            Err(ProtocolError::InvalidValue(_))
        ));
    }
}