memberlist_types/
alive.rs

1use crate::MetaError;
2
3use super::{
4  version::{UnknownDelegateVersion, UnknownProtocolVersion},
5  DelegateVersion, Meta, ProtocolVersion, MAX_ENCODED_LEN_SIZE,
6};
7
8use byteorder::{ByteOrder, NetworkEndian};
9use nodecraft::{CheapClone, Node, NodeTransformError};
10use transformable::Transformable;
11
/// Alive message.
///
/// Carries a node's incarnation number and metadata, together with the
/// protocol and delegate versions that node speaks.
///
/// NOTE(review): field declaration order feeds the derived `Hash` and the rkyv
/// archived layout; reordering fields changes hash values and archive
/// compatibility. The `Transformable` wire format is written explicitly and
/// does not depend on this order.
#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
#[cfg_attr(
  feature = "rkyv",
  derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive)
)]
#[cfg_attr(feature = "rkyv", rkyv(compare(PartialEq)))]
pub struct Alive<I, A> {
  /// The incarnation number carried by this alive message.
  #[viewit(
    getter(const, attrs(doc = "Returns the incarnation of the alive message")),
    setter(
      const,
      attrs(doc = "Sets the incarnation of the alive message (Builder pattern)")
    )
  )]
  incarnation: u32,
  /// The metadata attached to this alive message.
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the meta of the alive message")
    ),
    setter(attrs(doc = "Sets the meta of the alive message (Builder pattern)"))
  )]
  meta: Meta,
  /// The node (id and address) this alive message describes.
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the node of the alive message")
    ),
    setter(attrs(doc = "Sets the node of the alive message (Builder pattern)"))
  )]
  node: Node<I, A>,
  /// The protocol version this alive message is speaking.
  #[viewit(
    getter(
      const,
      attrs(doc = "Returns the protocol version of the alive message is speaking")
    ),
    setter(
      const,
      attrs(doc = "Sets the protocol version of the alive message is speaking (Builder pattern)")
    )
  )]
  protocol_version: ProtocolVersion,
  /// The delegate version this alive message is speaking.
  #[viewit(
    getter(
      const,
      attrs(doc = "Returns the delegate version of the alive message is speaking")
    ),
    setter(
      const,
      attrs(doc = "Sets the delegate version of the alive message is speaking (Builder pattern)")
    )
  )]
  delegate_version: DelegateVersion,
}
76
77impl<I, A> Alive<I, A> {
78  /// Construct a new alive message with the given incarnation, meta, node, protocol version and delegate version.
79  #[inline]
80  pub const fn new(incarnation: u32, node: Node<I, A>) -> Self {
81    Self {
82      incarnation,
83      meta: Meta::empty(),
84      node,
85      protocol_version: ProtocolVersion::V1,
86      delegate_version: DelegateVersion::V1,
87    }
88  }
89
90  /// Sets the incarnation of the alive message.
91  #[inline]
92  pub fn set_incarnation(&mut self, incarnation: u32) -> &mut Self {
93    self.incarnation = incarnation;
94    self
95  }
96
97  /// Sets the meta of the alive message.
98  #[inline]
99  pub fn set_meta(&mut self, meta: Meta) -> &mut Self {
100    self.meta = meta;
101    self
102  }
103
104  /// Sets the node of the alive message.
105  #[inline]
106  pub fn set_node(&mut self, node: Node<I, A>) -> &mut Self {
107    self.node = node;
108    self
109  }
110
111  /// Set the protocol version of the alive message is speaking.
112  #[inline]
113  pub fn set_protocol_version(&mut self, protocol_version: ProtocolVersion) -> &mut Self {
114    self.protocol_version = protocol_version;
115    self
116  }
117
118  /// Set the delegate version of the alive message is speaking.
119  #[inline]
120  pub fn set_delegate_version(&mut self, delegate_version: DelegateVersion) -> &mut Self {
121    self.delegate_version = delegate_version;
122    self
123  }
124}
125
126impl<I: CheapClone, A: CheapClone> CheapClone for Alive<I, A> {
127  fn cheap_clone(&self) -> Self {
128    Self {
129      incarnation: self.incarnation,
130      meta: self.meta.clone(),
131      node: self.node.cheap_clone(),
132      protocol_version: self.protocol_version,
133      delegate_version: self.delegate_version,
134    }
135  }
136}
137
/// Errors produced when encoding or decoding an [`Alive`] message.
#[derive(thiserror::Error)]
pub enum AliveTransformError<I: Transformable, A: Transformable> {
  /// Encoding or decoding the nested node failed.
  #[error("node transform error: {0}")]
  Node(#[from] NodeTransformError<I, A>),
  /// Encoding or decoding the nested metadata failed.
  #[error("meta transform error: {0}")]
  Meta(#[from] MetaError),
  /// The encoded length does not fit in the `u32` length prefix; carries the
  /// computed length.
  #[error("encoded message too large, max 4294967295 got {0}")]
  TooLarge(u64),
  /// The destination buffer passed to `encode` is shorter than the encoded length.
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// The buffer did not contain enough bytes to decode Alive.
  #[error("the buffer did not contain enough bytes to decode Alive")]
  NotEnoughBytes,
  /// The decoded protocol version byte is not a known [`ProtocolVersion`].
  #[error("unknown protocol version: {0}")]
  UnknownProtocolVersion(#[from] UnknownProtocolVersion),
  /// The decoded delegate version byte is not a known [`DelegateVersion`].
  #[error("unknown delegate version: {0}")]
  UnknownDelegateVersion(#[from] UnknownDelegateVersion),
}
163
164impl<I: Transformable, A: Transformable> core::fmt::Debug for AliveTransformError<I, A> {
165  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166    write!(f, "{}", self)
167  }
168}
169
170impl<I, A> Transformable for Alive<I, A>
171where
172  I: Transformable + 'static,
173  A: Transformable + 'static,
174{
175  type Error = AliveTransformError<I, A>;
176
177  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
178    let encoded_len = self.encoded_len();
179    if encoded_len as u64 > u32::MAX as u64 {
180      return Err(Self::Error::TooLarge(encoded_len as u64));
181    }
182
183    if encoded_len > dst.len() {
184      return Err(Self::Error::BufferTooSmall);
185    }
186
187    let mut offset = 0;
188    NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
189    offset += MAX_ENCODED_LEN_SIZE;
190
191    NetworkEndian::write_u32(&mut dst[offset..], self.incarnation);
192    offset += core::mem::size_of::<u32>();
193
194    offset += self
195      .meta
196      .encode(&mut dst[offset..])
197      .map_err(Self::Error::Meta)?;
198
199    offset += self
200      .node
201      .encode(&mut dst[offset..])
202      .map_err(Self::Error::Node)?;
203
204    dst[offset] = self.protocol_version as u8;
205    offset += 1;
206
207    dst[offset] = self.delegate_version as u8;
208    offset += 1;
209
210    debug_assert_eq!(
211      offset, encoded_len,
212      "expect bytes written ({encoded_len}) not match actual bytes written ({offset})"
213    );
214    Ok(offset)
215  }
216
217  fn encoded_len(&self) -> usize {
218    MAX_ENCODED_LEN_SIZE
219      + core::mem::size_of::<u32>() // incarnation
220      + self.meta.encoded_len()
221      + self.node.encoded_len()
222      + 1 // protocol_version
223      + 1 // delegate_version
224  }
225
226  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
227  where
228    Self: Sized,
229  {
230    let mut offset = 0;
231    if core::mem::size_of::<u32>() > src.len() {
232      return Err(Self::Error::NotEnoughBytes);
233    }
234
235    let encoded_len = NetworkEndian::read_u32(&src[offset..]) as usize;
236    offset += MAX_ENCODED_LEN_SIZE;
237    if encoded_len > src.len() {
238      return Err(Self::Error::NotEnoughBytes);
239    }
240
241    let incarnation = NetworkEndian::read_u32(&src[offset..]);
242    offset += core::mem::size_of::<u32>();
243
244    let (meta_len, meta) = Meta::decode(&src[offset..]).map_err(Self::Error::Meta)?;
245    offset += meta_len;
246
247    let (node_len, node) = Node::decode(&src[offset..]).map_err(Self::Error::Node)?;
248    offset += node_len;
249
250    if 1 + offset > src.len() {
251      return Err(Self::Error::NotEnoughBytes);
252    }
253    let protocol_version =
254      ProtocolVersion::try_from(src[offset]).map_err(Self::Error::UnknownProtocolVersion)?;
255    offset += 1;
256
257    if 1 + offset > src.len() {
258      return Err(Self::Error::NotEnoughBytes);
259    }
260    let delegate_version =
261      DelegateVersion::try_from(src[offset]).map_err(Self::Error::UnknownDelegateVersion)?;
262    offset += 1;
263
264    Ok((
265      offset,
266      Self {
267        incarnation,
268        meta,
269        node,
270        protocol_version,
271        delegate_version,
272      },
273    ))
274  }
275}
276
#[cfg(feature = "rkyv")]
const _: () = {
  use core::fmt::Debug;
  // The `Hash` trait must be in scope for the `.hash(state)` method calls on
  // the archived fields below; the prelude only provides the derive macro.
  use core::hash::Hash;
  use rkyv::Archive;

  // Formats an archived alive message with the same field names as `Alive`.
  impl<I: Debug + Archive, A: Debug + Archive> core::fmt::Debug for ArchivedAlive<I, A>
  where
    I::Archived: Debug,
    A::Archived: Debug,
  {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
      f.debug_struct("Alive")
        .field("incarnation", &self.incarnation)
        .field("meta", &self.meta)
        .field("node", &self.node)
        .field("protocol_version", &self.protocol_version)
        .field("delegate_version", &self.delegate_version)
        .finish()
    }
  }

  // Field-by-field equality between two archived alive messages.
  impl<I: Archive, A: Archive> PartialEq for ArchivedAlive<I, A>
  where
    I::Archived: PartialEq,
    A::Archived: PartialEq,
  {
    fn eq(&self, other: &Self) -> bool {
      self.incarnation == other.incarnation
        && self.meta == other.meta
        && self.node == other.node
        && self.protocol_version == other.protocol_version
        && self.delegate_version == other.delegate_version
    }
  }

  impl<I: Archive, A: Archive> Eq for ArchivedAlive<I, A>
  where
    I::Archived: Eq,
    A::Archived: Eq,
  {
  }

  // Hashes the same fields, in the same order, that `eq` compares, keeping
  // `Hash` consistent with `PartialEq`.
  impl<I: Archive, A: Archive> core::hash::Hash for ArchivedAlive<I, A>
  where
    I::Archived: core::hash::Hash,
    A::Archived: core::hash::Hash,
  {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
      self.incarnation.hash(state);
      self.meta.hash(state);
      self.node.hash(state);
      self.protocol_version.hash(state);
      self.delegate_version.hash(state);
    }
  }
};
333
#[cfg(test)]
const _: () = {
  use std::net::SocketAddr;

  use rand::{distr::Alphanumeric, random, rng, Rng};
  use smol_str::SmolStr;

  impl Alive<SmolStr, SocketAddr> {
    /// Builds a random alive message for tests.
    ///
    /// The node id is `size` random alphanumeric characters, the metadata is
    /// `size` random bytes, the incarnation is random, and the address is
    /// `127.0.0.1` with a random port. Both versions are `V1`.
    pub(crate) fn random(size: usize) -> Self {
      let id: SmolStr = rng()
        .sample_iter(Alphanumeric)
        .take(size)
        .map(char::from)
        .collect::<String>()
        .into();

      let meta_bytes: Vec<u8> = core::iter::repeat_with(random::<u8>).take(size).collect();

      let port = rng().random_range(0..65535);
      let addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap();

      Self {
        incarnation: random(),
        meta: meta_bytes.try_into().unwrap(),
        node: Node::new(id, addr),
        protocol_version: ProtocolVersion::V1,
        delegate_version: DelegateVersion::V1,
      }
    }
  }
};
367
#[cfg(test)]
mod tests {
  use super::*;

  /// Round-trips random messages of increasing size through `encode`/`decode`
  /// and checks that lengths and contents agree.
  #[test]
  fn test_encode_decode() {
    for i in 0..100 {
      let alive = Alive::random(i);
      let mut buf = vec![0; alive.encoded_len()];
      let encoded_len = alive.encode(&mut buf).unwrap();
      assert_eq!(encoded_len, alive.encoded_len());
      let (decoded_len, decoded) = Alive::decode(&buf).unwrap();
      assert_eq!(decoded_len, encoded_len);
      assert_eq!(decoded, alive);
    }
  }

  /// Every strict prefix of a valid encoding must be rejected with an error,
  /// and encoding into a buffer one byte too short must report
  /// `BufferTooSmall`.
  #[test]
  fn test_truncated_buffers() {
    for i in 0..16 {
      let alive = Alive::random(i);
      let len = alive.encoded_len();
      let mut buf = vec![0; len];
      alive.encode(&mut buf).unwrap();

      assert!(matches!(
        alive.encode(&mut buf[..len - 1]),
        Err(AliveTransformError::BufferTooSmall)
      ));

      for cut in 0..len {
        assert!(
          Alive::<smol_str::SmolStr, std::net::SocketAddr>::decode(&buf[..cut]).is_err(),
          "decoding a {cut}-byte prefix of a {len}-byte encoding should fail"
        );
      }
    }
  }

  /// Exercises each `set_*` method followed by its getter.
  #[test]
  fn test_access() {
    let mut alive = Alive::random(16);
    alive.set_incarnation(1);
    assert_eq!(alive.incarnation(), 1);
    alive.set_meta(Meta::empty());
    assert_eq!(alive.meta(), &Meta::empty());
    alive.set_node(Node::new("a".into(), "127.0.0.1:8081".parse().unwrap()));
    assert_eq!(alive.node().id(), "a");
    alive.set_protocol_version(ProtocolVersion::V1);
    assert_eq!(alive.protocol_version(), ProtocolVersion::V1);
    alive.set_delegate_version(DelegateVersion::V1);
    assert_eq!(alive.delegate_version(), DelegateVersion::V1);
  }
}