crabka_protocol/
tagged_fields.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
4
5use crate::ProtocolError;
6use crate::primitives::varint::{get_uvarint, put_uvarint, uvarint_len};
7
8#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct UnknownTaggedField {
11 pub tag: u32,
12 pub bytes: Bytes,
13}
14
15#[derive(Debug, Clone, Default, PartialEq, Eq)]
19pub struct UnknownTaggedFields(pub Vec<UnknownTaggedField>);
20
21impl UnknownTaggedFields {
22 #[must_use]
23 pub fn is_empty(&self) -> bool {
24 self.0.is_empty()
25 }
26 #[must_use]
27 pub fn len(&self) -> usize {
28 self.0.len()
29 }
30}
31
32pub fn read_tagged_fields<B, F>(
38 buf: &mut B,
39 mut known: F,
40) -> Result<UnknownTaggedFields, ProtocolError>
41where
42 B: Buf,
43 F: FnMut(u32, &mut &[u8]) -> Result<bool, ProtocolError>,
44{
45 let count = get_uvarint(buf)? as usize;
46 let mut unknown = Vec::new();
47 let mut last_tag: Option<u32> = None;
48 for _ in 0..count {
49 let tag = get_uvarint(buf)?;
50 if let Some(prev) = last_tag
51 && tag <= prev
52 {
53 return Err(ProtocolError::InvalidValue(
54 "tagged fields not strictly ascending",
55 ));
56 }
57 last_tag = Some(tag);
58 let size = get_uvarint(buf)? as usize;
59 if buf.remaining() < size {
60 return Err(ProtocolError::UnexpectedEof {
61 needed: size - buf.remaining(),
62 });
63 }
64 let mut payload = vec![0u8; size];
66 buf.copy_to_slice(&mut payload);
67 let mut slice = &payload[..];
68 if !known(tag, &mut slice)? {
69 unknown.push(UnknownTaggedField {
70 tag,
71 bytes: Bytes::from(payload),
72 });
73 } else if !slice.is_empty() {
74 return Err(ProtocolError::InvalidValue(
75 "tagged field decoder did not consume all bytes",
76 ));
77 }
78 }
79 Ok(UnknownTaggedFields(unknown))
80}
81
82pub struct WriteTaggedFields {
86 entries: Vec<(u32, Bytes)>,
87}
88
89impl Default for WriteTaggedFields {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl WriteTaggedFields {
96 #[must_use]
97 pub fn new() -> Self {
98 Self {
99 entries: Vec::new(),
100 }
101 }
102
103 pub fn add(&mut self, tag: u32, payload: Bytes) {
104 self.entries.push((tag, payload));
105 }
106
107 pub fn write<B: BufMut>(mut self, buf: &mut B, unknown: &UnknownTaggedFields) {
108 for u in &unknown.0 {
109 self.entries.push((u.tag, u.bytes.clone()));
110 }
111 self.entries.sort_by_key(|(t, _)| *t);
112 put_uvarint(
113 buf,
114 u32::try_from(self.entries.len()).expect("too many tagged fields"),
115 );
116 for (tag, payload) in self.entries {
117 put_uvarint(buf, tag);
118 put_uvarint(
119 buf,
120 u32::try_from(payload.len()).expect("tagged field too large"),
121 );
122 buf.put_slice(&payload);
123 }
124 }
125}
126
127#[must_use]
129pub fn tagged_fields_len(known: &[(u32, usize)], unknown: &UnknownTaggedFields) -> usize {
130 let total = known.len() + unknown.0.len();
131 let mut n = uvarint_len(u32::try_from(total).unwrap());
132 for (tag, len) in known {
133 n += uvarint_len(*tag) + uvarint_len(u32::try_from(*len).unwrap()) + *len;
134 }
135 for u in &unknown.0 {
136 n +=
137 uvarint_len(u.tag) + uvarint_len(u32::try_from(u.bytes.len()).unwrap()) + u.bytes.len();
138 }
139 n
140}
141
142pub fn encode_to_bytes<F>(predicted_len: usize, write: F) -> Bytes
149where
150 F: FnOnce(&mut BytesMut) -> Result<(), crate::ProtocolError>,
151{
152 let mut buf = BytesMut::with_capacity(predicted_len);
153 write(&mut buf).expect("tagged-field encode failed: emitter bug");
154 debug_assert_eq!(buf.len(), predicted_len, "encoded_len lied");
155 buf.freeze()
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;

    /// A section with a zero count yields no fields and consumes exactly one byte.
    #[test]
    fn empty_tagged_fields() {
        let buf = [0x00u8];
        let mut cur = &buf[..];
        let unknown = read_tagged_fields(&mut cur, |_, _| Ok(false)).unwrap();
        assert!(unknown.is_empty());
        assert!(cur.is_empty());
    }

    /// An entry the callback declines is returned verbatim.
    #[test]
    fn unknown_tagged_fields_preserved() {
        // count = 1, tag = 5, size = 3, payload = [10, 20, 30]
        let buf = [0x01, 0x05, 0x03, 10, 20, 30];
        let mut cur = &buf[..];
        let unknown = read_tagged_fields(&mut cur, |_, _| Ok(false)).unwrap();
        assert!(unknown.len() == 1);
        assert!(unknown.0[0].tag == 5);
        assert!(unknown.0[0].bytes.as_ref() == &[10, 20, 30]);
        assert!(cur.is_empty());
    }

    /// Tags that go down (5 then 3) are rejected.
    #[test]
    fn ascending_order_enforced() {
        let buf = [0x02, 0x05, 0x01, 0x00, 0x03, 0x01, 0x00];
        let mut cur = &buf[..];
        assert!(read_tagged_fields(&mut cur, |_, _| Ok(false)).is_err());
    }

    /// Ordering is strict: repeating a tag is rejected as well.
    #[test]
    fn duplicate_tag_rejected() {
        // count = 2, (tag = 5, size = 0), (tag = 5, size = 0)
        let buf = [0x02, 0x05, 0x00, 0x05, 0x00];
        let mut cur = &buf[..];
        assert!(read_tagged_fields(&mut cur, |_, _| Ok(false)).is_err());
    }

    /// A payload shorter than its announced size reports how many bytes are missing.
    #[test]
    fn truncated_payload_reports_missing_bytes() {
        // count = 1, tag = 5, size = 3, but only two payload bytes follow
        let buf = [0x01, 0x05, 0x03, 10, 20];
        let mut cur = &buf[..];
        let result = read_tagged_fields(&mut cur, |_, _| Ok(false));
        assert!(matches!(
            result,
            Err(ProtocolError::UnexpectedEof { needed: 1 })
        ));
    }

    /// A field the callback decodes completely is not reported as unknown.
    #[test]
    fn known_field_is_consumed_by_callback() {
        let buf = [0x01, 0x05, 0x03, 10, 20, 30];
        let mut cur = &buf[..];
        let mut seen = Vec::new();
        let unknown = read_tagged_fields(&mut cur, |tag, payload| {
            seen.push((tag, payload.to_vec()));
            let n = payload.len();
            payload.advance(n);
            Ok(true)
        })
        .unwrap();
        assert!(unknown.is_empty());
        assert!(cur.is_empty());
        assert!(seen == vec![(5u32, vec![10u8, 20, 30])]);
    }

    /// Claiming a field as known while leaving payload bytes unread is an error.
    #[test]
    fn known_field_must_consume_whole_payload() {
        let buf = [0x01, 0x05, 0x03, 10, 20, 30];
        let mut cur = &buf[..];
        assert!(read_tagged_fields(&mut cur, |_, _| Ok(true)).is_err());
    }

    /// Known and unknown entries are merged and emitted in ascending tag order.
    #[test]
    fn write_merges_known_and_unknown_sorted() {
        let mut w = WriteTaggedFields::new();
        w.add(10, Bytes::from_static(&[0xAA]));
        let unknown = UnknownTaggedFields(vec![UnknownTaggedField {
            tag: 5,
            bytes: Bytes::from_static(&[0xBB]),
        }]);
        let mut out = BytesMut::new();
        w.write(&mut out, &unknown);
        assert!(&out[..] == &[0x02, 0x05, 0x01, 0xBB, 0x0A, 0x01, 0xAA]);
    }

    /// Whatever the writer emits can be read back, and the predicted length matches.
    #[test]
    fn written_section_round_trips() {
        let mut w = WriteTaggedFields::new();
        w.add(10, Bytes::from_static(&[0xAA]));
        let unknown = UnknownTaggedFields(vec![UnknownTaggedField {
            tag: 5,
            bytes: Bytes::from_static(&[0xBB]),
        }]);
        let predicted = tagged_fields_len(&[(10, 1)], &unknown);

        let mut out = BytesMut::new();
        w.write(&mut out, &unknown);
        assert!(out.len() == predicted);

        let mut cur = &out[..];
        let decoded = read_tagged_fields(&mut cur, |_, _| Ok(false)).unwrap();
        assert!(cur.is_empty());
        let expected = UnknownTaggedFields(vec![
            UnknownTaggedField {
                tag: 5,
                bytes: Bytes::from_static(&[0xBB]),
            },
            UnknownTaggedField {
                tag: 10,
                bytes: Bytes::from_static(&[0xAA]),
            },
        ]);
        assert!(decoded == expected);
    }
}