// modality_api/protocol.rs

1use crate::{AttrVal, EventCoordinate, LogicalTime, Nanoseconds, TimelineId};
2use minicbor::{data::Tag, decode, encode, Decode, Decoder, Encode, Encoder};
3use uuid::Uuid;
4
5pub const TAG_NS: Tag = Tag::Unassigned(40000);
6pub const TAG_LOGICAL_TIME: Tag = Tag::Unassigned(40001);
7pub const TAG_TIMELINE_ID: Tag = Tag::Unassigned(40002);
8pub const TAG_EVENT_COORDINATE: Tag = Tag::Unassigned(40003);
9
10impl Encode for Nanoseconds {
11    fn encode<W: encode::Write>(&self, e: &mut Encoder<W>) -> Result<(), encode::Error<W::Error>> {
12        e.tag(TAG_NS)?.u64(self.get_raw())?;
13        Ok(())
14    }
15}
16
17impl<'b> Decode<'b> for Nanoseconds {
18    fn decode(d: &mut Decoder<'b>) -> Result<Self, decode::Error> {
19        let t = d.tag()?;
20        if t != TAG_NS {
21            return Err(decode::Error::Message("Expected TAG_NS"));
22        }
23
24        Ok(d.u64()?.into())
25    }
26}
27
28impl Encode for LogicalTime {
29    fn encode<W: encode::Write>(&self, e: &mut Encoder<W>) -> Result<(), encode::Error<W::Error>> {
30        e.tag(TAG_LOGICAL_TIME)?.encode(self.get_raw())?;
31        Ok(())
32    }
33}
34
35impl<'b> Decode<'b> for LogicalTime {
36    fn decode(d: &mut Decoder<'b>) -> Result<Self, decode::Error> {
37        let t = d.tag()?;
38        if t != TAG_LOGICAL_TIME {
39            return Err(decode::Error::Message("Expected TAG_LOGICAL_TIME"));
40        }
41
42        let els: Result<Vec<u64>, decode::Error> = d.array_iter()?.collect();
43        let els = els?;
44        if els.len() != 4 {
45            return Err(decode::Error::Message("LogicalTime array length must be 4"));
46        }
47
48        Ok(LogicalTime::quaternary(els[0], els[1], els[2], els[3]))
49    }
50}
51
52impl Encode for TimelineId {
53    fn encode<W: encode::Write>(&self, e: &mut Encoder<W>) -> Result<(), encode::Error<W::Error>> {
54        e.tag(TAG_TIMELINE_ID)?.bytes(self.get_raw().as_bytes())?;
55        Ok(())
56    }
57}
58
59impl<'b> Decode<'b> for TimelineId {
60    fn decode(d: &mut Decoder<'b>) -> Result<Self, decode::Error> {
61        let t = d.tag()?;
62        if t != TAG_TIMELINE_ID {
63            return Err(decode::Error::Message("Expected TAG_TIMELINE_ID"));
64        }
65
66        Uuid::from_slice(d.bytes()?)
67            .map(Into::into)
68            .map_err(|_uuid_err| decode::Error::Message("Error decoding uuid for TimelineId"))
69    }
70}
71
72impl Encode for EventCoordinate {
73    fn encode<W: encode::Write>(&self, e: &mut Encoder<W>) -> Result<(), encode::Error<W::Error>> {
74        e.tag(TAG_EVENT_COORDINATE)?.bytes(&self.as_bytes())?;
75        Ok(())
76    }
77}
78
79impl<'b> Decode<'b> for EventCoordinate {
80    fn decode(d: &mut Decoder<'b>) -> Result<Self, decode::Error> {
81        let t = d.tag()?;
82        if t != TAG_EVENT_COORDINATE {
83            return Err(decode::Error::Message("Expected TAG_EVENT_COORDINATE"));
84        }
85
86        EventCoordinate::from_byte_slice(d.bytes()?)
87            .ok_or(decode::Error::Message("Error decoding event coordinate"))
88    }
89}
90
91impl Encode for AttrVal {
92    fn encode<W: encode::Write>(&self, e: &mut Encoder<W>) -> Result<(), encode::Error<W::Error>> {
93        match self {
94            AttrVal::String(s) => {
95                e.str(s.as_ref())?;
96            }
97            AttrVal::Integer(i) => {
98                e.i64(*i)?;
99            }
100            AttrVal::BigInt(bi) => {
101                if **bi >= 0i128 {
102                    e.tag(Tag::PosBignum)?.bytes(&bi.to_be_bytes())?;
103                } else {
104                    // this is what the spec says to do. don't ask me.
105                    e.tag(Tag::NegBignum)?.bytes(&((-1 - **bi).to_be_bytes()))?;
106                }
107            }
108            AttrVal::Float(f) => {
109                e.f64(**f)?;
110            }
111            AttrVal::Bool(b) => {
112                e.bool(*b)?;
113            }
114            AttrVal::Timestamp(ns) => {
115                ns.encode(e)?;
116            }
117            AttrVal::LogicalTime(lt) => {
118                lt.encode(e)?;
119            }
120            AttrVal::TimelineId(tid) => {
121                tid.encode(e)?;
122            }
123            AttrVal::EventCoordinate(ec) => {
124                ec.encode(e)?;
125            }
126        }
127
128        Ok(())
129    }
130}
131
132impl<'b> Decode<'b> for AttrVal {
133    fn decode(d: &mut Decoder<'b>) -> Result<Self, decode::Error> {
134        use minicbor::data::Type;
135        let t = d.datatype()?;
136        match t {
137            Type::Bool => Ok((d.bool()?).into()),
138
139            Type::U8 => Ok((d.u8()?).into()),
140            Type::U16 => Ok((d.u16()?).into()),
141            Type::U32 => Ok((d.u32()?).into()),
142            Type::I8 => Ok((d.i8()?).into()),
143            Type::I16 => Ok((d.i16()?).into()),
144            Type::I32 => Ok((d.i32()?).into()),
145            Type::I64 => Ok((d.i64()?).into()),
146
147            Type::U64 => Ok((d.u64()? as i128).into()),
148            Type::F32 => Ok((d.f32()?).into()),
149            Type::F64 => Ok((d.f64()?).into()),
150
151            Type::String => Ok(d.str()?.into()),
152            Type::StringIndef => {
153                let mut s = String::new();
154                for s_res in d.str_iter()? {
155                    s += s_res?;
156                }
157                Ok(s.into())
158            }
159
160            Type::Tag => {
161                // probe == lookahead
162                match d.probe().tag()? {
163                    TAG_NS => Ok(Nanoseconds::decode(d)?.into()),
164                    TAG_LOGICAL_TIME => Ok(LogicalTime::decode(d)?.into()),
165                    TAG_TIMELINE_ID => Ok(TimelineId::decode(d)?.into()),
166
167                    Tag::PosBignum | Tag::NegBignum => {
168                        let tag = d.tag()?;
169                        let bytes = d.bytes()?;
170                        if bytes.len() != 16 {
171                            // Lame
172                            return Err(decode::Error::Message(
173                                "Bignums must be encoded as exactly 16 bytes",
174                            ));
175                        }
176                        // LAAAAAAAAAAAAAAAAAAAME
177                        let mut encoded_num = i128::from_be_bytes([
178                            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
179                            bytes[7], bytes[8], bytes[9], bytes[10], bytes[11], bytes[12],
180                            bytes[13], bytes[14], bytes[15],
181                        ]);
182                        if tag == Tag::NegBignum {
183                            encoded_num = -1 - encoded_num;
184                        }
185
186                        Ok(encoded_num.into())
187                    }
188
189                    _ => Err(decode::Error::Message("Unexpected Tag for Attrval")),
190                }
191            }
192            _ => Err(decode::Error::TypeMismatch(
193                t,
194                "Unexpected datatype for AttrVal",
195            )),
196        }
197    }
198}
199
#[cfg(test)]
mod test {
    use super::*;
    use proptest::prelude::*;

    /// Encodes `attr_val`, decodes it again, and asserts the value is unchanged.
    fn assert_round_trips(attr_val: AttrVal) {
        let mut buf = vec![];
        minicbor::encode(&attr_val, &mut buf).unwrap();

        let attr_val_prime: AttrVal = minicbor::decode(&buf).unwrap();
        assert_eq!(attr_val, attr_val_prime);
    }

    #[test]
    fn round_trip_attr_val() {
        proptest!(|(attr_val in crate::proptest_strategies::attr_val())| {
            let mut buf = vec![];
            minicbor::encode(&attr_val, &mut buf)?;

            let attr_val_prime: AttrVal = minicbor::decode(&buf)?;
            prop_assert_eq!(attr_val, attr_val_prime);
        });
    }

    #[test]
    fn round_trip_attr_val_with_codec_specific_negative_number_edge_cases() {
        let edges = [i8::MIN as i64, i16::MIN as i64, i32::MIN as i64];
        for edge in edges {
            for offset in -3..=3 {
                assert_round_trips(AttrVal::from(edge + offset));
            }
        }
    }

    /// Exercises the points where encoding switches between native CBOR integers
    /// and tagged bignums, plus the extremes of the bignum range itself.
    #[test]
    fn round_trip_attr_val_around_i64_and_i128_limits() {
        let edges = [
            i64::MIN as i128,
            i64::MAX as i128,
            u64::MAX as i128,
            i128::MIN,
            i128::MAX,
        ];
        for edge in edges {
            for offset in -3i128..=3 {
                // Skip offsets that would step past the ends of the i128 range.
                if let Some(val) = edge.checked_add(offset) {
                    assert_round_trips(AttrVal::from(val));
                }
            }
        }
    }
}