memberlist_proto/
ack.rs

1use bytes::Bytes;
2
3use super::{Data, DataRef, DecodeError, EncodeError, WireType, merge, skip};
4
/// Ack response is sent for a ping.
///
/// The message carries a sequence number and an opaque payload. When encoded
/// via [`Data`], an empty payload is omitted from the wire entirely, and a
/// missing payload field decodes back to an empty payload.
#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "arbitrary", test), derive(arbitrary::Arbitrary))]
pub struct Ack {
  /// The sequence number of the ack
  #[viewit(
    getter(const, attrs(doc = "Returns the sequence number of the ack")),
    setter(
      const,
      attrs(doc = "Sets the sequence number of the ack (Builder pattern)")
    )
  )]
  sequence_number: u32,
  /// The payload of the ack
  // An empty payload is valid; `Ack::new` starts with one.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the payload of the ack")),
    setter(attrs(doc = "Sets the payload of the ack (Builder pattern)"))
  )]
  #[cfg_attr(any(feature = "arbitrary", test), arbitrary(with = crate::arbitrary_impl::bytes))]
  payload: Bytes,
}
27
28impl Ack {
29  const SEQUENCE_NUMBER_TAG: u8 = 1;
30  const SEQUENCE_NUMBER_BYTE: u8 = merge(WireType::Varint, Self::SEQUENCE_NUMBER_TAG);
31  const PAYLOAD_TAG: u8 = 2;
32  const PAYLOAD_BYTE: u8 = merge(WireType::LengthDelimited, Self::PAYLOAD_TAG);
33
34  /// Decodes the sequence number from the given buffer
35  #[inline]
36  pub fn decode_sequence_number(src: &[u8]) -> Result<(usize, u32), DecodeError> {
37    let mut offset = 0;
38    let mut sequence_number = None;
39    let buf_len = src.len();
40
41    while offset < buf_len {
42      match src[offset] {
43        Self::SEQUENCE_NUMBER_BYTE => {
44          offset += 1;
45          let (bytes_read, value) = <u32 as DataRef<u32>>::decode(&src[offset..])?;
46          offset += bytes_read;
47          sequence_number = Some(value);
48        }
49        _ => offset += skip("Ack", &src[offset..])?,
50      }
51    }
52
53    // Ensure the sequence_number was found
54    Ok((offset, sequence_number.unwrap_or(0)))
55  }
56
57  /// Create a new ack response with the given sequence number and empty payload.
58  #[inline]
59  pub const fn new(sequence_number: u32) -> Self {
60    Self {
61      sequence_number,
62      payload: Bytes::new(),
63    }
64  }
65
66  /// Sets the sequence number of the ack
67  #[inline]
68  pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
69    self.sequence_number = sequence_number;
70    self
71  }
72
73  /// Sets the payload of the ack
74  #[inline]
75  pub fn set_payload(&mut self, payload: Bytes) -> &mut Self {
76    self.payload = payload;
77    self
78  }
79
80  /// Consumes the [`Ack`] and returns the sequence number and payload
81  #[inline]
82  pub fn into_components(self) -> (u32, Bytes) {
83    (self.sequence_number, self.payload)
84  }
85}
86
87impl Data for Ack {
88  type Ref<'a> = AckRef<'a>;
89
90  #[inline]
91  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
92  where
93    Self: Sized,
94  {
95    Ok(Self {
96      sequence_number: val.sequence_number,
97      payload: Bytes::copy_from_slice(val.payload),
98    })
99  }
100
101  fn encoded_len(&self) -> usize {
102    let mut len = 1 + self.sequence_number.encoded_len();
103    let payload_len = self.payload.len();
104    if payload_len != 0 {
105      len += 1 + self.payload.encoded_len_with_length_delimited();
106    }
107    len
108  }
109
110  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
111    macro_rules! bail {
112      ($offset:expr, $remaining:ident) => {
113        if $offset >= $remaining {
114          return Err(EncodeError::insufficient_buffer(
115            self.encoded_len(),
116            $remaining,
117          ));
118        }
119      };
120    }
121
122    let len = buf.len();
123    let mut offset = 0;
124    bail!(offset, len);
125    buf[offset] = Self::SEQUENCE_NUMBER_BYTE;
126    offset += 1;
127    offset += self
128      .sequence_number
129      .encode(&mut buf[offset..])
130      .map_err(|e| e.update(self.encoded_len(), len))?;
131
132    let payload_len = self.payload.len();
133    if payload_len != 0 {
134      bail!(offset, len);
135      buf[offset] = Self::PAYLOAD_BYTE;
136      offset += 1;
137      offset += self
138        .payload
139        .encode_length_delimited(&mut buf[offset..])
140        .map_err(|e| e.update(self.encoded_len(), len))?;
141    }
142
143    #[cfg(debug_assertions)]
144    super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
145    Ok(offset)
146  }
147}
148
/// The reference to an [`Ack`] message
///
/// This is the borrowed counterpart of [`Ack`]: the payload is a byte slice
/// instead of owned [`Bytes`]. When produced by [`DataRef::decode`], the slice
/// borrows directly from the decoded input. Use [`Data::from_ref`] to obtain
/// an owned [`Ack`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct AckRef<'a> {
  // Decoding sets this to 0 when the field is absent from the input.
  sequence_number: u32,
  // Decoding leaves this empty when the payload field is absent from the input.
  payload: &'a [u8],
}
155
156impl<'a> AckRef<'a> {
157  /// Create a new ack reference with the given sequence number and payload
158  #[inline]
159  pub const fn new(sequence_number: u32, payload: &'a [u8]) -> Self {
160    Self {
161      sequence_number,
162      payload,
163    }
164  }
165
166  /// Returns the sequence number of the ack
167  #[inline]
168  pub const fn sequence_number(&self) -> u32 {
169    self.sequence_number
170  }
171
172  /// Returns the payload of the ack
173  #[inline]
174  pub const fn payload(&self) -> &'a [u8] {
175    self.payload
176  }
177}
178
179impl<'a> DataRef<'a, Ack> for AckRef<'a> {
180  fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError>
181  where
182    Self: Sized,
183  {
184    let mut offset = 0;
185    let mut sequence_number = None;
186    let mut payload = None;
187
188    while offset < src.len() {
189      // Parse the tag and wire type
190      match src[offset] {
191        Ack::SEQUENCE_NUMBER_BYTE => {
192          if sequence_number.is_some() {
193            return Err(DecodeError::duplicate_field(
194              "Ack",
195              "sequence_number",
196              Ack::SEQUENCE_NUMBER_TAG,
197            ));
198          }
199          offset += 1;
200          let (bytes_read, value) = <u32 as DataRef<u32>>::decode(&src[offset..])?;
201          offset += bytes_read;
202          sequence_number = Some(value);
203        }
204        Ack::PAYLOAD_BYTE => {
205          if payload.is_some() {
206            return Err(DecodeError::duplicate_field(
207              "Ack",
208              "payload",
209              Ack::PAYLOAD_TAG,
210            ));
211          }
212          offset += 1;
213          let (readed, data) = <&[u8] as DataRef<Bytes>>::decode_length_delimited(&src[offset..])?;
214          offset += readed;
215          payload = Some(data);
216        }
217        _ => offset += skip("Ack", &src[offset..])?,
218      }
219    }
220
221    Ok((
222      offset,
223      AckRef {
224        sequence_number: sequence_number.unwrap_or(0),
225        payload: payload.unwrap_or_default(),
226      },
227    ))
228  }
229}
230
/// Nack response is sent for an indirect ping when the pinger doesn't hear from
/// the ping-ee within the configured timeout. This lets the original node know
/// that the indirect ping attempt happened but didn't succeed.
///
/// The message consists of a single sequence number. `Nack` is `Copy` and
/// serves as its own borrowed form in the [`Data`] impl (`Ref<'a> = Self`).
#[viewit::viewit(
  vis_all = "pub(crate)",
  getters(vis_all = "pub"),
  setters(vis_all = "pub", prefix = "with")
)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "arbitrary", test), derive(arbitrary::Arbitrary))]
// Single-field wrapper: `repr(transparent)` gives it the same layout as `u32`.
#[repr(transparent)]
pub struct Nack {
  /// The sequence number of the nack
  #[viewit(
    getter(const, attrs(doc = "Returns the sequence number of the nack")),
    setter(
      const,
      attrs(doc = "Sets the sequence number of the nack (Builder pattern)")
    )
  )]
  sequence_number: u32,
}
252
253impl Nack {
254  const SEQUENCE_NUMBER_TAG: u8 = 1;
255  const SEQUENCE_NUMBER_BYTE: u8 = merge(WireType::Varint, Self::SEQUENCE_NUMBER_TAG);
256
257  /// Create a new nack response with the given sequence number.
258  #[inline]
259  pub const fn new(sequence_number: u32) -> Self {
260    Self { sequence_number }
261  }
262
263  /// Sets the sequence number of the nack response
264  #[inline]
265  pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
266    self.sequence_number = sequence_number;
267    self
268  }
269}
270
271impl<'a> DataRef<'a, Self> for Nack {
272  fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError>
273  where
274    Self: Sized,
275  {
276    let mut sequence_number = None;
277    let mut offset = 0;
278    while offset < src.len() {
279      match src[offset] {
280        Self::SEQUENCE_NUMBER_BYTE => {
281          if sequence_number.is_some() {
282            return Err(DecodeError::duplicate_field(
283              "Nack",
284              "sequence_number",
285              Self::SEQUENCE_NUMBER_TAG,
286            ));
287          }
288          offset += 1;
289
290          let (bytes_read, value) = <u32 as DataRef<u32>>::decode(&src[offset..])?;
291          offset += bytes_read;
292          sequence_number = Some(value);
293        }
294        _ => offset += skip("Nack", &src[offset..])?,
295      }
296    }
297
298    Ok((
299      offset,
300      Self {
301        sequence_number: sequence_number.unwrap_or(0),
302      },
303    ))
304  }
305}
306
307impl Data for Nack {
308  type Ref<'a> = Self;
309
310  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
311  where
312    Self: Sized,
313  {
314    Ok(val)
315  }
316
317  fn encoded_len(&self) -> usize {
318    1 + self.sequence_number.encoded_len()
319  }
320
321  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
322    let mut offset = 0;
323    if buf.is_empty() {
324      return Err(EncodeError::insufficient_buffer(self.encoded_len(), 0));
325    }
326    buf[offset] = Self::SEQUENCE_NUMBER_BYTE;
327    offset += 1;
328    offset += self.sequence_number.encode(&mut buf[offset..])?;
329    #[cfg(debug_assertions)]
330    super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
331    Ok(offset)
332  }
333}
334
#[cfg(test)]
mod tests {
  use super::*;
  use arbitrary::{Arbitrary, Unstructured};

  /// Encodes `value` into an exactly-sized buffer and checks the byte count.
  fn encode_exact<D: Data>(value: &D) -> Vec<u8> {
    let mut buf = vec![0; value.encoded_len()];
    let written = value.encode(&mut buf).unwrap();
    assert_eq!(written, buf.len());
    buf
  }

  #[test]
  fn test_access() {
    let mut data = vec![0; 1024];
    rand::fill(&mut data[..]);
    let mut data = Unstructured::new(&data);
    let mut ack = Ack::arbitrary(&mut data).unwrap();
    ack.set_payload(Bytes::from_static(b"hello world"));
    ack.set_sequence_number(100);
    assert_eq!(ack.sequence_number(), 100);
    assert_eq!(ack.payload(), &Bytes::from_static(b"hello world"));

    let mut nack = Nack::arbitrary(&mut data).unwrap();
    nack.set_sequence_number(100);
    assert_eq!(nack.sequence_number(), 100);
  }

  #[test]
  fn ack_roundtrip() {
    // Cover both the "payload omitted" and "payload present" encodings.
    for payload in [Bytes::new(), Bytes::from_static(b"hello world")] {
      let mut ack = Ack::new(42);
      ack.set_payload(payload);
      let buf = encode_exact(&ack);

      let (read, decoded) = <AckRef as DataRef<Ack>>::decode(&buf).unwrap();
      assert_eq!(read, buf.len());
      assert_eq!(decoded.sequence_number(), 42);
      assert_eq!(<Ack as Data>::from_ref(decoded).unwrap(), ack);

      assert_eq!(Ack::decode_sequence_number(&buf).unwrap(), (buf.len(), 42));
    }
  }

  #[test]
  fn nack_roundtrip() {
    // 300 needs a multi-byte varint.
    let nack = Nack::new(300);
    let buf = encode_exact(&nack);
    let (read, decoded) = <Nack as DataRef<Nack>>::decode(&buf).unwrap();
    assert_eq!(read, buf.len());
    assert_eq!(decoded, nack);
  }

  #[test]
  fn encode_rejects_short_buffers() {
    let nack = Nack::new(300);
    let mut short = vec![0; nack.encoded_len() - 1];
    assert!(nack.encode(&mut short).is_err());
    assert!(nack.encode(&mut []).is_err());

    let mut ack = Ack::new(300);
    ack.set_payload(Bytes::from_static(b"hello world"));
    let mut short = vec![0; ack.encoded_len() - 1];
    assert!(ack.encode(&mut short).is_err());
    assert!(ack.encode(&mut []).is_err());
  }

  #[test]
  fn decode_rejects_duplicate_sequence_number() {
    // Nack's only field shares its tag byte with Ack's sequence number field,
    // so two concatenated nacks form a message with a repeated sequence number.
    assert_eq!(Ack::SEQUENCE_NUMBER_BYTE, Nack::SEQUENCE_NUMBER_BYTE);
    let one = encode_exact(&Nack::new(1));
    let twice = [one.as_slice(), one.as_slice()].concat();
    assert!(<AckRef as DataRef<Ack>>::decode(&twice).is_err());
    assert!(<Nack as DataRef<Nack>>::decode(&twice).is_err());
  }
}
355}