memberlist_types/
ping.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use nodecraft::{CheapClone, Node};
3use transformable::Transformable;
4
5use super::MAX_ENCODED_LEN_SIZE;
6
/// Defines a ping-style message type named `$name`.
///
/// Every generated type carries a sequence number, a source [`Node`] and a
/// target [`Node`], and comes with:
/// - a `new` constructor and in-place `set_*` setters (getters and `with_*`
///   builder setters are generated by `viewit`),
/// - a [`CheapClone`] impl,
/// - a `<$name>TransformError` error type and a [`Transformable`] impl,
/// - with the `rkyv` feature, `Debug`/`PartialEq`/`Eq`/`Clone`/`Hash` impls for
///   the archived form.
///
/// Wire format produced by `encode` and accepted by `decode`:
///
/// ```text
/// | frame length (u32, network endian; MAX_ENCODED_LEN_SIZE bytes) | sequence number (u32, network endian) | source node | target node |
/// ```
///
/// The frame length counts every byte of the frame, including the length
/// prefix itself.
macro_rules! bail_ping {
  (
    $(#[$meta:meta])*
    $name: ident
  ) => {
    $(#[$meta])*
    #[viewit::viewit(
      getters(vis_all = "pub"),
      setters(vis_all = "pub", prefix = "with")
    )]
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
    #[cfg_attr(feature = "rkyv", derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive))]
    #[cfg_attr(feature = "rkyv", rkyv(compare(PartialEq)))]
    pub struct $name<I, A> {
      /// The sequence number of the message
      #[viewit(
        getter(const, attrs(doc = "Returns the sequence number of the message")),
        setter(
          const,
          attrs(doc = "Sets the sequence number of the message (Builder pattern)")
        )
      )]
      sequence_number: u32,
      /// Source node, used for a direct reply
      #[viewit(
        getter(const, style = "ref", attrs(doc = "Returns the source node of the ping message")),
        setter(attrs(doc = "Sets the source node of the ping message (Builder pattern)"))
      )]
      source: Node<I, A>,

      /// [`Node`] is sent so the target can verify they are
      /// the intended recipient. This is to protect against an agent
      /// restart with a new name.
      #[viewit(
        getter(const, style = "ref", attrs(doc = "Returns the target node of the ping message")),
        setter(attrs(doc = "Sets the target node of the ping message (Builder pattern)"))
      )]
      target: Node<I, A>,
    }

    impl<I, A> $name<I, A> {
      /// Create a new message
      #[inline]
      pub const fn new(sequence_number: u32, source: Node<I, A>, target: Node<I, A>) -> Self {
        Self {
          sequence_number,
          source,
          target,
        }
      }

      /// Sets the sequence number of the message
      #[inline]
      pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
        self.sequence_number = sequence_number;
        self
      }

      /// Sets the source node of the message
      #[inline]
      pub fn set_source(&mut self, source: Node<I, A>) -> &mut Self {
        self.source = source;
        self
      }

      /// Sets the target node of the message
      #[inline]
      pub fn set_target(&mut self, target: Node<I, A>) -> &mut Self {
        self.target = target;
        self
      }
    }

    impl<I: CheapClone, A: CheapClone> CheapClone for $name<I, A> {
      fn cheap_clone(&self) -> Self {
        Self {
          sequence_number: self.sequence_number,
          source: self.source.cheap_clone(),
          target: self.target.cheap_clone(),
        }
      }
    }

    #[cfg(feature = "rkyv")]
    const _: () = {
      use core::fmt::Debug;
      use rkyv::Archive;

      paste::paste! {
        impl<I: Debug + Archive, A: Debug + Archive> core::fmt::Debug for [< Archived $name >] <I, A>
        where
          I::Archived: Debug,
          A::Archived: Debug,
        {
          fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            // `core` rather than `std`, matching the other paths in this block.
            f.debug_struct(core::any::type_name::<Self>())
              .field("sequence_number", &self.sequence_number)
              .field("target", &self.target)
              .field("source", &self.source)
              .finish()
          }
        }

        impl<I: Archive, A: Archive> PartialEq for [< Archived $name >] <I, A>
        where
          I::Archived: PartialEq,
          A::Archived: PartialEq,
        {
          fn eq(&self, other: &Self) -> bool {
            self.sequence_number == other.sequence_number
              && self.target == other.target
              && self.source == other.source
          }
        }

        impl<I: Archive, A: Archive> Eq for [< Archived $name >] <I, A>
        where
          I::Archived: Eq,
          A::Archived: Eq,
        {
        }

        impl<I: Archive, A: Archive> Clone for [< Archived $name >] <I, A>
        where
          I::Archived: Clone,
          A::Archived: Clone,
        {
          fn clone(&self) -> Self {
            Self {
              sequence_number: self.sequence_number,
              target: self.target.clone(),
              source: self.source.clone(),
            }
          }
        }

        impl<I: Archive, A: Archive> core::hash::Hash for [< Archived $name >] <I, A>
        where
          I::Archived: core::hash::Hash,
          A::Archived: core::hash::Hash,
        {
          fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
            self.sequence_number.hash(state);
            self.target.hash(state);
            self.source.hash(state);
          }
        }
      }
    };

    paste::paste! {
      #[doc = concat!("Error when transforming a [`", stringify!($name), "`]")]
      #[derive(thiserror::Error)]
      pub enum [< $name TransformError >]<I: Transformable, A: Transformable> {
        /// Error transforming the source node
        #[error("source node: {0}")]
        Source(<Node<I, A> as Transformable>::Error),
        /// Error transforming the target node
        #[error("target node: {0}")]
        Target(<Node<I, A> as Transformable>::Error),
        /// Encode buffer is too small
        #[error("encode buffer is too small")]
        BufferTooSmall,
        /// Not enough bytes to decode, or the declared frame length is shorter
        /// than the fixed header (length prefix + sequence number)
        #[error("not enough bytes to decode")]
        NotEnoughBytes,
        /// The message is too large for its `u32` length prefix, or a decoded
        /// frame declares more bytes than its contents occupy
        #[error("the encoded bytes is too large")]
        TooLarge,
      }

      impl<I: Transformable, A: Transformable> core::fmt::Debug for [< $name TransformError >]<I, A> {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
          write!(f, "{}", self)
        }
      }

      impl<I: Transformable, A: Transformable> Transformable for $name<I, A> {
        type Error = [< $name TransformError >]<I, A>;

        /// Writes the length-prefixed frame into `dst` and returns the number
        /// of bytes written (always equal to `encoded_len()`).
        fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
          let encoded_len = self.encoded_len();
          // The frame length is written as a u32, so larger frames cannot be represented.
          if encoded_len as u64 > u32::MAX as u64 {
            return Err(Self::Error::TooLarge);
          }

          if dst.len() < encoded_len {
            return Err(Self::Error::BufferTooSmall);
          }

          let mut offset = 0;
          NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
          offset += MAX_ENCODED_LEN_SIZE;
          NetworkEndian::write_u32(&mut dst[offset..], self.sequence_number);
          offset += core::mem::size_of::<u32>();
          offset += self.source.encode(&mut dst[offset..]).map_err(Self::Error::Source)?;
          offset += self.target.encode(&mut dst[offset..]).map_err(Self::Error::Target)?;

          debug_assert_eq!(
            offset, encoded_len,
            "expected bytes written ({encoded_len}) does not match actual bytes written ({offset})"
          );
          Ok(offset)
        }

        /// Total frame size: length prefix + sequence number + both nodes.
        fn encoded_len(&self) -> usize {
          MAX_ENCODED_LEN_SIZE
            + core::mem::size_of::<u32>()
            + self.source.encoded_len()
            + self.target.encoded_len()
        }

        /// Decodes one frame from the start of `src`, returning the number of
        /// bytes consumed and the message.
        ///
        /// # Errors
        /// - `NotEnoughBytes` if `src` is shorter than the length prefix, shorter
        ///   than the declared frame, or the declared frame cannot even hold the
        ///   fixed header.
        /// - `Source` / `Target` if a node fails to decode within the frame.
        /// - `TooLarge` if the declared frame is longer than the decoded message.
        fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
        where
          Self: Sized
        {
          let header_len = MAX_ENCODED_LEN_SIZE + core::mem::size_of::<u32>();

          if src.len() < MAX_ENCODED_LEN_SIZE {
            return Err(Self::Error::NotEnoughBytes);
          }
          let encoded_len = NetworkEndian::read_u32(src) as usize;
          // The length prefix comes from the peer. If it is smaller than the fixed
          // header, reading the sequence number below would index past the data
          // and panic, so reject it up front.
          if encoded_len < header_len || src.len() < encoded_len {
            return Err(Self::Error::NotEnoughBytes);
          }

          // Decode only within the declared frame, so the nodes cannot consume
          // bytes that belong to whatever follows this message in `src`.
          let frame = &src[..encoded_len];
          let mut offset = MAX_ENCODED_LEN_SIZE;
          let sequence_number = NetworkEndian::read_u32(&frame[offset..]);
          offset += core::mem::size_of::<u32>();
          let (source_len, source) = Node::decode(&frame[offset..]).map_err(Self::Error::Source)?;
          offset += source_len;
          let (target_len, target) = Node::decode(&frame[offset..]).map_err(Self::Error::Target)?;
          offset += target_len;

          // Leftover bytes inside the declared frame mean the length prefix does
          // not describe this message. Report it instead of panicking (debug)
          // or returning a short read (release).
          if offset != encoded_len {
            return Err(Self::Error::TooLarge);
          }
          Ok((offset, Self { sequence_number, source, target }))
        }
      }
    }

    #[cfg(test)]
    const _: () = {
      use rand::{Rng, distr::Alphanumeric, rng, random};

      impl $name<smol_str::SmolStr, std::net::SocketAddr> {
        pub(crate) fn generate(size: usize) -> Self {
          let trng = rng();
          let source = trng.sample_iter(&Alphanumeric).take(size).collect::<Vec<u8>>();
          let source = String::from_utf8(source).unwrap();
          let source = Node::new(
            source.into(),
            format!("127.0.0.1:{}", rng().random_range(0..65535))
              .parse()
              .unwrap(),
          );
          let trng = rng();
          let target = trng.sample_iter(&Alphanumeric).take(size).collect::<Vec<u8>>();
          let target = String::from_utf8(target).unwrap();
          let target = Node::new(
            target.into(),
            format!("127.0.0.1:{}", rng().random_range(0..65535))
              .parse()
              .unwrap(),
          );

          Self {
            sequence_number: random(),
            source,
            target,
          }
        }
      }
    };

    paste::paste! {
      #[cfg(test)]
      mod [< $name:snake _decode_tests >] {
        use super::*;

        /// A frame whose length prefix is shorter than the fixed header must be
        /// rejected instead of panicking while reading the sequence number.
        #[test]
        fn rejects_frame_shorter_than_header() {
          let buf = [0u8; MAX_ENCODED_LEN_SIZE];
          let res = $name::<smol_str::SmolStr, std::net::SocketAddr>::decode(&buf);
          assert!(matches!(res, Err([< $name TransformError >]::NotEnoughBytes)));
        }
      }
    }
  };
}
275
276bail_ping!(
277  #[doc = "Ping is sent to a target to check if it is alive"]
278  Ping
279);
280bail_ping!(
281  #[doc = "IndirectPing is sent to a target to check if it is alive"]
282  IndirectPing
283);
284
285impl<I, A> From<IndirectPing<I, A>> for Ping<I, A> {
286  fn from(ping: IndirectPing<I, A>) -> Self {
287    Self {
288      sequence_number: ping.sequence_number,
289      source: ping.source,
290      target: ping.target,
291    }
292  }
293}
294
#[cfg(test)]
mod tests {
  use super::*;

  /// Encodes `msg` into an exactly-sized buffer, decodes it back, and checks
  /// that the byte counts agree and the decoded value equals the original.
  fn assert_roundtrip<T>(msg: T)
  where
    T: Transformable + PartialEq + core::fmt::Debug,
    T::Error: core::fmt::Debug,
  {
    let mut buf = vec![0u8; msg.encoded_len()];
    let written = msg.encode(&mut buf).unwrap();
    assert_eq!(written, msg.encoded_len());

    let (consumed, decoded) = T::decode(&buf).unwrap();
    assert_eq!(consumed, written);
    assert_eq!(decoded, msg);
  }

  #[test]
  fn test_ping() {
    (0..100).for_each(|size| assert_roundtrip(Ping::<_, std::net::SocketAddr>::generate(size)));
  }

  #[test]
  fn test_indirect_ping() {
    (0..100)
      .for_each(|size| assert_roundtrip(IndirectPing::<_, std::net::SocketAddr>::generate(size)));
  }
}