use std::{
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
time::Duration,
};
use arc_swap::ArcSwap;
// `Rng` is the extension trait that provides `random::<f64>()` on `rand::rng()`
// (rand 0.9 API); the crate exports no `RngExt` trait.
use rand::Rng;
use crate::agent::route_provider::{
dynamic_routing::{
health_check::HealthCheckStatus, node::Node, snapshot::routing_snapshot::RoutingSnapshot,
},
RoutesStats,
};
/// Number of most recent health-check measurements kept per node for scoring.
const WINDOW_SIZE: usize = 15;
/// Decay rate for the exponential weights applied to windowed measurements;
/// larger values discount older measurements faster.
const LAMBDA_DECAY: f64 = 0.3;
/// Generates `n` exponentially decaying weights `e^(-lambda * i)` for
/// `i = 0..n`, so index 0 carries weight 1.0 and later indices decay toward 0.
/// (In `compute_score` index 0 is applied to the newest measurement.)
fn generate_exp_decaying_weights(n: usize, lambda: f64) -> Vec<f64> {
    (0..n).map(|i| (-lambda * i as f64).exp()).collect()
}
/// A healthy node paired with the score used for weighted random sampling.
#[derive(Clone, Debug)]
struct RoutingCandidateNode {
    node: Node,
    score: f64,
}
impl RoutingCandidateNode {
    // Bundles a node with its current routing score.
    fn new(node: Node, score: f64) -> Self {
        Self { node, score }
    }
}
/// Sliding-window health/latency measurements and the derived score for one node.
#[derive(Clone, Debug)]
struct NodeMetrics {
    // Maximum number of measurements retained in each window.
    window_size: usize,
    // Outcome of the most recent health check.
    is_healthy: bool,
    // Latencies (seconds) of recent successful checks, oldest at the front.
    latencies: VecDeque<f64>,
    // Success/failure outcomes of recent checks, oldest at the front.
    availabilities: VecDeque<bool>,
    // Combined latency/availability score; updated externally via `compute_score`.
    score: f64,
}
impl NodeMetrics {
    /// Creates empty measurement windows holding at most `window_size` entries each.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            is_healthy: false,
            // One spare slot so a push can land before the window is trimmed.
            latencies: VecDeque::with_capacity(window_size + 1),
            availabilities: VecDeque::with_capacity(window_size + 1),
            score: 0.0,
        }
    }
    /// Records one health-check outcome: `Some(latency)` for a successful
    /// check, `None` for a failed one. Only successful checks contribute a
    /// latency sample; every check contributes an availability sample.
    pub fn add_latency_measurement(&mut self, latency: Option<Duration>) {
        let is_available = latency.is_some();
        self.is_healthy = is_available;
        if let Some(duration) = latency {
            self.latencies.push_back(duration.as_secs_f64());
            // Trim the latency window to its configured size.
            while self.latencies.len() > self.window_size {
                self.latencies.pop_front();
            }
        }
        self.availabilities.push_back(is_available);
        // Trim the availability window to its configured size.
        while self.availabilities.len() > self.window_size {
            self.availabilities.pop_front();
        }
    }
}
/// Computes a node's overall score as `score_l * score_a`, where `score_l` is
/// a weighted mean of inverse latencies and `score_a` a weighted mean of
/// availability outcomes. Both windows are iterated newest-first, so
/// `window_weights[0]` applies to the most recent measurement.
///
/// When `use_availability_penalty` is false the availability factor is 1.0.
/// An empty latency window (or, with the penalty on, an empty availability
/// window) yields a factor of 0.0.
///
/// # Panics
/// Panics if the weights array is shorter than either measurement window.
fn compute_score(
    window_weights: &[f64],
    window_weights_sum: f64,
    availabilities: &VecDeque<bool>,
    latencies: &VecDeque<f64>,
    use_availability_penalty: bool,
) -> f64 {
    let weights_size = window_weights.len();
    let availabilities_size = availabilities.len();
    let latencies_size = latencies.len();
    assert!(
        weights_size >= availabilities_size,
        "Configuration error: Weights array of size {weights_size} is smaller than array of availabilities of size {availabilities_size}.",
    );
    assert!(
        weights_size >= latencies_size,
        "Configuration error: Weights array of size {weights_size} is smaller than array of latencies of size {latencies_size}.",
    );
    // Weighted availability factor: the newest outcome gets the largest weight.
    let score_a = if !use_availability_penalty {
        1.0
    } else if availabilities.is_empty() {
        0.0
    } else {
        let weighted: f64 = availabilities
            .iter()
            .rev()
            .enumerate()
            .map(|(idx, up)| window_weights[idx] * (*up as u8 as f64))
            .sum();
        // Normalize by only the weights actually consumed when the window
        // is not yet full.
        let denom: f64 = if availabilities_size < weights_size {
            window_weights.iter().take(availabilities_size).sum()
        } else {
            window_weights_sum
        };
        weighted / denom
    };
    // Weighted inverse-latency factor, normalized the same way.
    let score_l = if latencies.is_empty() {
        0.0
    } else {
        let weighted: f64 = latencies
            .iter()
            .rev()
            .enumerate()
            .map(|(idx, latency)| window_weights[idx] / latency)
            .sum();
        let denom: f64 = if latencies_size < weights_size {
            window_weights.iter().take(latencies_size).sum()
        } else {
            window_weights_sum
        };
        weighted / denom
    };
    score_l * score_a
}
/// Routing snapshot that scores nodes by latency and availability over a
/// sliding window and samples among the published healthy candidates.
#[derive(Default, Debug, Clone)]
pub struct LatencyRoutingSnapshot {
    // When set, only the k highest-scoring healthy nodes are published.
    k_top_nodes: Option<usize>,
    // Per-node sliding-window metrics for every node known to the snapshot.
    existing_nodes: HashMap<Node, NodeMetrics>,
    // Published list of healthy candidates, swapped atomically on update
    // so readers never block writers.
    routing_candidates: Arc<ArcSwap<Vec<RoutingCandidateNode>>>,
    // Weights applied to windowed measurements, newest-first.
    window_weights: Vec<f64>,
    // Cached sum of `window_weights`, reused by score computation.
    window_weights_sum: f64,
    // Whether the availability factor multiplies into node scores.
    use_availability_penalty: bool,
}
impl LatencyRoutingSnapshot {
    /// Creates a snapshot with exponentially decaying window weights and the
    /// availability penalty enabled.
    pub fn new() -> Self {
        let window_weights = generate_exp_decaying_weights(WINDOW_SIZE, LAMBDA_DECAY);
        let window_weights_sum = window_weights.iter().sum::<f64>();
        Self {
            k_top_nodes: None,
            existing_nodes: HashMap::new(),
            routing_candidates: Arc::new(ArcSwap::new(Arc::new(Vec::new()))),
            use_availability_penalty: true,
            window_weights,
            window_weights_sum,
        }
    }
    /// Limits the published candidate list to the `k_top_nodes` best-scoring nodes.
    pub fn set_k_top_nodes(mut self, k_top_nodes: usize) -> Self {
        self.k_top_nodes = Some(k_top_nodes);
        self
    }
    /// Enables or disables the availability factor in node scores.
    #[allow(unused)]
    pub fn set_availability_penalty(mut self, use_penalty: bool) -> Self {
        self.use_availability_penalty = use_penalty;
        self
    }
    /// Replaces the window weights and refreshes their cached sum.
    #[allow(unused)]
    pub fn set_window_weights(mut self, weights: &[f64]) -> Self {
        self.window_weights = weights.to_vec();
        self.window_weights_sum = self.window_weights.iter().sum();
        self
    }
    /// Rebuilds and atomically publishes the candidate list from all healthy
    /// nodes; when `k_top_nodes` is set, keeps only the highest-scoring k.
    fn publish_routing_candidates(&self) {
        let mut candidates = Vec::new();
        for (node, metrics) in self.existing_nodes.iter() {
            if metrics.is_healthy {
                candidates.push(RoutingCandidateNode::new(node.clone(), metrics.score));
            }
        }
        if let Some(k_top) = self.k_top_nodes {
            // Sort descending by score; NaN scores compare as equal.
            candidates.sort_by(|a, b| {
                b.score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            candidates.truncate(k_top);
        }
        self.routing_candidates.store(Arc::new(candidates));
    }
}
/// Selects an index with probability proportional to each candidate's score,
/// driven by `number`, a uniform draw from `[0.0, 1.0]`.
///
/// Returns `None` for an empty slice, an out-of-range `number`, or a zero
/// total score. Otherwise walks the cumulative score distribution and falls
/// back to the last index to absorb floating-point round-off.
#[inline(always)]
fn weighted_sample(weighted_nodes: &[RoutingCandidateNode], number: f64) -> Option<usize> {
    if weighted_nodes.is_empty() || !(0.0..=1.0).contains(&number) {
        return None;
    }
    let total: f64 = weighted_nodes.iter().map(|n| n.score).sum();
    if total == 0.0 {
        return None;
    }
    // Scale the draw onto the cumulative score axis and find the bucket it lands in.
    let mut remaining = number * total;
    weighted_nodes
        .iter()
        .position(|candidate| {
            remaining -= candidate.score;
            remaining <= 0.0
        })
        .or(Some(weighted_nodes.len() - 1))
}
impl RoutingSnapshot for LatencyRoutingSnapshot {
    /// True when at least one healthy candidate has been published.
    fn has_nodes(&self) -> bool {
        !self.routing_candidates.load().is_empty()
    }
    /// Picks a single node via weighted random sampling, if any candidate exists.
    fn next_node(&self) -> Option<Node> {
        self.next_n_nodes(1).into_iter().next()
    }
    /// Returns up to `n` distinct nodes, drawn without replacement with
    /// probability proportional to each candidate's score.
    fn next_n_nodes(&self, n: usize) -> Vec<Node> {
        if n == 0 {
            return Vec::new();
        }
        // Work on a private copy so removals don't affect the published list.
        let mut candidates: Vec<RoutingCandidateNode> =
            self.routing_candidates.load().as_ref().clone();
        let take = n.min(candidates.len());
        let mut picked = Vec::with_capacity(take);
        let mut rng = rand::rng();
        for _ in 0..take {
            let roll = rng.random::<f64>();
            if let Some(idx) = weighted_sample(candidates.as_slice(), roll) {
                // swap_remove keeps the draw O(1); order of the pool is irrelevant.
                picked.push(candidates.swap_remove(idx).node);
            }
        }
        picked
    }
    /// Reconciles the tracked node set with `nodes`; returns whether anything
    /// changed. Newly added nodes start with empty metrics and are therefore
    /// not published until a successful health check arrives.
    fn sync_nodes(&mut self, nodes: &[Node]) -> bool {
        let incoming: HashSet<&Node> = nodes.iter().collect();
        // Drop nodes that disappeared from the topology.
        let before = self.existing_nodes.len();
        self.existing_nodes.retain(|node, _| incoming.contains(node));
        let mut has_changes = self.existing_nodes.len() != before;
        // Start tracking nodes we have not seen yet.
        let window = self.window_weights.len();
        let added: Vec<&Node> = nodes
            .iter()
            .filter(|node| !self.existing_nodes.contains_key(*node))
            .collect();
        for node in added {
            self.existing_nodes
                .insert(node.clone(), NodeMetrics::new(window));
            has_changes = true;
        }
        if has_changes {
            self.publish_routing_candidates();
        }
        has_changes
    }
    /// Folds one health-check result into the node's windows, recomputes its
    /// score, and republishes the candidate list. Returns false for untracked nodes.
    fn update_node(&mut self, node: &Node, health: HealthCheckStatus) -> bool {
        let Some(metrics) = self.existing_nodes.get_mut(node) else {
            return false;
        };
        metrics.add_latency_measurement(health.latency());
        metrics.score = compute_score(
            &self.window_weights,
            self.window_weights_sum,
            &metrics.availabilities,
            &metrics.latencies,
            self.use_availability_penalty,
        );
        self.publish_routing_candidates();
        true
    }
    /// Reports total tracked nodes and the count of published (healthy) candidates.
    fn routes_stats(&self) -> RoutesStats {
        let total = self.existing_nodes.len();
        let healthy = self.routing_candidates.load().len();
        RoutesStats::new(total, Some(healthy))
    }
}
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, VecDeque},
        slice,
        time::Duration,
    };
    use crate::agent::route_provider::{
        dynamic_routing::{
            health_check::HealthCheckStatus,
            node::Node,
            snapshot::{
                latency_based_routing::{
                    compute_score, weighted_sample, LatencyRoutingSnapshot, NodeMetrics,
                    RoutingCandidateNode,
                },
                routing_snapshot::RoutingSnapshot,
            },
        },
        RoutesStats,
    };

    /// A freshly created snapshot tracks no nodes and yields no routes.
    #[test]
    fn test_snapshot_init() {
        let snapshot = LatencyRoutingSnapshot::new();
        assert!(snapshot.existing_nodes.is_empty());
        assert!(!snapshot.has_nodes());
        assert!(snapshot.next_node().is_none());
        assert!(snapshot.next_n_nodes(1).is_empty());
        assert_eq!(snapshot.routes_stats(), RoutesStats::new(0, Some(0)));
    }

    /// Updating a node that was never synced is a no-op returning false.
    #[test]
    fn test_update_for_non_existing_node_fails() {
        let mut snapshot = LatencyRoutingSnapshot::new();
        let node = Node::new("api1.com").unwrap();
        let health = HealthCheckStatus::new(Some(Duration::from_secs(1)));
        let is_updated = snapshot.update_node(&node, health);
        assert!(!is_updated);
        assert!(snapshot.existing_nodes.is_empty());
        assert!(!snapshot.has_nodes());
        assert!(snapshot.next_node().is_none());
        assert_eq!(snapshot.routes_stats(), RoutesStats::new(0, Some(0)));
    }

    /// Health updates for a synced node adjust its score; a failed check
    /// unpublishes the node without dropping its measurement history.
    #[test]
    fn test_update_for_existing_node_succeeds() {
        // Weights [2.0, 1.0]: the newest measurement counts twice as much.
        let mut snapshot = LatencyRoutingSnapshot::new()
            .set_window_weights(&[2.0, 1.0])
            .set_availability_penalty(false);
        let node = Node::new("api1.com").unwrap();
        let health = HealthCheckStatus::new(Some(Duration::from_secs(1)));
        snapshot.sync_nodes(slice::from_ref(&node));
        assert_eq!(snapshot.routes_stats(), RoutesStats::new(1, Some(0)));
        let is_updated = snapshot.update_node(&node, health);
        assert!(is_updated);
        assert!(snapshot.has_nodes());
        let metrics = snapshot.existing_nodes.get(&node).unwrap();
        // One 1s sample: normalized by the partial weights sum (2.0).
        assert_eq!(metrics.score, (2.0 / 1.0) / 2.0);
        assert_eq!(snapshot.next_node().unwrap(), node);
        assert_eq!(snapshot.routes_stats(), RoutesStats::new(1, Some(1)));
        let health = HealthCheckStatus::new(Some(Duration::from_secs(2)));
        let is_updated = snapshot.update_node(&node, health);
        assert!(is_updated);
        let metrics = snapshot.existing_nodes.get(&node).unwrap();
        // Window [1s, 2s]: the newer 2s sample gets weight 2.0.
        assert_eq!(metrics.score, (2.0 / 2.0 + 1.0 / 1.0) / 3.0);
        // A failed check keeps the last score but removes the node from the
        // published candidates.
        let health = HealthCheckStatus::new(None);
        let is_updated = snapshot.update_node(&node, health);
        assert!(is_updated);
        let metrics = snapshot.existing_nodes.get(&node).unwrap();
        assert_eq!(metrics.score, (2.0 / 2.0 + 1.0 / 1.0) / 3.0);
        assert!(!snapshot.has_nodes());
        assert_eq!(snapshot.existing_nodes.len(), 1);
        assert!(snapshot.next_node().is_none());
        assert_eq!(snapshot.routes_stats(), RoutesStats::new(1, Some(0)));
        // A later successful check republishes the node; the latency window
        // now holds [2s, 3s].
        let health = HealthCheckStatus::new(Some(Duration::from_secs(3)));
        let is_updated = snapshot.update_node(&node, health);
        assert!(is_updated);
        let metrics = snapshot.existing_nodes.get(&node).unwrap();
        assert_eq!(metrics.score, (2.0 / 3.0 + 1.0 / 2.0) / 3.0);
    }

    /// `sync_nodes` adds new nodes, drops vanished ones, and reports whether
    /// the tracked set changed; syncing alone never publishes candidates.
    #[test]
    fn test_sync_node_scenarios() {
        let mut snapshot = LatencyRoutingSnapshot::new();
        let node_1 = Node::new("api1.com").unwrap();
        // First sync introduces node_1.
        let nodes_changed = snapshot.sync_nodes(slice::from_ref(&node_1));
        assert!(nodes_changed);
        assert!(snapshot.existing_nodes.contains_key(&node_1));
        assert!(!snapshot.has_nodes());
        // Re-syncing the same set is a no-op.
        let nodes_changed = snapshot.sync_nodes(slice::from_ref(&node_1));
        assert!(!nodes_changed);
        assert_eq!(
            snapshot.existing_nodes.keys().collect::<Vec<_>>(),
            vec![&node_1]
        );
        // Replacing node_1 with node_2 swaps the tracked set.
        let node_2 = Node::new("api2.com").unwrap();
        let nodes_changed = snapshot.sync_nodes(slice::from_ref(&node_2));
        assert!(nodes_changed);
        assert_eq!(
            snapshot.existing_nodes.keys().collect::<Vec<_>>(),
            vec![&node_2]
        );
        assert!(!snapshot.has_nodes());
        // Adding node_3 alongside node_2.
        let node_3 = Node::new("api3.com").unwrap();
        let nodes_changed = snapshot.sync_nodes(&[node_3.clone(), node_2.clone()]);
        assert!(nodes_changed);
        let mut keys = snapshot.existing_nodes.keys().collect::<Vec<_>>();
        keys.sort_by(|a, b| a.domain().cmp(b.domain()));
        assert_eq!(keys, vec![&node_2, &node_3]);
        assert!(!snapshot.has_nodes());
        // Same pair again: no change.
        let nodes_changed = snapshot.sync_nodes(&[node_3.clone(), node_2.clone()]);
        assert!(!nodes_changed);
        let mut keys = snapshot.existing_nodes.keys().collect::<Vec<_>>();
        keys.sort_by(|a, b| a.domain().cmp(b.domain()));
        assert_eq!(keys, vec![&node_2, &node_3]);
        assert!(!snapshot.has_nodes());
        // An empty topology clears everything.
        let nodes_changed = snapshot.sync_nodes(&[]);
        assert!(nodes_changed);
        assert!(snapshot.existing_nodes.is_empty());
        let nodes_changed = snapshot.sync_nodes(&[]);
        assert!(!nodes_changed);
        assert!(snapshot.existing_nodes.is_empty());
        assert!(!snapshot.has_nodes());
    }

    /// Exercises `weighted_sample` over empty, single, and multi-node pools
    /// with in-range and out-of-range draws.
    #[test]
    fn test_weighted_sample() {
        let node = Node::new("api1.com").unwrap();
        // Empty input yields no index.
        let arr: &[RoutingCandidateNode] = &[];
        assert_eq!(weighted_sample(arr, 0.5), None);
        // Single node: any in-range draw selects it; out-of-range is rejected.
        let arr = &[RoutingCandidateNode::new(node.clone(), 1.0)];
        assert_eq!(weighted_sample(arr, 0.0), Some(0));
        assert_eq!(weighted_sample(arr, 1.0), Some(0));
        assert_eq!(weighted_sample(arr, -1.0), None);
        assert_eq!(weighted_sample(arr, 1.1), None);
        // Two nodes with weights 1:2 split the unit interval at 1/3.
        let arr = &[
            RoutingCandidateNode::new(node.clone(), 1.0),
            RoutingCandidateNode::new(node.clone(), 2.0),
        ];
        assert_eq!(weighted_sample(arr, 0.0), Some(0));
        assert_eq!(weighted_sample(arr, 0.33), Some(0));
        assert_eq!(weighted_sample(arr, 0.35), Some(1));
        assert_eq!(weighted_sample(arr, 1.0), Some(1));
        assert_eq!(weighted_sample(arr, -1.0), None);
        assert_eq!(weighted_sample(arr, 1.1), None);
        // Four nodes with weights 1:2:1.5:2.5 (total 7); bucket boundaries
        // fall at ~0.143, ~0.429, and ~0.643.
        let arr = &[
            RoutingCandidateNode::new(node.clone(), 1.0),
            RoutingCandidateNode::new(node.clone(), 2.0),
            RoutingCandidateNode::new(node.clone(), 1.5),
            RoutingCandidateNode::new(node.clone(), 2.5),
        ];
        assert_eq!(weighted_sample(arr, 0.14), Some(0));
        assert_eq!(weighted_sample(arr, 0.15), Some(1));
        assert_eq!(weighted_sample(arr, 0.42), Some(1));
        assert_eq!(weighted_sample(arr, 0.43), Some(2));
        assert_eq!(weighted_sample(arr, 0.64), Some(2));
        assert_eq!(weighted_sample(arr, 0.65), Some(3));
        assert_eq!(weighted_sample(arr, 0.99), Some(3));
        assert_eq!(weighted_sample(arr, -1.0), None);
        assert_eq!(weighted_sample(arr, 1.1), None);
    }

    /// Verifies `compute_score` across empty, partially filled, full, and
    /// unevenly filled windows with the availability penalty enabled.
    #[test]
    fn test_compute_score_with_penalty() {
        let use_penalty = true;
        // Empty windows produce a zero score.
        let weights: &[f64] = &[];
        let weights_sum: f64 = weights.iter().sum();
        let availabilities = VecDeque::new();
        let latencies = VecDeque::new();
        let score = compute_score(
            weights,
            weights_sum,
            &availabilities,
            &latencies,
            use_penalty,
        );
        assert_eq!(score, 0.0);
        // Single measurement: normalized by the partial weights sum.
        let weights: &[f64] = &[2.0, 1.0];
        let weights_sum: f64 = weights.iter().sum();
        let availabilities = vec![true].into();
        let latencies = vec![2.0].into();
        let score = compute_score(
            weights,
            weights_sum,
            &availabilities,
            &latencies,
            use_penalty,
        );
        let score_l = (2.0 / 2.0) / 2.0;
        let score_a = 1.0;
        assert_eq!(score, score_l * score_a);
        // Full windows: the newest measurement takes the largest weight.
        let weights: &[f64] = &[2.0, 1.0];
        let weights_sum: f64 = weights.iter().sum();
        let availabilities = vec![true, false].into();
        let latencies = vec![1.0, 2.0].into();
        let score = compute_score(
            weights,
            weights_sum,
            &availabilities,
            &latencies,
            use_penalty,
        );
        let score_l = (2.0 / 2.0 + 1.0 / 1.0) / weights_sum;
        let score_a = (2.0 * 0.0 + 1.0 * 1.0) / weights_sum;
        assert_eq!(score, score_l * score_a);
        // Windows of different lengths are normalized independently.
        let weights: &[f64] = &[3.0, 2.0, 1.0];
        let weights_sum: f64 = weights.iter().sum();
        let availabilities = vec![true, false, true].into();
        let latencies = vec![1.0, 2.0].into();
        let score = compute_score(
            weights,
            weights_sum,
            &availabilities,
            &latencies,
            use_penalty,
        );
        let score_l = (3.0 / 2.0 + 2.0 / 1.0) / 5.0;
        let score_a = (3.0 * 1.0 + 2.0 * 0.0 + 1.0 * 1.0) / weights_sum;
        assert_eq!(score, score_l * score_a);
    }

    /// Statistical eyeball test for weighted selection; ignored by default —
    /// run with `cargo test -- --ignored` and inspect the printed frequencies.
    #[test]
    #[ignore]
    fn test_stats_for_next_n_nodes() {
        let mut snapshot = LatencyRoutingSnapshot::new();
        let window_size = 1;
        let node_1 = Node::new("api1.com").unwrap();
        let node_2 = Node::new("api2.com").unwrap();
        let node_3 = Node::new("api3.com").unwrap();
        let node_4 = Node::new("api4.com").unwrap();
        let mut metrics_1 = NodeMetrics::new(window_size);
        let mut metrics_2 = NodeMetrics::new(window_size);
        let mut metrics_3 = NodeMetrics::new(window_size);
        let mut metrics_4 = NodeMetrics::new(window_size);
        metrics_1.is_healthy = true;
        metrics_2.is_healthy = true;
        metrics_3.is_healthy = true;
        // node_4 has the highest score but is unhealthy, so it must never be picked.
        metrics_4.is_healthy = false;
        metrics_1.score = 16.0;
        metrics_2.score = 8.0;
        metrics_3.score = 4.0;
        metrics_4.score = 30.0;
        snapshot.existing_nodes.extend(vec![
            (node_1, metrics_1),
            (node_2, metrics_2),
            (node_3, metrics_3),
            (node_4, metrics_4),
        ]);
        snapshot.publish_routing_candidates();
        let mut stats = HashMap::new();
        let experiments = 30;
        let select_nodes_count = 1;
        for i in 0..experiments {
            let nodes = snapshot.next_n_nodes(select_nodes_count);
            println!("Experiment {i}: selected nodes {nodes:?}");
            for item in nodes.into_iter() {
                // Fix: counters must start at 0 — `or_insert(1)` made every
                // count (and the reported probability) off by one.
                *stats.entry(item).or_insert(0) += 1;
            }
        }
        for (node, count) in stats {
            println!(
                "Node {:?} is selected with probability {}",
                node.domain(),
                count as f64 / experiments as f64
            );
        }
    }
}