1use byteorder::{ByteOrder, NetworkEndian};
2use nodecraft::{CheapClone, Node};
3use transformable::Transformable;
4
5use super::MAX_ENCODED_LEN_SIZE;
6
macro_rules! bail_ping {
    (
        $(#[$meta:meta])*
        $name: ident
    ) => {
        $(#[$meta])*
        #[viewit::viewit(
            getters(vis_all = "pub"),
            setters(vis_all = "pub", prefix = "with")
        )]
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
        #[cfg_attr(feature = "rkyv", derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive))]
        #[cfg_attr(feature = "rkyv", rkyv(compare(PartialEq)))]
        pub struct $name<I, A> {
            /// Sequence number of the message.
            #[viewit(
                getter(const, attrs(doc = "Returns the sequence number of the ping message")),
                setter(
                    const,
                    attrs(doc = "Sets the sequence number of the ping message (Builder pattern)")
                )
            )]
            sequence_number: u32,
            /// Source node of the message.
            #[viewit(
                getter(const, style = "ref", attrs(doc = "Returns the source node of the ping message")),
                setter(attrs(doc = "Sets the source node of the ping message (Builder pattern)"))
            )]
            source: Node<I, A>,

            /// Target node of the message.
            #[viewit(
                getter(const, style = "ref", attrs(doc = "Returns the target node of the ping message")),
                setter(attrs(doc = "Sets the target node of the ping message (Builder pattern)"))
            )]
            target: Node<I, A>,
        }

        impl<I, A> $name<I, A> {
            /// Creates a new message from its sequence number, source node and target node.
            #[inline]
            pub const fn new(sequence_number: u32, source: Node<I, A>, target: Node<I, A>) -> Self {
                Self {
                    sequence_number,
                    source,
                    target,
                }
            }

            /// Sets the sequence number in place and returns `&mut Self` for chaining.
            #[inline]
            pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
                self.sequence_number = sequence_number;
                self
            }

            /// Sets the source node in place and returns `&mut Self` for chaining.
            #[inline]
            pub fn set_source(&mut self, source: Node<I, A>) -> &mut Self {
                self.source = source;
                self
            }

            /// Sets the target node in place and returns `&mut Self` for chaining.
            #[inline]
            pub fn set_target(&mut self, target: Node<I, A>) -> &mut Self {
                self.target = target;
                self
            }
        }

        impl<I: CheapClone, A: CheapClone> CheapClone for $name<I, A> {
            fn cheap_clone(&self) -> Self {
                Self {
                    sequence_number: self.sequence_number,
                    source: self.source.cheap_clone(),
                    target: self.target.cheap_clone(),
                }
            }
        }

        #[cfg(feature = "rkyv")]
        const _: () = {
            use core::fmt::Debug;
            use rkyv::Archive;

            paste::paste! {
                impl<I: Debug + Archive, A: Debug + Archive> core::fmt::Debug for [< Archived $name >] <I, A>
                where
                    I::Archived: Debug,
                    A::Archived: Debug,
                {
                    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                        f.debug_struct(core::any::type_name::<Self>())
                            .field("sequence_number", &self.sequence_number)
                            .field("target", &self.target)
                            .field("source", &self.source)
                            .finish()
                    }
                }

                impl<I: Archive, A: Archive> PartialEq for [< Archived $name >] <I, A>
                where
                    I::Archived: PartialEq,
                    A::Archived: PartialEq,
                {
                    fn eq(&self, other: &Self) -> bool {
                        self.sequence_number == other.sequence_number
                            && self.target == other.target
                            && self.source == other.source
                    }
                }

                impl<I: Archive, A: Archive> Eq for [< Archived $name >] <I, A>
                where
                    I::Archived: Eq,
                    A::Archived: Eq,
                {
                }

                impl<I: Archive, A: Archive> Clone for [< Archived $name >] <I, A>
                where
                    I::Archived: Clone,
                    A::Archived: Clone,
                {
                    fn clone(&self) -> Self {
                        Self {
                            sequence_number: self.sequence_number,
                            target: self.target.clone(),
                            source: self.source.clone(),
                        }
                    }
                }

                impl<I: Archive, A: Archive> core::hash::Hash for [< Archived $name >] <I, A>
                where
                    I::Archived: core::hash::Hash,
                    A::Archived: core::hash::Hash,
                {
                    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
                        // Fully qualified calls: `Hash` is not imported in this scope,
                        // so method-call syntax is not guaranteed to resolve.
                        core::hash::Hash::hash(&self.sequence_number, state);
                        core::hash::Hash::hash(&self.target, state);
                        core::hash::Hash::hash(&self.source, state);
                    }
                }
            }
        };

        paste::paste! {
            #[doc = concat!("Error when transforming a [`", stringify!($name), "`]")]
            #[derive(thiserror::Error)]
            pub enum [< $name TransformError >]<I: Transformable, A: Transformable> {
                /// Encoding or decoding the source node failed.
                #[error("source node: {0}")]
                Source(<Node<I, A> as Transformable>::Error),
                /// Encoding or decoding the target node failed.
                #[error("target node: {0}")]
                Target(<Node<I, A> as Transformable>::Error),
                /// The destination buffer passed to `encode` is shorter than `encoded_len()`.
                #[error("encode buffer is too small")]
                BufferTooSmall,
                /// The input to `decode` is shorter than the length prefix or than the
                /// declared frame, or the declared frame is too short for the fixed header.
                #[error("not enough bytes to decode")]
                NotEnoughBytes,
                /// The encoded length does not fit in the `u32` length prefix.
                #[error("the encoded bytes is too large")]
                TooLarge,
            }

            // `Debug` intentionally prints the same message as `Display`.
            impl<I: Transformable, A: Transformable> core::fmt::Debug for [< $name TransformError >]<I, A> {
                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                    write!(f, "{}", self)
                }
            }

            // Wire layout (all integers big-endian):
            //   [length prefix: MAX_ENCODED_LEN_SIZE bytes, value = total length incl. prefix]
            //   [sequence number: u32]
            //   [source node]
            //   [target node]
            impl<I: Transformable, A: Transformable> Transformable for $name<I, A> {
                type Error = [< $name TransformError >]<I, A>;

                fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
                    let encoded_len = self.encoded_len();
                    // The length prefix is a `u32`; widen both sides so the comparison
                    // never truncates.
                    if encoded_len as u64 > u32::MAX as u64 {
                        return Err(Self::Error::TooLarge);
                    }

                    if dst.len() < encoded_len {
                        return Err(Self::Error::BufferTooSmall);
                    }

                    let mut offset = 0;
                    NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
                    offset += MAX_ENCODED_LEN_SIZE;
                    NetworkEndian::write_u32(&mut dst[offset..], self.sequence_number);
                    offset += core::mem::size_of::<u32>();
                    offset += self.source.encode(&mut dst[offset..]).map_err(Self::Error::Source)?;
                    offset += self.target.encode(&mut dst[offset..]).map_err(Self::Error::Target)?;

                    debug_assert_eq!(
                        offset, encoded_len,
                        "expected bytes written ({encoded_len}) does not match actual bytes written ({offset})"
                    );
                    Ok(offset)
                }

                fn encoded_len(&self) -> usize {
                    MAX_ENCODED_LEN_SIZE
                        + core::mem::size_of::<u32>()
                        + self.source.encoded_len()
                        + self.target.encoded_len()
                }

                fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
                where
                    Self: Sized,
                {
                    // Length prefix followed by the sequence number.
                    let header_len = MAX_ENCODED_LEN_SIZE + core::mem::size_of::<u32>();

                    if src.len() < MAX_ENCODED_LEN_SIZE {
                        return Err(Self::Error::NotEnoughBytes);
                    }
                    let encoded_len = NetworkEndian::read_u32(src) as usize;
                    // A declared length smaller than the fixed header would make the
                    // sequence-number read below panic on a short slice, so reject it
                    // together with a truncated frame.
                    if encoded_len < header_len || src.len() < encoded_len {
                        return Err(Self::Error::NotEnoughBytes);
                    }

                    // Decode only within the declared frame so the nodes cannot consume
                    // bytes that follow this message in `src`.
                    let frame = &src[..encoded_len];
                    let mut offset = MAX_ENCODED_LEN_SIZE;
                    let sequence_number = NetworkEndian::read_u32(&frame[offset..]);
                    offset += core::mem::size_of::<u32>();
                    let (source_len, source) = Node::decode(&frame[offset..]).map_err(Self::Error::Source)?;
                    offset += source_len;
                    let (target_len, target) = Node::decode(&frame[offset..]).map_err(Self::Error::Target)?;
                    offset += target_len;

                    debug_assert_eq!(
                        offset, encoded_len,
                        "expected bytes read ({encoded_len}) does not match actual bytes read ({offset})"
                    );
                    Ok((offset, Self { sequence_number, source, target }))
                }
            }
        }

        #[cfg(test)]
        const _: () = {
            use rand::{Rng, distr::Alphanumeric, rng, random};

            impl $name<smol_str::SmolStr, std::net::SocketAddr> {
                /// Builds a message with random `size`-character alphanumeric node ids,
                /// random loopback ports and a random sequence number.
                pub(crate) fn generate(size: usize) -> Self {
                    let random_node = || -> Node<smol_str::SmolStr, std::net::SocketAddr> {
                        let id = rng()
                            .sample_iter(&Alphanumeric)
                            .take(size)
                            .map(char::from)
                            .collect::<String>();
                        let port: u16 = rng().random_range(0..65535);
                        Node::new(id.into(), std::net::SocketAddr::from(([127, 0, 0, 1], port)))
                    };

                    let source = random_node();
                    let target = random_node();

                    Self {
                        sequence_number: random(),
                        source,
                        target,
                    }
                }
            }
        };
    };
}
275
276bail_ping!(
277 #[doc = "Ping is sent to a target to check if it is alive"]
278 Ping
279);
280bail_ping!(
281 #[doc = "IndirectPing is sent to a target to check if it is alive"]
282 IndirectPing
283);
284
285impl<I, A> From<IndirectPing<I, A>> for Ping<I, A> {
286 fn from(ping: IndirectPing<I, A>) -> Self {
287 Self {
288 sequence_number: ping.sequence_number,
289 source: ping.source,
290 target: ping.target,
291 }
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `msg`, decodes it back, and checks three things: the byte counts agree
    /// in both directions, `encode` writes exactly `encoded_len()` bytes, and the
    /// decoded value equals the original.
    fn assert_roundtrip<T>(msg: T)
    where
        T: Transformable + PartialEq + core::fmt::Debug,
        T::Error: core::fmt::Debug,
    {
        let expected_len = msg.encoded_len();
        let mut buf = vec![0; expected_len];
        let written = msg.encode(&mut buf).unwrap();
        assert_eq!(written, expected_len);
        let (consumed, decoded) = T::decode(&buf).unwrap();
        assert_eq!(consumed, written);
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_ping() {
        for size in 0..100 {
            assert_roundtrip(Ping::<_, std::net::SocketAddr>::generate(size));
        }
    }

    #[test]
    fn test_indirect_ping() {
        for size in 0..100 {
            assert_roundtrip(IndirectPing::<_, std::net::SocketAddr>::generate(size));
        }
    }
}