use rand::{Rng, rng};
use tracing::error;
use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
use crate::cluster::ClusterState;
use crate::routing::Shard;
/// Internal state machine driving the iteration of [`Plan`].
enum PlanState<'a> {
    /// `next()` has not been called yet; the policy has not been consulted.
    Created,
    /// `pick()` returned `None` and `fallback()` produced no targets either,
    /// so the plan is empty and yields nothing more.
    PickedNone,
    /// `pick()` returned this target, and it has already been yielded.
    Picked((NodeRef<'a>, Option<Shard>)),
    /// Targets are now being drawn from the policy's fallback plan.
    Fallback {
        /// The iterator obtained from the policy's `fallback()`.
        iter: FallbackPlan<'a>,
        /// The target that was yielded first (either by `pick()` or as the
        /// first fallback element); fallback targets equal to it are skipped.
        target_to_filter_out: (NodeRef<'a>, Option<Shard>),
    },
}
/// A lazily evaluated query plan: an iterator over `(node, shard)` targets
/// supplied by a [`LoadBalancingPolicy`].
///
/// The policy's `pick()` is consulted first. Its `fallback()` is requested
/// only when a second target is needed, or immediately if `pick()` returns
/// `None`.
pub struct Plan<'a> {
    /// Policy that supplies the targets.
    policy: &'a dyn LoadBalancingPolicy,
    /// Routing information handed to the policy's `pick()` / `fallback()`.
    routing_info: &'a RoutingInfo<'a>,
    /// Cluster state handed to the policy's `pick()` / `fallback()`.
    cluster: &'a ClusterState,
    /// Where the iteration currently is; see [`PlanState`].
    state: PlanState<'a>,
}
impl<'a> Plan<'a> {
    /// Creates a plan over the targets that `policy` chooses for
    /// `routing_info` within `cluster`.
    ///
    /// Construction is cheap: the policy is not consulted until the first
    /// call to [`Iterator::next`].
    pub fn new(
        policy: &'a dyn LoadBalancingPolicy,
        routing_info: &'a RoutingInfo<'a>,
        cluster: &'a ClusterState,
    ) -> Self {
        let state = PlanState::Created;
        Self {
            state,
            policy,
            cluster,
            routing_info,
        }
    }

    /// Turns a target with an optional shard into a target with a concrete shard.
    ///
    /// A known shard is passed through unchanged. An unknown shard is replaced
    /// by one drawn uniformly at random from `0..nr_shards` of the node's
    /// sharder. When the node has no sharder, the range is `0..1`, i.e. shard 0.
    fn with_random_shard_if_unknown(
        (node, shard): (NodeRef<'_>, Option<Shard>),
    ) -> (NodeRef<'_>, Shard) {
        let shard = match shard {
            Some(known) => known,
            None => {
                let shard_count = node
                    .sharder()
                    .map_or(1, |sharder| sharder.nr_shards.get());
                rng().random_range(0..shard_count).into()
            }
        };
        (node, shard)
    }
}
impl<'a> Iterator for Plan<'a> {
type Item = (NodeRef<'a>, Shard);
fn next(&mut self) -> Option<Self::Item> {
match &mut self.state {
PlanState::Created => {
let picked = self.policy.pick(self.routing_info, self.cluster);
if let Some(picked) = picked {
self.state = PlanState::Picked(picked);
Some(Self::with_random_shard_if_unknown(picked))
} else {
let mut iter = self.policy.fallback(self.routing_info, self.cluster);
let first_fallback_node = iter.next();
if let Some(node) = first_fallback_node {
self.state = PlanState::Fallback {
iter,
target_to_filter_out: node,
};
Some(Self::with_random_shard_if_unknown(node))
} else {
error!(
"Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}",
self.routing_info
);
self.state = PlanState::PickedNone;
None
}
}
}
PlanState::Picked(node) => {
self.state = PlanState::Fallback {
iter: self.policy.fallback(self.routing_info, self.cluster),
target_to_filter_out: *node,
};
self.next()
}
PlanState::Fallback {
iter,
target_to_filter_out: node_to_filter_out,
} => {
for node in iter {
if node == *node_to_filter_out {
continue;
} else {
return Some(Self::with_random_shard_if_unknown(node));
}
}
None
}
PlanState::PickedNone => None,
}
}
}
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, str::FromStr, sync::Arc};

    use crate::{
        cluster::{Node, NodeAddr},
        routing::locator::test::{create_locator, mock_metadata_for_token_aware_tests},
        test_utils::setup_tracing,
    };

    use super::*;

    /// The single `(node, shard)` target that `PickingNonePolicy` yields from
    /// its fallback, and that the plan is expected to reproduce.
    fn expected_nodes() -> Vec<(Arc<Node>, Shard)> {
        let address = SocketAddr::from_str("127.0.0.1:9042").unwrap();
        let node = Node::new_for_test(None, Some(NodeAddr::Translatable(address)), None, None);
        vec![(Arc::new(node), 42)]
    }

    /// A policy whose `pick()` always returns `None`, so `Plan` has to rely on
    /// `fallback()` alone.
    #[derive(Debug)]
    struct PickingNonePolicy {
        /// Targets returned, in order, by `fallback()`, each with a known shard.
        expected_nodes: Vec<(Arc<Node>, Shard)>,
    }

    impl LoadBalancingPolicy for PickingNonePolicy {
        fn pick<'a>(
            &'a self,
            _query: &'a RoutingInfo,
            _cluster: &'a ClusterState,
        ) -> Option<(NodeRef<'a>, Option<Shard>)> {
            None
        }

        fn fallback<'a>(
            &'a self,
            _query: &'a RoutingInfo,
            _cluster: &'a ClusterState,
        ) -> FallbackPlan<'a> {
            let targets = self
                .expected_nodes
                .iter()
                .map(|(node, shard)| (node, Some(*shard)));
            Box::new(targets)
        }

        fn name(&self) -> String {
            "PickingNone".into()
        }
    }

    #[tokio::test]
    async fn plan_calls_fallback_even_if_pick_returned_none() {
        setup_tracing();

        let policy = PickingNonePolicy {
            expected_nodes: expected_nodes(),
        };
        let locator = create_locator(&mock_metadata_for_token_aware_tests());
        let cluster_state = ClusterState {
            known_peers: Default::default(),
            all_nodes: Default::default(),
            keyspaces: Default::default(),
            locator,
        };
        let routing_info = RoutingInfo::default();

        let produced: Vec<(Arc<Node>, Shard)> = Plan::new(&policy, &routing_info, &cluster_state)
            .map(|(node, shard)| (Arc::clone(node), shard))
            .collect();

        assert_eq!(produced, policy.expected_nodes);
    }
}