1use bcp_wire::WireError;
2use bcp_wire::varint::{decode_varint, encode_varint};
3
4use crate::error::TypeError;
5
/// Wire-type tag encoded immediately after every field id.
///
/// The explicit discriminants are the exact numbers written on the wire by the
/// `encode_*_field` functions and accepted by `FieldWireType::from_raw`, so they
/// must never be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FieldWireType {
    /// The header is followed by a single varint value.
    Varint = 0,
    /// The header is followed by a varint length prefix and that many raw bytes.
    Bytes = 1,
    /// The header is followed by a varint length prefix and that many bytes
    /// holding an encoded nested field sequence.
    Nested = 2,
}
35
36impl FieldWireType {
37 pub fn from_raw(value: u64) -> Result<Self, TypeError> {
41 match value {
42 0 => Ok(Self::Varint),
43 1 => Ok(Self::Bytes),
44 2 => Ok(Self::Nested),
45 other => Err(TypeError::UnknownFieldWireType { value: other }),
46 }
47 }
48}
49
50const MAX_VARINT_LEN: usize = 10;
59
60fn push_varint(buf: &mut Vec<u8>, value: u64) {
64 let mut scratch = [0u8; MAX_VARINT_LEN];
65 let n = encode_varint(value, &mut scratch);
66 buf.extend_from_slice(&scratch[..n]);
67}
68
69pub fn encode_varint_field(buf: &mut Vec<u8>, field_id: u64, value: u64) {
76 push_varint(buf, field_id);
77 push_varint(buf, 0); push_varint(buf, value);
79}
80
81pub fn encode_bytes_field(buf: &mut Vec<u8>, field_id: u64, data: &[u8]) {
88 push_varint(buf, field_id);
89 push_varint(buf, 1); push_varint(buf, data.len() as u64);
91 buf.extend_from_slice(data);
92}
93
94pub fn encode_nested_field(buf: &mut Vec<u8>, field_id: u64, nested_data: &[u8]) {
105 push_varint(buf, field_id);
106 push_varint(buf, 2); push_varint(buf, nested_data.len() as u64);
108 buf.extend_from_slice(nested_data);
109}
110
111#[derive(Clone, Copy, Debug, PartialEq, Eq)]
124pub struct FieldHeader {
125 pub field_id: u64,
126 pub wire_type: FieldWireType,
127}
128
129pub fn decode_field_header(buf: &[u8]) -> Result<(FieldHeader, usize), TypeError> {
134 let mut cursor = 0;
135
136 let (field_id, n) = decode_varint(buf).map_err(WireError::from)?;
137 cursor += n;
138
139 let (wire_type_raw, n) = decode_varint(
140 buf.get(cursor..)
141 .ok_or(WireError::UnexpectedEof { offset: cursor })?,
142 )
143 .map_err(WireError::from)?;
144 cursor += n;
145
146 let wire_type = FieldWireType::from_raw(wire_type_raw)?;
147
148 Ok((
149 FieldHeader {
150 field_id,
151 wire_type,
152 },
153 cursor,
154 ))
155}
156
157pub fn decode_varint_value(buf: &[u8]) -> Result<(u64, usize), TypeError> {
162 let (value, n) = decode_varint(buf)?;
163 Ok((value, n))
164}
165
166pub fn decode_bytes_value(buf: &[u8]) -> Result<(&[u8], usize), TypeError> {
172 let (len, n) = decode_varint(buf)?;
173 let len = usize::try_from(len).map_err(|_| {
175 WireError::UnexpectedEof { offset: n }
176 })?;
177 let end = match n.checked_add(len) {
179 Some(e) => e,
180 None => {
181 return Err(WireError::UnexpectedEof { offset: n }.into())
182 }
183 };
184 let data = buf
185 .get(n..end)
186 .ok_or(WireError::UnexpectedEof { offset: n })?;
187 Ok((data, end))
188}
189
190pub fn skip_field(buf: &[u8], wire_type: FieldWireType) -> Result<usize, TypeError> {
196 match wire_type {
197 FieldWireType::Varint => {
198 let (_value, n) = decode_varint(buf)?;
199 Ok(n)
200 }
201 FieldWireType::Bytes | FieldWireType::Nested => {
202 let (_data, n) = decode_bytes_value(buf)?;
203 Ok(n)
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_varint_field() {
        let mut buf = Vec::new();
        encode_varint_field(&mut buf, 1, 42);

        let (header, mut cursor) = decode_field_header(&buf).unwrap();
        assert_eq!(header.field_id, 1);
        assert_eq!(header.wire_type, FieldWireType::Varint);

        let (value, n) = decode_varint_value(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(value, 42);
        assert_eq!(cursor, buf.len());
    }

    #[test]
    fn roundtrip_bytes_field() {
        let mut buf = Vec::new();
        encode_bytes_field(&mut buf, 2, b"hello");

        let (header, mut cursor) = decode_field_header(&buf).unwrap();
        assert_eq!(header.field_id, 2);
        assert_eq!(header.wire_type, FieldWireType::Bytes);

        let (data, n) = decode_bytes_value(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(data, b"hello");
        assert_eq!(cursor, buf.len());
    }

    #[test]
    fn roundtrip_nested_field() {
        // The nested payload is itself a complete encoded field.
        let mut inner = Vec::new();
        encode_varint_field(&mut inner, 1, 99);

        let mut buf = Vec::new();
        encode_nested_field(&mut buf, 3, &inner);

        let (header, mut cursor) = decode_field_header(&buf).unwrap();
        assert_eq!(header.field_id, 3);
        assert_eq!(header.wire_type, FieldWireType::Nested);

        let (nested_bytes, n) = decode_bytes_value(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(cursor, buf.len());

        // Decode the inner field from the borrowed nested slice.
        let (inner_header, inner_cursor) = decode_field_header(nested_bytes).unwrap();
        assert_eq!(inner_header.field_id, 1);
        let (value, _) = decode_varint_value(&nested_bytes[inner_cursor..]).unwrap();
        assert_eq!(value, 99);
    }

    #[test]
    fn skip_varint_field() {
        let mut buf = Vec::new();
        encode_varint_field(&mut buf, 1, 12345);

        let (header, cursor) = decode_field_header(&buf).unwrap();
        let skipped = skip_field(&buf[cursor..], header.wire_type).unwrap();
        assert_eq!(cursor + skipped, buf.len());
    }

    #[test]
    fn skip_bytes_field() {
        let mut buf = Vec::new();
        encode_bytes_field(&mut buf, 2, b"skip me");

        let (header, cursor) = decode_field_header(&buf).unwrap();
        let skipped = skip_field(&buf[cursor..], header.wire_type).unwrap();
        assert_eq!(cursor + skipped, buf.len());
    }

    #[test]
    fn skip_nested_field() {
        let mut inner = Vec::new();
        encode_varint_field(&mut inner, 1, 99);
        let mut buf = Vec::new();
        encode_nested_field(&mut buf, 3, &inner);

        let (header, cursor) = decode_field_header(&buf).unwrap();
        assert_eq!(header.wire_type, FieldWireType::Nested);
        let skipped = skip_field(&buf[cursor..], header.wire_type).unwrap();
        assert_eq!(cursor + skipped, buf.len());
    }

    #[test]
    fn multiple_fields_sequential() {
        let mut buf = Vec::new();
        encode_varint_field(&mut buf, 1, 7);
        encode_bytes_field(&mut buf, 2, b"world");
        encode_varint_field(&mut buf, 3, 256);

        let mut cursor = 0;

        // Field 1: varint.
        let (h, n) = decode_field_header(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(h.field_id, 1);
        let (v, n) = decode_varint_value(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(v, 7);

        // Field 2: bytes.
        let (h, n) = decode_field_header(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(h.field_id, 2);
        let (data, n) = decode_bytes_value(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(data, b"world");

        // Field 3: multi-byte varint.
        let (h, n) = decode_field_header(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(h.field_id, 3);
        let (v, n) = decode_varint_value(&buf[cursor..]).unwrap();
        cursor += n;
        assert_eq!(v, 256);

        assert_eq!(cursor, buf.len());
    }

    #[test]
    fn roundtrip_extreme_values() {
        let mut buf = Vec::new();
        encode_varint_field(&mut buf, u64::MAX, u64::MAX);

        let (header, cursor) = decode_field_header(&buf).unwrap();
        assert_eq!(header.field_id, u64::MAX);
        assert_eq!(header.wire_type, FieldWireType::Varint);
        let (value, n) = decode_varint_value(&buf[cursor..]).unwrap();
        assert_eq!(value, u64::MAX);
        assert_eq!(cursor + n, buf.len());
    }

    #[test]
    fn empty_bytes_field_roundtrips() {
        let mut buf = Vec::new();
        encode_bytes_field(&mut buf, 4, b"");

        let (header, cursor) = decode_field_header(&buf).unwrap();
        assert_eq!(header.wire_type, FieldWireType::Bytes);
        let (data, n) = decode_bytes_value(&buf[cursor..]).unwrap();
        assert!(data.is_empty());
        assert_eq!(cursor + n, buf.len());
    }

    #[test]
    fn truncated_bytes_payload_rejected() {
        let mut buf = Vec::new();
        encode_bytes_field(&mut buf, 2, b"hello");

        let (_, cursor) = decode_field_header(&buf).unwrap();
        // Drop the last payload byte so the declared length overruns the buffer.
        let truncated = &buf[cursor..buf.len() - 1];
        assert!(decode_bytes_value(truncated).is_err());
        assert!(skip_field(truncated, FieldWireType::Bytes).is_err());
    }

    #[test]
    fn truncated_header_rejected() {
        // Only the field id is present; the wire-type varint is missing.
        let mut buf = Vec::new();
        push_varint(&mut buf, 1);
        assert!(decode_field_header(&buf).is_err());
        assert!(decode_field_header(&[]).is_err());
    }

    #[test]
    fn unknown_wire_type_rejected() {
        let mut buf = Vec::new();
        push_varint(&mut buf, 1); // field id
        push_varint(&mut buf, 5); // wire type 5 is not defined
        let result = decode_field_header(&buf);
        assert!(matches!(
            result,
            Err(TypeError::UnknownFieldWireType { value: 5 })
        ));
    }
}