1use std::io;
4
5use bytes::{Bytes, BytesMut};
6use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
7
8use crate::core::message::Message;
9
/// Codec that turns [`Message`] values into framed bytes and back.
///
/// On encode, a message is serialized with `bincode` using its standard configuration and the
/// resulting payload is handed to a [`LengthDelimitedCodec`]; on decode, the inner codec
/// extracts one frame and its payload is deserialized back into a `Message`.
pub struct MessageCodec {
    /// Inner framing codec; it adds and strips the length prefix around each payload.
    codec: LengthDelimitedCodec,
}
15
16impl MessageCodec {
17 pub fn new() -> Self {
19 Self {
20 codec: LengthDelimitedCodec::new(),
21 }
22 }
23}
24
25impl Default for MessageCodec {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl Decoder for MessageCodec {
32 type Item = Message;
33 type Error = io::Error;
34
35 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
36 let bytes = match self.codec.decode(src)? {
37 Some(bytes) => bytes,
38 None => return Ok(None),
39 };
40
41 match bincode::decode_from_slice(&bytes, bincode::config::standard()) {
42 Ok((message, _length)) => Ok(Some(message)),
43 Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
44 }
45 }
46}
47
48impl Encoder<Message> for MessageCodec {
49 type Error = io::Error;
50
51 fn encode(&mut self, message: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
52 let _ = match bincode::encode_to_vec(message, bincode::config::standard()) {
53 Ok(bytes) => self.codec.encode(Bytes::copy_from_slice(&bytes), dst),
54 Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
55 };
56
57 Ok(())
58 }
59}
60
#[cfg(test)]
mod tests {
    use rand::{thread_rng, Rng};

    use super::*;
    use crate::{
        core::id::Id,
        message::{Chunk, FindKNodes, KNodes, Ping, Pong},
    };

    /// Encodes `message` into a fresh buffer and decodes it back.
    ///
    /// Asserts that the round trip reproduces the original message and that decoding
    /// consumes the entire frame.
    fn assert_roundtrip(message: Message) {
        let mut codec = MessageCodec::new();
        let mut buf = BytesMut::new();

        codec
            .encode(message.clone(), &mut buf)
            .expect("encoding a message should succeed");
        let decoded = codec
            .decode(&mut buf)
            .expect("decoding a freshly encoded frame should succeed")
            .expect("the buffer should hold a complete frame");

        assert_eq!(decoded, message);
        assert!(buf.is_empty(), "decoding should consume the whole frame");
    }

    #[test]
    fn codec_ping() {
        let mut rng = thread_rng();

        assert_roundtrip(Message::Ping(Ping {
            nonce: rng.gen(),
            id: Id::rand(),
        }));
    }

    #[test]
    fn codec_pong() {
        let mut rng = thread_rng();

        assert_roundtrip(Message::Pong(Pong {
            nonce: rng.gen(),
            id: Id::rand(),
        }));
    }

    #[test]
    fn codec_find_k_nodes() {
        let mut rng = thread_rng();

        assert_roundtrip(Message::FindKNodes(FindKNodes {
            nonce: rng.gen(),
            id: Id::rand(),
        }));
    }

    #[test]
    fn codec_k_nodes() {
        let mut rng = thread_rng();

        assert_roundtrip(Message::KNodes(KNodes {
            nonce: rng.gen(),
            nodes: vec![(Id::from_u16(0), "127.0.0.1:0".parse().unwrap())],
        }));
    }

    #[test]
    fn codec_chunk() {
        use rand::Fill;

        let mut rng = thread_rng();
        let mut data = [0u8; 32];
        data.try_fill(&mut rng).unwrap();

        assert_roundtrip(Message::Chunk(Chunk {
            nonce: rng.gen(),
            height: rng.gen(),
            data: Bytes::copy_from_slice(&data),
        }));
    }

    #[test]
    fn codec_partial_frame() {
        let mut rng = thread_rng();
        let message = Message::Ping(Ping {
            nonce: rng.gen(),
            id: Id::rand(),
        });

        let mut codec = MessageCodec::new();
        let mut encoded = BytesMut::new();
        codec.encode(message.clone(), &mut encoded).unwrap();

        // Withhold the final byte: the decoder must ask for more input instead of failing.
        let split = encoded.len() - 1;
        let mut src = BytesMut::from(&encoded[..split]);
        assert!(codec.decode(&mut src).unwrap().is_none());

        // Once the rest of the frame arrives, the same codec finishes decoding it.
        src.extend_from_slice(&encoded[split..]);
        assert_eq!(codec.decode(&mut src).unwrap().unwrap(), message);
    }

    #[test]
    fn codec_empty_frame_is_an_error() {
        let mut codec = MessageCodec::new();

        // A zero length prefix (the inner codec's default 4-byte big-endian header) forms a
        // complete frame with an empty payload, which is not a valid bincode `Message`.
        let mut src = BytesMut::from(&[0u8, 0, 0, 0][..]);
        assert!(codec.decode(&mut src).is_err());
    }
}