amadeus_utils/
vanilla_ser.rs

//! Custom wire format ("vanilla" serialization) for encoding transactions.
//!
//! Every value starts with a one-byte type tag:
//! - 0 => nil
//! - 1 => true
//! - 2 => false
//! - 3 => integer: varint
//! - 5 => binary/atom: varint(len) + raw bytes
//! - 6 => list: varint(len) + encoded elements
//! - 7 => map: varint(len) + encoded [key, value] pairs, sorted by key
//!
//! Tag 4 is not assigned; decoding it yields `Error::InvalidType(4)`.
//!
//! Varint:
//! - 0 => a single 0x00 byte
//! - otherwise: the first byte holds the sign (MSB, 1 bit) and the magnitude length in bytes
//!   (low 7 bits), followed by that many big-endian magnitude bytes.
//!   sign=0 => positive, sign=1 => negative.
14use std::cmp::Ordering;
15use std::collections::BTreeMap;
16
/// A single term of the wire format; each variant corresponds to one type tag.
///
/// `Ord`/`PartialOrd` are implemented by hand rather than derived, because the order of map
/// keys is part of the encoding.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum Value {
    /// Tag 0.
    Nil,
    /// Tag 1 (`true`) or tag 2 (`false`).
    Bool(bool),
    /// Tag 3 followed by a varint; values are limited to the `i128` range.
    Int(i128),
    /// Tag 5: length-prefixed raw bytes (binaries and atoms share this tag).
    Bytes(Vec<u8>),
    /// Tag 6: length-prefixed sequence of encoded values.
    List(Vec<Value>),
    /// Tag 7: length-prefixed key/value pairs. The `BTreeMap` keeps entries sorted by the
    /// hand-written `Ord`, which fixes their order on the wire.
    Map(BTreeMap<Value, Value>),
}
26
27#[derive(Debug, thiserror::Error, PartialEq, Eq)]
28pub enum Error {
29    #[error("unexpected end of input")]
30    UnexpectedEof,
31    #[error("invalid type tag: {0}")]
32    InvalidType(u8),
33    #[error("invalid variant encoding")]
34    InvalidVarInt,
35    #[error("integer overflow (requires bigint)")]
36    Overflow,
37    #[error("trailing data after full decode")]
38    TrailingData,
39}
40
41impl Value {
42    /// Deterministic comparator roughly analogous to term ordering for common types used as keys
43    fn cmp_keys(a: &Value, b: &Value) -> Ordering {
44        use Value::*;
45        let tag_order = |v: &Value| -> u8 {
46            match v {
47                Nil => 0,
48                Bool(false) => 1,
49                Bool(true) => 2,
50                Int(_) => 3,
51                Bytes(_) => 5,
52                List(_) => 6,
53                Map(_) => 7,
54            }
55        };
56        let ta = tag_order(a);
57        let tb = tag_order(b);
58        if ta != tb {
59            return ta.cmp(&tb);
60        }
61        match (a, b) {
62            (Nil, Nil) => Ordering::Equal,
63            (Bool(x), Bool(y)) => x.cmp(y), // false < true
64            (Int(x), Int(y)) => x.cmp(y),
65            (Bytes(x), Bytes(y)) => x.as_slice().cmp(y.as_slice()),
66            (List(x), List(y)) => {
67                let min_len = x.len().min(y.len());
68                for i in 0..min_len {
69                    let c = Value::cmp_keys(&x[i], &y[i]);
70                    if c != Ordering::Equal {
71                        return c;
72                    }
73                }
74                x.len().cmp(&y.len())
75            }
76            (Map(x), Map(y)) => {
77                // Compare by iterating BTreeMap entries (already sorted by key)
78                let mut xi = x.iter();
79                let mut yi = y.iter();
80                loop {
81                    match (xi.next(), yi.next()) {
82                        (None, None) => return Ordering::Equal,
83                        (None, Some(_)) => return Ordering::Less,
84                        (Some(_), None) => return Ordering::Greater,
85                        (Some((kxa, vxa)), Some((kxb, vxb))) => {
86                            let c = Value::cmp_keys(kxa, kxb);
87                            if c != Ordering::Equal {
88                                return c;
89                            }
90                            let c2 = Value::cmp_keys(vxa, vxb);
91                            if c2 != Ordering::Equal {
92                                return c2;
93                            }
94                        }
95                    }
96                }
97            }
98            _ => Ordering::Equal, // same tag values are handled above; cross-tags handled by tag_order
99        }
100    }
101}
102
103impl Ord for Value {
104    fn cmp(&self, other: &Self) -> Ordering {
105        Value::cmp_keys(self, other)
106    }
107}
108
109impl PartialOrd for Value {
110    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
111        Some(self.cmp(other))
112    }
113}
114
115pub fn encode(value: &Value) -> Vec<u8> {
116    let mut out = Vec::new();
117    encode_into(value, &mut out);
118    out
119}
120
121pub fn validate(bytes: &[u8]) -> Result<Value, Error> {
122    let (term, rest) = decode(bytes)?;
123    if !rest.is_empty() {
124        return Err(Error::TrailingData);
125    }
126    let re = encode(&term);
127    if re.as_slice() == bytes {
128        Ok(term)
129    } else {
130        Err(Error::InvalidVarInt) // mismatch with Elixir roundtrip; keep simple error
131    }
132}
133
134pub fn decode_all(bytes: &[u8]) -> Result<Value, Error> {
135    let (v, rest) = decode(bytes)?;
136    if rest.is_empty() { Ok(v) } else { Err(Error::TrailingData) }
137}
138
139pub fn decode(mut bytes: &[u8]) -> Result<(Value, &[u8]), Error> {
140    if bytes.is_empty() {
141        return Err(Error::UnexpectedEof);
142    }
143    let t = bytes[0];
144    bytes = &bytes[1..];
145    match t {
146        0 => Ok((Value::Nil, bytes)),
147        1 => Ok((Value::Bool(true), bytes)),
148        2 => Ok((Value::Bool(false), bytes)),
149        3 => {
150            let (i, rest) = decode_varint(bytes)?;
151            Ok((Value::Int(i), rest))
152        }
153        5 => {
154            let (len_i, rest) = decode_varint(bytes)?;
155            let len: usize = len_i.try_into().map_err(|_| Error::InvalidVarInt)?;
156            let rest_len = rest.len();
157            if rest_len < len {
158                return Err(Error::UnexpectedEof);
159            }
160            let payload = rest[..len].to_vec();
161            let rest2 = &rest[len..];
162            Ok((Value::Bytes(payload), rest2))
163        }
164        6 => {
165            let (len_i, mut rest) = decode_varint(bytes)?;
166            let len: usize = len_i.try_into().map_err(|_| Error::InvalidVarInt)?;
167            let mut items = Vec::with_capacity(len);
168            for _ in 0..len {
169                let (v, r) = decode(rest)?;
170                items.push(v);
171                rest = r;
172            }
173            Ok((Value::List(items), rest))
174        }
175        7 => {
176            let (len_i, mut rest) = decode_varint(bytes)?;
177            let len: usize = len_i.try_into().map_err(|_| Error::InvalidVarInt)?;
178            let mut map: BTreeMap<Value, Value> = BTreeMap::new();
179            for _ in 0..len {
180                let (k, r1) = decode(rest)?;
181                let (v, r2) = decode(r1)?;
182                map.insert(k, v);
183                rest = r2;
184            }
185            Ok((Value::Map(map), rest))
186        }
187        other => Err(Error::InvalidType(other)),
188    }
189}
190
191fn encode_into(value: &Value, out: &mut Vec<u8>) {
192    use Value::*;
193    match value {
194        Nil => out.push(0),
195        Bool(true) => out.push(1),
196        Bool(false) => out.push(2),
197        Int(i) => {
198            out.push(3);
199            encode_varint(*i, out);
200        }
201        Bytes(b) => {
202            out.push(5);
203            encode_varint(b.len() as i128, out);
204            out.extend_from_slice(b);
205        }
206        List(items) => {
207            out.push(6);
208            encode_varint(items.len() as i128, out);
209            for it in items {
210                encode_into(it, out);
211            }
212        }
213        Map(kvs) => {
214            out.push(7);
215            encode_varint(kvs.len() as i128, out);
216            for (k, v) in kvs.iter() {
217                encode_into(k, out);
218                encode_into(v, out);
219            }
220        }
221    }
222}
223
224fn encode_varint(n: i128, out: &mut Vec<u8>) {
225    if n == 0 {
226        out.push(0);
227        return;
228    }
229    let sign_bit: u8 = if n >= 0 { 0 } else { 1 };
230    let mag = magnitude_u128(n);
231    let be = mag.to_be_bytes();
232    // strip leading zeros
233    let first_nz = be.iter().position(|&b| b != 0).unwrap_or(be.len() - 1);
234    let bytes = &be[first_nz..];
235    let len = bytes.len();
236    assert!(len <= 127, "varint magnitude too large to encode length in 7 bits");
237    out.push((sign_bit << 7) | (len as u8));
238    out.extend_from_slice(bytes);
239}
240
241fn decode_varint(input: &[u8]) -> Result<(i128, &[u8]), Error> {
242    if input.is_empty() {
243        return Err(Error::UnexpectedEof);
244    }
245    let b0 = input[0];
246    if b0 == 0 {
247        return Ok((0, &input[1..]));
248    }
249    let sign = (b0 & 0b1000_0000) >> 7;
250    let len = (b0 & 0b0111_1111) as usize;
251    let rest = &input[1..];
252    if rest.len() < len {
253        return Err(Error::UnexpectedEof);
254    }
255    let payload = &rest[..len];
256    let rest2 = &rest[len..];
257    let mut mag: u128 = 0;
258    for &byte in payload {
259        mag = (mag << 8) | (byte as u128);
260    }
261    if sign == 0 {
262        if mag > i128::MAX as u128 {
263            return Err(Error::Overflow);
264        }
265        Ok((mag as i128, rest2))
266    } else {
267        // negative: -mag, but ensure mag fits into i128::MAX + 1
268        if mag > (i128::MAX as u128) + 1 {
269            return Err(Error::Overflow);
270        }
271        if mag == 0 {
272            // -0 should still be 0, but Elixir encode won't produce sign=1 with mag=0
273            return Ok((0, rest2));
274        }
275        let val = if mag == (i128::MAX as u128) + 1 { i128::MIN } else { -(mag as i128) };
276        Ok((val, rest2))
277    }
278}
279
/// Returns the absolute value of `n` as a `u128`.
///
/// Unlike `i128::abs`, this cannot overflow: `i128::MIN` maps to `2^127`.
#[inline]
fn magnitude_u128(n: i128) -> u128 {
    n.unsigned_abs()
}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `v` survives `encode` -> `decode_all` unchanged and that its encoding is
    /// canonical, i.e. accepted by `validate`.
    fn assert_roundtrip(v: &Value) {
        let enc = encode(v);
        assert_eq!(decode_all(&enc), Ok(v.clone()));
        assert_eq!(validate(&enc), Ok(v.clone()));
    }

    #[test]
    fn roundtrip_primitives() {
        let cases = vec![
            Value::Nil,
            Value::Bool(true),
            Value::Bool(false),
            Value::Int(0),
            Value::Int(1),
            Value::Int(-1),
            Value::Int(255),
            Value::Int(256),
            Value::Int(i128::MAX),
            Value::Int(i128::MIN),
            Value::Bytes(vec![]),
            Value::Bytes(b"abc".to_vec()),
        ];
        for v in &cases {
            assert_roundtrip(v);
        }
    }

    #[test]
    fn roundtrip_lists_and_maps() {
        let list = Value::List(vec![
            Value::Int(1),
            Value::Bool(false),
            Value::Bytes(b"x".to_vec()),
        ]);
        let mut bm = BTreeMap::new();
        bm.insert(Value::Bytes(b"b".to_vec()), Value::Int(2));
        bm.insert(Value::Bytes(b"a".to_vec()), Value::Int(1));
        assert_roundtrip(&Value::List(vec![list, Value::Map(bm)]));
    }

    #[test]
    fn roundtrip_mixed_key_map() {
        let mut bm = BTreeMap::new();
        bm.insert(Value::Bytes(b"k".to_vec()), Value::Nil);
        bm.insert(Value::Int(-7), Value::List(vec![]));
        bm.insert(Value::Nil, Value::Bool(true));
        bm.insert(Value::List(vec![Value::Int(1)]), Value::Int(0));
        assert_roundtrip(&Value::Map(bm));
    }

    #[test]
    fn map_entries_are_encoded_in_key_order() {
        let mut bm = BTreeMap::new();
        bm.insert(Value::Bytes(b"b".to_vec()), Value::Int(2));
        bm.insert(Value::Bytes(b"a".to_vec()), Value::Int(1));
        let expected: Vec<u8> = vec![
            7, 0x01, 0x02, // map with 2 entries
            5, 0x01, 0x01, b'a', 3, 0x01, 0x01, // "a" => 1
            5, 0x01, 0x01, b'b', 3, 0x01, 0x02, // "b" => 2
        ];
        assert_eq!(encode(&Value::Map(bm)), expected);
    }

    #[test]
    fn varint_sign_and_length() {
        let samples = [
            0i128,
            1,
            -1,
            127,
            128,
            -128,
            255,
            256,
            i128::MAX,
            i128::MIN + 1,
            i128::MIN,
        ];
        for &n in &samples {
            let mut buf = Vec::new();
            encode_varint(n, &mut buf);
            let (m, rest) = decode_varint(&buf).unwrap();
            assert_eq!(n, m);
            assert!(rest.is_empty());
        }
    }

    #[test]
    fn decode_reports_errors() {
        assert_eq!(decode(&[]), Err(Error::UnexpectedEof));
        // Tag 4 is unassigned; anything above 7 is unknown.
        assert_eq!(decode(&[4]), Err(Error::InvalidType(4)));
        assert_eq!(decode(&[9]), Err(Error::InvalidType(9)));
        // Binary declares 3 bytes but only 1 follows.
        assert_eq!(decode(&[5, 0x01, 0x03, b'a']), Err(Error::UnexpectedEof));
        // List declares 2 elements but only 1 follows.
        assert_eq!(decode(&[6, 0x01, 0x02, 0]), Err(Error::UnexpectedEof));
        // Negative length prefix.
        assert_eq!(decode(&[5, 0x81, 0x01]), Err(Error::InvalidVarInt));
        // Truncated varint payload.
        assert_eq!(decode(&[3, 0x02, 0x01]), Err(Error::UnexpectedEof));
    }

    #[test]
    fn decode_returns_unconsumed_tail() {
        let tail: &[u8] = &[0xAB];
        assert_eq!(decode(&[0, 0xAB]), Ok((Value::Nil, tail)));
        assert_eq!(decode_all(&[0, 0xAB]), Err(Error::TrailingData));
        assert_eq!(validate(&[0, 0xAB]), Err(Error::TrailingData));
    }

    #[test]
    fn validate_rejects_non_canonical_input() {
        // Padded varint: decodes leniently, but is not what `encode` would emit.
        let padded: &[u8] = &[3, 0x02, 0x00, 0x01];
        assert_eq!(decode_all(padded), Ok(Value::Int(1)));
        assert_eq!(validate(padded), Err(Error::InvalidVarInt));

        // Map keys out of order ("b" before "a").
        let unsorted: &[u8] = &[
            7, 0x01, 0x02, //
            5, 0x01, 0x01, b'b', 3, 0x01, 0x02, //
            5, 0x01, 0x01, b'a', 3, 0x01, 0x01,
        ];
        assert!(decode_all(unsorted).is_ok());
        assert_eq!(validate(unsorted), Err(Error::InvalidVarInt));

        // Duplicate key: decoding keeps the later value, validation rejects.
        let duplicate: &[u8] = &[7, 0x01, 0x02, 0, 0, 0, 1];
        let mut only = BTreeMap::new();
        only.insert(Value::Nil, Value::Bool(true));
        assert_eq!(decode_all(duplicate), Ok(Value::Map(only)));
        assert_eq!(validate(duplicate), Err(Error::InvalidVarInt));
    }
}