memberlist_types/
ack.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use bytes::Bytes;
3use transformable::{utils::*, Transformable};
4
/// Largest payload, in bytes, that the reader-based `Ack` decoders read into a
/// fixed-size stack array before copying it into [`Bytes`]; longer payloads are
/// read straight into a heap-allocated `Vec<u8>`.
const MAX_INLINED_BYTES: usize = 1 << 6;
6
/// Ack response is sent for a ping.
///
/// On the wire (see the [`Transformable`] impl) an ack is laid out as
/// `[total_len: u32][sequence_number: u32][payload]`, with both integers in
/// network (big-endian) byte order and `total_len` counting the 8-byte header
/// as well as the payload.
#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
#[cfg_attr(
  feature = "rkyv",
  derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive)
)]
#[cfg_attr(
  feature = "rkyv",
  rkyv(compare(PartialEq), derive(Debug, PartialEq, Eq, Hash),)
)]
pub struct Ack {
  /// The sequence number of the ack.
  #[viewit(
    getter(const, attrs(doc = "Returns the sequence number of the ack")),
    setter(
      const,
      attrs(doc = "Sets the sequence number of the ack (Builder pattern)")
    )
  )]
  sequence_number: u32,
  /// The payload of the ack: opaque bytes copied verbatim after the header
  /// when the ack is encoded.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the payload of the ack")),
    setter(attrs(doc = "Sets the payload of the ack (Builder pattern)"))
  )]
  payload: Bytes,
}
36
37impl Ack {
38  /// Create a new ack response with the given sequence number and empty payload.
39  #[inline]
40  pub const fn new(sequence_number: u32) -> Self {
41    Self {
42      sequence_number,
43      payload: Bytes::new(),
44    }
45  }
46
47  /// Sets the sequence number of the ack
48  #[inline]
49  pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
50    self.sequence_number = sequence_number;
51    self
52  }
53
54  /// Sets the payload of the ack
55  #[inline]
56  pub fn set_payload(&mut self, payload: Bytes) -> &mut Self {
57    self.payload = payload;
58    self
59  }
60
61  /// Consumes the [`Ack`] and returns the sequence number and payload
62  #[inline]
63  pub fn into_components(self) -> (u32, Bytes) {
64    (self.sequence_number, self.payload)
65  }
66}
67
/// Error that can occur when transforming (encoding or decoding) an ack response.
#[derive(Debug, thiserror::Error)]
pub enum AckTransformError {
  /// The destination buffer passed to `encode` is shorter than the ack's
  /// `encoded_len`, so the ack response cannot be encoded into it.
  #[error("encode buffer too small")]
  InsufficientBuffer(#[from] InsufficientBuffer),
  /// The buffer did not contain enough bytes to decode an ack response.
  #[error("the buffer did not contain enough bytes to decode Ack")]
  NotEnoughBytes,
  /// Varint decoding error.
  // NOTE(review): the fixed-width `Ack` codec in this file never constructs
  // this variant — confirm whether it is still needed.
  #[error("fail to decode sequence number: {0}")]
  DecodeVarint(#[from] DecodeVarintError),
}
81
82impl Transformable for Ack {
83  type Error = AckTransformError;
84
85  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
86    let encoded_len = self.encoded_len();
87
88    if encoded_len > dst.len() {
89      return Err(Self::Error::InsufficientBuffer(
90        InsufficientBuffer::with_information(encoded_len as u64, dst.len() as u64),
91      ));
92    }
93
94    let mut offset = 0;
95    NetworkEndian::write_u32(dst, encoded_len as u32);
96    offset += core::mem::size_of::<u32>();
97    NetworkEndian::write_u32(&mut dst[offset..], self.sequence_number);
98    offset += core::mem::size_of::<u32>();
99
100    let payload_size = self.payload.len();
101    if !self.payload.is_empty() {
102      dst[offset..offset + payload_size].copy_from_slice(&self.payload);
103      offset += payload_size;
104    }
105
106    debug_assert_eq!(
107      offset, encoded_len,
108      "expect bytes written ({encoded_len}) not match actual bytes writtend ({offset})"
109    );
110    Ok(offset)
111  }
112
113  fn encoded_len(&self) -> usize {
114    core::mem::size_of::<u32>() + core::mem::size_of::<u32>() + self.payload.len()
115  }
116
117  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
118  where
119    Self: Sized,
120  {
121    let mut offset = 0;
122    if core::mem::size_of::<u32>() > src.len() {
123      return Err(Self::Error::NotEnoughBytes);
124    }
125
126    let total_len = NetworkEndian::read_u32(&src[offset..]);
127    offset += core::mem::size_of::<u32>();
128    let sequence_number = NetworkEndian::read_u32(&src[offset..]);
129    offset += core::mem::size_of::<u32>();
130
131    if total_len as usize == 2 * core::mem::size_of::<u32>() {
132      return Ok((
133        offset,
134        Self {
135          sequence_number,
136          payload: Bytes::new(),
137        },
138      ));
139    }
140
141    if total_len as usize - core::mem::size_of::<u32>() > src.len() {
142      return Err(Self::Error::NotEnoughBytes);
143    }
144
145    let payload = Bytes::copy_from_slice(&src[offset..total_len as usize]);
146    Ok((
147      total_len as usize,
148      Self {
149        sequence_number,
150        payload,
151      },
152    ))
153  }
154
155  fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
156  where
157    Self: Sized,
158  {
159    let mut buf = [0; 8];
160    reader.read_exact(&mut buf)?;
161    let total_len = NetworkEndian::read_u32(&buf) as usize;
162    let sequence_number = NetworkEndian::read_u32(&buf[core::mem::size_of::<u32>()..]);
163
164    if total_len == 2 * core::mem::size_of::<u32>() {
165      return Ok((
166        total_len,
167        Self {
168          sequence_number,
169          payload: Bytes::new(),
170        },
171      ));
172    }
173
174    let payload_len = total_len - core::mem::size_of::<u32>() * 2;
175    if payload_len <= MAX_INLINED_BYTES {
176      let mut buf = [0; MAX_INLINED_BYTES];
177      reader.read_exact(&mut buf[..payload_len])?;
178      let payload = Bytes::copy_from_slice(&buf[..payload_len]);
179      Ok((
180        total_len,
181        Self {
182          sequence_number,
183          payload,
184        },
185      ))
186    } else {
187      let mut payload = vec![0; payload_len];
188      reader.read_exact(&mut payload)?;
189      Ok((
190        total_len,
191        Self {
192          sequence_number,
193          payload: payload.into(),
194        },
195      ))
196    }
197  }
198
199  async fn decode_from_async_reader<R: futures::AsyncRead + Send + Unpin>(
200    reader: &mut R,
201  ) -> std::io::Result<(usize, Self)>
202  where
203    Self: Sized,
204  {
205    use futures::AsyncReadExt;
206
207    let mut buf = [0; 8];
208    reader.read_exact(&mut buf).await?;
209
210    let total_len = NetworkEndian::read_u32(&buf) as usize;
211    let sequence_number = NetworkEndian::read_u32(&buf[core::mem::size_of::<u32>()..]);
212
213    if total_len == 2 * core::mem::size_of::<u32>() {
214      return Ok((
215        total_len,
216        Self {
217          sequence_number,
218          payload: Bytes::new(),
219        },
220      ));
221    }
222
223    let payload_len = total_len - core::mem::size_of::<u32>() * 2;
224    if payload_len <= MAX_INLINED_BYTES {
225      let mut buf = [0; MAX_INLINED_BYTES];
226      reader.read_exact(&mut buf[..payload_len]).await?;
227      let payload = Bytes::copy_from_slice(&buf[..payload_len]);
228      Ok((
229        total_len,
230        Self {
231          sequence_number,
232          payload,
233        },
234      ))
235    } else {
236      let mut payload = vec![0; payload_len];
237      reader.read_exact(&mut payload).await?;
238      Ok((
239        total_len,
240        Self {
241          sequence_number,
242          payload: payload.into(),
243        },
244      ))
245    }
246  }
247}
248
/// Nack response is sent for an indirect ping when the pinger doesn't hear from
/// the ping-ee within the configured timeout. This lets the original node know
/// that the indirect ping attempt happened but didn't succeed.
///
/// A nack is encoded exactly like its `u32` sequence number (see the
/// [`Transformable`] impl). With the `serde` feature it serializes
/// transparently as that bare `u32`.
#[viewit::viewit(
  vis_all = "pub(crate)",
  getters(vis_all = "pub"),
  setters(vis_all = "pub", prefix = "with")
)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(
  feature = "rkyv",
  derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive)
)]
#[cfg_attr(
  feature = "rkyv",
  rkyv(derive(Debug, Clone, PartialEq, Eq, Hash), compare(PartialEq))
)]
#[repr(transparent)]
pub struct Nack {
  /// The sequence number of the nack.
  #[viewit(
    getter(const, attrs(doc = "Returns the sequence number of the nack")),
    setter(
      const,
      attrs(doc = "Sets the sequence number of the nack (Builder pattern)")
    )
  )]
  sequence_number: u32,
}
279
280impl Nack {
281  /// Create a new nack response with the given sequence number.
282  #[inline]
283  pub const fn new(sequence_number: u32) -> Self {
284    Self { sequence_number }
285  }
286
287  /// Sets the sequence number of the nack response
288  #[inline]
289  pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
290    self.sequence_number = sequence_number;
291    self
292  }
293}
294
295impl Transformable for Nack {
296  type Error = <u32 as Transformable>::Error;
297
298  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
299    <u32 as Transformable>::encode(&self.sequence_number, dst)
300  }
301
302  fn encoded_len(&self) -> usize {
303    <u32 as Transformable>::encoded_len(&self.sequence_number)
304  }
305
306  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
307  where
308    Self: Sized,
309  {
310    let (n, sequence_number) = <u32 as Transformable>::decode(src)?;
311    Ok((n, Self { sequence_number }))
312  }
313
314  async fn encode_to_async_writer<W: futures::io::AsyncWrite + Send + Unpin>(
315    &self,
316    writer: &mut W,
317  ) -> std::io::Result<usize> {
318    <u32 as Transformable>::encode_to_async_writer(&self.sequence_number, writer).await
319  }
320
321  fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
322    <u32 as Transformable>::encode_to_writer(&self.sequence_number, writer)
323  }
324
325  fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
326  where
327    Self: Sized,
328  {
329    <u32 as Transformable>::decode_from_reader(reader)
330      .map(|(n, sequence_number)| (n, Self { sequence_number }))
331  }
332
333  async fn decode_from_async_reader<R: futures::io::AsyncRead + Send + Unpin>(
334    reader: &mut R,
335  ) -> std::io::Result<(usize, Self)>
336  where
337    Self: Sized,
338  {
339    <u32 as Transformable>::decode_from_async_reader(reader)
340      .await
341      .map(|(n, sequence_number)| (n, Self { sequence_number }))
342  }
343}
344
#[cfg(test)]
const _: () = {
  use rand::random;

  impl Ack {
    /// Creates an ack with a random sequence number and a payload of
    /// `payload_size` random bytes.
    #[inline]
    pub fn random(payload_size: usize) -> Self {
      // The sequence number is drawn before the payload bytes.
      let sequence_number = random();
      let payload = (0..payload_size)
        .map(|_| random::<u8>())
        .collect::<Vec<u8>>();
      Self {
        sequence_number,
        payload: Bytes::from(payload),
      }
    }
  }

  impl Nack {
    /// Creates a nack with a random sequence number.
    #[inline]
    pub fn random() -> Self {
      Self::new(random())
    }
  }
};
375
#[cfg(test)]
mod tests {
  use super::*;
  use futures::io::Cursor as FCursor;
  use std::io::Cursor;

  /// Asserts that `actual` carries the same sequence number and payload as `expected`.
  fn assert_same_ack(expected: &Ack, actual: &Ack) {
    assert_eq!(expected.sequence_number, actual.sequence_number);
    assert_eq!(expected.payload, actual.payload);
  }

  #[tokio::test]
  async fn test_ack_response_encode_decode() {
    // 100 random acks with payload sizes 0..100, which covers both the
    // inlined (<= MAX_INLINED_BYTES) and heap-allocated reader paths.
    for payload_size in 0..100 {
      let original = Ack::random(payload_size);

      // Slice-based round trip.
      let mut encoded = vec![0; original.encoded_len()];
      let written = original.encode(&mut encoded).unwrap();
      assert_eq!(written, encoded.len());
      let (read, decoded) = Ack::decode(&encoded).unwrap();
      assert_eq!(read, encoded.len());
      assert_same_ack(&original, &decoded);

      // The same bytes, decoded through the blocking and async reader paths.
      let (_, decoded) = Ack::decode_from_reader(&mut Cursor::new(&encoded)).unwrap();
      assert_same_ack(&original, &decoded);
      let (_, decoded) = Ack::decode_from_async_reader(&mut FCursor::new(&encoded))
        .await
        .unwrap();
      assert_same_ack(&original, &decoded);

      // Blocking writer -> blocking reader.
      let mut sink = Vec::new();
      original.encode_to_writer(&mut sink).unwrap();
      let (_, decoded) = Ack::decode_from_reader(&mut Cursor::new(sink)).unwrap();
      assert_same_ack(&original, &decoded);

      // Async writer -> async reader.
      let mut sink = Vec::new();
      original.encode_to_async_writer(&mut sink).await.unwrap();
      let (_, decoded) = Ack::decode_from_async_reader(&mut FCursor::new(sink))
        .await
        .unwrap();
      assert_same_ack(&original, &decoded);
    }
  }

  #[tokio::test]
  async fn test_nack_response_encode_decode() {
    for _ in 0..100 {
      let original = Nack::random();

      // Slice-based round trip.
      let mut encoded = vec![0; original.encoded_len()];
      let written = original.encode(&mut encoded).unwrap();
      assert_eq!(written, encoded.len());
      let (read, decoded) = Nack::decode(&encoded).unwrap();
      assert_eq!(read, encoded.len());
      assert_eq!(original.sequence_number, decoded.sequence_number);

      // The same bytes, decoded through the blocking and async reader paths.
      let (_, decoded) = Nack::decode_from_reader(&mut Cursor::new(&encoded)).unwrap();
      assert_eq!(original.sequence_number, decoded.sequence_number);
      let (_, decoded) = Nack::decode_from_async_reader(&mut FCursor::new(&encoded))
        .await
        .unwrap();
      assert_eq!(original.sequence_number, decoded.sequence_number);

      // Blocking writer -> blocking reader.
      let mut sink = Vec::new();
      original.encode_to_writer(&mut sink).unwrap();
      let (_, decoded) = Nack::decode_from_reader(&mut Cursor::new(sink)).unwrap();
      assert_eq!(original.sequence_number, decoded.sequence_number);

      // Async writer -> async reader.
      let mut sink = Vec::new();
      original.encode_to_async_writer(&mut sink).await.unwrap();
      let (_, decoded) = Nack::decode_from_async_reader(&mut FCursor::new(sink))
        .await
        .unwrap();
      assert_eq!(original.sequence_number, decoded.sequence_number);
    }
  }

  #[test]
  fn test_access() {
    let expected_payload = Bytes::from_static(b"hello world");

    let mut ack = Ack::random(100);
    ack
      .set_payload(expected_payload.clone())
      .set_sequence_number(100);
    assert_eq!(ack.sequence_number(), 100);
    assert_eq!(ack.payload(), &expected_payload);

    let mut nack = Nack::random();
    nack.set_sequence_number(100);
    assert_eq!(nack.sequence_number(), 100);
  }
}