kadmium/
codec.rs

1//! Encoding and decoding utilities for messages.
2
3use std::io;
4
5use bytes::{Bytes, BytesMut};
6use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
7
8use crate::core::message::Message;
9
10/// Backed by Bincode and Tokio's [`LengthDelimitedCodec`], this codec implements the [`Encoder`] and
11/// [`Decoder`] traits for [`Message`].
12pub struct MessageCodec {
13    codec: LengthDelimitedCodec,
14}
15
16impl MessageCodec {
17    /// Returns a new message codec.
18    pub fn new() -> Self {
19        Self {
20            codec: LengthDelimitedCodec::new(),
21        }
22    }
23}
24
25impl Default for MessageCodec {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl Decoder for MessageCodec {
32    type Item = Message;
33    type Error = io::Error;
34
35    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
36        let bytes = match self.codec.decode(src)? {
37            Some(bytes) => bytes,
38            None => return Ok(None),
39        };
40
41        match bincode::decode_from_slice(&bytes, bincode::config::standard()) {
42            Ok((message, _length)) => Ok(Some(message)),
43            Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
44        }
45    }
46}
47
48impl Encoder<Message> for MessageCodec {
49    type Error = io::Error;
50
51    fn encode(&mut self, message: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
52        let _ = match bincode::encode_to_vec(message, bincode::config::standard()) {
53            Ok(bytes) => self.codec.encode(Bytes::copy_from_slice(&bytes), dst),
54            Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
55        };
56
57        Ok(())
58    }
59}
60
#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};

    use super::*;
    use crate::{
        core::id::Id,
        message::{Chunk, FindKNodes, KNodes, Ping, Pong},
    };

    /// Encodes `message` with a fresh codec into an empty buffer, decodes it back out, and checks
    /// that the decoded value equals the original.
    fn assert_roundtrip(message: Message) {
        let mut codec = MessageCodec::new();
        let mut buffer = BytesMut::new();

        assert!(codec.encode(message.clone(), &mut buffer).is_ok());

        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn codec_ping() {
        let nonce = thread_rng().gen();
        assert_roundtrip(Message::Ping(Ping {
            nonce,
            id: Id::rand(),
        }));
    }

    #[test]
    fn codec_pong() {
        let nonce = thread_rng().gen();
        assert_roundtrip(Message::Pong(Pong {
            nonce,
            id: Id::rand(),
        }));
    }

    #[test]
    fn codec_find_k_nodes() {
        let nonce = thread_rng().gen();
        assert_roundtrip(Message::FindKNodes(FindKNodes {
            nonce,
            id: Id::rand(),
        }));
    }

    #[test]
    fn codec_k_nodes() {
        let nonce = thread_rng().gen();
        let addr = "127.0.0.1:0".parse().unwrap();
        assert_roundtrip(Message::KNodes(KNodes {
            nonce,
            nodes: vec![(Id::from_u16(0), addr)],
        }));
    }

    #[test]
    fn codec_chunk() {
        use rand::Fill;

        let mut rng = thread_rng();
        let mut payload = [0u8; 32];
        payload.try_fill(&mut rng).unwrap();

        assert_roundtrip(Message::Chunk(Chunk {
            nonce: rng.gen(),
            height: rng.gen(),
            data: Bytes::copy_from_slice(&payload),
        }));
    }
}