commonware_codec/
codec.rs

1//! Core codec traits and implementations
2
3use crate::{
4    buffer::{ReadBuffer, WriteBuffer},
5    error::Error,
6};
7use bytes::Bytes;
8
9/// Trait for types that can be encoded to and decoded from bytes
10pub trait Codec: Sized {
11    /// Encodes this value to a writer.
12    fn write(&self, writer: &mut impl Writer);
13
14    /// Decodes a value from a reader.
15    fn read(reader: &mut impl Reader) -> Result<Self, Error>;
16
17    /// Returns the encoded length of this value.
18    fn len_encoded(&self) -> usize;
19
20    /// Encodes a value to bytes.
21    fn encode(&self) -> Vec<u8> {
22        let len = self.len_encoded();
23        let mut buffer = WriteBuffer::new(len);
24        self.write(&mut buffer);
25        assert!(buffer.len() == len);
26        buffer.into()
27    }
28
29    /// Decodes a value from bytes.
30    /// Returns an error if there is extra data after decoding the value.
31    fn decode(bytes: impl Into<Bytes>) -> Result<Self, Error> {
32        let mut reader = ReadBuffer::new(bytes.into());
33        let result = Self::read(&mut reader);
34        let remaining = reader.remaining();
35        if remaining > 0 {
36            return Err(Error::ExtraData(remaining));
37        }
38        result
39    }
40}
41
42/// Trait for types that have a fixed-length encoding
43pub trait SizedCodec: Codec {
44    /// The encoded length of this value.
45    const LEN_ENCODED: usize;
46
47    /// Returns the encoded length of this value.
48    ///
49    /// Should not be overridden by implementations.
50    fn len_encoded(&self) -> usize {
51        Self::LEN_ENCODED
52    }
53
54    /// Encodes a value to fixed-size bytes.
55    fn encode_fixed<const N: usize>(&self) -> [u8; N] {
56        // Ideally this is a compile-time check, but we can't do that in the current Rust version
57        // without adding a new generic parameter to the trait.
58        assert_eq!(
59            N,
60            Self::LEN_ENCODED,
61            "Can't encode {} bytes into {} bytes",
62            Self::LEN_ENCODED,
63            N
64        );
65
66        self.encode().try_into().unwrap()
67    }
68}
69
70/// Trait for codec read operations
71pub trait Reader {
72    /// Reads a u8 value
73    fn read_u8(&mut self) -> Result<u8, Error>;
74
75    /// Reads a u16 value
76    fn read_u16(&mut self) -> Result<u16, Error>;
77
78    /// Reads a u32 value
79    fn read_u32(&mut self) -> Result<u32, Error>;
80
81    /// Reads a u64 value
82    fn read_u64(&mut self) -> Result<u64, Error>;
83
84    /// Reads a u128 value
85    fn read_u128(&mut self) -> Result<u128, Error>;
86
87    /// Reads a i8 value
88    fn read_i8(&mut self) -> Result<i8, Error>;
89
90    /// Reads a i16 value
91    fn read_i16(&mut self) -> Result<i16, Error>;
92
93    /// Reads a i32 value
94    fn read_i32(&mut self) -> Result<i32, Error>;
95
96    /// Reads a i64 value
97    fn read_i64(&mut self) -> Result<i64, Error>;
98
99    /// Reads a i128 value
100    fn read_i128(&mut self) -> Result<i128, Error>;
101
102    /// Reads a f32 value
103    fn read_f32(&mut self) -> Result<f32, Error>;
104
105    /// Reads a f64 value
106    fn read_f64(&mut self) -> Result<f64, Error>;
107
108    /// Reads a varint-encoded integer
109    fn read_varint(&mut self) -> Result<u64, Error>;
110
111    /// Reads bytes with a length prefix
112    fn read_bytes(&mut self) -> Result<Bytes, Error>;
113
114    /// Reads bytes with a length prefix, with a limit on the number of bytes
115    fn read_bytes_lte(&mut self, max: usize) -> Result<Bytes, Error>;
116
117    /// Reads a fixed number of bytes
118    fn read_n_bytes(&mut self, n: usize) -> Result<Bytes, Error>;
119
120    /// Reads a fixed number of bytes into a fixed-size byte array
121    fn read_fixed<const N: usize>(&mut self) -> Result<[u8; N], Error>;
122
123    /// Reads a boolean value
124    fn read_bool(&mut self) -> Result<bool, Error>;
125
126    /// Reads an option value
127    fn read_option<T: Codec>(&mut self) -> Result<Option<T>, Error>;
128
129    /// Reads a vector with a length prefix
130    fn read_vec<T: Codec>(&mut self) -> Result<Vec<T>, Error>;
131
132    /// Reads a vector with a length prefix, with a limit on the number of elements
133    fn read_vec_lte<T: Codec>(&mut self, max: usize) -> Result<Vec<T>, Error>;
134}
135
136/// Trait for codec write operations
137pub trait Writer {
138    /// Writes a u8 value
139    fn write_u8(&mut self, value: u8);
140
141    /// Writes a u16 value
142    fn write_u16(&mut self, value: u16);
143
144    /// Writes a u32 value
145    fn write_u32(&mut self, value: u32);
146
147    /// Writes a u64 value
148    fn write_u64(&mut self, value: u64);
149
150    /// Writes a u128 value
151    fn write_u128(&mut self, value: u128);
152
153    /// Writes a i8 value
154    fn write_i8(&mut self, value: i8);
155
156    /// Writes a i16 value
157    fn write_i16(&mut self, value: i16);
158
159    /// Writes a i32 value
160    fn write_i32(&mut self, value: i32);
161
162    /// Writes a i64 value
163    fn write_i64(&mut self, value: i64);
164
165    /// Writes a i128 value
166    fn write_i128(&mut self, value: i128);
167
168    /// Writes a f32 value
169    fn write_f32(&mut self, value: f32);
170
171    /// Writes a f64 value
172    fn write_f64(&mut self, value: f64);
173
174    /// Writes a varint-encoded integer
175    fn write_varint(&mut self, value: u64);
176
177    /// Writes bytes with a length prefix
178    fn write_bytes(&mut self, bytes: &[u8]);
179
180    /// Writes a fixed-size byte array
181    fn write_fixed(&mut self, bytes: &[u8]);
182
183    /// Writes a boolean value
184    fn write_bool(&mut self, value: bool);
185
186    /// Writes an option value
187    fn write_option<T: Codec>(&mut self, value: &Option<T>);
188
189    /// Writes a vector with a length prefix
190    fn write_vec<T: Codec>(&mut self, values: &[T]);
191}
192
193// Implement Reader for ReadBuffer
194impl Reader for ReadBuffer {
195    fn read_u8(&mut self) -> Result<u8, Error> {
196        self.get_u8()
197    }
198
199    fn read_u16(&mut self) -> Result<u16, Error> {
200        self.get_u16()
201    }
202
203    fn read_u32(&mut self) -> Result<u32, Error> {
204        self.get_u32()
205    }
206
207    fn read_u64(&mut self) -> Result<u64, Error> {
208        self.get_u64()
209    }
210
211    fn read_u128(&mut self) -> Result<u128, Error> {
212        self.get_u128()
213    }
214
215    fn read_i8(&mut self) -> Result<i8, Error> {
216        self.get_i8()
217    }
218
219    fn read_i16(&mut self) -> Result<i16, Error> {
220        self.get_i16()
221    }
222
223    fn read_i32(&mut self) -> Result<i32, Error> {
224        self.get_i32()
225    }
226
227    fn read_i64(&mut self) -> Result<i64, Error> {
228        self.get_i64()
229    }
230
231    fn read_i128(&mut self) -> Result<i128, Error> {
232        self.get_i128()
233    }
234
235    fn read_f32(&mut self) -> Result<f32, Error> {
236        self.get_f32()
237    }
238
239    fn read_f64(&mut self) -> Result<f64, Error> {
240        self.get_f64()
241    }
242
243    fn read_varint(&mut self) -> Result<u64, Error> {
244        self.read_varint()
245    }
246
247    fn read_bytes(&mut self) -> Result<Bytes, Error> {
248        let len = self.read_varint()? as usize;
249        self.read_n_bytes(len)
250    }
251
252    fn read_n_bytes(&mut self, n: usize) -> Result<Bytes, Error> {
253        let bytes = self.split_to(n)?;
254        Ok(bytes)
255    }
256
257    fn read_bytes_lte(&mut self, max: usize) -> Result<Bytes, Error> {
258        let len = self.read_varint()? as usize;
259        if len > max {
260            return Err(Error::LengthExceeded(len, max));
261        }
262        self.read_n_bytes(len)
263    }
264
265    fn read_fixed<const N: usize>(&mut self) -> Result<[u8; N], Error> {
266        let mut bytes = [0u8; N];
267        self.copy_to_slice(&mut bytes)?;
268        Ok(bytes)
269    }
270
271    fn read_bool(&mut self) -> Result<bool, Error> {
272        let b = self.read_u8()?;
273        if b > 1 {
274            return Err(Error::InvalidBool);
275        }
276        Ok(b != 0)
277    }
278
279    fn read_option<T: Codec>(&mut self) -> Result<Option<T>, Error> {
280        let has_value = self.read_bool()?;
281
282        if has_value {
283            Ok(Some(T::read(self)?))
284        } else {
285            Ok(None)
286        }
287    }
288
289    fn read_vec<T: Codec>(&mut self) -> Result<Vec<T>, Error> {
290        let len = self.read_varint()? as usize;
291        let mut items = Vec::with_capacity(len);
292        for _ in 0..len {
293            items.push(T::read(self)?);
294        }
295        Ok(items)
296    }
297
298    fn read_vec_lte<T: Codec>(&mut self, max: usize) -> Result<Vec<T>, Error> {
299        let len = self.read_varint()? as usize;
300
301        if len > max {
302            return Err(Error::LengthExceeded(len, max));
303        }
304
305        let mut items = Vec::with_capacity(len);
306        for _ in 0..len {
307            items.push(T::read(self)?);
308        }
309        Ok(items)
310    }
311}
312
313// Implement Writer for WriteBuffer
314impl Writer for WriteBuffer {
315    fn write_u8(&mut self, value: u8) {
316        self.put_u8(value)
317    }
318
319    fn write_u16(&mut self, value: u16) {
320        self.put_u16(value)
321    }
322
323    fn write_u32(&mut self, value: u32) {
324        self.put_u32(value)
325    }
326
327    fn write_u64(&mut self, value: u64) {
328        self.put_u64(value)
329    }
330
331    fn write_u128(&mut self, value: u128) {
332        self.put_u128(value)
333    }
334
335    fn write_i8(&mut self, value: i8) {
336        self.put_i8(value)
337    }
338
339    fn write_i16(&mut self, value: i16) {
340        self.put_i16(value)
341    }
342
343    fn write_i32(&mut self, value: i32) {
344        self.put_i32(value)
345    }
346
347    fn write_i64(&mut self, value: i64) {
348        self.put_i64(value)
349    }
350
351    fn write_i128(&mut self, value: i128) {
352        self.put_i128(value)
353    }
354
355    fn write_f32(&mut self, value: f32) {
356        self.put_f32(value)
357    }
358
359    fn write_f64(&mut self, value: f64) {
360        self.put_f64(value)
361    }
362
363    fn write_varint(&mut self, value: u64) {
364        self.write_varint(value)
365    }
366
367    fn write_bytes(&mut self, bytes: &[u8]) {
368        self.write_varint(bytes.len() as u64);
369        self.write_fixed(bytes);
370    }
371
372    fn write_fixed(&mut self, bytes: &[u8]) {
373        self.put_slice(bytes);
374    }
375
376    fn write_bool(&mut self, value: bool) {
377        self.put_u8(if value { 1 } else { 0 });
378    }
379
380    fn write_option<T: Codec>(&mut self, value: &Option<T>) {
381        match value {
382            Some(v) => {
383                self.write_bool(true);
384                v.write(self);
385            }
386            None => {
387                self.write_bool(false);
388            }
389        }
390    }
391
392    fn write_vec<T: Codec>(&mut self, values: &[T]) {
393        self.write_varint(values.len() as u64);
394        for value in values {
395            value.write(self);
396        }
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{varint::varint_size, Codec, Error, ReadBuffer, WriteBuffer};
    use bytes::Bytes;

    #[test]
    fn test_insufficient_buffer() {
        let mut reader = ReadBuffer::new(Bytes::from_static(&[0x01, 0x02]));
        assert!(matches!(u32::read(&mut reader), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_extra_data() {
        let encoded = Bytes::from_static(&[0x01, 0x02]);
        assert!(matches!(u8::decode(encoded), Err(Error::ExtraData(1))));
    }

    #[test]
    fn test_invalid_bool() {
        let encoded = Bytes::from_static(&[0x02]);
        assert!(matches!(bool::decode(encoded), Err(Error::InvalidBool)));
    }

    #[test]
    fn test_bool_roundtrip() {
        let mut writer = WriteBuffer::new(2);
        writer.write_bool(true);
        writer.write_bool(false);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert!(reader.read_bool().unwrap());
        assert!(!reader.read_bool().unwrap());
    }

    #[test]
    fn test_option_roundtrip() {
        // Some(u8) encodes as flag + value (2 bytes); None as the flag alone (1 byte).
        let mut writer = WriteBuffer::new(3);
        writer.write_option(&Some(7u8));
        writer.write_option::<u8>(&None);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert_eq!(reader.read_option::<u8>().unwrap(), Some(7));
        assert_eq!(reader.read_option::<u8>().unwrap(), None);
    }

    #[test]
    fn test_varint() {
        let value = u64::MAX / 2;
        let mut writer = WriteBuffer::new(varint_size(value));
        writer.write_varint(value);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_varint().unwrap();
        assert_eq!(result, value);
    }

    #[test]
    fn test_read_fixed() {
        let mut reader = ReadBuffer::new(Bytes::from_static(&[1, 2, 3]));
        assert_eq!(reader.read_fixed::<3>().unwrap(), [1, 2, 3]);
        assert_eq!(reader.remaining(), 0);
    }

    #[test]
    fn test_bytes_lte_success() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3]);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_bytes_lte(5).unwrap();
        assert_eq!(result, Bytes::from_static(&[1, 2, 3]));
    }

    #[test]
    fn test_bytes_lte_at_limit() {
        // The limit is inclusive: a length equal to `max` is accepted.
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3, 4, 5]);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_bytes_lte(5).unwrap();
        assert_eq!(result, Bytes::from_static(&[1, 2, 3, 4, 5]));
    }

    #[test]
    fn test_bytes_lte_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3, 4, 5, 6]);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert!(matches!(
            reader.read_bytes_lte(5),
            Err(Error::LengthExceeded(6, 5))
        ));
    }

    #[test]
    fn test_vec_lte_success() {
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8]);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_vec_lte::<u8>(3).unwrap();
        assert_eq!(result, vec![1u8, 2u8]);
    }

    #[test]
    fn test_vec_lte_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8, 3u8]);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert!(matches!(
            reader.read_vec_lte::<u8>(2),
            Err(Error::LengthExceeded(3, 2))
        ));
    }

    #[test]
    fn test_encode_fixed() {
        let value = 42u32;
        let encoded: [u8; 4] = value.encode_fixed();
        let decoded = u32::decode(Bytes::copy_from_slice(&encoded)).unwrap();
        assert_eq!(value, decoded);
    }

    #[test]
    #[should_panic(expected = "Can't encode 4 bytes into 5 bytes")]
    fn test_encode_fixed_panic() {
        let _: [u8; 5] = 42u32.encode_fixed();
    }
}