use std::borrow::Cow;
use std::cmp;
use std::collections::{BTreeMap, HashMap};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
/// How long a successful selection is reused for the same consistency level
/// before the selector is consulted again.
const NODE_CACHE_TIMEOUT: Duration = Duration::from_millis(2_000);
pub async fn start_node_selector<S>(
local_node: SocketAddr,
local_dc: Cow<'static, str>,
mut selector: S,
) -> NodeSelectorHandle
where
S: NodeSelector + Send + 'static,
{
let (tx, rx) = flume::bounded(100);
tokio::spawn(async move {
let mut total_nodes = 0;
let mut data_centers = BTreeMap::new();
let mut cached_nodes = HashMap::<Consistency, (Instant, Vec<SocketAddr>)>::new();
while let Ok(op) = rx.recv_async().await {
match op {
Op::SetNodes {
data_centers: new_data_centers,
} => {
let mut new_total = 0;
for (name, nodes) in new_data_centers {
new_total += nodes.len();
data_centers.insert(name, NodeCycler::from(nodes));
}
total_nodes = new_total;
info!(
total_nodes = total_nodes,
num_data_centers = data_centers.len(),
"Node selector has updated eligible nodes.",
);
cached_nodes.clear();
},
Op::GetNodes { consistency, tx } => {
if let Some((last_refreshed, nodes)) = cached_nodes.get(&consistency)
{
if last_refreshed.elapsed() < NODE_CACHE_TIMEOUT {
let _ = tx.send(Ok(nodes.clone()));
continue;
}
}
let nodes = selector.select_nodes(
local_node,
&local_dc,
total_nodes,
&mut data_centers,
consistency,
);
if let Ok(ref nodes) = nodes {
cached_nodes
.insert(consistency, (Instant::now(), nodes.clone()));
}
let _ = tx.send(nodes);
},
}
}
info!("Node selector service has shutdown.");
});
NodeSelectorHandle { tx }
}
/// Cloneable handle used to send requests to the node-selector task created by
/// `start_node_selector`.
#[derive(Clone)]
pub struct NodeSelectorHandle {
    /// Request channel into the selector task.
    tx: flume::Sender<Op>,
}

impl NodeSelectorHandle {
    /// Sends a new set of eligible nodes, keyed by data-center name, to the
    /// selector task.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be delivered to the selector task.
    pub async fn set_nodes(
        &self,
        data_centers: BTreeMap<Cow<'static, str>, Vec<SocketAddr>>,
    ) {
        let request = Op::SetNodes { data_centers };
        self.tx.send_async(request).await.expect("contact actor");
    }

    /// Asks the selector task for the nodes required by `consistency`.
    ///
    /// # Errors
    ///
    /// Forwards the [`ConsistencyError`] produced by the selector.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be delivered or no reply is received.
    pub async fn get_nodes(
        &self,
        consistency: Consistency,
    ) -> Result<Vec<SocketAddr>, ConsistencyError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = Op::GetNodes {
            consistency,
            tx: reply_tx,
        };
        self.tx.send_async(request).await.expect("contact actor");
        reply_rx.await.expect("get actor response")
    }
}
/// Requests handled by the node-selector task.
enum Op {
    /// Provide a new set of eligible nodes, grouped by data-center name.
    SetNodes {
        data_centers: BTreeMap<Cow<'static, str>, Vec<SocketAddr>>,
    },
    /// Request the nodes for `consistency`; the result is sent back on `tx`.
    GetNodes {
        consistency: Consistency,
        tx: oneshot::Sender<Result<Vec<SocketAddr>, ConsistencyError>>,
    },
}
/// Reasons a requested consistency level could not be satisfied.
#[derive(Debug, thiserror::Error)]
pub enum ConsistencyError {
    /// Fewer nodes could be selected than the consistency level requires.
    #[error(
        "Not enough nodes are present in the cluster to achieve this consistency level \
         (live: {live}, required: {required})."
    )]
    NotEnoughNodes {
        /// Number of nodes that could be selected.
        live: usize,
        /// Number of nodes the consistency level needed.
        required: usize,
    },
    /// Fewer than `required` responses arrived before `timeout` elapsed.
    #[error(
        "Failed to achieve the desired consistency level before the timeout \
        ({timeout:?}) elapsed. Got {responses} responses but needed {required} responses."
    )]
    ConsistencyFailure {
        responses: usize,
        required: usize,
        timeout: Duration,
    },
}
/// Consistency levels a [`NodeSelector`] can be asked to satisfy.
///
/// How many nodes each level maps to is decided by the selector
/// implementation (see `DCAwareSelector`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Consistency {
    /// No other nodes are required.
    None,
    /// One other node.
    One,
    /// Two other nodes.
    Two,
    /// Three other nodes.
    Three,
    /// A majority of the whole cluster.
    Quorum,
    /// A majority of the local data center.
    LocalQuorum,
    /// Every known node.
    All,
    /// A majority in each data center.
    EachQuorum,
}
/// Strategy the node-selector task uses to decide which nodes to contact.
pub trait NodeSelector {
    /// Returns the nodes to use for `consistency`.
    ///
    /// * `local_node` / `local_dc` — the values given to `start_node_selector`.
    /// * `total_nodes` — node count tracked by the caller for `data_centers`.
    /// * `data_centers` — eligible nodes per data center; passed mutably so an
    ///   implementation may advance each [`NodeCycler`].
    ///
    /// # Errors
    ///
    /// Returns a [`ConsistencyError`] when the level cannot be satisfied.
    fn select_nodes(
        &mut self,
        local_node: SocketAddr,
        local_dc: &str,
        total_nodes: usize,
        data_centers: &mut BTreeMap<Cow<'static, str>, NodeCycler>,
        consistency: Consistency,
    ) -> Result<Vec<SocketAddr>, ConsistencyError>;
}
/// A [`NodeSelector`] that spreads selections across data centers.
///
/// Per [`Consistency`] level:
///
/// - `None`: no nodes.
/// - `One` / `Two` / `Three`: delegated to `select_n_nodes` with `n = 1 / 2 / 3`.
/// - `Quorum`: exactly `total_nodes / 2` nodes other than `local_node`, taken
///   round-robin across data centers in map order.
/// - `LocalQuorum`: `len / 2` nodes of the local DC other than `local_node`;
///   empty when `local_dc` is not in the map.
/// - `All`: every node except `local_node`.
/// - `EachQuorum`: `len / 2` nodes of the local DC and `len / 2 + 1` nodes of
///   every other DC, excluding `local_node`.
///
/// The quorum and `All` levels always read each DC's node list from the front;
/// they do not advance the [`NodeCycler`] cursors.
#[derive(Debug, Copy, Clone, Default)]
pub struct DCAwareSelector;

impl NodeSelector for DCAwareSelector {
    fn select_nodes(
        &mut self,
        local_node: SocketAddr,
        local_dc: &str,
        total_nodes: usize,
        data_centers: &mut BTreeMap<Cow<'static, str>, NodeCycler>,
        consistency: Consistency,
    ) -> Result<Vec<SocketAddr>, ConsistencyError> {
        /// Nodes of one DC in list order, skipping `local_node`.
        fn others(
            cycler: &NodeCycler,
            local_node: SocketAddr,
        ) -> impl Iterator<Item = SocketAddr> + '_ {
            cycler
                .get_nodes()
                .iter()
                .copied()
                .filter(move |addr| *addr != local_node)
        }

        match consistency {
            Consistency::None => Ok(Vec::new()),
            Consistency::One => {
                select_n_nodes(local_node, local_dc, 1, total_nodes, data_centers)
            },
            Consistency::Two => {
                select_n_nodes(local_node, local_dc, 2, total_nodes, data_centers)
            },
            Consistency::Three => {
                select_n_nodes(local_node, local_dc, 3, total_nodes, data_centers)
            },
            Consistency::Quorum => {
                // The local node is not in the result, so `total_nodes / 2` other
                // nodes plus the local node form a majority of `total_nodes`.
                let majority = total_nodes / 2;
                let mut per_dc = data_centers
                    .values()
                    .map(|cycler| others(cycler, local_node))
                    .collect::<Vec<_>>();
                let mut selected = Vec::new();
                while selected.len() < majority {
                    let before = selected.len();
                    // One node per DC per round, but never more than still needed
                    // so the result matches `majority` exactly.
                    selected.extend(
                        per_dc
                            .iter_mut()
                            .filter_map(|nodes| nodes.next())
                            .take(majority - before),
                    );
                    if selected.len() == before {
                        // Every DC is exhausted.
                        return Err(ConsistencyError::NotEnoughNodes {
                            live: before,
                            required: majority,
                        });
                    }
                }
                Ok(selected)
            },
            Consistency::LocalQuorum => Ok(data_centers
                .get(local_dc)
                .map(|cycler| {
                    others(cycler, local_node)
                        .take(cycler.len() / 2)
                        .collect::<Vec<_>>()
                })
                .unwrap_or_default()),
            Consistency::All => Ok(data_centers
                .values()
                .flat_map(|cycler| others(cycler, local_node))
                .collect::<Vec<_>>()),
            Consistency::EachQuorum => Ok(data_centers
                .iter()
                .flat_map(|(name, cycler)| {
                    // In the local DC the local node itself counts toward the
                    // majority, so one fewer remote node is needed.
                    let majority = if name == local_dc {
                        cycler.len() / 2
                    } else {
                        cycler.len() / 2 + 1
                    };
                    others(cycler, local_node).take(majority)
                })
                .collect::<Vec<_>>()),
        }
    }
}
/// Picks at least `n` nodes for the fixed-count levels (`One` / `Two` / `Three`),
/// spreading them across data centers.
///
/// The local DC is left out when, according to `total_nodes`, the other DCs
/// hold at least `n` nodes.
///
/// DC choice:
/// - If there are no more candidate DCs than `n`, all of them are used and the
///   difference is made up with "extra" nodes taken from the DCs as the loop
///   visits them.
/// - Otherwise `n` DCs are chosen at random.
///
/// Nodes are drawn through each DC's [`NodeCycler`], so repeated calls rotate
/// through a DC's nodes.
///
/// # Errors
///
/// Returns [`ConsistencyError::NotEnoughNodes`] when fewer than `n` nodes could
/// be selected.
#[instrument(name = "dc-aware-selector")]
fn select_n_nodes(
    local_node: SocketAddr,
    local_dc: &str,
    n: usize,
    total_nodes: usize,
    data_centers: &mut BTreeMap<Cow<'static, str>, NodeCycler>,
) -> Result<Vec<SocketAddr>, ConsistencyError> {
    use rand::seq::IteratorRandom;
    let mut rng = rand::thread_rng();
    let local_dc_size = data_centers
        .get(local_dc)
        .map(|nodes| nodes.len())
        .unwrap_or_default();
    // Saturating: a `total_nodes` that is inconsistent with the map (smaller
    // than the local DC) must not overflow.
    let num_nodes_outside_dc = total_nodes.saturating_sub(local_dc_size);
    let can_skip_local_dc = num_nodes_outside_dc >= n;
    // Only discount the local DC when it is actually in the map. Otherwise the
    // candidate count would be one too low, and `len() - 1` could underflow on
    // an empty map.
    let num_data_centers = if can_skip_local_dc && data_centers.contains_key(local_dc) {
        data_centers.len() - 1
    } else {
        data_centers.len()
    };
    let mut num_extra_nodes = 0;
    let selected_dcs = if num_data_centers <= n {
        // Every candidate DC contributes one node; the rest become extras.
        num_extra_nodes = n - num_data_centers;
        data_centers
            .iter_mut()
            .filter(|(dc, _)| !(can_skip_local_dc && (dc.as_ref() == local_dc)))
            .collect::<Vec<_>>()
    } else {
        data_centers
            .iter_mut()
            .filter(|(dc, _)| !(can_skip_local_dc && (dc.as_ref() == local_dc)))
            .choose_multiple(&mut rng, n)
    };
    let mut dc_count = selected_dcs.len();
    let mut selected_nodes = Vec::new();
    for (_, dc_nodes) in selected_dcs.into_iter() {
        let node = match dc_nodes.next() {
            Some(node) => {
                if node == local_node {
                    // The local node cannot count; if it is the DC's only node,
                    // shift this DC's share onto the extras.
                    if dc_nodes.len() <= 1 {
                        num_extra_nodes += 1;
                        dc_count -= 1;
                        continue;
                    }
                    // A non-empty cycler always yields a node.
                    dc_nodes.next().unwrap()
                } else {
                    node
                }
            },
            None => {
                // Empty DC: its share becomes an extra node.
                num_extra_nodes += 1;
                dc_count -= 1;
                continue;
            },
        };
        selected_nodes.push(node);
        if num_extra_nodes == 0 {
            continue;
        }
        // NOTE(review): extra picks are only checked against `selected_nodes`,
        // not against `local_node`, so once a cycler wraps around the local node
        // itself can be returned. Confirm whether that is intended.
        let num_extra_nodes_per_dc = num_extra_nodes / cmp::max(dc_count - 1, 1);
        for _ in 0..num_extra_nodes_per_dc {
            if let Some(node) = dc_nodes.next() {
                if selected_nodes.contains(&node) {
                    continue;
                }
                selected_nodes.push(node);
                num_extra_nodes -= 1;
            }
        }
        dc_count -= 1;
    }
    if selected_nodes.len() >= n {
        debug!(selected_node = ?selected_nodes, "Nodes have been selected for the given parameters.");
        Ok(selected_nodes)
    } else {
        warn!(
            // Saturating: `total_nodes` is 0 before the first `set_nodes`, and
            // `0 - 1` would overflow here.
            live_nodes = total_nodes.saturating_sub(1),
            required_node = n,
            "Failed to meet consistency level due to shortage of live nodes"
        );
        Err(ConsistencyError::NotEnoughNodes {
            live: selected_nodes.len(),
            required: n,
        })
    }
}
/// A list of node addresses that can be iterated round-robin forever.
///
/// As an [`Iterator`] it yields the nodes in order and then starts again from
/// the first one. It yields `None` only when the list is empty.
#[derive(Debug)]
pub struct NodeCycler {
    /// Index of the next node to yield. Values `>= nodes.len()` mean
    /// "restart from the front".
    cursor: usize,
    nodes: Vec<SocketAddr>,
}

impl NodeCycler {
    /// Appends nodes to the end of the list without moving the cursor.
    pub fn extend(&mut self, iter: impl Iterator<Item = SocketAddr>) {
        self.nodes.extend(iter);
    }

    /// Mutable access to the node list.
    ///
    /// If the list shrinks below the cursor, iteration restarts from the front.
    pub fn get_nodes_mut(&mut self) -> &mut Vec<SocketAddr> {
        &mut self.nodes
    }

    /// The node list, in iteration order.
    pub fn get_nodes(&self) -> &Vec<SocketAddr> {
        &self.nodes
    }

    /// Number of nodes in the list.
    #[inline]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` when the list holds no nodes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}

impl From<Vec<SocketAddr>> for NodeCycler {
    /// Wraps `nodes`, starting iteration at the first element.
    fn from(nodes: Vec<SocketAddr>) -> Self {
        Self { cursor: 0, nodes }
    }
}

impl Iterator for NodeCycler {
    type Item = SocketAddr;

    fn next(&mut self) -> Option<Self::Item> {
        // Wrap around at the end. This also covers a list that has shrunk
        // (through `get_nodes_mut`) since the previous call.
        if self.cursor >= self.nodes.len() {
            self.cursor = 0;
        }
        let node = self.nodes.get(self.cursor).copied();
        self.cursor += 1;
        node
    }
}
// Unit tests for the DC-aware selector.
//
// IMPORTANT: the fixtures' `NodeCycler`s are reused across consecutive calls
// within each test, and `select_n_nodes` advances their cursors. Expected
// vectors therefore depend on the exact order of the calls; reordering or
// removing a call changes the expected results of the calls after it.
#[cfg(test)]
mod tests {
    use std::borrow::Cow;
    use std::collections::BTreeMap;
    use std::fmt::Display;
    use std::net::{IpAddr, SocketAddr};
    use crate::nodes_selector::{
        select_n_nodes,
        Consistency,
        DCAwareSelector,
        NodeCycler,
        NodeSelector,
    };

    // Exercises every consistency level of `DCAwareSelector` on a 3/2/1 layout,
    // then checks that One/Two/Three are rejected on small layouts.
    #[test]
    fn test_dc_aware_selector() {
        let addr = make_addr(0, 0);
        let total_nodes = 6;
        let mut dc = make_dc(vec![3, 2, 1]);
        let mut selector = DCAwareSelector;
        let nodes = selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::All)
            .expect("Get nodes");
        assert_eq!(
            nodes.len(),
            total_nodes - 1,
            "Expected all nodes to be selected except for local node."
        );
        let nodes = selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::None)
            .expect("Get nodes");
        assert!(nodes.is_empty(), "Expected no nodes to be selected.");
        let nodes = selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::EachQuorum)
            .expect("Get nodes");
        // dc-0 (local, 3 nodes): 3/2 = 1; dc-1: 2/2+1 = 2; dc-2: 1/2+1 = 1.
        assert_eq!(
            nodes,
            vec![
                make_addr(0, 1),
                make_addr(1, 0),
                make_addr(1, 1),
                make_addr(2, 0),
            ]
        );
        let nodes = selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::LocalQuorum)
            .expect("Get nodes");
        assert_eq!(nodes, vec![make_addr(0, 1)]);
        let nodes = selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::Quorum)
            .expect("Get nodes");
        // 6 / 2 = 3 nodes, one from each DC, skipping the local node.
        assert_eq!(
            nodes,
            vec![make_addr(0, 1), make_addr(1, 0), make_addr(2, 0),]
        );
        // NOTE(review): `total_nodes` stays 6 for the smaller layouts below, so
        // `select_n_nodes` sees enough nodes "outside" dc-0 and skips the local
        // DC entirely; the rejections come from that rather than from a
        // consistent node count. Confirm this is the intended scenario.
        let mut dc = make_dc(vec![1]);
        selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::One)
            .expect_err("Node selector should reject consistency level.");
        let mut dc = make_dc(vec![2]);
        selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::Two)
            .expect_err("Node selector should reject consistency level.");
        let mut dc = make_dc(vec![1, 1]);
        selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::Two)
            .expect_err("Node selector should reject consistency level.");
        let mut dc = make_dc(vec![1, 1, 1]);
        selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::Three)
            .expect_err("Node selector should reject consistency level.");
        let mut dc = make_dc(vec![2, 1]);
        selector
            .select_nodes(addr, "dc-0", total_nodes, &mut dc, Consistency::Three)
            .expect_err("Node selector should reject consistency level.");
    }

    // `select_n_nodes` where the number of candidate DCs is at most `n`
    // (the "use every DC plus extras" path). Cursor state carries over between
    // calls, which is why e.g. the dc-1 and dc-2 cases start mid-list.
    #[test]
    fn test_select_n_nodes_equal_dc_count() {
        let addr = make_addr(0, 0);
        let total_nodes = 6;
        let mut dc = make_dc(vec![3, 2, 1]);
        let nodes =
            select_n_nodes(addr, "dc-0", 3, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(
            nodes,
            vec![make_addr(1, 0), make_addr(1, 1), make_addr(2, 0),],
        );
        let nodes =
            select_n_nodes(addr, "dc-0", 2, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, vec![make_addr(1, 0), make_addr(2, 0),],);
        let nodes =
            select_n_nodes(addr, "dc-0", 0, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, Vec::<SocketAddr>::new());
        let nodes =
            select_n_nodes(addr, "dc-1", 3, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(
            nodes,
            vec![make_addr(0, 1), make_addr(0, 2), make_addr(2, 0),],
        );
        let nodes =
            select_n_nodes(addr, "dc-1", 2, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, vec![make_addr(0, 1), make_addr(2, 0),],);
        let nodes =
            select_n_nodes(addr, "dc-1", 0, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, Vec::<SocketAddr>::new());
        // NOTE(review): `addr` (127.0.0.0:80) lives in dc-0 while the local DC
        // is given as "dc-2"; the expected result below contains `addr` itself,
        // picked as an extra node after dc-0's cycler wraps around.
        let nodes =
            select_n_nodes(addr, "dc-2", 3, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(
            nodes,
            vec![make_addr(0, 2), make_addr(0, 0), make_addr(1, 1),],
        );
        let nodes =
            select_n_nodes(addr, "dc-2", 2, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, vec![make_addr(0, 1), make_addr(1, 0),],);
        let nodes =
            select_n_nodes(addr, "dc-2", 0, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, Vec::<SocketAddr>::new());
    }

    // `select_n_nodes` with fewer DCs than `n`, so some nodes must come from
    // the local DC (or be taken as extras from the same DC).
    #[test]
    fn test_select_n_nodes_less_dc_count() {
        let addr = make_addr(0, 0);
        let total_nodes = 5;
        let mut dc = make_dc(vec![3, 2]);
        let nodes =
            select_n_nodes(addr, "dc-0", 3, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(
            nodes,
            vec![make_addr(0, 1), make_addr(0, 2), make_addr(1, 0),],
        );
        let nodes =
            select_n_nodes(addr, "dc-0", 2, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, vec![make_addr(1, 1), make_addr(1, 0),],);
        let nodes =
            select_n_nodes(addr, "dc-0", 0, total_nodes, &mut dc).expect("get nodes");
        assert_eq!(nodes, Vec::<SocketAddr>::new());
    }

    // Builds DCs named "dc-0", "dc-1", ...; DC `i` gets `distribution[i]` nodes
    // with addresses `make_addr(i, 0..n)`.
    fn make_dc(distribution: Vec<usize>) -> BTreeMap<Cow<'static, str>, NodeCycler> {
        let mut dc = BTreeMap::new();
        for (dc_n, num_nodes) in distribution.into_iter().enumerate() {
            let name = to_dc_name(dc_n);
            let mut nodes = Vec::new();
            for i in 0..num_nodes {
                let addr = make_addr(dc_n as u8, i as u8);
                nodes.push(addr);
            }
            dc.insert(name, NodeCycler::from(nodes));
        }
        dc
    }

    // Address 127.<dc_id>.0.<node_n>:80, so the DC and node index are readable
    // from the address.
    fn make_addr(dc_id: u8, node_n: u8) -> SocketAddr {
        SocketAddr::new(IpAddr::from([127, dc_id, 0, node_n]), 80)
    }

    // DC naming scheme shared by `make_dc` and the string literals in the tests.
    fn to_dc_name(dc: impl Display) -> Cow<'static, str> {
        Cow::Owned(format!("dc-{}", dc))
    }
}