1use core::marker::PhantomData;
2
3use bytes::Bytes;
4use nodecraft::CheapClone;
5use triomphe::Arc;
6
7use crate::RepeatedDecoder;
8
9use super::{Data, DataRef, DecodeError, EncodeError, WireType, merge, skip};
10
11mod state;
12pub use state::*;
13
#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
/// Push/pull message: a join flag, a list of [`PushNodeState`]s and raw user data.
///
/// Accessors are generated by `viewit`: a public getter named after each field
/// and a public `with_`-prefixed builder-style setter.
#[derive(Debug, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "arbitrary", test), derive(arbitrary::Arbitrary))]
pub struct PushPull<I, A> {
    // Join flag; the encoder only writes this field when it is `true`.
    #[viewit(
        getter(
            const,
            attrs(doc = "Returns whether the push pull message is a join message")
        ),
        setter(
            const,
            attrs(doc = "Sets whether the push pull message is a join message (Builder pattern)")
        )
    )]
    join: bool,
    // Node states, held in a reference-counted slice so cloning the message
    // shares them instead of copying; each one is encoded as its own
    // length-delimited field.
    #[viewit(
        getter(
            const,
            style = "ref",
            attrs(doc = "Returns the states of the push pull message")
        ),
        setter(attrs(doc = "Sets the states of the push pull message (Builder pattern)"))
    )]
    #[cfg_attr(any(feature = "arbitrary", test), arbitrary(with = crate::arbitrary_impl::triomphe_arc))]
    states: Arc<[PushNodeState<I, A>]>,
    // Raw user data bytes; omitted from the encoding when empty.
    #[viewit(
        getter(
            const,
            style = "ref",
            attrs(doc = "Returns the user data of the push pull message")
        ),
        setter(attrs(doc = "Sets the user data of the push pull message (Builder pattern)"))
    )]
    #[cfg_attr(any(feature = "arbitrary", test), arbitrary(with = crate::arbitrary_impl::bytes))]
    user_data: Bytes,
}
54
55impl<I, A> Clone for PushPull<I, A> {
56 fn clone(&self) -> Self {
57 Self {
58 join: self.join,
59 states: self.states.clone(),
60 user_data: self.user_data.clone(),
61 }
62 }
63}
64
65impl<I, A> CheapClone for PushPull<I, A> {
66 fn cheap_clone(&self) -> Self {
67 Self {
68 join: self.join,
69 states: self.states.clone(),
70 user_data: self.user_data.clone(),
71 }
72 }
73}
74
75const JOIN_TAG: u8 = 1;
76const JOIN_BYTE: u8 = merge(WireType::Varint, JOIN_TAG);
77const STATES_TAG: u8 = 2;
78const STATES_BYTE: u8 = merge(WireType::LengthDelimited, STATES_TAG);
79const USER_DATA_TAG: u8 = 3;
80const USER_DATA_BYTE: u8 = merge(WireType::LengthDelimited, USER_DATA_TAG);
81
82impl<I, A> PushPull<I, A> {
83 #[inline]
85 pub fn new(join: bool, states: impl Iterator<Item = PushNodeState<I, A>>) -> Self {
86 Self {
87 states: Arc::from_iter(states),
88 user_data: Bytes::new(),
89 join,
90 }
91 }
92
93 #[inline]
95 pub fn into_components(self) -> (bool, Bytes, Arc<[PushNodeState<I, A>]>) {
96 (self.join, self.user_data, self.states)
97 }
98}
99
100impl<I, A> Data for PushPull<I, A>
101where
102 I: Data,
103 A: Data,
104{
105 type Ref<'a> = PushPullRef<'a, I::Ref<'a>, A::Ref<'a>>;
106
107 fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
108 where
109 Self: Sized,
110 {
111 val
112 .states
113 .iter::<PushNodeState<I, A>>()
114 .map(|res| res.and_then(PushNodeState::from_ref))
115 .collect::<Result<Arc<[_]>, DecodeError>>()
116 .map(|states| Self {
117 join: val.join,
118 states,
119 user_data: Bytes::copy_from_slice(val.user_data),
120 })
121 }
122
123 fn encoded_len(&self) -> usize {
124 let mut len = 0;
125 if self.join {
126 len += 1 + 1; }
128
129 for i in self.states.iter() {
130 len += 1 + i.encoded_len_with_length_delimited();
131 }
132
133 let user_data_len = self.user_data.len();
134
135 if user_data_len != 0 {
136 len += 1 + self.user_data.encoded_len_with_length_delimited();
137 }
138
139 len
140 }
141
142 fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
143 macro_rules! bail {
144 ($this:ident($offset:expr, $len:ident)) => {
145 if $offset >= $len {
146 return Err(EncodeError::insufficient_buffer($offset, $len));
147 }
148 };
149 }
150
151 let mut offset = 0;
152 let len = buf.len();
153 if self.join {
154 if len < 2 {
155 return Err(EncodeError::insufficient_buffer(self.encoded_len(), len));
156 }
157 buf[offset] = JOIN_BYTE;
158 offset += 1;
159 buf[offset] = 1;
160 offset += 1;
161 }
162
163 for i in self.states.iter() {
164 bail!(self(offset, len));
165 buf[offset] = STATES_BYTE;
166 offset += 1;
167 {
168 offset += i
169 .encode_length_delimited(&mut buf[offset..])
170 .map_err(|e| e.update(self.encoded_len(), len))?
171 }
172 }
173
174 let user_data_len = self.user_data.len();
175 if user_data_len != 0 {
176 bail!(self(offset, len));
177 buf[offset] = USER_DATA_BYTE;
178 offset += 1;
179 offset += self.user_data.encode_length_delimited(&mut buf[offset..])?;
180 }
181
182 #[cfg(debug_assertions)]
183 super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
184 Ok(offset)
185 }
186}
187
#[viewit::viewit(getters(vis_all = "pub"), setters(skip))]
/// Borrowed view of an encoded [`PushPull`], produced by `DataRef::decode`.
///
/// `user_data` borrows directly from the source buffer, and the node states are
/// not decoded up front: `states` is a `RepeatedDecoder` over the source buffer
/// that `PushPull::from_ref` iterates to build owned states.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct PushPullRef<'a, I, A> {
    // Join flag; `false` when the field is absent from the input.
    #[viewit(getter(
        const,
        attrs(doc = "Returns whether the push pull message is a join message")
    ))]
    join: bool,
    // Lazy decoder over every `states` entry found in the source buffer.
    #[viewit(getter(
        const,
        style = "ref",
        attrs(doc = "Returns the states of the push pull message")
    ))]
    states: RepeatedDecoder<'a>,
    // User data bytes borrowed from the input; empty when the field is absent.
    #[viewit(getter(const, attrs(doc = "Returns the user data of the push pull message")))]
    user_data: &'a [u8],

    // Marker for the `I`/`A` type parameters, which no field stores directly.
    #[viewit(getter(skip))]
    _m: PhantomData<(I, A)>,
}
212
// Manual impls so `PushPullRef` is `Clone`/`Copy` for any `I`/`A`; a derive
// would add `I: Clone`/`A: Copy`-style bounds even though only `PhantomData`
// mentions them. The unbounded `Copy` impl below requires every field to be `Copy`.
impl<I, A> Clone for PushPullRef<'_, I, A> {
    fn clone(&self) -> Self {
        // Canonical `Clone` for a `Copy` type.
        *self
    }
}

impl<I, A> Copy for PushPullRef<'_, I, A> {}
220
221impl<'a, I, A> DataRef<'a, PushPull<I, A>> for PushPullRef<'a, I::Ref<'a>, A::Ref<'a>>
222where
223 I: Data,
224 A: Data,
225{
226 fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> {
227 let mut offset = 0;
228 let mut join = None;
229 let mut node_state_offsets = None;
230 let mut num_states = 0;
231 let mut user_data = None;
232
233 while offset < src.len() {
234 match src[offset] {
235 JOIN_BYTE => {
236 if join.is_some() {
237 return Err(DecodeError::duplicate_field("PushPull", "join", JOIN_TAG));
238 }
239 offset += 1;
240
241 if offset >= src.len() {
242 return Err(DecodeError::buffer_underflow());
243 }
244 let (read, val) = <bool as Data>::decode(&src[offset..])?;
245 offset += read;
246 join = Some(val);
247 }
248 STATES_BYTE => {
249 let readed = super::skip("PushPull", &src[offset..])?;
250 if let Some((ref mut fnso, ref mut lnso)) = node_state_offsets {
251 if *fnso > offset {
252 *fnso = offset;
253 }
254
255 if *lnso < offset + readed {
256 *lnso = offset + readed;
257 }
258 } else {
259 node_state_offsets = Some((offset, offset + readed));
260 }
261 num_states += 1;
262 offset += readed;
263 }
264 USER_DATA_BYTE => {
265 if user_data.is_some() {
266 return Err(DecodeError::duplicate_field(
267 "PushPull",
268 "user_data",
269 USER_DATA_TAG,
270 ));
271 }
272 offset += 1;
273
274 let (readed, value) = <&[u8] as DataRef<Bytes>>::decode_length_delimited(&src[offset..])?;
275 offset += readed;
276 user_data = Some(value);
277 }
278 _ => offset += skip("PushPull", &src[offset..])?,
279 }
280 }
281
282 let join = join.unwrap_or_default();
283 let user_data = user_data.unwrap_or_default();
284 Ok((
285 offset,
286 Self {
287 join,
288 states: {
289 let val =
290 RepeatedDecoder::new(STATES_TAG, WireType::LengthDelimited, src).with_nums(num_states);
291 if let Some((first, last)) = node_state_offsets {
292 val.with_offsets(first, last)
293 } else {
294 val
295 }
296 },
297 user_data,
298 _m: PhantomData,
299 },
300 ))
301 }
302}
303
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use arbitrary::{Arbitrary, Unstructured};

    use crate::{DelegateVersion, Meta, ProtocolVersion, State};

    use super::*;

    /// Returns 1024 random bytes used to seed `arbitrary` value generation.
    fn random_seed() -> Vec<u8> {
        let mut bytes = vec![0u8; 1024];
        rand::fill(&mut bytes[..]);
        bytes
    }

    #[test]
    fn test_push_pull_clone_and_cheap_clone() {
        let seed = random_seed();
        let mut source = Unstructured::new(&seed);

        let original = PushPull::<String, SocketAddr>::arbitrary(&mut source).unwrap();
        let cloned = original.clone();
        let cheap = original.cheap_clone();

        assert_eq!(cloned, original);
        assert_eq!(cheap, original);
        assert_eq!(format!("{cloned:?}"), format!("{cheap:?}"));
    }

    #[test]
    fn test_push_node_state() {
        let seed = random_seed();
        let mut source = Unstructured::new(&seed);
        let localhost = SocketAddr::from(([127, 0, 0, 1], 8080));

        let mut state = PushNodeState::<String, SocketAddr>::arbitrary(&mut source).unwrap();

        state.set_id("test".into());
        assert_eq!(state.id(), "test");

        state.set_address(localhost);
        assert_eq!(state.address(), &localhost);

        state.set_meta(Meta::try_from("test").unwrap());
        assert_eq!(state.meta(), &Meta::try_from("test").unwrap());

        state.set_incarnation(100);
        assert_eq!(state.incarnation(), 100);

        state.set_state(State::Alive);
        assert_eq!(state.state(), State::Alive);

        state.set_protocol_version(ProtocolVersion::V1);
        assert_eq!(state.protocol_version(), ProtocolVersion::V1);

        state.set_delegate_version(DelegateVersion::V1);
        assert_eq!(state.delegate_version(), DelegateVersion::V1);
    }
}