use std::{
collections::HashSet,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use crate::agent::route_provider::{
dynamic_routing::{
health_check::HealthCheckStatus, node::Node, snapshot::routing_snapshot::RoutingSnapshot,
},
RoutesStats,
};
/// Routing snapshot that hands out healthy nodes in round-robin order.
///
/// `existing_nodes` tracks every node the snapshot knows about, while `healthy_nodes` tracks the
/// nodes currently marked healthy; only the latter are ever handed out. The rotation cursor lives
/// behind an `Arc`, so clones of a snapshot share (and jointly advance) the same cursor.
#[derive(Debug, Clone, Default)]
pub struct RoundRobinRoutingSnapshot {
    /// Shared round-robin cursor; it is used modulo the number of healthy nodes.
    current_idx: Arc<AtomicUsize>,
    /// Nodes currently marked healthy, i.e. the pool nodes are selected from.
    healthy_nodes: HashSet<Node>,
    /// Every node known to the snapshot, healthy or not.
    existing_nodes: HashSet<Node>,
}
impl RoundRobinRoutingSnapshot {
pub fn new() -> Self {
Self {
current_idx: Arc::new(AtomicUsize::new(0)),
healthy_nodes: HashSet::new(),
existing_nodes: HashSet::new(),
}
}
}
impl RoutingSnapshot for RoundRobinRoutingSnapshot {
    /// Returns `true` if at least one node is currently marked healthy.
    fn has_nodes(&self) -> bool {
        !self.healthy_nodes.is_empty()
    }

    /// Returns the next healthy node in round-robin order, or `None` if no node is healthy.
    ///
    /// The rotation order is the iteration order of the healthy-node set. Every call that finds
    /// at least one healthy node advances the shared cursor by one.
    fn next_node(&self) -> Option<Node> {
        if self.healthy_nodes.is_empty() {
            return None;
        }
        let prev_idx = self.current_idx.fetch_add(1, Ordering::Relaxed);
        self.healthy_nodes
            .iter()
            .nth(prev_idx % self.healthy_nodes.len())
            .cloned()
    }

    /// Returns up to `n` distinct healthy nodes, continuing the same rotation as `next_node`.
    ///
    /// - `n == 0` returns an empty vector without touching the cursor.
    /// - If `n` is at least the number of healthy nodes, every healthy node is returned once and
    ///   the cursor is not advanced.
    /// - Otherwise the cursor advances by `n`, and the `n` nodes starting at the previous cursor
    ///   position are returned, wrapping around the end of the rotation.
    fn next_n_nodes(&self, n: usize) -> Vec<Node> {
        if n == 0 {
            return Vec::new();
        }
        let healthy_count = self.healthy_nodes.len();
        if n >= healthy_count {
            return self.healthy_nodes.iter().cloned().collect();
        }
        let start = self.current_idx.fetch_add(n, Ordering::Relaxed) % healthy_count;
        // Treat the set's iteration order (the same order `next_node` walks) as a ring and clone
        // only the `n` nodes actually returned, instead of materialising the whole set first.
        self.healthy_nodes
            .iter()
            .cycle()
            .skip(start)
            .take(n)
            .cloned()
            .collect()
    }

    /// Replaces the set of known nodes with `nodes`.
    ///
    /// Nodes that disappear are also removed from the healthy set, so they stop being routed to
    /// immediately. Newly added nodes are not inserted into the healthy set; a later
    /// `update_node` call with a healthy status does that. Returns `true` if the set of known
    /// nodes changed.
    fn sync_nodes(&mut self, nodes: &[Node]) -> bool {
        let new_nodes: HashSet<Node> = nodes.iter().cloned().collect();
        if new_nodes == self.existing_nodes {
            return false;
        }
        for removed in self.existing_nodes.difference(&new_nodes) {
            self.healthy_nodes.remove(removed);
        }
        self.existing_nodes = new_nodes;
        true
    }

    /// Records the health status of a known node.
    ///
    /// Unknown nodes (not registered via `sync_nodes`) are ignored and yield `false`. Otherwise a
    /// healthy status adds the node to the healthy set and an unhealthy one removes it. The
    /// return value is `true` only if that actually changed the healthy set.
    fn update_node(&mut self, node: &Node, health: HealthCheckStatus) -> bool {
        if !self.existing_nodes.contains(node) {
            return false;
        }
        if health.is_healthy() {
            self.healthy_nodes.insert(node.clone())
        } else {
            self.healthy_nodes.remove(node)
        }
    }

    /// Reports the number of known nodes together with the number of healthy ones.
    fn routes_stats(&self) -> RoutesStats {
        RoutesStats::new(self.existing_nodes.len(), Some(self.healthy_nodes.len()))
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::slice;
    use std::time::Duration;
    use std::{collections::HashSet, sync::atomic::Ordering};

    use crate::agent::route_provider::dynamic_routing::{
        health_check::HealthCheckStatus,
        node::Node,
        snapshot::{
            round_robin_routing::RoundRobinRoutingSnapshot, routing_snapshot::RoutingSnapshot,
        },
    };

    #[test]
    fn test_snapshot_init() {
        let snapshot = RoundRobinRoutingSnapshot::new();
        assert!(snapshot.healthy_nodes.is_empty());
        assert!(snapshot.existing_nodes.is_empty());
        assert!(!snapshot.has_nodes());
        assert_eq!(snapshot.current_idx.load(Ordering::SeqCst), 0);
        assert!(snapshot.next_node().is_none());
    }

    #[test]
    fn test_update_of_non_existing_node_always_returns_false() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let node = Node::new("api1.com").unwrap();
        let healthy = HealthCheckStatus::new(Some(Duration::from_secs(1)));
        let unhealthy = HealthCheckStatus::new(None);
        let is_updated = snapshot.update_node(&node, healthy);
        assert!(!is_updated);
        assert!(snapshot.existing_nodes.is_empty());
        assert!(snapshot.next_node().is_none());
        let is_updated = snapshot.update_node(&node, unhealthy);
        assert!(!is_updated);
        assert!(snapshot.existing_nodes.is_empty());
        assert!(snapshot.next_node().is_none());
    }

    #[test]
    fn test_update_of_existing_unhealthy_node_with_healthy_node_returns_true() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let node = Node::new("api1.com").unwrap();
        snapshot.existing_nodes.insert(node.clone());
        let health = HealthCheckStatus::new(Some(Duration::from_secs(1)));
        let is_updated = snapshot.update_node(&node, health);
        assert!(is_updated);
        assert!(snapshot.has_nodes());
        assert_eq!(snapshot.next_node().unwrap(), node);
        assert_eq!(snapshot.current_idx.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_update_of_existing_healthy_node_with_unhealthy_node_returns_true() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let node = Node::new("api1.com").unwrap();
        snapshot.existing_nodes.insert(node.clone());
        snapshot.healthy_nodes.insert(node.clone());
        let unhealthy = HealthCheckStatus::new(None);
        let is_updated = snapshot.update_node(&node, unhealthy);
        assert!(is_updated);
        assert!(!snapshot.has_nodes());
        assert!(snapshot.next_node().is_none());
    }

    #[test]
    fn test_sync_node_scenarios() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let node_1 = Node::new("api1.com").unwrap();
        let nodes_changed = snapshot.sync_nodes(slice::from_ref(&node_1));
        assert!(nodes_changed);
        assert!(snapshot.healthy_nodes.is_empty());
        assert_eq!(
            snapshot.existing_nodes,
            HashSet::from_iter(vec![node_1.clone()])
        );
        snapshot.healthy_nodes.insert(node_1.clone());
        let nodes_changed = snapshot.sync_nodes(slice::from_ref(&node_1));
        assert!(!nodes_changed);
        assert_eq!(
            snapshot.existing_nodes,
            HashSet::from_iter(vec![node_1.clone()])
        );
        assert_eq!(snapshot.healthy_nodes, HashSet::from_iter(vec![node_1]));
        let node_2 = Node::new("api2.com").unwrap();
        let nodes_changed = snapshot.sync_nodes(slice::from_ref(&node_2));
        assert!(nodes_changed);
        assert_eq!(
            snapshot.existing_nodes,
            HashSet::from_iter(vec![node_2.clone()])
        );
        assert!(snapshot.healthy_nodes.is_empty());
        snapshot.healthy_nodes.insert(node_2.clone());
        let node_3 = Node::new("api3.com").unwrap();
        let nodes_changed = snapshot.sync_nodes(&[node_3.clone(), node_2.clone()]);
        assert!(nodes_changed);
        assert_eq!(
            snapshot.existing_nodes,
            HashSet::from_iter(vec![node_3.clone(), node_2.clone()])
        );
        assert_eq!(snapshot.healthy_nodes, HashSet::from_iter(vec![node_2]));
        snapshot.healthy_nodes.insert(node_3);
        let nodes_changed = snapshot.sync_nodes(&[]);
        assert!(nodes_changed);
        assert!(snapshot.existing_nodes.is_empty());
        assert!(snapshot.healthy_nodes.is_empty());
        let nodes_changed = snapshot.sync_nodes(&[]);
        assert!(!nodes_changed);
        assert!(snapshot.existing_nodes.is_empty());
    }

    #[test]
    fn test_next_node() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let node_1 = Node::new("api1.com").unwrap();
        let node_2 = Node::new("api2.com").unwrap();
        let node_3 = Node::new("api3.com").unwrap();
        let nodes = vec![node_1, node_2, node_3];
        snapshot.existing_nodes.extend(nodes.clone());
        snapshot.healthy_nodes.extend(nodes.clone());
        let n = 6;
        let mut count_map = HashMap::new();
        for _ in 0..n {
            let node = snapshot.next_node().unwrap();
            count_map.entry(node).and_modify(|v| *v += 1).or_insert(1);
        }
        let k = 2;
        assert_eq!(
            count_map.len(),
            nodes.len(),
            "The number of unique elements is not {}",
            nodes.len()
        );
        for (item, &count) in &count_map {
            assert_eq!(
                count, k,
                "Element {:?} does not appear exactly {} times",
                item, k
            );
        }
    }

    #[test]
    fn test_n_nodes() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let node_1 = Node::new("api1.com").unwrap();
        let node_2 = Node::new("api2.com").unwrap();
        let node_3 = Node::new("api3.com").unwrap();
        let node_4 = Node::new("api4.com").unwrap();
        let node_5 = Node::new("api5.com").unwrap();
        let nodes = vec![
            node_1.clone(),
            node_2.clone(),
            node_3.clone(),
            node_4.clone(),
            node_5.clone(),
        ];
        snapshot.healthy_nodes.extend(nodes.clone());
        let mut n_nodes: Vec<_> = snapshot.next_n_nodes(3);
        n_nodes.extend(snapshot.next_n_nodes(3));
        n_nodes.extend(snapshot.next_n_nodes(4));
        n_nodes.extend(snapshot.next_n_nodes(5));
        let k = 3;
        let mut count_map = HashMap::new();
        for item in n_nodes.iter() {
            count_map.entry(item).and_modify(|v| *v += 1).or_insert(1);
        }
        assert_eq!(
            count_map.len(),
            nodes.len(),
            "The number of unique elements is not {}",
            nodes.len()
        );
        for (item, &count) in &count_map {
            assert_eq!(
                count, k,
                "Element {:?} does not appear exactly {} times",
                item, k
            );
        }
    }

    #[test]
    fn test_next_n_nodes_edge_cases() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        // No healthy nodes: nothing is returned, whatever `n` is.
        assert!(snapshot.next_n_nodes(0).is_empty());
        assert!(snapshot.next_n_nodes(3).is_empty());

        let nodes = vec![
            Node::new("api1.com").unwrap(),
            Node::new("api2.com").unwrap(),
            Node::new("api3.com").unwrap(),
        ];
        snapshot.existing_nodes.extend(nodes.clone());
        snapshot.healthy_nodes.extend(nodes.clone());

        // `n == 0` yields nothing and leaves the cursor untouched.
        assert!(snapshot.next_n_nodes(0).is_empty());
        assert_eq!(snapshot.current_idx.load(Ordering::SeqCst), 0);

        // `n >= healthy count` yields every healthy node exactly once without moving the cursor.
        for n in [nodes.len(), nodes.len() + 2] {
            let returned = snapshot.next_n_nodes(n);
            assert_eq!(returned.len(), nodes.len());
            assert_eq!(
                HashSet::<Node>::from_iter(returned),
                HashSet::from_iter(nodes.clone())
            );
        }
        assert_eq!(snapshot.current_idx.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_next_n_nodes_follows_next_node_rotation() {
        let mut snapshot = RoundRobinRoutingSnapshot::new();
        let nodes = vec![
            Node::new("api1.com").unwrap(),
            Node::new("api2.com").unwrap(),
            Node::new("api3.com").unwrap(),
            Node::new("api4.com").unwrap(),
        ];
        snapshot.existing_nodes.extend(nodes.clone());
        snapshot.healthy_nodes.extend(nodes.clone());

        // One full cycle of `next_node` reveals the rotation order and leaves the cursor at
        // `nodes.len()`, i.e. back at position 0 of the rotation.
        let rotation: Vec<Node> = (0..nodes.len())
            .map(|_| snapshot.next_node().unwrap())
            .collect();

        // Cursor 4 -> positions 0, 1, 2.
        assert_eq!(snapshot.next_n_nodes(3), rotation[..3].to_vec());
        // Cursor 7 -> position 3, then wraps around to 0, 1.
        assert_eq!(
            snapshot.next_n_nodes(3),
            vec![rotation[3].clone(), rotation[0].clone(), rotation[1].clone()]
        );
        // Cursor 10 -> `next_node` continues at position 2.
        assert_eq!(snapshot.next_node().unwrap(), rotation[2]);
    }
}