1use crate::{
4 buffer::{ReadBuffer, WriteBuffer},
5 error::Error,
6};
7use bytes::Bytes;
8
9pub trait Codec: Sized {
11 fn write(&self, writer: &mut impl Writer);
13
14 fn read(reader: &mut impl Reader) -> Result<Self, Error>;
16
17 fn len_encoded(&self) -> usize;
19
20 fn encode(&self) -> Vec<u8> {
22 let len = self.len_encoded();
23 let mut buffer = WriteBuffer::new(len);
24 self.write(&mut buffer);
25 assert!(buffer.len() == len);
26 buffer.into()
27 }
28
29 fn decode(bytes: impl Into<Bytes>) -> Result<Self, Error> {
32 let mut reader = ReadBuffer::new(bytes.into());
33 let result = Self::read(&mut reader);
34 let remaining = reader.remaining();
35 if remaining > 0 {
36 return Err(Error::ExtraData(remaining));
37 }
38 result
39 }
40}
41
42pub trait SizedCodec: Codec {
44 const LEN_ENCODED: usize;
46
47 fn len_encoded(&self) -> usize {
51 Self::LEN_ENCODED
52 }
53
54 fn encode_fixed<const N: usize>(&self) -> [u8; N] {
56 assert_eq!(
59 N,
60 Self::LEN_ENCODED,
61 "Can't encode {} bytes into {} bytes",
62 Self::LEN_ENCODED,
63 N
64 );
65
66 self.encode().try_into().unwrap()
67 }
68}
69
70pub trait Reader {
72 fn read_u8(&mut self) -> Result<u8, Error>;
74
75 fn read_u16(&mut self) -> Result<u16, Error>;
77
78 fn read_u32(&mut self) -> Result<u32, Error>;
80
81 fn read_u64(&mut self) -> Result<u64, Error>;
83
84 fn read_u128(&mut self) -> Result<u128, Error>;
86
87 fn read_i8(&mut self) -> Result<i8, Error>;
89
90 fn read_i16(&mut self) -> Result<i16, Error>;
92
93 fn read_i32(&mut self) -> Result<i32, Error>;
95
96 fn read_i64(&mut self) -> Result<i64, Error>;
98
99 fn read_i128(&mut self) -> Result<i128, Error>;
101
102 fn read_f32(&mut self) -> Result<f32, Error>;
104
105 fn read_f64(&mut self) -> Result<f64, Error>;
107
108 fn read_varint(&mut self) -> Result<u64, Error>;
110
111 fn read_bytes(&mut self) -> Result<Bytes, Error>;
113
114 fn read_bytes_lte(&mut self, max: usize) -> Result<Bytes, Error>;
116
117 fn read_n_bytes(&mut self, n: usize) -> Result<Bytes, Error>;
119
120 fn read_fixed<const N: usize>(&mut self) -> Result<[u8; N], Error>;
122
123 fn read_bool(&mut self) -> Result<bool, Error>;
125
126 fn read_option<T: Codec>(&mut self) -> Result<Option<T>, Error>;
128
129 fn read_vec<T: Codec>(&mut self) -> Result<Vec<T>, Error>;
131
132 fn read_vec_lte<T: Codec>(&mut self, max: usize) -> Result<Vec<T>, Error>;
134}
135
136pub trait Writer {
138 fn write_u8(&mut self, value: u8);
140
141 fn write_u16(&mut self, value: u16);
143
144 fn write_u32(&mut self, value: u32);
146
147 fn write_u64(&mut self, value: u64);
149
150 fn write_u128(&mut self, value: u128);
152
153 fn write_i8(&mut self, value: i8);
155
156 fn write_i16(&mut self, value: i16);
158
159 fn write_i32(&mut self, value: i32);
161
162 fn write_i64(&mut self, value: i64);
164
165 fn write_i128(&mut self, value: i128);
167
168 fn write_f32(&mut self, value: f32);
170
171 fn write_f64(&mut self, value: f64);
173
174 fn write_varint(&mut self, value: u64);
176
177 fn write_bytes(&mut self, bytes: &[u8]);
179
180 fn write_fixed(&mut self, bytes: &[u8]);
182
183 fn write_bool(&mut self, value: bool);
185
186 fn write_option<T: Codec>(&mut self, value: &Option<T>);
188
189 fn write_vec<T: Codec>(&mut self, values: &[T]);
191}
192
193impl Reader for ReadBuffer {
195 fn read_u8(&mut self) -> Result<u8, Error> {
196 self.get_u8()
197 }
198
199 fn read_u16(&mut self) -> Result<u16, Error> {
200 self.get_u16()
201 }
202
203 fn read_u32(&mut self) -> Result<u32, Error> {
204 self.get_u32()
205 }
206
207 fn read_u64(&mut self) -> Result<u64, Error> {
208 self.get_u64()
209 }
210
211 fn read_u128(&mut self) -> Result<u128, Error> {
212 self.get_u128()
213 }
214
215 fn read_i8(&mut self) -> Result<i8, Error> {
216 self.get_i8()
217 }
218
219 fn read_i16(&mut self) -> Result<i16, Error> {
220 self.get_i16()
221 }
222
223 fn read_i32(&mut self) -> Result<i32, Error> {
224 self.get_i32()
225 }
226
227 fn read_i64(&mut self) -> Result<i64, Error> {
228 self.get_i64()
229 }
230
231 fn read_i128(&mut self) -> Result<i128, Error> {
232 self.get_i128()
233 }
234
235 fn read_f32(&mut self) -> Result<f32, Error> {
236 self.get_f32()
237 }
238
239 fn read_f64(&mut self) -> Result<f64, Error> {
240 self.get_f64()
241 }
242
243 fn read_varint(&mut self) -> Result<u64, Error> {
244 self.read_varint()
245 }
246
247 fn read_bytes(&mut self) -> Result<Bytes, Error> {
248 let len = self.read_varint()? as usize;
249 self.read_n_bytes(len)
250 }
251
252 fn read_n_bytes(&mut self, n: usize) -> Result<Bytes, Error> {
253 let bytes = self.split_to(n)?;
254 Ok(bytes)
255 }
256
257 fn read_bytes_lte(&mut self, max: usize) -> Result<Bytes, Error> {
258 let len = self.read_varint()? as usize;
259 if len > max {
260 return Err(Error::LengthExceeded(len, max));
261 }
262 self.read_n_bytes(len)
263 }
264
265 fn read_fixed<const N: usize>(&mut self) -> Result<[u8; N], Error> {
266 let mut bytes = [0u8; N];
267 self.copy_to_slice(&mut bytes)?;
268 Ok(bytes)
269 }
270
271 fn read_bool(&mut self) -> Result<bool, Error> {
272 let b = self.read_u8()?;
273 if b > 1 {
274 return Err(Error::InvalidBool);
275 }
276 Ok(b != 0)
277 }
278
279 fn read_option<T: Codec>(&mut self) -> Result<Option<T>, Error> {
280 let has_value = self.read_bool()?;
281
282 if has_value {
283 Ok(Some(T::read(self)?))
284 } else {
285 Ok(None)
286 }
287 }
288
289 fn read_vec<T: Codec>(&mut self) -> Result<Vec<T>, Error> {
290 let len = self.read_varint()? as usize;
291 let mut items = Vec::with_capacity(len);
292 for _ in 0..len {
293 items.push(T::read(self)?);
294 }
295 Ok(items)
296 }
297
298 fn read_vec_lte<T: Codec>(&mut self, max: usize) -> Result<Vec<T>, Error> {
299 let len = self.read_varint()? as usize;
300
301 if len > max {
302 return Err(Error::LengthExceeded(len, max));
303 }
304
305 let mut items = Vec::with_capacity(len);
306 for _ in 0..len {
307 items.push(T::read(self)?);
308 }
309 Ok(items)
310 }
311}
312
313impl Writer for WriteBuffer {
315 fn write_u8(&mut self, value: u8) {
316 self.put_u8(value)
317 }
318
319 fn write_u16(&mut self, value: u16) {
320 self.put_u16(value)
321 }
322
323 fn write_u32(&mut self, value: u32) {
324 self.put_u32(value)
325 }
326
327 fn write_u64(&mut self, value: u64) {
328 self.put_u64(value)
329 }
330
331 fn write_u128(&mut self, value: u128) {
332 self.put_u128(value)
333 }
334
335 fn write_i8(&mut self, value: i8) {
336 self.put_i8(value)
337 }
338
339 fn write_i16(&mut self, value: i16) {
340 self.put_i16(value)
341 }
342
343 fn write_i32(&mut self, value: i32) {
344 self.put_i32(value)
345 }
346
347 fn write_i64(&mut self, value: i64) {
348 self.put_i64(value)
349 }
350
351 fn write_i128(&mut self, value: i128) {
352 self.put_i128(value)
353 }
354
355 fn write_f32(&mut self, value: f32) {
356 self.put_f32(value)
357 }
358
359 fn write_f64(&mut self, value: f64) {
360 self.put_f64(value)
361 }
362
363 fn write_varint(&mut self, value: u64) {
364 self.write_varint(value)
365 }
366
367 fn write_bytes(&mut self, bytes: &[u8]) {
368 self.write_varint(bytes.len() as u64);
369 self.write_fixed(bytes);
370 }
371
372 fn write_fixed(&mut self, bytes: &[u8]) {
373 self.put_slice(bytes);
374 }
375
376 fn write_bool(&mut self, value: bool) {
377 self.put_u8(if value { 1 } else { 0 });
378 }
379
380 fn write_option<T: Codec>(&mut self, value: &Option<T>) {
381 match value {
382 Some(v) => {
383 self.write_bool(true);
384 v.write(self);
385 }
386 None => {
387 self.write_bool(false);
388 }
389 }
390 }
391
392 fn write_vec<T: Codec>(&mut self, values: &[T]) {
393 self.write_varint(values.len() as u64);
394 for value in values {
395 value.write(self);
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{varint::varint_size, Codec, Error, ReadBuffer, WriteBuffer};
    use bytes::Bytes;

    /// Freezes `writer` and wraps the written bytes in a reader.
    fn into_reader(writer: WriteBuffer) -> ReadBuffer {
        ReadBuffer::new(writer.freeze())
    }

    // Reading a `u32` from two bytes must fail with `EndOfBuffer`.
    #[test]
    fn test_insufficient_buffer() {
        let mut reader = ReadBuffer::new(Bytes::from_static(&[0x01, 0x02]));
        let outcome = u32::read(&mut reader);
        assert!(matches!(outcome, Err(Error::EndOfBuffer)));
    }

    // Decoding a `u8` from two bytes leaves one byte unread.
    #[test]
    fn test_extra_data() {
        let outcome = u8::decode(Bytes::from_static(&[0x01, 0x02]));
        assert!(matches!(outcome, Err(Error::ExtraData(1))));
    }

    // A byte other than 0 or 1 is not a valid `bool`.
    #[test]
    fn test_invalid_bool() {
        let outcome = bool::decode(Bytes::from_static(&[0x02]));
        assert!(matches!(outcome, Err(Error::InvalidBool)));
    }

    // A varint written to a buffer reads back as the same value.
    #[test]
    fn test_varint() {
        let value = u64::MAX / 2;
        let mut writer = WriteBuffer::new(varint_size(value));
        writer.write_varint(value);

        let mut reader = into_reader(writer);
        assert_eq!(reader.read_varint().unwrap(), value);
    }

    // Six bytes exceed a limit of five.
    #[test]
    fn test_length_limit_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3, 4, 5, 6]);

        let mut reader = into_reader(writer);
        let outcome = reader.read_bytes_lte(5);
        assert!(matches!(outcome, Err(Error::LengthExceeded(6, 5))));
    }

    // Three bytes fit within a limit of five.
    #[test]
    fn test_bytes_lte_success() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3]);

        let mut reader = into_reader(writer);
        let bytes = reader.read_bytes_lte(5).unwrap();
        assert_eq!(bytes, Bytes::from_static(&[1, 2, 3]));
    }

    // Same scenario as `test_length_limit_exceeded`.
    #[test]
    fn test_bytes_lte_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_bytes(&[1, 2, 3, 4, 5, 6]);

        let mut reader = into_reader(writer);
        let outcome = reader.read_bytes_lte(5);
        assert!(matches!(outcome, Err(Error::LengthExceeded(6, 5))));
    }

    // Two elements fit within a limit of three.
    #[test]
    fn test_vec_lte_success() {
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8]);

        let mut reader = into_reader(writer);
        let items = reader.read_vec_lte::<u8>(3).unwrap();
        assert_eq!(items, vec![1u8, 2u8]);
    }

    // Three elements exceed a limit of two.
    #[test]
    fn test_vec_lte_exceeded() {
        let mut writer = WriteBuffer::new(10);
        writer.write_vec(&[1u8, 2u8, 3u8]);

        let mut reader = into_reader(writer);
        let outcome = reader.read_vec_lte::<u8>(2);
        assert!(matches!(outcome, Err(Error::LengthExceeded(3, 2))));
    }

    // A `u32` survives a trip through a 4-byte array.
    #[test]
    fn test_encode_fixed() {
        let value = 42u32;
        let encoded: [u8; 4] = value.encode_fixed();
        let decoded = u32::decode(Bytes::copy_from_slice(&encoded)).unwrap();
        assert_eq!(value, decoded);
    }

    // Asking for a 5-byte array from a 4-byte encoding must panic.
    #[test]
    #[should_panic(expected = "Can't encode 4 bytes into 5 bytes")]
    fn test_encode_fixed_panic() {
        let _: [u8; 5] = 42u32.encode_fixed();
    }
}