1use std::fmt;
8
/// Decoded contents of a bulk JSON insert frame.
///
/// Produced by `decode_bulk_json_payload`; the same two pieces of data are the
/// inputs of `encode_bulk_json_payload`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BulkJsonPayload {
    /// Collection name carried at the start of the frame.
    pub collection: String,
    /// One string per row, in wire order. The decoder only checks that each
    /// entry is valid UTF-8; it does not parse the text as JSON.
    pub json_payloads: Vec<String>,
}
14
/// Failures reported by the bulk JSON insert encoder and decoder.
///
/// The decode-side variants render with a `bulk insert:` prefix; the two
/// encode-side variants (`FieldTooLarge`, `RowCountOverflow`) do not.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BulkJsonError {
    /// The frame is shorter than the two-byte collection length prefix.
    PayloadTooShort,
    /// The collection length prefix could not be read.
    MissingCollectionLength,
    /// Fewer bytes remain than the collection length prefix announced.
    TruncatedCollectionName,
    /// The collection name bytes are not valid UTF-8.
    InvalidCollectionName,
    /// The four-byte row count could not be read.
    MissingRowCount,
    /// The four-byte length prefix of a row could not be read.
    MissingJsonLength,
    /// Fewer bytes remain than a row's length prefix announced.
    TruncatedJsonPayload,
    /// A row's bytes are not valid UTF-8.
    InvalidJsonPayload,
    /// A field's byte length does not fit its on-wire length prefix; the
    /// payload names the offending field.
    FieldTooLarge(&'static str),
    /// The number of rows does not fit the on-wire `u32` row count.
    RowCountOverflow,
}

impl fmt::Display for BulkJsonError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only `FieldTooLarge` needs interpolation; every other variant maps
        // to a fixed message that is written out verbatim below.
        let message = match self {
            Self::FieldTooLarge(field) => {
                return write!(f, "{field} is too large for RedWire JSON bulk insert");
            }
            Self::PayloadTooShort => "bulk insert: payload too short",
            Self::MissingCollectionLength => "bulk insert: missing collection length",
            Self::TruncatedCollectionName => "bulk insert: truncated collection name",
            Self::InvalidCollectionName => "bulk insert: invalid collection name",
            Self::MissingRowCount => "bulk insert: missing row count",
            Self::MissingJsonLength => "bulk insert: missing JSON length",
            Self::TruncatedJsonPayload => "bulk insert: truncated JSON payload",
            Self::InvalidJsonPayload => "bulk insert: invalid JSON payload",
            Self::RowCountOverflow => "row count is too large for RedWire JSON bulk insert",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BulkJsonError {}
55
56pub fn encode_bulk_json_payload(
57 collection: &str,
58 json_payloads: &[String],
59) -> Result<Vec<u8>, BulkJsonError> {
60 let collection_len =
61 u16::try_from(collection.len()).map_err(|_| BulkJsonError::FieldTooLarge("collection"))?;
62 let row_count =
63 u32::try_from(json_payloads.len()).map_err(|_| BulkJsonError::RowCountOverflow)?;
64
65 let mut out = Vec::new();
66 out.extend_from_slice(&collection_len.to_le_bytes());
67 out.extend_from_slice(collection.as_bytes());
68 out.extend_from_slice(&row_count.to_le_bytes());
69 for payload in json_payloads {
70 let json_len =
71 u32::try_from(payload.len()).map_err(|_| BulkJsonError::FieldTooLarge("json"))?;
72 out.extend_from_slice(&json_len.to_le_bytes());
73 out.extend_from_slice(payload.as_bytes());
74 }
75 Ok(out)
76}
77
78pub fn decode_bulk_json_payload(payload: &[u8]) -> Result<BulkJsonPayload, BulkJsonError> {
79 if payload.len() < 2 {
80 return Err(BulkJsonError::PayloadTooShort);
81 }
82
83 let mut pos = 0;
84 let coll_len = read_u16(payload, &mut pos, BulkJsonError::MissingCollectionLength)? as usize;
85 let collection = read_string(
86 payload,
87 &mut pos,
88 coll_len,
89 BulkJsonError::TruncatedCollectionName,
90 BulkJsonError::InvalidCollectionName,
91 )?;
92
93 let nrows = read_u32(payload, &mut pos, BulkJsonError::MissingRowCount)? as usize;
94 let mut json_payloads = Vec::with_capacity(nrows);
95 for _ in 0..nrows {
96 let json_len = read_u32(payload, &mut pos, BulkJsonError::MissingJsonLength)? as usize;
97 json_payloads.push(read_string(
98 payload,
99 &mut pos,
100 json_len,
101 BulkJsonError::TruncatedJsonPayload,
102 BulkJsonError::InvalidJsonPayload,
103 )?);
104 }
105
106 Ok(BulkJsonPayload {
107 collection,
108 json_payloads,
109 })
110}
111
112fn read_u16(payload: &[u8], pos: &mut usize, err: BulkJsonError) -> Result<u16, BulkJsonError> {
113 let bytes = read_bytes(payload, pos, 2, err)?;
114 Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
115}
116
117fn read_u32(payload: &[u8], pos: &mut usize, err: BulkJsonError) -> Result<u32, BulkJsonError> {
118 let bytes = read_bytes(payload, pos, 4, err)?;
119 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
120}
121
122fn read_string(
123 payload: &[u8],
124 pos: &mut usize,
125 len: usize,
126 truncated_err: BulkJsonError,
127 utf8_err: BulkJsonError,
128) -> Result<String, BulkJsonError> {
129 let bytes = read_bytes(payload, pos, len, truncated_err)?;
130 std::str::from_utf8(bytes)
131 .map(str::to_string)
132 .map_err(|_| utf8_err)
133}
134
135fn read_bytes<'a>(
136 payload: &'a [u8],
137 pos: &mut usize,
138 len: usize,
139 err: BulkJsonError,
140) -> Result<&'a [u8], BulkJsonError> {
141 let Some(end) = pos.checked_add(len) else {
142 return Err(err);
143 };
144 if end > payload.len() {
145 return Err(err);
146 }
147 let bytes = &payload[*pos..end];
148 *pos = end;
149 Ok(bytes)
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding and then decoding must hand back the same collection and rows.
    #[test]
    fn bulk_json_payload_round_trips() {
        let rows: Vec<String> = [r#"{"id":1}"#, r#"{"id":2}"#]
            .iter()
            .map(|row| row.to_string())
            .collect();

        let encoded = encode_bulk_json_payload("events", &rows).unwrap();
        let BulkJsonPayload {
            collection,
            json_payloads,
        } = decode_bulk_json_payload(&encoded).unwrap();

        assert_eq!(collection, "events");
        assert_eq!(json_payloads, rows);
    }

    /// Decode failures keep their `bulk insert:` message prefix.
    #[test]
    fn bulk_json_decode_preserves_error_prefixes() {
        // A single byte cannot even hold the collection length prefix.
        let too_short = decode_bulk_json_payload(&[0]).unwrap_err();
        assert_eq!(too_short.to_string(), "bulk insert: payload too short");

        // Collection "t", one row announced as 10 bytes long, but only one
        // byte of row data actually follows.
        let frame = vec![1, 0, b't', 1, 0, 0, 0, 10, 0, 0, 0, b'{'];
        let truncated = decode_bulk_json_payload(&frame).unwrap_err();
        assert_eq!(truncated.to_string(), "bulk insert: truncated JSON payload");
    }
}