// memberlist_proto/push_pull.rs

1use core::marker::PhantomData;
2
3use bytes::Bytes;
4use nodecraft::CheapClone;
5use triomphe::Arc;
6
7use crate::RepeatedDecoder;
8
9use super::{Data, DataRef, DecodeError, EncodeError, WireType, merge, skip};
10
11mod state;
12pub use state::*;
13
/// Push pull message.
///
/// Carries a `join` flag, the [`PushNodeState`] entries of the message, and
/// opaque `user_data` bytes. Public getters and `with_*` builder setters are
/// generated by `viewit`.
///
/// `states` is an `Arc` slice and `user_data` is a [`Bytes`] buffer, so clones
/// of this message share that data instead of copying it.
#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
#[derive(Debug, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "arbitrary", test), derive(arbitrary::Arbitrary))]
pub struct PushPull<I, A> {
  /// Whether the push pull message is a join message.
  ///
  /// Only written to the wire when `true`; an absent field decodes as `false`.
  #[viewit(
    getter(
      const,
      attrs(doc = "Returns whether the push pull message is a join message")
    ),
    setter(
      const,
      attrs(doc = "Sets whether the push pull message is a join message (Builder pattern)")
    )
  )]
  join: bool,
  /// The states of the push pull message.
  ///
  /// Each entry is encoded as its own length-delimited field.
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the states of the push pull message")
    ),
    setter(attrs(doc = "Sets the states of the push pull message (Builder pattern)"))
  )]
  #[cfg_attr(any(feature = "arbitrary", test), arbitrary(with = crate::arbitrary_impl::triomphe_arc))]
  states: Arc<[PushNodeState<I, A>]>,
  /// The user data of the push pull message.
  ///
  /// Omitted from the wire when empty; an absent field decodes as empty.
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the user data of the push pull message")
    ),
    setter(attrs(doc = "Sets the user data of the push pull message (Builder pattern)"))
  )]
  #[cfg_attr(any(feature = "arbitrary", test), arbitrary(with = crate::arbitrary_impl::bytes))]
  user_data: Bytes,
}
54
55impl<I, A> Clone for PushPull<I, A> {
56  fn clone(&self) -> Self {
57    Self {
58      join: self.join,
59      states: self.states.clone(),
60      user_data: self.user_data.clone(),
61    }
62  }
63}
64
65impl<I, A> CheapClone for PushPull<I, A> {
66  fn cheap_clone(&self) -> Self {
67    Self {
68      join: self.join,
69      states: self.states.clone(),
70      user_data: self.user_data.clone(),
71    }
72  }
73}
74
75const JOIN_TAG: u8 = 1;
76const JOIN_BYTE: u8 = merge(WireType::Varint, JOIN_TAG);
77const STATES_TAG: u8 = 2;
78const STATES_BYTE: u8 = merge(WireType::LengthDelimited, STATES_TAG);
79const USER_DATA_TAG: u8 = 3;
80const USER_DATA_BYTE: u8 = merge(WireType::LengthDelimited, USER_DATA_TAG);
81
82impl<I, A> PushPull<I, A> {
83  /// Create a new [`PushPull`] message.
84  #[inline]
85  pub fn new(join: bool, states: impl Iterator<Item = PushNodeState<I, A>>) -> Self {
86    Self {
87      states: Arc::from_iter(states),
88      user_data: Bytes::new(),
89      join,
90    }
91  }
92
93  /// Consumes the [`PushPull`] and returns the states and user data.
94  #[inline]
95  pub fn into_components(self) -> (bool, Bytes, Arc<[PushNodeState<I, A>]>) {
96    (self.join, self.user_data, self.states)
97  }
98}
99
100impl<I, A> Data for PushPull<I, A>
101where
102  I: Data,
103  A: Data,
104{
105  type Ref<'a> = PushPullRef<'a, I::Ref<'a>, A::Ref<'a>>;
106
107  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
108  where
109    Self: Sized,
110  {
111    val
112      .states
113      .iter::<PushNodeState<I, A>>()
114      .map(|res| res.and_then(PushNodeState::from_ref))
115      .collect::<Result<Arc<[_]>, DecodeError>>()
116      .map(|states| Self {
117        join: val.join,
118        states,
119        user_data: Bytes::copy_from_slice(val.user_data),
120      })
121  }
122
123  fn encoded_len(&self) -> usize {
124    let mut len = 0;
125    if self.join {
126      len += 1 + 1; // join
127    }
128
129    for i in self.states.iter() {
130      len += 1 + i.encoded_len_with_length_delimited();
131    }
132
133    let user_data_len = self.user_data.len();
134
135    if user_data_len != 0 {
136      len += 1 + self.user_data.encoded_len_with_length_delimited();
137    }
138
139    len
140  }
141
142  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
143    macro_rules! bail {
144      ($this:ident($offset:expr, $len:ident)) => {
145        if $offset >= $len {
146          return Err(EncodeError::insufficient_buffer($offset, $len));
147        }
148      };
149    }
150
151    let mut offset = 0;
152    let len = buf.len();
153    if self.join {
154      if len < 2 {
155        return Err(EncodeError::insufficient_buffer(self.encoded_len(), len));
156      }
157      buf[offset] = JOIN_BYTE;
158      offset += 1;
159      buf[offset] = 1;
160      offset += 1;
161    }
162
163    for i in self.states.iter() {
164      bail!(self(offset, len));
165      buf[offset] = STATES_BYTE;
166      offset += 1;
167      {
168        offset += i
169          .encode_length_delimited(&mut buf[offset..])
170          .map_err(|e| e.update(self.encoded_len(), len))?
171      }
172    }
173
174    let user_data_len = self.user_data.len();
175    if user_data_len != 0 {
176      bail!(self(offset, len));
177      buf[offset] = USER_DATA_BYTE;
178      offset += 1;
179      offset += self.user_data.encode_length_delimited(&mut buf[offset..])?;
180    }
181
182    #[cfg(debug_assertions)]
183    super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
184    Ok(offset)
185  }
186}
187
/// The reference (borrowed) form of a [`PushPull`] message.
///
/// `user_data` borrows from the decoded buffer. `states` is not decoded eagerly:
/// it is a [`RepeatedDecoder`] over the source buffer whose entries are decoded
/// on demand (see `Data::from_ref` for `PushPull`). `I` and `A` appear only
/// inside `PhantomData`.
#[viewit::viewit(getters(vis_all = "pub"), setters(skip))]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PushPullRef<'a, I, A> {
  /// Whether the push pull message is a join message.
  ///
  /// `false` when the field was absent from the input.
  #[viewit(getter(
    const,
    style = "ref",
    attrs(doc = "Returns the states of the push pull message")
  ))]
  join: bool,
  /// The states of the push pull message.
  ///
  /// Lazy decoder over the encoded state entries.
  #[viewit(getter(
    const,
    style = "ref",
    attrs(doc = "Returns the states of the push pull message")
  ))]
  states: RepeatedDecoder<'a>,
  /// The user data of the push pull message.
  ///
  /// Empty when the field was absent from the input.
  #[viewit(getter(const, attrs(doc = "Returns the user data of the push pull message")))]
  user_data: &'a [u8],

  // Ties the otherwise-unused `I`/`A` parameters to the struct; no getter.
  #[viewit(getter(skip))]
  _m: PhantomData<(I, A)>,
}
212
// `Clone` and `Copy` are implemented by hand because `#[derive]` would add
// `I: Clone`/`I: Copy` (and likewise `A`) bounds, even though `I` and `A`
// only appear inside `PhantomData`. `*self` is the canonical `Clone` body for
// a `Copy` type.
impl<I, A> Clone for PushPullRef<'_, I, A> {
  fn clone(&self) -> Self {
    *self
  }
}

impl<I, A> Copy for PushPullRef<'_, I, A> {}
220
221impl<'a, I, A> DataRef<'a, PushPull<I, A>> for PushPullRef<'a, I::Ref<'a>, A::Ref<'a>>
222where
223  I: Data,
224  A: Data,
225{
226  fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> {
227    let mut offset = 0;
228    let mut join = None;
229    let mut node_state_offsets = None;
230    let mut num_states = 0;
231    let mut user_data = None;
232
233    while offset < src.len() {
234      match src[offset] {
235        JOIN_BYTE => {
236          if join.is_some() {
237            return Err(DecodeError::duplicate_field("PushPull", "join", JOIN_TAG));
238          }
239          offset += 1;
240
241          if offset >= src.len() {
242            return Err(DecodeError::buffer_underflow());
243          }
244          let (read, val) = <bool as Data>::decode(&src[offset..])?;
245          offset += read;
246          join = Some(val);
247        }
248        STATES_BYTE => {
249          let readed = super::skip("PushPull", &src[offset..])?;
250          if let Some((ref mut fnso, ref mut lnso)) = node_state_offsets {
251            if *fnso > offset {
252              *fnso = offset;
253            }
254
255            if *lnso < offset + readed {
256              *lnso = offset + readed;
257            }
258          } else {
259            node_state_offsets = Some((offset, offset + readed));
260          }
261          num_states += 1;
262          offset += readed;
263        }
264        USER_DATA_BYTE => {
265          if user_data.is_some() {
266            return Err(DecodeError::duplicate_field(
267              "PushPull",
268              "user_data",
269              USER_DATA_TAG,
270            ));
271          }
272          offset += 1;
273
274          let (readed, value) = <&[u8] as DataRef<Bytes>>::decode_length_delimited(&src[offset..])?;
275          offset += readed;
276          user_data = Some(value);
277        }
278        _ => offset += skip("PushPull", &src[offset..])?,
279      }
280    }
281
282    let join = join.unwrap_or_default();
283    let user_data = user_data.unwrap_or_default();
284    Ok((
285      offset,
286      Self {
287        join,
288        states: {
289          let val =
290            RepeatedDecoder::new(STATES_TAG, WireType::LengthDelimited, src).with_nums(num_states);
291          if let Some((first, last)) = node_state_offsets {
292            val.with_offsets(first, last)
293          } else {
294            val
295          }
296        },
297        user_data,
298        _m: PhantomData,
299      },
300    ))
301  }
302}
303
#[cfg(test)]
mod tests {
  use std::net::SocketAddr;

  use arbitrary::{Arbitrary, Unstructured};

  use crate::{DelegateVersion, Meta, ProtocolVersion, State};

  use super::*;

  /// Produces 1 KiB of random bytes to seed `arbitrary` generation.
  fn random_bytes() -> Vec<u8> {
    let mut bytes = vec![0u8; 1024];
    rand::fill(&mut bytes[..]);
    bytes
  }

  /// `clone` and `cheap_clone` must both equal the source value, and their
  /// `Debug` renderings must match.
  #[test]
  fn test_push_pull_clone_and_cheap_clone() {
    let bytes = random_bytes();
    let mut unstructured = Unstructured::new(&bytes);

    let push_pull = PushPull::<String, SocketAddr>::arbitrary(&mut unstructured).unwrap();
    let cloned = push_pull.clone();
    let cheap_cloned = push_pull.cheap_clone();

    assert_eq!(cloned, push_pull);
    assert_eq!(cheap_cloned, push_pull);
    assert_eq!(format!("{cloned:?}"), format!("{cheap_cloned:?}"));
  }

  /// Every setter on a `PushNodeState` must be observable through its getter.
  #[test]
  fn test_push_node_state() {
    let bytes = random_bytes();
    let mut unstructured = Unstructured::new(&bytes);
    let mut state = PushNodeState::<String, SocketAddr>::arbitrary(&mut unstructured).unwrap();

    state.set_id("test".into());
    assert_eq!(state.id(), "test");

    let address = SocketAddr::from(([127, 0, 0, 1], 8080));
    state.set_address(address);
    assert_eq!(state.address(), &address);

    state.set_meta(Meta::try_from("test").unwrap());
    assert_eq!(state.meta(), &Meta::try_from("test").unwrap());

    state.set_incarnation(100);
    assert_eq!(state.incarnation(), 100);

    state.set_state(State::Alive);
    assert_eq!(state.state(), State::Alive);

    state.set_protocol_version(ProtocolVersion::V1);
    assert_eq!(state.protocol_version(), ProtocolVersion::V1);

    state.set_delegate_version(DelegateVersion::V1);
    assert_eq!(state.delegate_version(), DelegateVersion::V1);
  }
}