1use crate::legacy::{encode_value, try_decode_value, WireValue};
9use std::fmt;
10
/// Decoded contents of a bulk stream "start" payload, as produced by
/// `decode_bulk_stream_start_payload`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BulkStreamStartPayload {
    /// Name of the collection, read from the leading length-prefixed string.
    pub collection: String,
    /// Column names, in the order they appear in the payload.
    pub columns: Vec<String>,
}
16
/// Decoded contents of a bulk stream "rows" payload, as produced by
/// `decode_bulk_stream_rows_payload`.
#[derive(Debug, Clone, PartialEq)]
pub struct BulkStreamRowsPayload {
    /// One inner vector per row; each holds the row's values in payload order.
    pub rows: Vec<Vec<WireValue>>,
}
21
/// Errors raised while encoding or decoding bulk stream payloads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BulkStreamError {
    /// Start payload ended before the `u16` collection-name length could be read.
    MissingCollectionLength,
    /// Start payload holds fewer bytes than the declared collection-name length.
    TruncatedCollectionName,
    /// Collection-name bytes are not valid UTF-8.
    InvalidCollectionName,
    /// Start payload ended before the `u16` column count could be read.
    MissingColumnCount,
    /// Start payload ended before a column's `u16` name length could be read.
    MissingColumnNameLength,
    /// Start payload holds fewer bytes than a declared column-name length.
    TruncatedColumnName,
    /// Column-name bytes are not valid UTF-8.
    InvalidColumnName,
    /// Rows payload ended before the `u32` row count could be read.
    MissingRowCount,
    /// Message returned by `try_decode_value` when a row value fails to decode.
    Value(&'static str),
    /// A length does not fit in a `u16` prefix; carries the offending field's label.
    LengthOverflow(&'static str),
    /// A row passed to the rows encoder did not have exactly `expected` values.
    RowWidthMismatch { got: usize, expected: usize },
    /// The number of rows passed to the rows encoder does not fit in a `u32`.
    RowCountOverflow,
}
37
38impl fmt::Display for BulkStreamError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::MissingCollectionLength => write!(f, "stream start: missing collection length"),
42 Self::TruncatedCollectionName => write!(f, "stream start: truncated collection name"),
43 Self::InvalidCollectionName => write!(f, "stream start: invalid collection name"),
44 Self::MissingColumnCount => write!(f, "stream start: missing column count"),
45 Self::MissingColumnNameLength => {
46 write!(f, "stream start: missing column name length")
47 }
48 Self::TruncatedColumnName => write!(f, "stream start: truncated column name"),
49 Self::InvalidColumnName => write!(f, "stream start: invalid column name"),
50 Self::MissingRowCount => write!(f, "stream rows: missing row count"),
51 Self::Value(err) => write!(f, "stream rows: {err}"),
52 Self::LengthOverflow(field) => {
53 write!(f, "{field} is too large for RedWire bulk stream")
54 }
55 Self::RowWidthMismatch { got, expected } => {
56 write!(f, "row had {got} values for {expected} columns")
57 }
58 Self::RowCountOverflow => write!(f, "row count is too large for RedWire bulk stream"),
59 }
60 }
61}
62
// Marker impl so `BulkStreamError` can be used wherever a `std::error::Error` is
// expected; no trait methods are overridden.
impl std::error::Error for BulkStreamError {}
64
65pub fn encode_bulk_stream_start_payload(
66 collection: &str,
67 columns: &[&str],
68) -> Result<Vec<u8>, BulkStreamError> {
69 write_len_u16(collection.len(), "collection")?;
70 write_len_u16(columns.len(), "columns")?;
71 let mut out = Vec::with_capacity(4 + collection.len() + columns.len() * 16);
72 write_string_u16(&mut out, collection, "collection")?;
73 out.extend_from_slice(&(columns.len() as u16).to_le_bytes());
74 for column in columns {
75 write_string_u16(&mut out, column, "column")?;
76 }
77 Ok(out)
78}
79
80pub fn decode_bulk_stream_start_payload(
81 payload: &[u8],
82) -> Result<BulkStreamStartPayload, BulkStreamError> {
83 let mut pos = 0;
84 let coll_len = read_u16(payload, &mut pos, BulkStreamError::MissingCollectionLength)? as usize;
85 let collection = read_string(
86 payload,
87 &mut pos,
88 coll_len,
89 BulkStreamError::TruncatedCollectionName,
90 BulkStreamError::InvalidCollectionName,
91 )?;
92 let ncols = read_u16(payload, &mut pos, BulkStreamError::MissingColumnCount)? as usize;
93 let mut columns = Vec::with_capacity(ncols);
94 for _ in 0..ncols {
95 let len = read_u16(payload, &mut pos, BulkStreamError::MissingColumnNameLength)? as usize;
96 columns.push(read_string(
97 payload,
98 &mut pos,
99 len,
100 BulkStreamError::TruncatedColumnName,
101 BulkStreamError::InvalidColumnName,
102 )?);
103 }
104 Ok(BulkStreamStartPayload {
105 collection,
106 columns,
107 })
108}
109
110pub fn encode_bulk_stream_rows_payload(
111 rows: &[Vec<WireValue>],
112 column_count: usize,
113) -> Result<Vec<u8>, BulkStreamError> {
114 if rows.len() > u32::MAX as usize {
115 return Err(BulkStreamError::RowCountOverflow);
116 }
117 let mut out = Vec::with_capacity(4 + rows.len() * column_count * 16);
118 out.extend_from_slice(&(rows.len() as u32).to_le_bytes());
119 for row in rows {
120 if row.len() != column_count {
121 return Err(BulkStreamError::RowWidthMismatch {
122 got: row.len(),
123 expected: column_count,
124 });
125 }
126 for value in row {
127 encode_value(&mut out, value);
128 }
129 }
130 Ok(out)
131}
132
133pub fn decode_bulk_stream_rows_payload(
134 payload: &[u8],
135 column_count: usize,
136) -> Result<BulkStreamRowsPayload, BulkStreamError> {
137 let mut pos = 0;
138 let nrows = read_u32(payload, &mut pos, BulkStreamError::MissingRowCount)? as usize;
139 let mut rows = Vec::with_capacity(nrows);
140 for _ in 0..nrows {
141 let mut values = Vec::with_capacity(column_count);
142 for _ in 0..column_count {
143 values.push(try_decode_value(payload, &mut pos).map_err(BulkStreamError::Value)?);
144 }
145 rows.push(values);
146 }
147 Ok(BulkStreamRowsPayload { rows })
148}
149
150fn write_string_u16(
151 out: &mut Vec<u8>,
152 value: &str,
153 field: &'static str,
154) -> Result<(), BulkStreamError> {
155 write_len_u16(value.len(), field)?;
156 out.extend_from_slice(&(value.len() as u16).to_le_bytes());
157 out.extend_from_slice(value.as_bytes());
158 Ok(())
159}
160
161fn write_len_u16(len: usize, field: &'static str) -> Result<(), BulkStreamError> {
162 if len > u16::MAX as usize {
163 return Err(BulkStreamError::LengthOverflow(field));
164 }
165 Ok(())
166}
167
168fn read_u16(payload: &[u8], pos: &mut usize, err: BulkStreamError) -> Result<u16, BulkStreamError> {
169 let bytes = read_bytes(payload, pos, 2, err)?;
170 Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
171}
172
173fn read_u32(payload: &[u8], pos: &mut usize, err: BulkStreamError) -> Result<u32, BulkStreamError> {
174 let bytes = read_bytes(payload, pos, 4, err)?;
175 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
176}
177
178fn read_string(
179 payload: &[u8],
180 pos: &mut usize,
181 len: usize,
182 truncated_err: BulkStreamError,
183 utf8_err: BulkStreamError,
184) -> Result<String, BulkStreamError> {
185 let bytes = read_bytes(payload, pos, len, truncated_err)?;
186 std::str::from_utf8(bytes)
187 .map(str::to_owned)
188 .map_err(|_| utf8_err)
189}
190
191fn read_bytes<'a>(
192 payload: &'a [u8],
193 pos: &mut usize,
194 len: usize,
195 err: BulkStreamError,
196) -> Result<&'a [u8], BulkStreamError> {
197 let end = pos.saturating_add(len);
198 if end > payload.len() {
199 return Err(err);
200 }
201 let bytes = &payload[*pos..end];
202 *pos = end;
203 Ok(bytes)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// A start payload decodes back to the collection and columns it was built from.
    #[test]
    fn stream_start_payload_round_trips() {
        let encoded = encode_bulk_stream_start_payload("events", &["id", "name"]).unwrap();
        let BulkStreamStartPayload {
            collection,
            columns,
        } = decode_bulk_stream_start_payload(&encoded).unwrap();

        assert_eq!(collection, "events");
        assert_eq!(columns, vec!["id", "name"]);
    }

    /// A rows payload decodes back to the exact values it was built from.
    #[test]
    fn stream_rows_payload_round_trips_values() {
        let expected = vec![vec![WireValue::I64(7), WireValue::Text("Ada".into())]];
        let encoded = encode_bulk_stream_rows_payload(&expected, 2).unwrap();
        let decoded = decode_bulk_stream_rows_payload(&encoded, 2).unwrap();

        assert_eq!(decoded.rows, expected);
    }

    /// Value decoding failures are reported with the "stream rows: " prefix.
    #[test]
    fn stream_rows_payload_preserves_error_prefix() {
        // Row count of 1 (little-endian u32) followed by a single byte where a
        // complete value is expected.
        let truncated = [1u8, 0, 0, 0, 1];
        let message = decode_bulk_stream_rows_payload(&truncated, 1)
            .unwrap_err()
            .to_string();

        assert_eq!(message, "stream rows: truncated i64 value");
    }
}