1use std::borrow::Cow;
2use std::cmp;
3use std::collections::{BTreeMap, HashMap};
4use std::net::SocketAddr;
5use std::time::{Duration, Instant};
6
7use smallvec::SmallVec;
8use tokio::sync::oneshot;
9use tracing::{debug, info, instrument, warn};
10
/// How long a successful node selection is reused for the same consistency
/// level before the selector is asked for a fresh one.
const NODE_CACHE_TIMEOUT: Duration = Duration::from_millis(2_000);
12pub type Nodes = SmallVec<[SocketAddr; 5]>;
13
14pub async fn start_node_selector<S>(
18 local_node: SocketAddr,
19 local_dc: Cow<'static, str>,
20 mut selector: S,
21) -> NodeSelectorHandle
22where
23 S: NodeSelector + Send + 'static,
24{
25 let (tx, rx) = flume::bounded(100);
26
27 tokio::spawn(async move {
28 let mut total_nodes = 0;
29 let mut data_centers = BTreeMap::new();
30 let mut cached_nodes = HashMap::<Consistency, (Instant, Nodes)>::new();
31
32 while let Ok(op) = rx.recv_async().await {
33 match op {
34 Op::SetNodes {
35 data_centers: new_data_centers,
36 } => {
37 let mut new_total = 0;
38 for (name, nodes) in new_data_centers {
39 new_total += nodes.len();
40 data_centers.insert(name, NodeCycler::from(nodes));
41 }
42 total_nodes = new_total;
43 info!(
44 total_nodes = total_nodes,
45 num_data_centers = data_centers.len(),
46 "Node selector has updated eligible nodes.",
47 );
48
49 cached_nodes.clear();
50 },
51 Op::GetNodes { consistency, tx } => {
52 if let Some((last_refreshed, nodes)) = cached_nodes.get(&consistency)
53 {
54 if last_refreshed.elapsed() < NODE_CACHE_TIMEOUT {
55 let _ = tx.send(Ok(nodes.clone()));
56 continue;
57 }
58 }
59
60 let nodes = selector.select_nodes(
61 local_node,
62 &local_dc,
63 total_nodes,
64 &mut data_centers,
65 consistency,
66 );
67
68 if let Ok(ref nodes) = nodes {
69 cached_nodes
70 .insert(consistency, (Instant::now(), nodes.clone()));
71 }
72
73 let _ = tx.send(nodes);
74 },
75 }
76 }
77
78 info!("Node selector service has shutdown.");
79 });
80
81 NodeSelectorHandle { tx }
82}
83
84#[derive(Clone)]
85pub struct NodeSelectorHandle {
89 tx: flume::Sender<Op>,
90}
91
92impl NodeSelectorHandle {
93 pub(crate) async fn set_nodes(
95 &self,
96 data_centers: BTreeMap<Cow<'static, str>, Nodes>,
97 ) {
98 self.tx
99 .send_async(Op::SetNodes { data_centers })
100 .await
101 .expect("contact actor");
102 }
103
104 pub async fn get_nodes(
109 &self,
110 consistency: Consistency,
111 ) -> Result<Nodes, ConsistencyError> {
112 let (tx, rx) = oneshot::channel();
113
114 self.tx
115 .send_async(Op::GetNodes { consistency, tx })
116 .await
117 .expect("contact actor");
118
119 rx.await.expect("get actor response")
120 }
121}
122
123enum Op {
124 SetNodes {
125 data_centers: BTreeMap<Cow<'static, str>, Nodes>,
126 },
127 GetNodes {
128 consistency: Consistency,
129 tx: oneshot::Sender<Result<Nodes, ConsistencyError>>,
130 },
131}
132
133#[derive(Debug, thiserror::Error)]
134pub enum ConsistencyError {
135 #[error(
136 "Not enough nodes are present in the cluster to achieve this consistency level."
137 )]
138 NotEnoughNodes { live: usize, required: usize },
139
140 #[error(
141 "Failed to achieve the desired consistency level before the timeout \
142 ({timeout:?}) elapsed. Got {responses} responses but needed {required} responses."
143 )]
144 ConsistencyFailure {
145 responses: usize,
146 required: usize,
147 timeout: Duration,
148 },
149}
150
/// How many nodes must take part in an operation.
///
/// The counts below describe how [`DCAwareSelector`] interprets each level;
/// the local node itself is never part of the selection.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Consistency {
    /// No other nodes.
    None,

    /// One other node.
    One,

    /// Two other nodes.
    Two,

    /// Three other nodes.
    Three,

    /// `total / 2` other nodes drawn from the whole cluster; together with the
    /// local node that is a strict majority.
    Quorum,

    /// `len / 2` other nodes from the local data center.
    LocalQuorum,

    /// Every other node in the cluster.
    All,

    /// `len / 2` other nodes from the local data center plus `len / 2 + 1`
    /// nodes from every remote data center.
    EachQuorum,
}
196
197pub trait NodeSelector {
198 fn select_nodes(
202 &mut self,
203 local_node: SocketAddr,
204 local_dc: &str,
205 total_nodes: usize,
206 data_centers: &mut BTreeMap<Cow<'static, str>, NodeCycler>,
207 consistency: Consistency,
208 ) -> Result<Nodes, ConsistencyError>;
209}
210
/// A stateless [`NodeSelector`] that spreads its picks across data centers
/// and never selects the local node.
#[derive(Debug, Copy, Clone, Default)]
pub struct DCAwareSelector;
220
221impl NodeSelector for DCAwareSelector {
222 fn select_nodes(
223 &mut self,
224 local_node: SocketAddr,
225 local_dc: &str,
226 total_nodes: usize,
227 data_centers: &mut BTreeMap<Cow<'static, str>, NodeCycler>,
228 consistency: Consistency,
229 ) -> Result<Nodes, ConsistencyError> {
230 let mut selected_nodes = Nodes::new();
231
232 match consistency {
233 Consistency::One => {
234 return select_n_nodes(
235 local_node,
236 local_dc,
237 1,
238 total_nodes,
239 data_centers,
240 )
241 },
242 Consistency::Two => {
243 return select_n_nodes(
244 local_node,
245 local_dc,
246 2,
247 total_nodes,
248 data_centers,
249 )
250 },
251 Consistency::Three => {
252 return select_n_nodes(
253 local_node,
254 local_dc,
255 3,
256 total_nodes,
257 data_centers,
258 )
259 },
260 Consistency::Quorum => {
261 let majority = total_nodes / 2;
264
265 let mut dcs_iterators = data_centers
266 .iter()
267 .map(|(_, nodes)| {
268 nodes
269 .get_nodes()
270 .iter()
271 .copied()
272 .filter(|addr| addr != &local_node)
273 })
274 .collect::<Vec<_>>();
275 let mut previous_total = selected_nodes.len();
276 while selected_nodes.len() < majority {
277 let nodes = dcs_iterators.iter_mut().filter_map(|iter| iter.next());
278 selected_nodes.extend(nodes);
279
280 if previous_total == selected_nodes.len() {
282 return Err(ConsistencyError::NotEnoughNodes {
283 live: selected_nodes.len(),
284 required: majority,
285 });
286 }
287
288 previous_total = selected_nodes.len();
289 }
290 },
291 Consistency::LocalQuorum => {
292 if let Some(nodes) = data_centers.get(local_dc) {
293 let majority = nodes.len() / 2;
296 selected_nodes.extend(
297 nodes
298 .get_nodes()
299 .iter()
300 .copied()
301 .filter(|addr| addr != &local_node)
302 .take(majority),
303 );
304 }
305 },
306 Consistency::All => selected_nodes.extend(
307 data_centers
308 .values()
309 .flat_map(|cycler| cycler.nodes.clone())
310 .filter(|addr| addr != &local_node),
311 ),
312 Consistency::EachQuorum => {
313 for (name, nodes) in data_centers {
314 let majority = if name == local_dc {
315 nodes.len() / 2
318 } else {
319 (nodes.len() / 2) + 1
320 };
321
322 selected_nodes.extend(
323 nodes
324 .get_nodes()
325 .iter()
326 .copied()
327 .filter(|addr| addr != &local_node)
328 .take(majority),
329 );
330 }
331 },
332 Consistency::None => {},
333 }
334
335 Ok(selected_nodes)
336 }
337}
338
339#[instrument(name = "dc-aware-selector")]
340fn select_n_nodes(
377 local_node: SocketAddr,
378 local_dc: &str,
379 n: usize,
380 total_nodes: usize,
381 data_centers: &mut BTreeMap<Cow<'static, str>, NodeCycler>,
382) -> Result<Nodes, ConsistencyError> {
383 use rand::seq::IteratorRandom;
384 let mut rng = rand::thread_rng();
385
386 let num_nodes_outside_dc = total_nodes
387 - data_centers
388 .get(local_dc)
389 .map(|nodes| nodes.len())
390 .unwrap_or_default();
391 let can_skip_local_dc = num_nodes_outside_dc >= n;
392
393 let num_data_centers = if can_skip_local_dc {
394 data_centers.len() - 1
395 } else {
396 data_centers.len()
397 };
398
399 let mut num_extra_nodes = 0;
400 let selected_dcs = if num_data_centers <= n {
401 num_extra_nodes = n - num_data_centers;
402 data_centers
403 .iter_mut()
404 .filter(|(dc, _)| !(can_skip_local_dc && (dc.as_ref() == local_dc)))
405 .collect::<Vec<_>>()
406 } else {
407 data_centers
408 .iter_mut()
409 .filter(|(dc, _)| !(can_skip_local_dc && (dc.as_ref() == local_dc)))
410 .choose_multiple(&mut rng, n)
411 };
412
413 let mut dc_count = selected_dcs.len();
414 let mut selected_nodes = Nodes::new();
415 for (_, dc_nodes) in selected_dcs.into_iter() {
416 let node = match dc_nodes.next() {
417 Some(node) => {
418 if node == local_node {
419 if dc_nodes.len() <= 1 {
420 num_extra_nodes += 1;
421 dc_count -= 1;
422 continue;
423 }
424
425 dc_nodes.next().unwrap()
426 } else {
427 node
428 }
429 },
430 None => {
432 num_extra_nodes += 1;
433 dc_count -= 1;
434 continue;
435 },
436 };
437
438 selected_nodes.push(node);
439
440 if num_extra_nodes == 0 {
441 continue;
442 }
443
444 let num_extra_nodes_per_dc = num_extra_nodes / cmp::max(dc_count - 1, 1);
445 for _ in 0..num_extra_nodes_per_dc {
446 if let Some(node) = dc_nodes.next() {
447 if node == local_node || selected_nodes.contains(&node) {
448 continue;
449 }
450
451 selected_nodes.push(node);
452 num_extra_nodes -= 1;
453 }
454 }
455
456 dc_count -= 1;
457 }
458
459 if selected_nodes.len() >= n {
460 debug!(selected_node = ?selected_nodes, "Nodes have been selected for the given parameters.");
461 Ok(selected_nodes)
462 } else {
463 warn!(
464 live_nodes = total_nodes - 1,
465 required_node = n,
466 "Failed to meet consistency level due to shortage of live nodes"
467 );
468 Err(ConsistencyError::NotEnoughNodes {
469 live: selected_nodes.len(),
470 required: n,
471 })
472 }
473}
474
475#[derive(Debug)]
476pub struct NodeCycler {
477 cursor: usize,
478 nodes: Nodes,
479}
480
481impl NodeCycler {
482 pub fn extend(&mut self, iter: impl Iterator<Item = SocketAddr>) {
484 self.nodes.extend(iter);
485 }
486
487 pub fn get_nodes_mut(&mut self) -> &mut Nodes {
489 &mut self.nodes
490 }
491
492 pub fn get_nodes(&self) -> &Nodes {
494 &self.nodes
495 }
496
497 #[inline]
498 pub fn len(&self) -> usize {
500 self.nodes.len()
501 }
502}
503
504impl From<Nodes> for NodeCycler {
505 fn from(nodes: Nodes) -> Self {
506 Self { cursor: 0, nodes }
507 }
508}
509
510impl Iterator for NodeCycler {
511 type Item = SocketAddr;
512
513 fn next(&mut self) -> Option<Self::Item> {
514 if self.cursor >= self.nodes.len() {
515 self.cursor = 0;
516 }
517
518 let res = self.nodes.get(self.cursor).copied();
519
520 self.cursor += 1;
521
522 res
523 }
524}
525
#[cfg(test)]
mod tests {
    use std::borrow::Cow;
    use std::collections::BTreeMap;
    use std::fmt::Display;
    use std::net::{IpAddr, SocketAddr};

    use crate::nodes_selector::{
        select_n_nodes,
        Consistency,
        ConsistencyError,
        DCAwareSelector,
        NodeCycler,
        NodeSelector,
    };
    use crate::Nodes;

    /// Runs a [`DCAwareSelector`] once against a freshly built cluster, using
    /// the first node of `dc-0` as the local node.
    fn select_on_fresh_cluster(
        distribution: Vec<usize>,
        total_nodes: usize,
        consistency: Consistency,
    ) -> Result<Nodes, ConsistencyError> {
        let mut data_centers = make_dc(distribution);
        DCAwareSelector::default().select_nodes(
            make_addr(0, 0),
            "dc-0",
            total_nodes,
            &mut data_centers,
            consistency,
        )
    }

    #[test]
    fn test_dc_aware_selector() {
        const MET: &str = "Consistency requirements should be met";

        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::None).expect(MET);
        assert!(nodes.is_empty(), "No nodes should be selected");

        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::One).expect(MET);
        assert_eq!(nodes.len(), 1, "1 node should be selected");

        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::Two).expect(MET);
        assert_eq!(nodes.len(), 2, "2 nodes should be selected");

        select_on_fresh_cluster(vec![3], 3, Consistency::Three).expect_err(
            "Consistency requirements should not be met when only 2 remote nodes available",
        );

        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::All).expect(MET);
        assert_eq!(nodes.len(), 2, "2 nodes should be selected");

        // Quorum over a single data center of varying size.
        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::Quorum).expect(MET);
        assert_eq!(nodes.len(), 1, "1 node should be selected");

        let nodes = select_on_fresh_cluster(vec![2], 2, Consistency::Quorum).expect(MET);
        assert_eq!(nodes.len(), 1, "1 node should be selected");

        let nodes = select_on_fresh_cluster(vec![1], 1, Consistency::Quorum).expect(MET);
        assert!(nodes.is_empty(), "0 nodes should be selected");

        let nodes = select_on_fresh_cluster(vec![4], 4, Consistency::Quorum).expect(MET);
        assert_eq!(nodes.len(), 2, "2 nodes should be selected");

        let nodes = select_on_fresh_cluster(vec![5], 5, Consistency::Quorum).expect(MET);
        assert_eq!(nodes.len(), 2, "2 nodes should be selected");

        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::LocalQuorum).expect(MET);
        assert_eq!(nodes.len(), 1, "1 node should be selected");

        let nodes = select_on_fresh_cluster(vec![3], 3, Consistency::EachQuorum).expect(MET);
        assert_eq!(nodes.len(), 1, "1 node should be selected");

        // Two data centers.
        let nodes =
            select_on_fresh_cluster(vec![3, 3], 6, Consistency::LocalQuorum).expect(MET);
        assert_eq!(nodes.len(), 1, "1 node should be selected");

        let nodes =
            select_on_fresh_cluster(vec![3, 3], 3, Consistency::EachQuorum).expect(MET);
        assert_eq!(nodes.len(), 3, "3 nodes should be selected");
    }

    #[test]
    fn test_dc_aware_selector_cycling() {
        let local = make_addr(0, 0);
        let total_nodes = 6;
        let mut cluster = make_dc(vec![3, 2, 1]);
        let mut selector = DCAwareSelector;

        let nodes = selector
            .select_nodes(local, "dc-0", total_nodes, &mut cluster, Consistency::All)
            .expect("Get nodes");
        assert_eq!(
            nodes.len(),
            total_nodes - 1,
            "Expected all nodes to be selected except for local node."
        );

        let nodes = selector
            .select_nodes(local, "dc-0", total_nodes, &mut cluster, Consistency::None)
            .expect("Get nodes");
        assert!(nodes.is_empty(), "Expected no nodes to be selected.");

        let nodes = selector
            .select_nodes(local, "dc-0", total_nodes, &mut cluster, Consistency::EachQuorum)
            .expect("Get nodes");
        assert_eq!(
            nodes.as_ref(),
            [
                make_addr(0, 1),
                make_addr(1, 0),
                make_addr(1, 1),
                make_addr(2, 0),
            ]
        );

        let nodes = selector
            .select_nodes(local, "dc-0", total_nodes, &mut cluster, Consistency::LocalQuorum)
            .expect("Get nodes");
        assert_eq!(nodes.as_ref(), [make_addr(0, 1)]);

        let nodes = selector
            .select_nodes(local, "dc-0", total_nodes, &mut cluster, Consistency::Quorum)
            .expect("Get nodes");
        assert_eq!(
            nodes.as_ref(),
            [make_addr(0, 1), make_addr(1, 0), make_addr(2, 0)]
        );

        // Clusters that cannot provide enough non-local nodes for the level.
        let unsatisfiable = vec![
            (vec![1], Consistency::One),
            (vec![2], Consistency::Two),
            (vec![1, 1], Consistency::Two),
            (vec![1, 1, 1], Consistency::Three),
            (vec![2, 1], Consistency::Three),
        ];
        for (distribution, consistency) in unsatisfiable {
            let mut cluster = make_dc(distribution);
            selector
                .select_nodes(local, "dc-0", total_nodes, &mut cluster, consistency)
                .expect_err("Node selector should reject consistency level.");
        }
    }

    #[test]
    fn test_select_n_nodes_equal_dc_count() {
        let addr = make_addr(0, 0);
        let total_nodes = 6;
        let mut dc = make_dc(vec![3, 2, 1]);

        // Order matters: every call advances the cyclers shared through `dc`.
        let expectations: Vec<(&str, usize, Vec<SocketAddr>)> = vec![
            ("dc-0", 3, vec![make_addr(1, 0), make_addr(1, 1), make_addr(2, 0)]),
            ("dc-0", 2, vec![make_addr(1, 0), make_addr(2, 0)]),
            ("dc-0", 0, vec![]),
            ("dc-1", 3, vec![make_addr(0, 1), make_addr(0, 2), make_addr(2, 0)]),
            ("dc-1", 2, vec![make_addr(0, 1), make_addr(2, 0)]),
            ("dc-1", 0, vec![]),
            ("dc-2", 3, vec![make_addr(0, 2), make_addr(1, 1), make_addr(1, 0)]),
            ("dc-2", 2, vec![make_addr(0, 1), make_addr(1, 1)]),
            ("dc-2", 0, vec![]),
        ];

        for (local_dc, n, expected) in expectations {
            let nodes =
                select_n_nodes(addr, local_dc, n, total_nodes, &mut dc).expect("get nodes");
            assert_eq!(nodes.as_ref(), expected.as_slice());
        }
    }

    #[test]
    fn test_select_n_nodes_less_dc_count() {
        let addr = make_addr(0, 0);
        let total_nodes = 5;
        let mut dc = make_dc(vec![3, 2]);

        // Order matters: every call advances the cyclers shared through `dc`.
        let expectations: Vec<(usize, Vec<SocketAddr>)> = vec![
            (3, vec![make_addr(0, 1), make_addr(0, 2), make_addr(1, 0)]),
            (2, vec![make_addr(1, 1), make_addr(1, 0)]),
            (0, vec![]),
        ];

        for (n, expected) in expectations {
            let nodes =
                select_n_nodes(addr, "dc-0", n, total_nodes, &mut dc).expect("get nodes");
            assert_eq!(nodes.as_ref(), expected.as_slice());
        }
    }

    /// Builds data centers `dc-0`, `dc-1`, ... where `distribution[i]` is the
    /// number of nodes in `dc-i`.
    fn make_dc(distribution: Vec<usize>) -> BTreeMap<Cow<'static, str>, NodeCycler> {
        distribution
            .into_iter()
            .enumerate()
            .map(|(dc_n, num_nodes)| {
                let nodes: Nodes = (0..num_nodes)
                    .map(|i| make_addr(dc_n as u8, i as u8))
                    .collect();
                (to_dc_name(dc_n), NodeCycler::from(nodes))
            })
            .collect()
    }

    /// Address `127.<dc_id>.0.<node_n>:80`.
    fn make_addr(dc_id: u8, node_n: u8) -> SocketAddr {
        let ip = IpAddr::from([127, dc_id, 0, node_n]);
        SocketAddr::new(ip, 80)
    }

    fn to_dc_name(dc: impl Display) -> Cow<'static, str> {
        format!("dc-{}", dc).into()
    }
}