Skip to main content

crabka_protocol/primitives/
array.rs

1//! Wire-level helpers for `[]<elem>` and nullable `[]<elem>` arrays.
2//!
3//! Non-flexible arrays use a 4-byte big-endian `INT32` length prefix (−1 for
4//! null).  Flexible (compact) arrays use an unsigned varint whose value is
5//! `len + 1` (0 means null).
6
7use bytes::{Buf, BufMut};
8
9use crate::ProtocolError;
10use crate::primitives::fixed::{get_i32, put_i32};
11use crate::primitives::varint::{get_uvarint, put_uvarint, uvarint_len};
12
13/// Write a non-nullable array-length prefix.
14pub fn put_array_len<B: BufMut>(buf: &mut B, n: usize, flexible: bool) {
15    if flexible {
16        put_uvarint(buf, u32::try_from(n + 1).expect("array too large"));
17    } else {
18        put_i32(buf, i32::try_from(n).expect("array too large"));
19    }
20}
21
22/// Write a nullable array-length prefix.  `None` encodes as −1 (non-flex) or
23/// 0 (flex).
24pub fn put_nullable_array_len<B: BufMut>(buf: &mut B, len: Option<usize>, flexible: bool) {
25    match (flexible, len) {
26        (false, None) => put_i32(buf, -1),
27        (false, Some(n)) => put_i32(buf, i32::try_from(n).expect("array too large")),
28        (true, None) => put_uvarint(buf, 0),
29        (true, Some(n)) => put_uvarint(buf, u32::try_from(n + 1).expect("array too large")),
30    }
31}
32
33/// Number of bytes consumed by a non-nullable array-length prefix.
34#[must_use]
35pub fn array_len_prefix_len(n: usize, flexible: bool) -> usize {
36    if flexible {
37        uvarint_len(u32::try_from(n + 1).unwrap())
38    } else {
39        4
40    }
41}
42
43/// Number of bytes consumed by a nullable array-length prefix.
44#[must_use]
45pub fn nullable_array_len_prefix_len(len: Option<usize>, flexible: bool) -> usize {
46    match (flexible, len) {
47        (false, _) => 4,
48        (true, None) => uvarint_len(0),
49        (true, Some(n)) => uvarint_len(u32::try_from(n + 1).unwrap()),
50    }
51}
52
53/// Read a non-nullable array length.  Returns an error if the encoded value is
54/// null (−1 / 0).
55pub fn get_array_len<B: Buf>(buf: &mut B, flexible: bool) -> Result<usize, ProtocolError> {
56    let n = if flexible {
57        let raw = get_uvarint(buf)?;
58        if raw == 0 {
59            return Err(ProtocolError::InvalidValue(
60                "non-nullable array was null (compact encoding)",
61            ));
62        }
63        (raw - 1) as usize
64    } else {
65        let n = get_i32(buf)?;
66        if n < 0 {
67            return Err(ProtocolError::InvalidValue(
68                "non-nullable array had negative length",
69            ));
70        }
71        usize::try_from(n).expect("n is non-negative")
72    };
73    // Every array element occupies at least one byte on the wire, so a
74    // legitimate array of `n` elements always has at least `n` bytes left in
75    // the buffer. Any larger `n` is impossible and indicates a malformed or
76    // hostile frame; reject before a caller can `Vec::with_capacity(n)`.
77    if n > buf.remaining() {
78        return Err(ProtocolError::InvalidValue(
79            "array length exceeds remaining buffer",
80        ));
81    }
82    Ok(n)
83}
84
85/// Read a nullable array length.  Returns `None` when the encoded value is
86/// null (−1 non-flex, 0 flex).
87pub fn get_nullable_array_len<B: Buf>(
88    buf: &mut B,
89    flexible: bool,
90) -> Result<Option<usize>, ProtocolError> {
91    let n = if flexible {
92        let raw = get_uvarint(buf)?;
93        if raw == 0 {
94            return Ok(None);
95        }
96        (raw - 1) as usize
97    } else {
98        let n = get_i32(buf)?;
99        if n < 0 {
100            return Ok(None);
101        }
102        usize::try_from(n).expect("n is non-negative")
103    };
104    // See `get_array_len`: a length larger than the remaining bytes cannot
105    // describe a real array and must be rejected before pre-allocation.
106    if n > buf.remaining() {
107        return Err(ProtocolError::InvalidValue(
108            "array length exceeds remaining buffer",
109        ));
110    }
111    Ok(Some(n))
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;
    use bytes::BytesMut;

    // --- non-nullable, non-flexible -----------------------------------------

    #[test]
    fn non_flex_empty_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 0, false);
        assert!(buf.len() == 4, "prefix must be 4 bytes");
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, false).unwrap() == 0);
        assert!(cur.is_empty());
    }

    #[test]
    fn non_flex_three_element_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 3, false);
        assert!(buf.len() == 4);
        // Provide 3 bytes of element payload so the length is within bounds.
        buf.extend_from_slice(&[0u8; 3]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, false).unwrap() == 3);
        assert!(cur.len() == 3);
    }

    // --- non-nullable, flexible (compact) -----------------------------------

    #[test]
    fn flex_empty_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 0, true);
        // len=0 → encode 1 → single byte 0x01
        assert!(&buf[..] == &[0x01]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, true).unwrap() == 0);
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_three_element_array_roundtrip() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 3, true);
        // len=3 → encode 4 → single byte 0x04
        assert!(&buf[..] == &[0x04]);
        // Provide 3 bytes of element payload so the length is within bounds.
        buf.extend_from_slice(&[0u8; 3]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, true).unwrap() == 3);
        assert!(cur.len() == 3);
    }

    #[test]
    fn flex_multi_byte_length_roundtrip() {
        // len=200 → encode 201, which needs a two-byte varint.
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 200, true);
        assert!(buf.len() == 2);
        buf.extend_from_slice(&[0u8; 200]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, true).unwrap() == 200);
        assert!(cur.len() == 200);
    }

    // --- nullable, non-flexible ---------------------------------------------

    #[test]
    fn non_flex_nullable_null_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, None, false);
        assert!(&buf[..] == &[0xFF, 0xFF, 0xFF, 0xFF]); // -1 in big-endian
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, false).unwrap().is_none());
        assert!(cur.is_empty());
    }

    #[test]
    fn non_flex_nullable_some_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, Some(3), false);
        buf.extend_from_slice(&[0u8; 3]);
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, false).unwrap() == Some(3));
        assert!(cur.len() == 3);
    }

    #[test]
    fn nullable_treats_any_negative_as_null_non_flex() {
        // Not just -1: every negative INT32 decodes as null.
        let bytes = (-5i32).to_be_bytes();
        let mut cur = &bytes[..];
        assert!(get_nullable_array_len(&mut cur, false).unwrap().is_none());
        assert!(cur.is_empty());
    }

    // --- nullable, flexible -------------------------------------------------

    #[test]
    fn flex_nullable_null_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, None, true);
        assert!(&buf[..] == &[0x00]); // 0 = null in compact encoding
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, true).unwrap().is_none());
        assert!(cur.is_empty());
    }

    #[test]
    fn flex_nullable_some_roundtrip() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, Some(3), true);
        // Some(3) → encode 4 → 0x04
        assert!(&buf[..] == &[0x04]);
        buf.extend_from_slice(&[0u8; 3]);
        let mut cur = &buf[..];
        assert!(get_nullable_array_len(&mut cur, true).unwrap() == Some(3));
        assert!(cur.len() == 3);
    }

    // --- prefix_len helpers -------------------------------------------------

    #[test]
    fn array_len_prefix_len_non_flex() {
        assert!(array_len_prefix_len(0, false) == 4);
        assert!(array_len_prefix_len(100, false) == 4);
    }

    #[test]
    fn array_len_prefix_len_flex() {
        // len=0 → varint(1) = 1 byte; len=126 → varint(127) = 1 byte;
        // len=127 → varint(128) = 2 bytes.
        assert!(array_len_prefix_len(0, true) == 1);
        assert!(array_len_prefix_len(126, true) == 1);
        assert!(array_len_prefix_len(127, true) == 2);
    }

    #[test]
    fn nullable_prefix_len_non_flex_always_4() {
        assert!(nullable_array_len_prefix_len(None, false) == 4);
        assert!(nullable_array_len_prefix_len(Some(3), false) == 4);
    }

    #[test]
    fn nullable_prefix_len_flex_null_is_1() {
        // null → varint(0) = 1 byte
        assert!(nullable_array_len_prefix_len(None, true) == 1);
    }

    #[test]
    fn prefix_len_matches_written_bytes() {
        // The size helpers must agree with what the writers actually emit,
        // including across the 1→2 and 2→3 byte varint boundaries.
        for flexible in [false, true] {
            for n in [0usize, 1, 126, 127, 16_382, 16_383] {
                let mut buf = BytesMut::new();
                put_array_len(&mut buf, n, flexible);
                assert!(array_len_prefix_len(n, flexible) == buf.len());

                let mut buf = BytesMut::new();
                put_nullable_array_len(&mut buf, Some(n), flexible);
                assert!(nullable_array_len_prefix_len(Some(n), flexible) == buf.len());
            }
            let mut buf = BytesMut::new();
            put_nullable_array_len(&mut buf, None, flexible);
            assert!(nullable_array_len_prefix_len(None, flexible) == buf.len());
        }
    }

    // --- error cases --------------------------------------------------------

    #[test]
    fn non_nullable_rejects_null_non_flex() {
        let bytes = (-1i32).to_be_bytes();
        let mut cur = &bytes[..];
        assert!(matches!(
            get_array_len(&mut cur, false),
            Err(ProtocolError::InvalidValue(_))
        ));
    }

    #[test]
    fn non_nullable_rejects_any_negative_non_flex() {
        let bytes = (-5i32).to_be_bytes();
        let mut cur = &bytes[..];
        assert!(matches!(
            get_array_len(&mut cur, false),
            Err(ProtocolError::InvalidValue(
                "non-nullable array had negative length"
            ))
        ));
    }

    #[test]
    fn non_nullable_rejects_null_flex() {
        // varint 0 = null in compact encoding
        let bytes = [0x00u8];
        let mut cur = &bytes[..];
        assert!(matches!(
            get_array_len(&mut cur, true),
            Err(ProtocolError::InvalidValue(_))
        ));
    }

    // --- pre-allocation DoS bound (length > remaining buffer) ----------------

    #[test]
    fn non_flex_rejects_length_exceeding_remaining() {
        // Declare a ~2-billion-element array with no element bytes following.
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 2_000_000_000, false);
        let mut cur = &buf[..];
        assert!(matches!(
            get_array_len(&mut cur, false),
            Err(ProtocolError::InvalidValue(
                "array length exceeds remaining buffer"
            ))
        ));
    }

    #[test]
    fn flex_rejects_length_exceeding_remaining() {
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 2_000_000_000, true);
        let mut cur = &buf[..];
        assert!(matches!(
            get_array_len(&mut cur, true),
            Err(ProtocolError::InvalidValue(
                "array length exceeds remaining buffer"
            ))
        ));
    }

    #[test]
    fn nullable_non_flex_rejects_length_exceeding_remaining() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, Some(2_000_000_000), false);
        let mut cur = &buf[..];
        assert!(matches!(
            get_nullable_array_len(&mut cur, false),
            Err(ProtocolError::InvalidValue(
                "array length exceeds remaining buffer"
            ))
        ));
    }

    #[test]
    fn nullable_flex_rejects_length_exceeding_remaining() {
        let mut buf = BytesMut::new();
        put_nullable_array_len(&mut buf, Some(2_000_000_000), true);
        let mut cur = &buf[..];
        assert!(matches!(
            get_nullable_array_len(&mut cur, true),
            Err(ProtocolError::InvalidValue(
                "array length exceeds remaining buffer"
            ))
        ));
    }

    #[test]
    fn length_one_past_remaining_is_rejected() {
        // Off-by-one boundary: declare 5 but supply only 4 trailing bytes.
        for flexible in [false, true] {
            let mut buf = BytesMut::new();
            put_array_len(&mut buf, 5, flexible);
            buf.extend_from_slice(&[0u8; 4]);
            let mut cur = &buf[..];
            assert!(matches!(
                get_array_len(&mut cur, flexible),
                Err(ProtocolError::InvalidValue(
                    "array length exceeds remaining buffer"
                ))
            ));
        }
    }

    #[test]
    fn length_equal_to_remaining_is_accepted() {
        // Non-flex: declare 5, supply exactly 5 trailing bytes.
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 5, false);
        buf.extend_from_slice(&[0u8; 5]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, false).unwrap() == 5);

        // Compact: same, declare 5 with 5 trailing bytes.
        let mut buf = BytesMut::new();
        put_array_len(&mut buf, 5, true);
        buf.extend_from_slice(&[0u8; 5]);
        let mut cur = &buf[..];
        assert!(get_array_len(&mut cur, true).unwrap() == 5);
    }

    #[test]
    fn nullable_length_equal_to_remaining_is_accepted() {
        for flexible in [false, true] {
            let mut buf = BytesMut::new();
            put_nullable_array_len(&mut buf, Some(5), flexible);
            buf.extend_from_slice(&[0u8; 5]);
            let mut cur = &buf[..];
            assert!(get_nullable_array_len(&mut cur, flexible).unwrap() == Some(5));
            assert!(cur.len() == 5);
        }
    }
}