Skip to main content

reddb_wire/redwire/
bulk_binary.rs

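//! RedWire binary bulk payload codec.
//!
//! Layout (all integers little-endian, all names UTF-8):
//! `[collection_len:u16][collection][column_count:u16]`
//! `([column_len:u16][column_name]) × column_count`
//! `[row_count:u32]([WireValue] × column_count) × row_count`,
//! where values are stored row-major, each in the legacy `WireValue` encoding.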
7
8use crate::legacy::{encode_value, try_decode_value, WireValue};
9use std::fmt;
10
/// Which binary bulk variant a payload belongs to.
///
/// Within this module the flavor never changes how bytes are parsed; it is
/// carried only so that decode errors can be tagged with the right label.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BulkBinaryFlavor {
    /// The plain `binary bulk` variant.
    Binary,
    /// The `prevalidated` binary bulk variant.
    Prevalidated,
}
16
17#[derive(Debug, Clone, PartialEq)]
18pub struct BulkBinaryPayload {
19    pub collection: String,
20    pub columns: Vec<String>,
21    pub rows: Vec<Vec<WireValue>>,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum BulkBinaryError {
26    PayloadTooShort(BulkBinaryFlavor),
27    MissingCollectionLength(BulkBinaryFlavor),
28    TruncatedCollectionName(BulkBinaryFlavor),
29    InvalidCollectionName(BulkBinaryFlavor),
30    MissingColumnCount(BulkBinaryFlavor),
31    MissingColumnNameLength(BulkBinaryFlavor),
32    TruncatedColumnName(BulkBinaryFlavor),
33    InvalidColumnName(BulkBinaryFlavor),
34    MissingRowCount(BulkBinaryFlavor),
35    Value(BulkBinaryFlavor, &'static str),
36    LengthOverflow(&'static str),
37    RowWidthMismatch { got: usize, expected: usize },
38    RowCountOverflow,
39}
40
41impl fmt::Display for BulkBinaryError {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Self::PayloadTooShort(flavor) => write!(f, "{}: payload too short", prefix(*flavor)),
45            Self::MissingCollectionLength(flavor) => {
46                write!(f, "{}: missing collection length", label(*flavor))
47            }
48            Self::TruncatedCollectionName(flavor) => {
49                write!(f, "{}: truncated collection name", label(*flavor))
50            }
51            Self::InvalidCollectionName(flavor) => {
52                write!(f, "{}: invalid collection name", label(*flavor))
53            }
54            Self::MissingColumnCount(flavor) => {
55                write!(f, "{}: missing column count", label(*flavor))
56            }
57            Self::MissingColumnNameLength(flavor) => {
58                write!(f, "{}: missing column name length", label(*flavor))
59            }
60            Self::TruncatedColumnName(flavor) => {
61                write!(f, "{}: truncated column name", label(*flavor))
62            }
63            Self::InvalidColumnName(flavor) => {
64                write!(f, "{}: invalid column name", label(*flavor))
65            }
66            Self::MissingRowCount(flavor) => write!(f, "{}: missing row count", label(*flavor)),
67            Self::Value(flavor, err) => write!(f, "{}: {err}", label(*flavor)),
68            Self::LengthOverflow(field) => {
69                write!(f, "{field} is too large for RedWire binary bulk")
70            }
71            Self::RowWidthMismatch { got, expected } => {
72                write!(f, "row had {got} values for {expected} columns")
73            }
74            Self::RowCountOverflow => write!(f, "row count is too large for RedWire binary bulk"),
75        }
76    }
77}
78
79impl std::error::Error for BulkBinaryError {}
80
81pub fn encode_bulk_binary_payload(
82    collection: &str,
83    columns: &[&str],
84    rows: &[Vec<WireValue>],
85) -> Result<Vec<u8>, BulkBinaryError> {
86    write_len_u16(collection.len(), "collection")?;
87    write_len_u16(columns.len(), "columns")?;
88    if rows.len() > u32::MAX as usize {
89        return Err(BulkBinaryError::RowCountOverflow);
90    }
91    let mut out = Vec::with_capacity(64 + rows.len() * columns.len() * 16);
92    write_string_u16(&mut out, collection, "collection")?;
93    out.extend_from_slice(&(columns.len() as u16).to_le_bytes());
94    for column in columns {
95        write_string_u16(&mut out, column, "column")?;
96    }
97    out.extend_from_slice(&(rows.len() as u32).to_le_bytes());
98    for row in rows {
99        if row.len() != columns.len() {
100            return Err(BulkBinaryError::RowWidthMismatch {
101                got: row.len(),
102                expected: columns.len(),
103            });
104        }
105        for value in row {
106            encode_value(&mut out, value);
107        }
108    }
109    Ok(out)
110}
111
112pub fn decode_bulk_binary_payload(
113    payload: &[u8],
114    flavor: BulkBinaryFlavor,
115) -> Result<BulkBinaryPayload, BulkBinaryError> {
116    let mut pos = 0;
117    if payload.len() < 6 {
118        return Err(BulkBinaryError::PayloadTooShort(flavor));
119    }
120    let coll_len = read_u16(
121        payload,
122        &mut pos,
123        BulkBinaryError::MissingCollectionLength(flavor),
124    )? as usize;
125    let collection = read_string(
126        payload,
127        &mut pos,
128        coll_len,
129        BulkBinaryError::TruncatedCollectionName(flavor),
130        BulkBinaryError::InvalidCollectionName(flavor),
131    )?;
132    let ncols = read_u16(
133        payload,
134        &mut pos,
135        BulkBinaryError::MissingColumnCount(flavor),
136    )? as usize;
137    let mut columns = Vec::with_capacity(ncols);
138    for _ in 0..ncols {
139        let name_len = read_u16(
140            payload,
141            &mut pos,
142            BulkBinaryError::MissingColumnNameLength(flavor),
143        )? as usize;
144        columns.push(read_string(
145            payload,
146            &mut pos,
147            name_len,
148            BulkBinaryError::TruncatedColumnName(flavor),
149            BulkBinaryError::InvalidColumnName(flavor),
150        )?);
151    }
152    let nrows = read_u32(payload, &mut pos, BulkBinaryError::MissingRowCount(flavor))? as usize;
153    let mut rows = Vec::with_capacity(nrows);
154    for _ in 0..nrows {
155        let mut values = Vec::with_capacity(ncols);
156        for _ in 0..ncols {
157            values.push(
158                try_decode_value(payload, &mut pos)
159                    .map_err(|err| BulkBinaryError::Value(flavor, err))?,
160            );
161        }
162        rows.push(values);
163    }
164    Ok(BulkBinaryPayload {
165        collection,
166        columns,
167        rows,
168    })
169}
170
171fn prefix(flavor: BulkBinaryFlavor) -> &'static str {
172    match flavor {
173        BulkBinaryFlavor::Binary => "binary bulk",
174        BulkBinaryFlavor::Prevalidated => "binary bulk prevalidated",
175    }
176}
177
178fn label(flavor: BulkBinaryFlavor) -> &'static str {
179    match flavor {
180        BulkBinaryFlavor::Binary => "binary bulk",
181        BulkBinaryFlavor::Prevalidated => "prevalidated",
182    }
183}
184
185fn write_string_u16(
186    out: &mut Vec<u8>,
187    value: &str,
188    field: &'static str,
189) -> Result<(), BulkBinaryError> {
190    write_len_u16(value.len(), field)?;
191    out.extend_from_slice(&(value.len() as u16).to_le_bytes());
192    out.extend_from_slice(value.as_bytes());
193    Ok(())
194}
195
196fn write_len_u16(len: usize, field: &'static str) -> Result<(), BulkBinaryError> {
197    if len > u16::MAX as usize {
198        return Err(BulkBinaryError::LengthOverflow(field));
199    }
200    Ok(())
201}
202
203fn read_u16(payload: &[u8], pos: &mut usize, err: BulkBinaryError) -> Result<u16, BulkBinaryError> {
204    let bytes = read_bytes(payload, pos, 2, err)?;
205    Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
206}
207
208fn read_u32(payload: &[u8], pos: &mut usize, err: BulkBinaryError) -> Result<u32, BulkBinaryError> {
209    let bytes = read_bytes(payload, pos, 4, err)?;
210    Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
211}
212
213fn read_string(
214    payload: &[u8],
215    pos: &mut usize,
216    len: usize,
217    truncated_err: BulkBinaryError,
218    utf8_err: BulkBinaryError,
219) -> Result<String, BulkBinaryError> {
220    let bytes = read_bytes(payload, pos, len, truncated_err)?;
221    std::str::from_utf8(bytes)
222        .map(str::to_owned)
223        .map_err(|_| utf8_err)
224}
225
226fn read_bytes<'a>(
227    payload: &'a [u8],
228    pos: &mut usize,
229    len: usize,
230    err: BulkBinaryError,
231) -> Result<&'a [u8], BulkBinaryError> {
232    let end = pos.saturating_add(len);
233    if end > payload.len() {
234        return Err(err);
235    }
236    let bytes = &payload[*pos..end];
237    *pos = end;
238    Ok(bytes)
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn binary_bulk_payload_round_trips_values() {
        let rows = vec![vec![
            WireValue::I64(7),
            WireValue::Text("Ada".into()),
            WireValue::Bool(true),
        ]];
        let bytes = encode_bulk_binary_payload("users", &["id", "name", "active"], &rows).unwrap();
        let decoded = decode_bulk_binary_payload(&bytes, BulkBinaryFlavor::Binary).unwrap();
        assert_eq!(decoded.collection, "users");
        assert_eq!(decoded.columns, vec!["id", "name", "active"]);
        assert_eq!(decoded.rows, rows);
    }

    /// A table with no columns and no rows is still a valid payload.
    #[test]
    fn binary_bulk_payload_round_trips_empty_table() {
        let bytes = encode_bulk_binary_payload("t", &[], &[]).unwrap();
        let decoded = decode_bulk_binary_payload(&bytes, BulkBinaryFlavor::Binary).unwrap();
        assert_eq!(
            decoded,
            BulkBinaryPayload {
                collection: "t".into(),
                columns: vec![],
                rows: vec![],
            }
        );
    }

    #[test]
    fn binary_bulk_decode_preserves_error_prefixes() {
        assert_eq!(
            decode_bulk_binary_payload(&[0; 5], BulkBinaryFlavor::Binary)
                .unwrap_err()
                .to_string(),
            "binary bulk: payload too short"
        );
        // `PayloadTooShort` uses the long-form flavor tag, unlike every other
        // decode error; pin it so the legacy message cannot drift silently.
        assert_eq!(
            decode_bulk_binary_payload(&[0; 5], BulkBinaryFlavor::Prevalidated)
                .unwrap_err()
                .to_string(),
            "binary bulk prevalidated: payload too short"
        );
        let payload = vec![1, 0, b't', 1, 0, 1, 0, b'x', 1, 0, 0, 0, 1];
        assert_eq!(
            decode_bulk_binary_payload(&payload, BulkBinaryFlavor::Prevalidated)
                .unwrap_err()
                .to_string(),
            "prevalidated: truncated i64 value"
        );
    }

    #[test]
    fn binary_bulk_encode_rejects_row_width_mismatch() {
        let rows = vec![vec![WireValue::I64(1)]];
        assert_eq!(
            encode_bulk_binary_payload("users", &["id", "name"], &rows),
            Err(BulkBinaryError::RowWidthMismatch {
                got: 1,
                expected: 2,
            })
        );
    }

    #[test]
    fn binary_bulk_encode_rejects_oversized_names() {
        let long = "x".repeat(usize::from(u16::MAX) + 1);
        assert_eq!(
            encode_bulk_binary_payload(&long, &[], &[]),
            Err(BulkBinaryError::LengthOverflow("collection"))
        );
        assert_eq!(
            encode_bulk_binary_payload("t", &[long.as_str()], &[]),
            Err(BulkBinaryError::LengthOverflow("column"))
        );
    }
}