Skip to main content

crabka_protocol/
tagged_fields.rs

//! KIP-482 tagged fields: the trailer of optional, tag-identified fields carried by flexible message versions.
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4
5use crate::ProtocolError;
6use crate::primitives::varint::{get_uvarint, put_uvarint, uvarint_len};
7
8/// An unknown tagged field that was preserved verbatim during decode.
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct UnknownTaggedField {
11    pub tag: u32,
12    pub bytes: Bytes,
13}
14
15/// A collection of tagged fields that the schema does not declare. Generated
16/// message types contain a `Vec<UnknownTaggedField>` (sorted by tag) so that
17/// values can be round-tripped without information loss.
18#[derive(Debug, Clone, Default, PartialEq, Eq)]
19pub struct UnknownTaggedFields(pub Vec<UnknownTaggedField>);
20
21impl UnknownTaggedFields {
22    #[must_use]
23    pub fn is_empty(&self) -> bool {
24        self.0.is_empty()
25    }
26    #[must_use]
27    pub fn len(&self) -> usize {
28        self.0.len()
29    }
30}
31
32/// Read the tagged-fields trailer at the current position of `buf`. `known`
33/// is called for each entry whose tag is in the schema; it should consume the
34/// field's payload (a `size`-byte slice) and return Ok if it recognised the
35/// tag. Anything `known` returns `Ok(false)` for is captured into the unknown
36/// vec instead.
37pub fn read_tagged_fields<B, F>(
38    buf: &mut B,
39    mut known: F,
40) -> Result<UnknownTaggedFields, ProtocolError>
41where
42    B: Buf,
43    F: FnMut(u32, &mut &[u8]) -> Result<bool, ProtocolError>,
44{
45    let count = get_uvarint(buf)? as usize;
46    let mut unknown = Vec::new();
47    let mut last_tag: Option<u32> = None;
48    for _ in 0..count {
49        let tag = get_uvarint(buf)?;
50        if let Some(prev) = last_tag
51            && tag <= prev
52        {
53            return Err(ProtocolError::InvalidValue(
54                "tagged fields not strictly ascending",
55            ));
56        }
57        last_tag = Some(tag);
58        let size = get_uvarint(buf)? as usize;
59        if buf.remaining() < size {
60            return Err(ProtocolError::UnexpectedEof {
61                needed: size - buf.remaining(),
62            });
63        }
64        // Copy the payload so we can hand a slice to the closure or store it.
65        let mut payload = vec![0u8; size];
66        buf.copy_to_slice(&mut payload);
67        let mut slice = &payload[..];
68        if !known(tag, &mut slice)? {
69            unknown.push(UnknownTaggedField {
70                tag,
71                bytes: Bytes::from(payload),
72            });
73        } else if !slice.is_empty() {
74            return Err(ProtocolError::InvalidValue(
75                "tagged field decoder did not consume all bytes",
76            ));
77        }
78    }
79    Ok(UnknownTaggedFields(unknown))
80}
81
82/// Helper used by generated code while emitting tagged fields. Call
83/// `WriteTaggedFields::new()`, then `add` each known tag-and-payload-encoder,
84/// then `write` to flush, merging with `unknown`.
85pub struct WriteTaggedFields {
86    entries: Vec<(u32, Bytes)>,
87}
88
89impl Default for WriteTaggedFields {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl WriteTaggedFields {
96    #[must_use]
97    pub fn new() -> Self {
98        Self {
99            entries: Vec::new(),
100        }
101    }
102
103    pub fn add(&mut self, tag: u32, payload: Bytes) {
104        self.entries.push((tag, payload));
105    }
106
107    pub fn write<B: BufMut>(mut self, buf: &mut B, unknown: &UnknownTaggedFields) {
108        for u in &unknown.0 {
109            self.entries.push((u.tag, u.bytes.clone()));
110        }
111        self.entries.sort_by_key(|(t, _)| *t);
112        put_uvarint(
113            buf,
114            u32::try_from(self.entries.len()).expect("too many tagged fields"),
115        );
116        for (tag, payload) in self.entries {
117            put_uvarint(buf, tag);
118            put_uvarint(
119                buf,
120                u32::try_from(payload.len()).expect("tagged field too large"),
121            );
122            buf.put_slice(&payload);
123        }
124    }
125}
126
127/// Predicted length of the tagged-fields trailer.
128#[must_use]
129pub fn tagged_fields_len(known: &[(u32, usize)], unknown: &UnknownTaggedFields) -> usize {
130    let total = known.len() + unknown.0.len();
131    let mut n = uvarint_len(u32::try_from(total).unwrap());
132    for (tag, len) in known {
133        n += uvarint_len(*tag) + uvarint_len(u32::try_from(*len).unwrap()) + *len;
134    }
135    for u in &unknown.0 {
136        n +=
137            uvarint_len(u.tag) + uvarint_len(u32::try_from(u.bytes.len()).unwrap()) + u.bytes.len();
138    }
139    n
140}
141
142/// Encode a value into a freshly-allocated `Bytes` (used to materialize a
143/// tagged-field payload before sizing the outer trailer).
144///
145/// The write closure may return an error (e.g. from nested struct encode);
146/// the error propagates as a panic since tagged-field encoding failures indicate
147/// a bug in the emitter's `encoded_len` prediction.
148pub fn encode_to_bytes<F>(predicted_len: usize, write: F) -> Bytes
149where
150    F: FnOnce(&mut BytesMut) -> Result<(), crate::ProtocolError>,
151{
152    let mut buf = BytesMut::with_capacity(predicted_len);
153    write(&mut buf).expect("tagged-field encode failed: emitter bug");
154    debug_assert_eq!(buf.len(), predicted_len, "encoded_len lied");
155    buf.freeze()
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use assert2::assert;

    #[test]
    fn empty_tagged_fields() {
        let buf = [0x00u8];
        let mut cur = &buf[..];
        let unknown = read_tagged_fields(&mut cur, |_, _| Ok(false)).unwrap();
        assert!(unknown.is_empty());
        assert!(cur.is_empty());
    }

    #[test]
    fn unknown_tagged_fields_preserved() {
        // count=1, tag=5, size=3, payload=[10,20,30]
        let buf = [0x01, 0x05, 0x03, 10, 20, 30];
        let mut cur = &buf[..];
        let unknown = read_tagged_fields(&mut cur, |_, _| Ok(false)).unwrap();
        assert!(unknown.len() == 1);
        assert!(unknown.0[0].tag == 5);
        assert!(unknown.0[0].bytes.as_ref() == &[10, 20, 30]);
        assert!(cur.is_empty());
    }

    #[test]
    fn ascending_order_enforced() {
        // count=2, tag=5..., tag=3...  — invalid (descending)
        let buf = [0x02, 0x05, 0x01, 0x00, 0x03, 0x01, 0x00];
        let mut cur = &buf[..];
        assert!(read_tagged_fields(&mut cur, |_, _| Ok(false)).is_err());
    }

    #[test]
    fn duplicate_tags_rejected() {
        // count=2, tag=5..., tag=5... — equal tags are not strictly ascending
        let buf = [0x02, 0x05, 0x01, 0x00, 0x05, 0x01, 0x00];
        let mut cur = &buf[..];
        assert!(read_tagged_fields(&mut cur, |_, _| Ok(false)).is_err());
    }

    #[test]
    fn known_field_is_decoded_not_preserved() {
        // count=1, tag=2, size=1, payload=[0x7F]
        let buf = [0x01, 0x02, 0x01, 0x7F];
        let mut cur = &buf[..];
        let mut seen = None;
        let unknown = read_tagged_fields(&mut cur, |tag, payload| {
            if tag == 2 {
                seen = Some(payload.get_u8());
                Ok(true)
            } else {
                Ok(false)
            }
        })
        .unwrap();
        assert!(unknown.is_empty());
        assert!(seen == Some(0x7F));
        assert!(cur.is_empty());
    }

    #[test]
    fn known_field_must_consume_whole_payload() {
        // count=1, tag=2, size=2, payload=[1,2]; the decoder claims the tag but reads nothing
        let buf = [0x01, 0x02, 0x02, 1, 2];
        let mut cur = &buf[..];
        let res = read_tagged_fields(&mut cur, |_, _| Ok(true));
        assert!(matches!(res, Err(ProtocolError::InvalidValue(_))));
    }

    #[test]
    fn truncated_payload_reports_missing_bytes() {
        // count=1, tag=5, size=3, but only one payload byte is present
        let buf = [0x01, 0x05, 0x03, 10];
        let mut cur = &buf[..];
        let res = read_tagged_fields(&mut cur, |_, _| Ok(false));
        assert!(matches!(res, Err(ProtocolError::UnexpectedEof { needed: 2 })));
    }

    #[test]
    fn write_merges_known_and_unknown_sorted() {
        let mut w = WriteTaggedFields::new();
        w.add(10, Bytes::from_static(&[0xAA]));
        let unknown = UnknownTaggedFields(vec![UnknownTaggedField {
            tag: 5,
            bytes: Bytes::from_static(&[0xBB]),
        }]);
        let mut out = BytesMut::new();
        w.write(&mut out, &unknown);
        // Expect: count=2, tag=5,len=1,0xBB, tag=10,len=1,0xAA
        assert!(&out[..] == &[0x02, 0x05, 0x01, 0xBB, 0x0A, 0x01, 0xAA]);
    }

    #[test]
    fn predicted_len_matches_written_len() {
        let mut w = WriteTaggedFields::new();
        w.add(10, Bytes::from_static(&[0xAA]));
        let unknown = UnknownTaggedFields(vec![UnknownTaggedField {
            tag: 5,
            bytes: Bytes::from_static(&[0xBB, 0xCC]),
        }]);
        let predicted = tagged_fields_len(&[(10, 1)], &unknown);
        let mut out = BytesMut::new();
        w.write(&mut out, &unknown);
        assert!(predicted == out.len());
    }

    #[test]
    fn write_then_read_round_trips() {
        let mut w = WriteTaggedFields::new();
        w.add(1, Bytes::from_static(&[0x11]));
        // Tag 200 needs a multi-byte uvarint.
        let unknown = UnknownTaggedFields(vec![UnknownTaggedField {
            tag: 200,
            bytes: Bytes::from_static(&[1, 2, 3]),
        }]);
        let mut out = BytesMut::new();
        w.write(&mut out, &unknown);

        let mut cur = &out[..];
        let mut known_payload = None;
        let decoded = read_tagged_fields(&mut cur, |tag, payload| {
            if tag == 1 {
                known_payload = Some(payload.get_u8());
                Ok(true)
            } else {
                Ok(false)
            }
        })
        .unwrap();
        assert!(known_payload == Some(0x11));
        assert!(decoded == unknown);
        assert!(cur.is_empty());
    }

    #[test]
    fn encode_to_bytes_returns_written_payload() {
        let bytes = encode_to_bytes(2, |buf| {
            buf.put_u8(0xDE);
            buf.put_u8(0xAD);
            Ok(())
        });
        assert!(bytes.as_ref() == &[0xDE, 0xAD]);
    }
}