1use crate::legacy::{encode_value, try_decode_value, WireValue};
9use std::fmt;
10
/// Selects the error-message wording used by [`decode_bulk_binary_payload`].
///
/// In the code here the flavor is only threaded into [`BulkBinaryError`] values
/// (and rendered via `prefix` / `label`); it does not change how bytes are parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BulkBinaryFlavor {
    /// Messages are labelled "binary bulk".
    Binary,
    /// Messages are labelled "prevalidated" ("binary bulk prevalidated" for a
    /// too-short payload).
    Prevalidated,
}
16
/// A decoded binary bulk payload: target collection, column names and row values.
#[derive(Debug, Clone, PartialEq)]
pub struct BulkBinaryPayload {
    /// Collection name read from the payload header.
    pub collection: String,
    /// Column names, in the order they appear on the wire.
    pub columns: Vec<String>,
    /// Decoded rows; the decoder reads exactly one value per column for each row.
    pub rows: Vec<Vec<WireValue>>,
}
23
/// Errors produced while encoding or decoding a binary bulk payload.
///
/// Decode errors carry the [`BulkBinaryFlavor`] they were raised under, which only
/// affects the `Display` wording. Encode errors carry no flavor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BulkBinaryError {
    /// Decoding: the payload is shorter than the decoder's up-front minimum length.
    PayloadTooShort(BulkBinaryFlavor),
    /// Decoding: the `u16` collection-name length could not be read.
    MissingCollectionLength(BulkBinaryFlavor),
    /// Decoding: fewer bytes remain than the collection-name length announces.
    TruncatedCollectionName(BulkBinaryFlavor),
    /// Decoding: the collection name is not valid UTF-8.
    InvalidCollectionName(BulkBinaryFlavor),
    /// Decoding: the `u16` column count could not be read.
    MissingColumnCount(BulkBinaryFlavor),
    /// Decoding: a column's `u16` name length could not be read.
    MissingColumnNameLength(BulkBinaryFlavor),
    /// Decoding: fewer bytes remain than a column-name length announces.
    TruncatedColumnName(BulkBinaryFlavor),
    /// Decoding: a column name is not valid UTF-8.
    InvalidColumnName(BulkBinaryFlavor),
    /// Decoding: the `u32` row count could not be read.
    MissingRowCount(BulkBinaryFlavor),
    /// Decoding: a row value failed to decode; carries the message returned by
    /// `try_decode_value`.
    Value(BulkBinaryFlavor, &'static str),
    /// Encoding: the named field is too long for its `u16` length/count prefix.
    LengthOverflow(&'static str),
    /// Encoding: a row's value count (`got`) differs from the column count (`expected`).
    RowWidthMismatch { got: usize, expected: usize },
    /// Encoding: there are more rows than fit in the `u32` row count.
    RowCountOverflow,
}
40
41impl fmt::Display for BulkBinaryError {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 Self::PayloadTooShort(flavor) => write!(f, "{}: payload too short", prefix(*flavor)),
45 Self::MissingCollectionLength(flavor) => {
46 write!(f, "{}: missing collection length", label(*flavor))
47 }
48 Self::TruncatedCollectionName(flavor) => {
49 write!(f, "{}: truncated collection name", label(*flavor))
50 }
51 Self::InvalidCollectionName(flavor) => {
52 write!(f, "{}: invalid collection name", label(*flavor))
53 }
54 Self::MissingColumnCount(flavor) => {
55 write!(f, "{}: missing column count", label(*flavor))
56 }
57 Self::MissingColumnNameLength(flavor) => {
58 write!(f, "{}: missing column name length", label(*flavor))
59 }
60 Self::TruncatedColumnName(flavor) => {
61 write!(f, "{}: truncated column name", label(*flavor))
62 }
63 Self::InvalidColumnName(flavor) => {
64 write!(f, "{}: invalid column name", label(*flavor))
65 }
66 Self::MissingRowCount(flavor) => write!(f, "{}: missing row count", label(*flavor)),
67 Self::Value(flavor, err) => write!(f, "{}: {err}", label(*flavor)),
68 Self::LengthOverflow(field) => {
69 write!(f, "{field} is too large for RedWire binary bulk")
70 }
71 Self::RowWidthMismatch { got, expected } => {
72 write!(f, "row had {got} values for {expected} columns")
73 }
74 Self::RowCountOverflow => write!(f, "row count is too large for RedWire binary bulk"),
75 }
76 }
77}
78
// Marker impl: the message comes from the `Display` impl; no `source()` is chained.
impl std::error::Error for BulkBinaryError {}
80
81pub fn encode_bulk_binary_payload(
82 collection: &str,
83 columns: &[&str],
84 rows: &[Vec<WireValue>],
85) -> Result<Vec<u8>, BulkBinaryError> {
86 write_len_u16(collection.len(), "collection")?;
87 write_len_u16(columns.len(), "columns")?;
88 if rows.len() > u32::MAX as usize {
89 return Err(BulkBinaryError::RowCountOverflow);
90 }
91 let mut out = Vec::with_capacity(64 + rows.len() * columns.len() * 16);
92 write_string_u16(&mut out, collection, "collection")?;
93 out.extend_from_slice(&(columns.len() as u16).to_le_bytes());
94 for column in columns {
95 write_string_u16(&mut out, column, "column")?;
96 }
97 out.extend_from_slice(&(rows.len() as u32).to_le_bytes());
98 for row in rows {
99 if row.len() != columns.len() {
100 return Err(BulkBinaryError::RowWidthMismatch {
101 got: row.len(),
102 expected: columns.len(),
103 });
104 }
105 for value in row {
106 encode_value(&mut out, value);
107 }
108 }
109 Ok(out)
110}
111
112pub fn decode_bulk_binary_payload(
113 payload: &[u8],
114 flavor: BulkBinaryFlavor,
115) -> Result<BulkBinaryPayload, BulkBinaryError> {
116 let mut pos = 0;
117 if payload.len() < 6 {
118 return Err(BulkBinaryError::PayloadTooShort(flavor));
119 }
120 let coll_len = read_u16(
121 payload,
122 &mut pos,
123 BulkBinaryError::MissingCollectionLength(flavor),
124 )? as usize;
125 let collection = read_string(
126 payload,
127 &mut pos,
128 coll_len,
129 BulkBinaryError::TruncatedCollectionName(flavor),
130 BulkBinaryError::InvalidCollectionName(flavor),
131 )?;
132 let ncols = read_u16(
133 payload,
134 &mut pos,
135 BulkBinaryError::MissingColumnCount(flavor),
136 )? as usize;
137 let mut columns = Vec::with_capacity(ncols);
138 for _ in 0..ncols {
139 let name_len = read_u16(
140 payload,
141 &mut pos,
142 BulkBinaryError::MissingColumnNameLength(flavor),
143 )? as usize;
144 columns.push(read_string(
145 payload,
146 &mut pos,
147 name_len,
148 BulkBinaryError::TruncatedColumnName(flavor),
149 BulkBinaryError::InvalidColumnName(flavor),
150 )?);
151 }
152 let nrows = read_u32(payload, &mut pos, BulkBinaryError::MissingRowCount(flavor))? as usize;
153 let mut rows = Vec::with_capacity(nrows);
154 for _ in 0..nrows {
155 let mut values = Vec::with_capacity(ncols);
156 for _ in 0..ncols {
157 values.push(
158 try_decode_value(payload, &mut pos)
159 .map_err(|err| BulkBinaryError::Value(flavor, err))?,
160 );
161 }
162 rows.push(values);
163 }
164 Ok(BulkBinaryPayload {
165 collection,
166 columns,
167 rows,
168 })
169}
170
171fn prefix(flavor: BulkBinaryFlavor) -> &'static str {
172 match flavor {
173 BulkBinaryFlavor::Binary => "binary bulk",
174 BulkBinaryFlavor::Prevalidated => "binary bulk prevalidated",
175 }
176}
177
178fn label(flavor: BulkBinaryFlavor) -> &'static str {
179 match flavor {
180 BulkBinaryFlavor::Binary => "binary bulk",
181 BulkBinaryFlavor::Prevalidated => "prevalidated",
182 }
183}
184
185fn write_string_u16(
186 out: &mut Vec<u8>,
187 value: &str,
188 field: &'static str,
189) -> Result<(), BulkBinaryError> {
190 write_len_u16(value.len(), field)?;
191 out.extend_from_slice(&(value.len() as u16).to_le_bytes());
192 out.extend_from_slice(value.as_bytes());
193 Ok(())
194}
195
196fn write_len_u16(len: usize, field: &'static str) -> Result<(), BulkBinaryError> {
197 if len > u16::MAX as usize {
198 return Err(BulkBinaryError::LengthOverflow(field));
199 }
200 Ok(())
201}
202
203fn read_u16(payload: &[u8], pos: &mut usize, err: BulkBinaryError) -> Result<u16, BulkBinaryError> {
204 let bytes = read_bytes(payload, pos, 2, err)?;
205 Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
206}
207
208fn read_u32(payload: &[u8], pos: &mut usize, err: BulkBinaryError) -> Result<u32, BulkBinaryError> {
209 let bytes = read_bytes(payload, pos, 4, err)?;
210 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
211}
212
213fn read_string(
214 payload: &[u8],
215 pos: &mut usize,
216 len: usize,
217 truncated_err: BulkBinaryError,
218 utf8_err: BulkBinaryError,
219) -> Result<String, BulkBinaryError> {
220 let bytes = read_bytes(payload, pos, len, truncated_err)?;
221 std::str::from_utf8(bytes)
222 .map(str::to_owned)
223 .map_err(|_| utf8_err)
224}
225
226fn read_bytes<'a>(
227 payload: &'a [u8],
228 pos: &mut usize,
229 len: usize,
230 err: BulkBinaryError,
231) -> Result<&'a [u8], BulkBinaryError> {
232 let end = pos.saturating_add(len);
233 if end > payload.len() {
234 return Err(err);
235 }
236 let bytes = &payload[*pos..end];
237 *pos = end;
238 Ok(bytes)
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `bytes` (expecting failure) and returns the rendered error message.
    fn decode_error_text(bytes: &[u8], flavor: BulkBinaryFlavor) -> String {
        decode_bulk_binary_payload(bytes, flavor)
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn binary_bulk_payload_round_trips_values() {
        let columns = ["id", "name", "active"];
        let rows = vec![vec![
            WireValue::I64(7),
            WireValue::Text("Ada".into()),
            WireValue::Bool(true),
        ]];

        let encoded = encode_bulk_binary_payload("users", &columns, &rows)
            .expect("well-formed input must encode");
        let decoded = decode_bulk_binary_payload(&encoded, BulkBinaryFlavor::Binary)
            .expect("freshly encoded payload must decode");

        let expected = BulkBinaryPayload {
            collection: "users".to_owned(),
            columns: columns.iter().map(|name| name.to_string()).collect(),
            rows,
        };
        assert_eq!(decoded, expected);
    }

    #[test]
    fn binary_bulk_decode_preserves_error_prefixes() {
        assert_eq!(
            decode_error_text(&[0; 5], BulkBinaryFlavor::Binary),
            "binary bulk: payload too short"
        );

        // Collection "t", one column "x", one row, then a single value byte (1)
        // with nothing after it, which the value decoder reports as a truncated i64.
        let truncated_value = [1, 0, b't', 1, 0, 1, 0, b'x', 1, 0, 0, 0, 1];
        assert_eq!(
            decode_error_text(&truncated_value, BulkBinaryFlavor::Prevalidated),
            "prevalidated: truncated i64 value"
        );
    }
}