Skip to main content

reddb_wire/redwire/
bulk_stream.rs

//! RedWire legacy bulk-stream payload codec.
//!
//! All integers are little-endian; names are UTF-8 with a `u16` byte-length prefix.
//!
//! Start layout:
//! `[collection_len:u16][collection][column_count:u16]([column_len:u16][column_name])*`.
//! Rows layout:
//! `[row_count:u32]` followed by `column_count * row_count` legacy `WireValue`s, row by row.
7
8use crate::legacy::{encode_value, try_decode_value, WireValue};
9use std::fmt;
10
/// Decoded contents of a bulk-stream "start" payload.
///
/// Holds the collection name followed by the column names, exactly as they
/// were read from the wire.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BulkStreamStartPayload {
    /// Collection name carried at the front of the payload.
    pub collection: String,
    /// Column names, in the order they appear in the payload.
    pub columns: Vec<String>,
}
16
17#[derive(Debug, Clone, PartialEq)]
18pub struct BulkStreamRowsPayload {
19    pub rows: Vec<Vec<WireValue>>,
20}
21
/// Failures produced while encoding or decoding RedWire bulk-stream payloads.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BulkStreamError {
    /// Start payload ended before the 2-byte collection-name length.
    MissingCollectionLength,
    /// Start payload ended before the declared collection-name bytes.
    TruncatedCollectionName,
    /// Collection-name bytes were not valid UTF-8.
    InvalidCollectionName,
    /// Start payload ended before the 2-byte column count.
    MissingColumnCount,
    /// Start payload ended before a column's 2-byte name length.
    MissingColumnNameLength,
    /// Start payload ended before a column's declared name bytes.
    TruncatedColumnName,
    /// A column name's bytes were not valid UTF-8.
    InvalidColumnName,
    /// Rows payload ended before the 4-byte row count.
    MissingRowCount,
    /// A row value failed to decode; carries the value decoder's message.
    Value(&'static str),
    /// A length did not fit in a `u16` wire field; carries the field label.
    LengthOverflow(&'static str),
    /// A row handed to the encoder held `got` values where `expected` were required.
    RowWidthMismatch { got: usize, expected: usize },
    /// The encoder was given more rows than a `u32` row count can express.
    RowCountOverflow,
}

impl fmt::Display for BulkStreamError {
    /// Writes the human-readable message for this error.
    ///
    /// Variants carrying data are formatted and returned directly; every
    /// other variant maps to a fixed message that is written verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fixed_message = match self {
            Self::Value(err) => return write!(f, "stream rows: {err}"),
            Self::LengthOverflow(field) => {
                return write!(f, "{field} is too large for RedWire bulk stream");
            }
            Self::RowWidthMismatch { got, expected } => {
                return write!(f, "row had {got} values for {expected} columns");
            }
            Self::MissingCollectionLength => "stream start: missing collection length",
            Self::TruncatedCollectionName => "stream start: truncated collection name",
            Self::InvalidCollectionName => "stream start: invalid collection name",
            Self::MissingColumnCount => "stream start: missing column count",
            Self::MissingColumnNameLength => "stream start: missing column name length",
            Self::TruncatedColumnName => "stream start: truncated column name",
            Self::InvalidColumnName => "stream start: invalid column name",
            Self::MissingRowCount => "stream rows: missing row count",
            Self::RowCountOverflow => "row count is too large for RedWire bulk stream",
        };
        f.write_str(fixed_message)
    }
}

impl std::error::Error for BulkStreamError {}
64
65pub fn encode_bulk_stream_start_payload(
66    collection: &str,
67    columns: &[&str],
68) -> Result<Vec<u8>, BulkStreamError> {
69    write_len_u16(collection.len(), "collection")?;
70    write_len_u16(columns.len(), "columns")?;
71    let mut out = Vec::with_capacity(4 + collection.len() + columns.len() * 16);
72    write_string_u16(&mut out, collection, "collection")?;
73    out.extend_from_slice(&(columns.len() as u16).to_le_bytes());
74    for column in columns {
75        write_string_u16(&mut out, column, "column")?;
76    }
77    Ok(out)
78}
79
80pub fn decode_bulk_stream_start_payload(
81    payload: &[u8],
82) -> Result<BulkStreamStartPayload, BulkStreamError> {
83    let mut pos = 0;
84    let coll_len = read_u16(payload, &mut pos, BulkStreamError::MissingCollectionLength)? as usize;
85    let collection = read_string(
86        payload,
87        &mut pos,
88        coll_len,
89        BulkStreamError::TruncatedCollectionName,
90        BulkStreamError::InvalidCollectionName,
91    )?;
92    let ncols = read_u16(payload, &mut pos, BulkStreamError::MissingColumnCount)? as usize;
93    let mut columns = Vec::with_capacity(ncols);
94    for _ in 0..ncols {
95        let len = read_u16(payload, &mut pos, BulkStreamError::MissingColumnNameLength)? as usize;
96        columns.push(read_string(
97            payload,
98            &mut pos,
99            len,
100            BulkStreamError::TruncatedColumnName,
101            BulkStreamError::InvalidColumnName,
102        )?);
103    }
104    Ok(BulkStreamStartPayload {
105        collection,
106        columns,
107    })
108}
109
110pub fn encode_bulk_stream_rows_payload(
111    rows: &[Vec<WireValue>],
112    column_count: usize,
113) -> Result<Vec<u8>, BulkStreamError> {
114    if rows.len() > u32::MAX as usize {
115        return Err(BulkStreamError::RowCountOverflow);
116    }
117    let mut out = Vec::with_capacity(4 + rows.len() * column_count * 16);
118    out.extend_from_slice(&(rows.len() as u32).to_le_bytes());
119    for row in rows {
120        if row.len() != column_count {
121            return Err(BulkStreamError::RowWidthMismatch {
122                got: row.len(),
123                expected: column_count,
124            });
125        }
126        for value in row {
127            encode_value(&mut out, value);
128        }
129    }
130    Ok(out)
131}
132
133pub fn decode_bulk_stream_rows_payload(
134    payload: &[u8],
135    column_count: usize,
136) -> Result<BulkStreamRowsPayload, BulkStreamError> {
137    let mut pos = 0;
138    let nrows = read_u32(payload, &mut pos, BulkStreamError::MissingRowCount)? as usize;
139    let mut rows = Vec::with_capacity(nrows);
140    for _ in 0..nrows {
141        let mut values = Vec::with_capacity(column_count);
142        for _ in 0..column_count {
143            values.push(try_decode_value(payload, &mut pos).map_err(BulkStreamError::Value)?);
144        }
145        rows.push(values);
146    }
147    Ok(BulkStreamRowsPayload { rows })
148}
149
150fn write_string_u16(
151    out: &mut Vec<u8>,
152    value: &str,
153    field: &'static str,
154) -> Result<(), BulkStreamError> {
155    write_len_u16(value.len(), field)?;
156    out.extend_from_slice(&(value.len() as u16).to_le_bytes());
157    out.extend_from_slice(value.as_bytes());
158    Ok(())
159}
160
161fn write_len_u16(len: usize, field: &'static str) -> Result<(), BulkStreamError> {
162    if len > u16::MAX as usize {
163        return Err(BulkStreamError::LengthOverflow(field));
164    }
165    Ok(())
166}
167
168fn read_u16(payload: &[u8], pos: &mut usize, err: BulkStreamError) -> Result<u16, BulkStreamError> {
169    let bytes = read_bytes(payload, pos, 2, err)?;
170    Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
171}
172
173fn read_u32(payload: &[u8], pos: &mut usize, err: BulkStreamError) -> Result<u32, BulkStreamError> {
174    let bytes = read_bytes(payload, pos, 4, err)?;
175    Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
176}
177
178fn read_string(
179    payload: &[u8],
180    pos: &mut usize,
181    len: usize,
182    truncated_err: BulkStreamError,
183    utf8_err: BulkStreamError,
184) -> Result<String, BulkStreamError> {
185    let bytes = read_bytes(payload, pos, len, truncated_err)?;
186    std::str::from_utf8(bytes)
187        .map(str::to_owned)
188        .map_err(|_| utf8_err)
189}
190
191fn read_bytes<'a>(
192    payload: &'a [u8],
193    pos: &mut usize,
194    len: usize,
195    err: BulkStreamError,
196) -> Result<&'a [u8], BulkStreamError> {
197    let end = pos.saturating_add(len);
198    if end > payload.len() {
199        return Err(err);
200    }
201    let bytes = &payload[*pos..end];
202    *pos = end;
203    Ok(bytes)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding a start payload returns the same names.
    #[test]
    fn stream_start_payload_round_trips() {
        let column_names = ["id", "name"];
        let encoded = encode_bulk_stream_start_payload("events", &column_names).unwrap();

        let BulkStreamStartPayload {
            collection,
            columns,
        } = decode_bulk_stream_start_payload(&encoded).unwrap();

        assert_eq!(collection, "events");
        assert_eq!(columns, vec!["id", "name"]);
    }

    /// Encoding then decoding a rows payload returns the same values.
    #[test]
    fn stream_rows_payload_round_trips_values() {
        let original = vec![vec![WireValue::I64(7), WireValue::Text("Ada".into())]];
        let encoded = encode_bulk_stream_rows_payload(&original, 2).unwrap();

        let decoded = decode_bulk_stream_rows_payload(&encoded, 2).unwrap();

        assert_eq!(decoded.rows, original);
    }

    /// A value decode failure is reported with the "stream rows: " prefix.
    #[test]
    fn stream_rows_payload_preserves_error_prefix() {
        // Row count of 1 followed by a single byte and nothing after it.
        let payload = [1u8, 0, 0, 0, 1];

        let failure = decode_bulk_stream_rows_payload(&payload, 1).unwrap_err();

        assert_eq!(failure.to_string(), "stream rows: truncated i64 value");
    }
}