1use crate::MetaError;
2
3use super::{
4 version::{UnknownDelegateVersion, UnknownProtocolVersion},
5 DelegateVersion, Meta, ProtocolVersion, MAX_ENCODED_LEN_SIZE,
6};
7
8use byteorder::{ByteOrder, NetworkEndian};
9use nodecraft::{CheapClone, Node, NodeTransformError};
10use transformable::Transformable;
11
12#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
16#[cfg_attr(
17 feature = "rkyv",
18 derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive)
19)]
20#[cfg_attr(feature = "rkyv", rkyv(compare(PartialEq)))]
21pub struct Alive<I, A> {
22 #[viewit(
24 getter(const, attrs(doc = "Returns the incarnation of the alive message")),
25 setter(
26 const,
27 attrs(doc = "Sets the incarnation of the alive message (Builder pattern)")
28 )
29 )]
30 incarnation: u32,
31 #[viewit(
33 getter(
34 const,
35 style = "ref",
36 attrs(doc = "Returns the meta of the alive message")
37 ),
38 setter(attrs(doc = "Sets the meta of the alive message (Builder pattern)"))
39 )]
40 meta: Meta,
41 #[viewit(
43 getter(
44 const,
45 style = "ref",
46 attrs(doc = "Returns the node of the alive message")
47 ),
48 setter(attrs(doc = "Sets the node of the alive message (Builder pattern)"))
49 )]
50 node: Node<I, A>,
51 #[viewit(
53 getter(
54 const,
55 attrs(doc = "Returns the protocol version of the alive message is speaking")
56 ),
57 setter(
58 const,
59 attrs(doc = "Sets the protocol version of the alive message is speaking (Builder pattern)")
60 )
61 )]
62 protocol_version: ProtocolVersion,
63 #[viewit(
65 getter(
66 const,
67 attrs(doc = "Returns the delegate version of the alive message is speaking")
68 ),
69 setter(
70 const,
71 attrs(doc = "Sets the delegate version of the alive message is speaking (Builder pattern)")
72 )
73 )]
74 delegate_version: DelegateVersion,
75}
76
77impl<I, A> Alive<I, A> {
78 #[inline]
80 pub const fn new(incarnation: u32, node: Node<I, A>) -> Self {
81 Self {
82 incarnation,
83 meta: Meta::empty(),
84 node,
85 protocol_version: ProtocolVersion::V1,
86 delegate_version: DelegateVersion::V1,
87 }
88 }
89
90 #[inline]
92 pub fn set_incarnation(&mut self, incarnation: u32) -> &mut Self {
93 self.incarnation = incarnation;
94 self
95 }
96
97 #[inline]
99 pub fn set_meta(&mut self, meta: Meta) -> &mut Self {
100 self.meta = meta;
101 self
102 }
103
104 #[inline]
106 pub fn set_node(&mut self, node: Node<I, A>) -> &mut Self {
107 self.node = node;
108 self
109 }
110
111 #[inline]
113 pub fn set_protocol_version(&mut self, protocol_version: ProtocolVersion) -> &mut Self {
114 self.protocol_version = protocol_version;
115 self
116 }
117
118 #[inline]
120 pub fn set_delegate_version(&mut self, delegate_version: DelegateVersion) -> &mut Self {
121 self.delegate_version = delegate_version;
122 self
123 }
124}
125
126impl<I: CheapClone, A: CheapClone> CheapClone for Alive<I, A> {
127 fn cheap_clone(&self) -> Self {
128 Self {
129 incarnation: self.incarnation,
130 meta: self.meta.clone(),
131 node: self.node.cheap_clone(),
132 protocol_version: self.protocol_version,
133 delegate_version: self.delegate_version,
134 }
135 }
136}
137
/// Errors that can occur while encoding or decoding an [`Alive`] message.
#[derive(thiserror::Error)]
pub enum AliveTransformError<I: Transformable, A: Transformable> {
  /// Encoding or decoding the embedded node failed.
  #[error("node transform error: {0}")]
  Node(#[from] NodeTransformError<I, A>),
  /// Encoding or decoding the embedded metadata failed.
  #[error("meta transform error: {0}")]
  Meta(#[from] MetaError),
  /// The encoded message is longer than `u32::MAX` bytes, so its length cannot
  /// be stored in the `u32` length header. Carries the offending length.
  #[error("encoded message too large, max 4294967295 got {0}")]
  TooLarge(u64),
  /// The destination buffer given to `encode` is shorter than the encoded message.
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// The source buffer ended before a complete message could be decoded.
  #[error("the buffer did not contain enough bytes to decode Alive")]
  NotEnoughBytes,
  /// The protocol version byte is not a known [`ProtocolVersion`].
  #[error("unknown protocol version: {0}")]
  UnknownProtocolVersion(#[from] UnknownProtocolVersion),
  /// The delegate version byte is not a known [`DelegateVersion`].
  #[error("unknown delegate version: {0}")]
  UnknownDelegateVersion(#[from] UnknownDelegateVersion),
}
163
164impl<I: Transformable, A: Transformable> core::fmt::Debug for AliveTransformError<I, A> {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 write!(f, "{}", self)
167 }
168}
169
170impl<I, A> Transformable for Alive<I, A>
171where
172 I: Transformable + 'static,
173 A: Transformable + 'static,
174{
175 type Error = AliveTransformError<I, A>;
176
177 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
178 let encoded_len = self.encoded_len();
179 if encoded_len as u64 > u32::MAX as u64 {
180 return Err(Self::Error::TooLarge(encoded_len as u64));
181 }
182
183 if encoded_len > dst.len() {
184 return Err(Self::Error::BufferTooSmall);
185 }
186
187 let mut offset = 0;
188 NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
189 offset += MAX_ENCODED_LEN_SIZE;
190
191 NetworkEndian::write_u32(&mut dst[offset..], self.incarnation);
192 offset += core::mem::size_of::<u32>();
193
194 offset += self
195 .meta
196 .encode(&mut dst[offset..])
197 .map_err(Self::Error::Meta)?;
198
199 offset += self
200 .node
201 .encode(&mut dst[offset..])
202 .map_err(Self::Error::Node)?;
203
204 dst[offset] = self.protocol_version as u8;
205 offset += 1;
206
207 dst[offset] = self.delegate_version as u8;
208 offset += 1;
209
210 debug_assert_eq!(
211 offset, encoded_len,
212 "expect bytes written ({encoded_len}) not match actual bytes written ({offset})"
213 );
214 Ok(offset)
215 }
216
217 fn encoded_len(&self) -> usize {
218 MAX_ENCODED_LEN_SIZE
219 + core::mem::size_of::<u32>() + self.meta.encoded_len()
221 + self.node.encoded_len()
222 + 1 + 1 }
225
226 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
227 where
228 Self: Sized,
229 {
230 let mut offset = 0;
231 if core::mem::size_of::<u32>() > src.len() {
232 return Err(Self::Error::NotEnoughBytes);
233 }
234
235 let encoded_len = NetworkEndian::read_u32(&src[offset..]) as usize;
236 offset += MAX_ENCODED_LEN_SIZE;
237 if encoded_len > src.len() {
238 return Err(Self::Error::NotEnoughBytes);
239 }
240
241 let incarnation = NetworkEndian::read_u32(&src[offset..]);
242 offset += core::mem::size_of::<u32>();
243
244 let (meta_len, meta) = Meta::decode(&src[offset..]).map_err(Self::Error::Meta)?;
245 offset += meta_len;
246
247 let (node_len, node) = Node::decode(&src[offset..]).map_err(Self::Error::Node)?;
248 offset += node_len;
249
250 if 1 + offset > src.len() {
251 return Err(Self::Error::NotEnoughBytes);
252 }
253 let protocol_version =
254 ProtocolVersion::try_from(src[offset]).map_err(Self::Error::UnknownProtocolVersion)?;
255 offset += 1;
256
257 if 1 + offset > src.len() {
258 return Err(Self::Error::NotEnoughBytes);
259 }
260 let delegate_version =
261 DelegateVersion::try_from(src[offset]).map_err(Self::Error::UnknownDelegateVersion)?;
262 offset += 1;
263
264 Ok((
265 offset,
266 Self {
267 incarnation,
268 meta,
269 node,
270 protocol_version,
271 delegate_version,
272 },
273 ))
274 }
275}
276
#[cfg(feature = "rkyv")]
const _: () = {
  use core::fmt::Debug;
  // `Hash` is not in the prelude; it must be imported for the `.hash(state)`
  // method calls on the archived fields below to resolve.
  use core::hash::{Hash, Hasher};
  use rkyv::Archive;

  /// Formats the archived message with the same shape as `Alive`'s derived `Debug`.
  impl<I: Debug + Archive, A: Debug + Archive> Debug for ArchivedAlive<I, A>
  where
    I::Archived: Debug,
    A::Archived: Debug,
  {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
      f.debug_struct("Alive")
        .field("incarnation", &self.incarnation)
        .field("meta", &self.meta)
        .field("node", &self.node)
        .field("protocol_version", &self.protocol_version)
        .field("delegate_version", &self.delegate_version)
        .finish()
    }
  }

  /// Field-wise equality between two archived messages.
  impl<I: Archive, A: Archive> PartialEq for ArchivedAlive<I, A>
  where
    I::Archived: PartialEq,
    A::Archived: PartialEq,
  {
    fn eq(&self, other: &Self) -> bool {
      self.incarnation == other.incarnation
        && self.meta == other.meta
        && self.node == other.node
        && self.protocol_version == other.protocol_version
        && self.delegate_version == other.delegate_version
    }
  }

  impl<I: Archive, A: Archive> Eq for ArchivedAlive<I, A>
  where
    I::Archived: Eq,
    A::Archived: Eq,
  {
  }

  /// Hashes every field, in declaration order, consistent with `PartialEq`.
  impl<I: Archive, A: Archive> Hash for ArchivedAlive<I, A>
  where
    I::Archived: Hash,
    A::Archived: Hash,
  {
    fn hash<H: Hasher>(&self, state: &mut H) {
      self.incarnation.hash(state);
      self.meta.hash(state);
      self.node.hash(state);
      self.protocol_version.hash(state);
      self.delegate_version.hash(state);
    }
  }
};
333
#[cfg(test)]
const _: () = {
  use std::net::SocketAddr;

  use rand::{distr::Alphanumeric, random, rng, Rng};
  use smol_str::SmolStr;

  impl Alive<SmolStr, SocketAddr> {
    /// Builds a random alive message for tests.
    ///
    /// `size` is used both as the length of the alphanumeric node id and as the
    /// number of random metadata bytes.
    pub(crate) fn random(size: usize) -> Self {
      let id: SmolStr = rng()
        .sample_iter(Alphanumeric)
        .take(size)
        .map(char::from)
        .collect::<String>()
        .into();
      let meta = (0..size)
        .map(|_| random::<u8>())
        .collect::<Vec<_>>()
        .try_into()
        .unwrap();
      // Any port in the full u16 range; no string formatting/parsing round trip.
      let addr = SocketAddr::from(([127, 0, 0, 1], random::<u16>()));

      Self {
        incarnation: random(),
        meta,
        node: Node::new(id, addr),
        protocol_version: ProtocolVersion::V1,
        delegate_version: DelegateVersion::V1,
      }
    }
  }
};
367
#[cfg(test)]
mod tests {
  use super::*;

  use smol_str::SmolStr;
  use std::net::SocketAddr;

  #[test]
  fn test_encode_decode() {
    for i in 0..100 {
      let alive = Alive::random(i);
      let mut buf = vec![0; alive.encoded_len()];
      let encoded_len = alive.encode(&mut buf).unwrap();
      assert_eq!(encoded_len, alive.encoded_len());
      let (decoded_len, decoded) = Alive::decode(&buf).unwrap();
      assert_eq!(decoded_len, encoded_len);
      assert_eq!(decoded, alive);
    }
  }

  /// Every strict prefix of a valid encoding must be rejected as too short,
  /// never decoded and never panicking.
  #[test]
  fn test_decode_truncated() {
    let alive = Alive::random(16);
    let mut buf = vec![0; alive.encoded_len()];
    alive.encode(&mut buf).unwrap();
    for len in 0..buf.len() {
      let res = Alive::<SmolStr, SocketAddr>::decode(&buf[..len]);
      assert!(
        matches!(res, Err(AliveTransformError::NotEnoughBytes)),
        "decoding a {len}-byte prefix should fail with NotEnoughBytes"
      );
    }
  }

  #[test]
  fn test_access() {
    let mut alive = Alive::random(16);
    alive.set_incarnation(1);
    assert_eq!(alive.incarnation(), 1);
    alive.set_meta(Meta::empty());
    assert_eq!(alive.meta(), &Meta::empty());
    alive.set_node(Node::new("a".into(), "127.0.0.1:8081".parse().unwrap()));
    assert_eq!(alive.node().id(), "a");
    alive.set_protocol_version(ProtocolVersion::V1);
    assert_eq!(alive.protocol_version(), ProtocolVersion::V1);
    alive.set_delegate_version(DelegateVersion::V1);
    assert_eq!(alive.delegate_version(), DelegateVersion::V1);
  }
}