Skip to main content

context_engine/
tree.rs

1use alloc::vec::Vec;
2
3use crate::provided::Tree;
4
// ── Wire format ───────────────────────────────────────────────────────────────
//
// Every value begins with a one-byte tag. All length and count prefixes are
// unsigned 32-bit little-endian integers, and each `item` is itself a value
// encoded with these same rules.
//
//   Null     : 0x00
//   Scalar   : 0x01 | len(u32le) | bytes
//   Sequence : 0x02 | count(u32le) | item...
//   Mapping  : 0x03 | count(u32le) | (key_len(u32le) | key_bytes | item)...

/// Tag byte for `Tree::Null`; no payload follows.
const TAG_NULL: u8 = 0x00;
/// Tag byte for `Tree::Scalar`; followed by a `u32` length and the raw bytes.
const TAG_SCALAR: u8 = 0x01;
/// Tag byte for `Tree::Sequence`; followed by a `u32` count and that many items.
const TAG_SEQUENCE: u8 = 0x02;
/// Tag byte for `Tree::Mapping`; followed by a `u32` count and that many key/item pairs.
const TAG_MAPPING: u8 = 0x03;
16
17impl Tree {
18    /// Serializes the tree to the wire format.
19    ///
20    /// ```
21    /// # extern crate alloc;
22    /// use context_engine::Tree;
23    /// let v = Tree::Scalar(b"hi".to_vec());
24    /// let bytes = v.wire();
25    /// assert_eq!(Tree::unwire(&bytes), Some(v));
26    /// ```
27    pub fn wire(&self) -> Vec<u8> {
28        let mut buf = Vec::new();
29        write_value(self, &mut buf);
30        buf
31    }
32
33    /// Deserializes a tree from wire-format bytes. Returns `None` on malformed input.
34    ///
35    /// ```
36    /// # extern crate alloc;
37    /// use context_engine::Tree;
38    /// assert_eq!(Tree::unwire(&[0xFF]), None);
39    /// let bytes = Tree::Null.wire();
40    /// assert_eq!(Tree::unwire(&bytes), Some(Tree::Null));
41    /// ```
42    pub fn unwire(bytes: &[u8]) -> Option<Self> {
43        let (value, _) = read_value(bytes)?;
44        Some(value)
45    }
46}
47
48fn write_value(value: &Tree, buf: &mut Vec<u8>) {
49    match value {
50        Tree::Null => {
51            buf.push(TAG_NULL);
52        }
53        Tree::Scalar(b) => {
54            buf.push(TAG_SCALAR);
55            buf.extend_from_slice(&(b.len() as u32).to_le_bytes());
56            buf.extend_from_slice(b);
57        }
58        Tree::Sequence(items) => {
59            buf.push(TAG_SEQUENCE);
60            buf.extend_from_slice(&(items.len() as u32).to_le_bytes());
61            for item in items {
62                write_value(item, buf);
63            }
64        }
65        Tree::Mapping(pairs) => {
66            buf.push(TAG_MAPPING);
67            buf.extend_from_slice(&(pairs.len() as u32).to_le_bytes());
68            for (k, v) in pairs {
69                buf.extend_from_slice(&(k.len() as u32).to_le_bytes());
70                buf.extend_from_slice(k);
71                write_value(v, buf);
72            }
73        }
74    }
75}
76
77fn read_value(bytes: &[u8]) -> Option<(Tree, &[u8])> {
78    let (&tag, rest) = bytes.split_first()?;
79    match tag {
80        TAG_NULL => Some((Tree::Null, rest)),
81        TAG_SCALAR => {
82            let (len, rest) = read_u32(rest)?;
83            let (data, rest) = split_at(rest, len)?;
84            Some((Tree::Scalar(data.to_vec()), rest))
85        }
86        TAG_SEQUENCE => {
87            let (count, mut rest) = read_u32(rest)?;
88            let mut items = Vec::with_capacity(count);
89            for _ in 0..count {
90                let (item, next) = read_value(rest)?;
91                items.push(item);
92                rest = next;
93            }
94            Some((Tree::Sequence(items), rest))
95        }
96        TAG_MAPPING => {
97            let (count, mut rest) = read_u32(rest)?;
98            let mut pairs = Vec::with_capacity(count);
99            for _ in 0..count {
100                let (klen, next) = read_u32(rest)?;
101                let (kdata, next) = split_at(next, klen)?;
102                let (val, next) = read_value(next)?;
103                pairs.push((kdata.to_vec(), val));
104                rest = next;
105            }
106            Some((Tree::Mapping(pairs), rest))
107        }
108        _ => None,
109    }
110}
111
/// Reads a little-endian `u32` length/count prefix from the front of `bytes`.
///
/// Returns the value as a `usize` together with the bytes after the prefix.
/// Returns `None` if fewer than four bytes remain, or if the value does not fit
/// in `usize` (possible only on targets where `usize` is narrower than 32 bits).
fn read_u32(bytes: &[u8]) -> Option<(usize, &[u8])> {
    if bytes.len() < 4 {
        return None;
    }
    let (prefix, rest) = bytes.split_at(4);
    let raw = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]);
    // `raw as usize` would silently truncate on 16-bit targets and misparse the
    // rest of the input; reject instead.
    let n = usize::try_from(raw).ok()?;
    Some((n, rest))
}
117
/// Bounds-checked split: yields the first `n` bytes and everything after them,
/// or `None` when `bytes` is shorter than `n`.
fn split_at(bytes: &[u8], n: usize) -> Option<(&[u8], &[u8])> {
    if n > bytes.len() {
        return None;
    }
    Some(bytes.split_at(n))
}
121
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Round-trips `v` through `wire`/`unwire`, panicking if decoding fails.
    fn rt(v: &Tree) -> Tree {
        Tree::unwire(&v.wire()).unwrap()
    }

    #[test]
    fn null_roundtrip() {
        assert_eq!(rt(&Tree::Null), Tree::Null);
    }

    #[test]
    fn scalar_roundtrip() {
        assert_eq!(rt(&Tree::Scalar(b"hello".to_vec())), Tree::Scalar(b"hello".to_vec()));
    }

    #[test]
    fn scalar_empty_roundtrip() {
        assert_eq!(rt(&Tree::Scalar(vec![])), Tree::Scalar(vec![]));
    }

    #[test]
    fn sequence_roundtrip() {
        let v = Tree::Sequence(vec![
            Tree::Scalar(b"a".to_vec()),
            Tree::Null,
            Tree::Scalar(b"b".to_vec()),
        ]);
        assert_eq!(rt(&v), v);
    }

    #[test]
    fn mapping_roundtrip() {
        let v = Tree::Mapping(vec![
            (b"id".to_vec(), Tree::Scalar(b"1".to_vec())),
            (b"name".to_vec(), Tree::Scalar(b"alice".to_vec())),
        ]);
        assert_eq!(rt(&v), v);
    }

    #[test]
    fn nested_roundtrip() {
        let v = Tree::Mapping(vec![(
            b"user".to_vec(),
            Tree::Mapping(vec![
                (b"id".to_vec(), Tree::Scalar(b"1".to_vec())),
                (
                    b"tags".to_vec(),
                    Tree::Sequence(vec![
                        Tree::Scalar(b"admin".to_vec()),
                        Tree::Scalar(b"staff".to_vec()),
                    ]),
                ),
                (b"extra".to_vec(), Tree::Null),
            ]),
        )]);
        assert_eq!(rt(&v), v);
    }

    #[test]
    fn unwire_invalid_returns_none() {
        assert_eq!(Tree::unwire(&[0xFF]), None);
        assert_eq!(Tree::unwire(&[TAG_SCALAR, 0x05, 0x00, 0x00, 0x00]), None);
    }

    #[test]
    fn unwire_empty_input_returns_none() {
        assert_eq!(Tree::unwire(&[]), None);
    }

    #[test]
    fn unwire_truncated_returns_none() {
        let v = Tree::Mapping(vec![(
            b"k".to_vec(),
            Tree::Sequence(vec![Tree::Scalar(b"xy".to_vec()), Tree::Null]),
        )]);
        let bytes = v.wire();
        // The format is self-delimiting, so every strict prefix is incomplete.
        for end in 0..bytes.len() {
            assert_eq!(Tree::unwire(&bytes[..end]), None, "prefix of length {} decoded", end);
        }
    }

    #[test]
    fn unwire_unknown_nested_tag_returns_none() {
        assert_eq!(Tree::unwire(&[TAG_SEQUENCE, 0x01, 0x00, 0x00, 0x00, 0xFF]), None);
    }

    #[test]
    fn mapping_with_null_field_roundtrip() {
        let v = Tree::Mapping(vec![
            (b"id".to_vec(), Tree::Scalar(b"1".to_vec())),
            (b"deleted_at".to_vec(), Tree::Null),
        ]);
        assert_eq!(Tree::unwire(&v.wire()).unwrap(), v);
    }
}
195