1use crate::{
4    buffer::{ReadBuffer, WriteBuffer},
5    error::Error,
6};
7use bytes::Bytes;
8
9pub trait Codec: Sized {
11    fn write(&self, writer: &mut impl Writer);
13
14    fn read(reader: &mut impl Reader) -> Result<Self, Error>;
16
17    fn len_encoded(&self) -> usize;
19
20    fn encode(&self) -> Vec<u8> {
22        let len = self.len_encoded();
23        let mut buffer = WriteBuffer::new(len);
24        self.write(&mut buffer);
25        assert!(buffer.len() == len);
26        buffer.into()
27    }
28
29    fn decode(bytes: impl Into<Bytes>) -> Result<Self, Error> {
32        let mut reader = ReadBuffer::new(bytes.into());
33        let result = Self::read(&mut reader);
34        let remaining = reader.remaining();
35        if remaining > 0 {
36            return Err(Error::ExtraData(remaining));
37        }
38        result
39    }
40}
41
42pub trait SizedCodec: Codec {
44    const LEN_CODEC: usize;
46
47    fn len_encoded(&self) -> usize {
49        Self::LEN_CODEC
50    }
51
52    fn encode_fixed<const N: usize>(&self) -> [u8; N] {
54        self.encode().try_into().unwrap()
55    }
56}
57
58pub trait Reader {
60    fn read<T: Codec>(&mut self) -> Result<T, Error>;
62
63    fn read_u8(&mut self) -> Result<u8, Error>;
65
66    fn read_u16(&mut self) -> Result<u16, Error>;
68
69    fn read_u32(&mut self) -> Result<u32, Error>;
71
72    fn read_u64(&mut self) -> Result<u64, Error>;
74
75    fn read_u128(&mut self) -> Result<u128, Error>;
77
78    fn read_i8(&mut self) -> Result<i8, Error>;
80
81    fn read_i16(&mut self) -> Result<i16, Error>;
83
84    fn read_i32(&mut self) -> Result<i32, Error>;
86
87    fn read_i64(&mut self) -> Result<i64, Error>;
89
90    fn read_i128(&mut self) -> Result<i128, Error>;
92
93    fn read_f32(&mut self) -> Result<f32, Error>;
95
96    fn read_f64(&mut self) -> Result<f64, Error>;
98
99    fn read_varint(&mut self) -> Result<u64, Error>;
101
102    fn read_bytes(&mut self) -> Result<Bytes, Error>;
104
105    fn read_bytes_lte(&mut self, max: usize) -> Result<Bytes, Error>;
107
108    fn read_n_bytes(&mut self, n: usize) -> Result<Bytes, Error>;
110
111    fn read_fixed<const N: usize>(&mut self) -> Result<[u8; N], Error>;
113
114    fn read_bool(&mut self) -> Result<bool, Error>;
116
117    fn read_option<T: Codec>(&mut self) -> Result<Option<T>, Error>;
119
120    fn read_vec<T: Codec>(&mut self) -> Result<Vec<T>, Error>;
122
123    fn read_vec_lte<T: Codec>(&mut self, max: usize) -> Result<Vec<T>, Error>;
125}
126
127pub trait Writer {
129    fn write<T: Codec>(&mut self, value: &T);
131
132    fn write_u8(&mut self, value: u8);
134
135    fn write_u16(&mut self, value: u16);
137
138    fn write_u32(&mut self, value: u32);
140
141    fn write_u64(&mut self, value: u64);
143
144    fn write_u128(&mut self, value: u128);
146
147    fn write_i8(&mut self, value: i8);
149
150    fn write_i16(&mut self, value: i16);
152
153    fn write_i32(&mut self, value: i32);
155
156    fn write_i64(&mut self, value: i64);
158
159    fn write_i128(&mut self, value: i128);
161
162    fn write_f32(&mut self, value: f32);
164
165    fn write_f64(&mut self, value: f64);
167
168    fn write_varint(&mut self, value: u64);
170
171    fn write_bytes(&mut self, bytes: &[u8]);
173
174    fn write_fixed(&mut self, bytes: &[u8]);
176
177    fn write_bool(&mut self, value: bool);
179
180    fn write_option<T: Codec>(&mut self, value: &Option<T>);
182
183    fn write_vec<T: Codec>(&mut self, values: &[T]);
185}
186
187impl Reader for ReadBuffer {
189    fn read<T: Codec>(&mut self) -> Result<T, Error> {
190        T::read(self)
191    }
192
193    fn read_u8(&mut self) -> Result<u8, Error> {
194        self.get_u8()
195    }
196
197    fn read_u16(&mut self) -> Result<u16, Error> {
198        self.get_u16()
199    }
200
201    fn read_u32(&mut self) -> Result<u32, Error> {
202        self.get_u32()
203    }
204
205    fn read_u64(&mut self) -> Result<u64, Error> {
206        self.get_u64()
207    }
208
209    fn read_u128(&mut self) -> Result<u128, Error> {
210        self.get_u128()
211    }
212
213    fn read_i8(&mut self) -> Result<i8, Error> {
214        self.get_i8()
215    }
216
217    fn read_i16(&mut self) -> Result<i16, Error> {
218        self.get_i16()
219    }
220
221    fn read_i32(&mut self) -> Result<i32, Error> {
222        self.get_i32()
223    }
224
225    fn read_i64(&mut self) -> Result<i64, Error> {
226        self.get_i64()
227    }
228
229    fn read_i128(&mut self) -> Result<i128, Error> {
230        self.get_i128()
231    }
232
233    fn read_f32(&mut self) -> Result<f32, Error> {
234        self.get_f32()
235    }
236
237    fn read_f64(&mut self) -> Result<f64, Error> {
238        self.get_f64()
239    }
240
241    fn read_varint(&mut self) -> Result<u64, Error> {
242        self.read_varint()
243    }
244
245    fn read_bytes(&mut self) -> Result<Bytes, Error> {
246        let len = self.read_varint()? as usize;
247        self.read_n_bytes(len)
248    }
249
250    fn read_n_bytes(&mut self, n: usize) -> Result<Bytes, Error> {
251        let bytes = self.split_to(n)?;
252        Ok(bytes)
253    }
254
255    fn read_bytes_lte(&mut self, max: usize) -> Result<Bytes, Error> {
256        let len = self.read_varint()? as usize;
257        if len > max {
258            return Err(Error::LengthExceeded(len, max));
259        }
260        self.read_n_bytes(len)
261    }
262
263    fn read_fixed<const N: usize>(&mut self) -> Result<[u8; N], Error> {
264        let mut bytes = [0u8; N];
265        self.copy_to_slice(&mut bytes)?;
266        Ok(bytes)
267    }
268
269    fn read_bool(&mut self) -> Result<bool, Error> {
270        let b = self.read_u8()?;
271        if b > 1 {
272            return Err(Error::InvalidBool);
273        }
274        Ok(b != 0)
275    }
276
277    fn read_option<T: Codec>(&mut self) -> Result<Option<T>, Error> {
278        let has_value = self.read_bool()?;
279
280        if has_value {
281            Ok(Some(self.read()?))
282        } else {
283            Ok(None)
284        }
285    }
286
287    fn read_vec<T: Codec>(&mut self) -> Result<Vec<T>, Error> {
288        let len = self.read_varint()? as usize;
289        let mut items = Vec::with_capacity(len);
290        for _ in 0..len {
291            items.push(self.read()?);
292        }
293        Ok(items)
294    }
295
296    fn read_vec_lte<T: Codec>(&mut self, max: usize) -> Result<Vec<T>, Error> {
297        let len = self.read_varint()? as usize;
298
299        if len > max {
300            return Err(Error::LengthExceeded(len, max));
301        }
302
303        let mut items = Vec::with_capacity(len);
304        for _ in 0..len {
305            items.push(self.read()?);
306        }
307        Ok(items)
308    }
309}
310
311impl Writer for WriteBuffer {
313    fn write<T: Codec>(&mut self, value: &T) {
314        value.write(self);
315    }
316
317    fn write_u8(&mut self, value: u8) {
318        self.put_u8(value)
319    }
320
321    fn write_u16(&mut self, value: u16) {
322        self.put_u16(value)
323    }
324
325    fn write_u32(&mut self, value: u32) {
326        self.put_u32(value)
327    }
328
329    fn write_u64(&mut self, value: u64) {
330        self.put_u64(value)
331    }
332
333    fn write_u128(&mut self, value: u128) {
334        self.put_u128(value)
335    }
336
337    fn write_i8(&mut self, value: i8) {
338        self.put_i8(value)
339    }
340
341    fn write_i16(&mut self, value: i16) {
342        self.put_i16(value)
343    }
344
345    fn write_i32(&mut self, value: i32) {
346        self.put_i32(value)
347    }
348
349    fn write_i64(&mut self, value: i64) {
350        self.put_i64(value)
351    }
352
353    fn write_i128(&mut self, value: i128) {
354        self.put_i128(value)
355    }
356
357    fn write_f32(&mut self, value: f32) {
358        self.put_f32(value)
359    }
360
361    fn write_f64(&mut self, value: f64) {
362        self.put_f64(value)
363    }
364
365    fn write_varint(&mut self, value: u64) {
366        self.write_varint(value)
367    }
368
369    fn write_bytes(&mut self, bytes: &[u8]) {
370        self.write_varint(bytes.len() as u64);
371        self.write_fixed(bytes);
372    }
373
374    fn write_fixed(&mut self, bytes: &[u8]) {
375        self.put_slice(bytes);
376    }
377
378    fn write_bool(&mut self, value: bool) {
379        self.put_u8(if value { 1 } else { 0 });
380    }
381
382    fn write_option<T: Codec>(&mut self, value: &Option<T>) {
383        match value {
384            Some(v) => {
385                self.write_bool(true);
386                self.write(v);
387            }
388            None => {
389                self.write_bool(false);
390            }
391        }
392    }
393
394    fn write_vec<T: Codec>(&mut self, values: &[T]) {
395        self.write_varint(values.len() as u64);
396        for value in values {
397            self.write(value);
398        }
399    }
400}
// Unit tests for the `Reader`/`Writer` implementations and the provided
// `Codec::decode` behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::{ReadBuffer, WriteBuffer};
    use crate::error::Error;
    use bytes::Bytes;

    #[test]
    fn test_insufficient_buffer() {
        let mut reader = ReadBuffer::new(Bytes::from_static(&[0x01, 0x02]));
        assert!(matches!(u32::read(&mut reader), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_extra_data() {
        let encoded = Bytes::from_static(&[0x01, 0x02]);
        assert!(matches!(u8::decode(encoded), Err(Error::ExtraData(1))));
    }

    #[test]
    fn test_invalid_bool() {
        let encoded = Bytes::from_static(&[0x02]);
        assert!(matches!(bool::decode(encoded), Err(Error::InvalidBool)));
    }

    #[test]
    fn test_invalid_varint() {
        let encoded = Bytes::from_static(&[
            0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
        ]);
        assert!(matches!(
            ReadBuffer::new(encoded).read_varint(),
            Err(Error::InvalidVarint)
        ));
    }

    #[test]
    fn test_bytes_round_trip() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3]);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert_eq!(reader.read_bytes().unwrap(), Bytes::from_static(&[1, 2, 3]));
        assert_eq!(reader.remaining(), 0);
    }

    #[test]
    fn test_bytes_lte_success() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3]);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_bytes_lte(5).unwrap();
        assert_eq!(result, Bytes::from_static(&[1, 2, 3]));
    }

    #[test]
    fn test_bytes_lte_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3, 4, 5, 6]);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert!(matches!(
            reader.read_bytes_lte(5),
            Err(Error::LengthExceeded(6, 5))
        ));
    }

    #[test]
    fn test_option_round_trip() {
        let mut writer = WriteBuffer::new(10);
        writer.write_option(&Some(7u8));
        writer.write_option::<u8>(&None);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert_eq!(reader.read_option::<u8>().unwrap(), Some(7));
        assert_eq!(reader.read_option::<u8>().unwrap(), None);
    }

    #[test]
    fn test_vec_lte_success() {
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8]);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_vec_lte::<u8>(3).unwrap();
        assert_eq!(result, vec![1u8, 2u8]);
    }

    #[test]
    fn test_vec_lte_at_limit() {
        // A count equal to the limit is accepted; only a larger count fails.
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8]);
        let mut reader = ReadBuffer::new(writer.freeze());
        let result = reader.read_vec_lte::<u8>(2).unwrap();
        assert_eq!(result, vec![1u8, 2u8]);
    }

    #[test]
    fn test_vec_lte_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8, 3u8]);
        let mut reader = ReadBuffer::new(writer.freeze());
        assert!(matches!(
            reader.read_vec_lte::<u8>(2),
            Err(Error::LengthExceeded(3, 2))
        ));
    }
}