1use std::{
2 collections::{HashMap, VecDeque},
3 sync::{
4 Arc,
5 atomic::{AtomicBool, AtomicU32, AtomicUsize},
6 },
7};
8
9use agnostic_lite::{AsyncSpawner, RuntimeLite};
10use async_channel::{Receiver, Sender};
11use async_lock::{Mutex, RwLock};
12
13use atomic_refcell::AtomicRefCell;
14use futures::stream::FuturesUnordered;
15use nodecraft::{CheapClone, Node, resolver::AddressResolver};
16use rand::RngExt;
17
18use super::{
19 Options,
20 awareness::Awareness,
21 broadcast::MemberlistBroadcast,
22 delegate::{Delegate, VoidDelegate},
23 error::Error,
24 proto::{Message, PushNodeState, TinyVec},
25 queue::TransmitLimitedQueue,
26 state::{AckManager, LocalNodeState},
27 suspicion::Suspicion,
28 transport::Transport,
29};
30
31#[cfg(feature = "encryption")]
32use super::keyring::Keyring;
33
34#[cfg(any(test, feature = "test"))]
35pub(crate) mod tests;
36
37#[viewit::viewit]
38pub(crate) struct HotData {
39 sequence_num: AtomicU32,
40 incarnation: AtomicU32,
41 push_pull_req: AtomicU32,
42 leave: AtomicBool,
43 num_nodes: Arc<AtomicU32>,
44}
45
46impl HotData {
47 fn new() -> Self {
48 Self {
49 sequence_num: AtomicU32::new(0),
50 incarnation: AtomicU32::new(0),
51 num_nodes: Arc::new(AtomicU32::new(0)),
52 push_pull_req: AtomicU32::new(0),
53 leave: AtomicBool::new(false),
54 }
55 }
56}
57
/// A decoded message paired with the `from` address it was handed off with
/// (presumably the sender's address — confirm against the enqueuing code).
#[viewit::viewit]
pub(crate) struct MessageHandoff<I, A> {
  /// The message itself.
  msg: Message<I, A>,
  /// Address associated with the message.
  from: A,
}
63
64#[viewit::viewit]
65pub(crate) struct MessageQueue<I, A> {
66 high: VecDeque<MessageHandoff<I, A>>,
68 low: VecDeque<MessageHandoff<I, A>>,
70}
71
72impl<I, A> MessageQueue<I, A> {
73 const fn new() -> Self {
74 Self {
75 high: VecDeque::new(),
76 low: VecDeque::new(),
77 }
78 }
79}
80
/// One entry of the membership table: a node's locally tracked state plus an
/// optional suspicion tracker.
///
/// `Member` dereferences to its `state`.
pub(crate) struct Member<T, D>
where
  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
  T: Transport,
{
  /// The node's locally tracked state.
  pub(crate) state: LocalNodeState<T::Id, T::ResolvedAddress>,
  /// Suspicion tracker; presumably `Some` only while the node is suspected — confirm.
  pub(crate) suspicion: Option<Suspicion<T, D>>,
}
90
91impl<T, D> core::fmt::Debug for Member<T, D>
92where
93 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
94 T: Transport,
95{
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("Member")
98 .field("state", &self.state)
99 .finish()
100 }
101}
102
103impl<T, D> core::ops::Deref for Member<T, D>
104where
105 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
106 T: Transport,
107{
108 type Target = LocalNodeState<T::Id, T::ResolvedAddress>;
109
110 fn deref(&self) -> &Self::Target {
111 &self.state
112 }
113}
114
/// The membership table.
///
/// `node_map` maps each member's id to that member's index in `nodes`; any code
/// that reorders `nodes` (e.g. `shuffle`) must update `node_map` accordingly.
pub(crate) struct Members<T, D>
where
  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
  T: Transport,
{
  /// The local node (its id and advertised address).
  pub(crate) local: Node<T::Id, T::ResolvedAddress>,
  /// All known members, in table order.
  pub(crate) nodes: TinyVec<Member<T, D>>,
  /// Member id -> index into `nodes`.
  pub(crate) node_map: HashMap<T::Id, usize>,
}
124
125impl<T, D> core::ops::Index<usize> for Members<T, D>
126where
127 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
128 T: Transport,
129{
130 type Output = Member<T, D>;
131
132 fn index(&self, index: usize) -> &Self::Output {
133 &self.nodes[index]
134 }
135}
136
137impl<T, D> core::ops::IndexMut<usize> for Members<T, D>
138where
139 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
140 T: Transport,
141{
142 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
143 &mut self.nodes[index]
144 }
145}
146
147impl<T, D> rand::seq::IndexedRandom for Members<T, D>
148where
149 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
150 T: Transport,
151{
152 fn len(&self) -> usize {
153 self.nodes.len()
154 }
155}
156
157impl<T, D> rand::seq::SliceRandom for Members<T, D>
158where
159 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
160 T: Transport,
161{
162 fn shuffle<R>(&mut self, rng: &mut R)
163 where
164 R: rand::Rng + ?Sized,
165 {
166 #[inline]
170 fn gen_index<R: rand::Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize {
171 if ubound <= (u32::MAX as usize) {
172 rng.random_range(0..ubound as u32) as usize
173 } else {
174 rng.random_range(0..ubound)
175 }
176 }
177
178 for i in (1..self.nodes.len()).rev() {
179 let ridx = gen_index(rng, i + 1);
181 let curr = self.node_map.get_mut(self.nodes[i].state.id()).unwrap();
182 *curr = ridx;
183 let target = self.node_map.get_mut(self.nodes[ridx].state.id()).unwrap();
184 *target = i;
185 self.nodes.swap(i, ridx);
186 }
187 }
188
189 fn partial_shuffle<R>(
190 &mut self,
191 _rng: &mut R,
192 _amount: usize,
193 ) -> (&mut [Self::Output], &mut [Self::Output])
194 where
195 Self::Output: Sized,
196 R: rand::Rng + ?Sized,
197 {
198 unreachable!()
199 }
200}
201
202impl<T, D> Members<T, D>
203where
204 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
205 T: Transport,
206{
207 fn new(local: Node<T::Id, T::ResolvedAddress>) -> Self {
208 Self {
209 nodes: TinyVec::new(),
210 node_map: HashMap::new(),
211 local,
212 }
213 }
214}
215
216impl<T, D> Members<T, D>
217where
218 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
219 T: Transport,
220{
221 pub(crate) fn any_alive(&self) -> bool {
222 for m in self.nodes.iter() {
223 if !m.dead_or_left() && m.id().ne(self.local.id()) {
224 return true;
225 }
226 }
227
228 false
229 }
230}
231
/// State shared by every `Memberlist` handle (they hold it behind an `Arc`).
///
/// NOTE: Rust drops struct fields in declaration order, so reordering these
/// fields changes the teardown order.
pub(crate) struct MemberlistCore<T, D>
where
  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
  T: Transport,
{
  /// Id of the local node, taken from `Transport::local_id` at construction.
  pub(crate) id: T::Id,
  /// Frequently updated atomic counters and flags.
  pub(crate) hot: HotData,
  /// Local awareness state (built from `opts.awareness_max_multiplier`).
  pub(crate) awareness: Awareness,
  /// Outgoing broadcast queue, built from `opts.retransmit_mult` and the node
  /// counter shared with `hot.num_nodes`.
  pub(crate) broadcast:
    TransmitLimitedQueue<MemberlistBroadcast<T::Id, T::ResolvedAddress>, Arc<AtomicU32>>,
  /// Capacity-1 `()` channel; presumably signals that the leave broadcast went
  /// out — confirm against its users.
  pub(crate) leave_broadcast_tx: Sender<()>,
  pub(crate) leave_broadcast_rx: Receiver<()>,
  /// Join handles of the background tasks spawned by `Memberlist::new_in`.
  pub(crate) handles: AtomicRefCell<
    FuturesUnordered<<<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()>>,
  >,
  /// Presumably the position in `nodes` of the next probe target — confirm.
  pub(crate) probe_index: AtomicUsize,
  /// Capacity-1 `()` channel; presumably wakes the consumer of `queue` — confirm.
  pub(crate) handoff_tx: Sender<()>,
  pub(crate) handoff_rx: Receiver<()>,
  /// Handed-off messages awaiting processing, split by priority.
  pub(crate) queue: Mutex<MessageQueue<T::Id, T::ResolvedAddress>>,
  /// The membership table, behind an async `RwLock`.
  pub(crate) nodes: Arc<RwLock<Members<T, D>>>,
  pub(crate) ack_manager: AckManager<T::Runtime>,
  /// The transport this instance communicates through.
  pub(crate) transport: Arc<T>,
  /// Closed by `shutdown` and on drop; its receiver clones are handed to the
  /// background tasks.
  pub(crate) shutdown_tx: Sender<()>,
  /// Address advertised to other nodes, from `Transport::advertise_address`.
  pub(crate) advertise: T::ResolvedAddress,
  pub(crate) opts: Arc<Options>,
  /// Keyring derived from `opts.primary_key` / `opts.secret_keys`; `None` when
  /// neither is configured.
  #[cfg(feature = "encryption")]
  pub(crate) keyring: Option<Keyring>,
}
261
262impl<T, D> MemberlistCore<T, D>
263where
264 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
265 T: Transport,
266{
267 pub(crate) async fn shutdown(&self) -> Result<(), T::Error> {
268 if !self.shutdown_tx.close() {
269 return Ok(());
270 }
271
272 if let Err(e) = self.transport.shutdown().await {
276 tracing::error!(err=%e, "memberlist: failed to shutdown transport");
277 return Err(e);
278 }
279
280 Ok(())
281 }
282}
283
284impl<T, D> Drop for MemberlistCore<T, D>
285where
286 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
287 T: Transport,
288{
289 fn drop(&mut self) {
290 self.shutdown_tx.close();
291 }
292}
293
/// A handle to a memberlist instance.
///
/// Cloning a handle is cheap: clones share the same core state and the same
/// optional delegate through `Arc`s.
///
/// `D` defaults to `VoidDelegate` over the transport's id and resolved
/// address types.
pub struct Memberlist<
  T,
  D = VoidDelegate<
    <T as Transport>::Id,
    <<T as Transport>::Resolver as AddressResolver>::ResolvedAddress,
  >,
> where
  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
  T: Transport,
{
  /// Core state shared by all clones of this handle.
  pub(crate) inner: Arc<MemberlistCore<T, D>>,
  /// User-supplied delegate, if any.
  pub(crate) delegate: Option<Arc<D>>,
}
318
319impl<T, D> Clone for Memberlist<T, D>
320where
321 T: Transport,
322 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
323{
324 fn clone(&self) -> Self {
325 Self {
326 inner: self.inner.clone(),
327 delegate: self.delegate.clone(),
328 }
329 }
330}
331
332impl<T, D> Memberlist<T, D>
333where
334 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
335 T: Transport,
336{
337 pub(crate) async fn new_in(
338 transport: T,
339 delegate: Option<D>,
340 opts: Options,
341 ) -> Result<(Receiver<()>, T::ResolvedAddress, Self), Error<T, D>> {
342 let (handoff_tx, handoff_rx) = async_channel::bounded(1);
343 let (leave_broadcast_tx, leave_broadcast_rx) = async_channel::bounded(1);
344
345 let advertise = transport.advertise_address();
349 let id = transport.local_id();
350 let node = Node::new(id.clone(), advertise.clone());
351 let awareness = Awareness::new(
352 opts.awareness_max_multiplier as isize,
353 #[cfg(feature = "metrics")]
354 Arc::new(vec![]),
355 );
356 let hot = HotData::new();
357 let num_nodes = hot.num_nodes.clone();
358 let broadcast = TransmitLimitedQueue::new(opts.retransmit_mult, num_nodes);
359
360 let (shutdown_tx, shutdown_rx) = async_channel::bounded(1);
361
362 #[cfg(feature = "encryption")]
363 let keyring = match (opts.primary_key, opts.secret_keys.is_empty()) {
364 (None, false) => {
365 tracing::warn!("memberlist: using first key in keyring as primary key");
366 let mut iter = opts.secret_keys.iter().copied();
367 let pk = iter.next().unwrap();
368 let keyring = Keyring::with_keys(pk, iter);
369 Some(keyring)
370 }
371 (Some(pk), true) => Some(Keyring::new(pk)),
372 (Some(pk), false) => Some(Keyring::with_keys(pk, opts.secret_keys.iter().copied())),
373 (None, true) => None,
374 };
375
376 let this = Memberlist {
377 inner: Arc::new(MemberlistCore {
378 id: id.cheap_clone(),
379 hot,
380 awareness,
381 broadcast,
382 leave_broadcast_tx,
383 leave_broadcast_rx,
384 probe_index: AtomicUsize::new(0),
385 handles: AtomicRefCell::new(FuturesUnordered::new()),
386 handoff_tx,
387 handoff_rx,
388 queue: Mutex::new(MessageQueue::new()),
389 nodes: Arc::new(RwLock::new(Members::new(node))),
390 ack_manager: AckManager::new(),
391 shutdown_tx,
392 advertise: advertise.cheap_clone(),
393 transport: Arc::new(transport),
394 opts: Arc::new(opts),
395 #[cfg(feature = "encryption")]
396 keyring,
397 }),
398 delegate: delegate.map(Arc::new),
399 };
400
401 {
402 let handles = this.inner.handles.borrow();
403 handles.push(this.stream_listener(shutdown_rx.clone()));
404 handles.push(this.packet_handler(shutdown_rx.clone()));
405 handles.push(this.packet_listener(shutdown_rx.clone()));
406 #[cfg(feature = "metrics")]
407 handles.push(this.check_broadcast_queue_depth(shutdown_rx.clone()));
408 }
409
410 Ok((shutdown_rx, this.inner.advertise.cheap_clone(), this))
411 }
412}
413
414impl<T, D> Memberlist<T, D>
416where
417 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
418 T: Transport,
419{
420 #[inline]
421 pub(crate) fn get_advertise(&self) -> &T::ResolvedAddress {
422 &self.inner.advertise
423 }
424
425 #[inline]
427 pub(crate) async fn any_alive(&self) -> bool {
428 self
429 .inner
430 .nodes
431 .read()
432 .await
433 .nodes
434 .iter()
435 .any(|n| !n.state.dead_or_left() && n.state.server.id().ne(&self.inner.id))
436 }
437
438 #[cfg(feature = "metrics")]
439 fn check_broadcast_queue_depth(
440 &self,
441 shutdown_rx: Receiver<()>,
442 ) -> <<T::Runtime as RuntimeLite>::Spawner as AsyncSpawner>::JoinHandle<()> {
443 use futures::{FutureExt, StreamExt};
444
445 let queue_check_interval = self.inner.opts.queue_check_interval;
446 let this = self.clone();
447
448 <T::Runtime as RuntimeLite>::spawn(async move {
449 let tick = <T::Runtime as RuntimeLite>::interval(queue_check_interval);
450 futures::pin_mut!(tick);
451 loop {
452 futures::select! {
453 _ = shutdown_rx.recv().fuse() => {
454 tracing::debug!("memberlist: broadcast queue checker exits");
455 return;
456 },
457 _ = tick.next().fuse() => {
458 let numq = this.inner.broadcast.num_queued().await;
459 metrics::histogram!("memberlist.queue.broadcasts").record(numq as f64);
460 }
461 }
462 }
463 })
464 }
465
466 pub(crate) async fn verify_protocol(
467 &self,
468 _remote: &[PushNodeState<T::Id, T::ResolvedAddress>],
469 ) -> Result<(), Error<T, D>> {
470 Ok(())
473 }
474}