volo 0.12.3

Volo is a high-performance, highly extensible Rust RPC framework that helps developers build microservices.
Documentation
use std::{hash::Hash, sync::Arc};

use dashmap::{DashMap, mapref::entry::Entry};
use rand::Rng;

use super::{LoadBalance, error::LoadBalanceError};
use crate::{
    context::Endpoint,
    discovery::{Change, Discover, Instance},
    net::Address,
};

/// Picks one instance at random, with probability proportional to its weight.
///
/// `prefix_sum_of_weights[i]` is expected to hold the sum of the weights of
/// `instances[0..=i]`, and `sum_of_weight` the total weight (the last prefix sum).
///
/// Returns the index of the chosen instance together with a shared handle to it, or
/// `None` when the total weight is zero, when there are no instances, or when the
/// selected index falls outside `instances` (only possible if the arguments disagree
/// with each other).
#[inline]
fn pick_one(
    sum_of_weight: usize,
    prefix_sum_of_weights: &[usize],
    instances: &[Arc<Instance>],
) -> Option<(usize, Arc<Instance>)> {
    if sum_of_weight == 0 || instances.is_empty() {
        return None;
    }
    // Draw a ticket in `1..=sum_of_weight`; instance `i` owns the tickets in
    // `prefix[i - 1] + 1 ..= prefix[i]`.
    let ticket = rand::rng().random_range(0..sum_of_weight) + 1;
    // Select the FIRST prefix sum that reaches the ticket. A zero-weight instance repeats
    // the previous prefix sum, and `binary_search` may return any of the equal entries,
    // which could select an instance that owns no tickets at all.
    let index = prefix_sum_of_weights.partition_point(|&prefix| prefix < ticket);
    instances
        .get(index)
        .map(|instance| (index, Arc::clone(instance)))
}

/// Iterator over the addresses of one cached instance list.
///
/// The first address is chosen by weighted random selection; later calls walk the
/// remaining instances in list order, wrapping around at the end.
#[derive(Debug)]
pub struct InstancePicker {
    /// Instance list (with its prefix sums) shared with the balancer's cache.
    shared_instances: Arc<WeightedInstances>,
    /// Total weight handed to the weighted random pick.
    sum_of_weights: usize,
    /// Index of the instance yielded most recently; `None` until the first successful pick.
    last_offset: Option<usize>,
    /// Number of `next` calls made so far on a non-empty instance list.
    iter_times: usize,
}

impl Iterator for InstancePicker {
    type Item = Address;

    /// Yields the next candidate address.
    ///
    /// The first successful call performs a weighted random pick. Every later call
    /// advances to the following instance in list order (wrapping to index 0), until the
    /// number of calls exceeds the number of instances, after which `None` is returned.
    fn next(&mut self) -> Option<Self::Item> {
        let candidates = &self.shared_instances.instances;
        if candidates.is_empty() {
            return None;
        }
        self.iter_times += 1;

        // No instance yielded yet: make the weighted random choice.
        let Some(previous) = self.last_offset else {
            let (offset, instance) = pick_one(
                self.sum_of_weights,
                &self.shared_instances.prefix_sum_of_weights,
                candidates,
            )?;
            self.last_offset = Some(offset);
            return Some(instance.address.clone());
        };

        // Every instance has been offered once already.
        if self.iter_times > candidates.len() {
            return None;
        }

        // Step to the neighbour of the previous pick, wrapping at the end of the list.
        let offset = if previous + 1 == candidates.len() {
            0
        } else {
            previous + 1
        };
        self.last_offset = Some(offset);
        Some(candidates[offset].address.clone())
    }
}

/// An instance list together with the weight sums needed for weighted random selection.
#[derive(Debug, Clone)]
struct WeightedInstances {
    /// Sum of the weights of all instances.
    sum_of_weights: usize,
    /// Running totals of the weights: entry `i` is the sum of the weights of
    /// `instances[0..=i]`.
    prefix_sum_of_weights: Vec<usize>,
    /// The instances, in the order they were supplied.
    instances: Vec<Arc<Instance>>,
}

impl From<Vec<Arc<Instance>>> for WeightedInstances {
    /// Builds the running weight totals for `instances`, keeping the list order intact.
    fn from(instances: Vec<Arc<Instance>>) -> Self {
        let mut prefix_sum_of_weights = Vec::with_capacity(instances.len());
        prefix_sum_of_weights.extend(instances.iter().scan(0usize, |running, instance| {
            *running += instance.weight as usize;
            Some(*running)
        }));
        // The last running total is the overall sum; an empty list sums to zero.
        let sum_of_weights = prefix_sum_of_weights.last().copied().unwrap_or(0);
        Self {
            sum_of_weights,
            prefix_sum_of_weights,
            instances,
        }
    }
}

/// Load balancer that picks instances at random in proportion to their weight.
///
/// Instance lists are cached per discovery key `K`.
#[derive(Debug, Clone)]
pub struct WeightedRandomBalance<K>
where
    K: Hash + PartialEq + Eq + Send + Sync + 'static,
{
    /// Cached weighted instance list for each discovery key.
    router: DashMap<K, Arc<WeightedInstances>>,
}

impl<K> WeightedRandomBalance<K>
where
    K: Hash + PartialEq + Eq + Send + Sync + 'static,
{
    /// Creates an empty balancer whose key type is taken from the given discover.
    ///
    /// The discover value itself is not used; it only fixes `K` to `D::Key`.
    pub fn with_discover<D>(_: &D) -> Self
    where
        D: Discover<Key = K>,
    {
        Self::new()
    }

    /// Creates a balancer with an empty cache.
    pub fn new() -> Self {
        let router = DashMap::new();
        Self { router }
    }
}

impl<K> Default for WeightedRandomBalance<K>
where
    K: Hash + PartialEq + Eq + Send + Sync + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<D> LoadBalance<D> for WeightedRandomBalance<D::Key>
where
    D: Discover,
{
    type InstanceIter = InstancePicker;

    /// Returns a picker over the instances for `endpoint`.
    ///
    /// The instance list is looked up in the cache under `discover.key(endpoint)`. On a
    /// miss it is fetched with `discover.discover(endpoint)`, stored in the cache, and
    /// then used.
    ///
    /// # Errors
    ///
    /// Returns the discovery error, converted into [`LoadBalanceError`], when the lookup
    /// misses and discovery fails.
    async fn get_picker<'future>(
        &'future self,
        endpoint: &'future Endpoint,
        discover: &'future D,
    ) -> Result<Self::InstanceIter, LoadBalanceError> {
        let key = discover.key(endpoint);
        let cached = self.router.get(&key).map(|entry| Arc::clone(&*entry));
        let shared_instances = match cached {
            Some(hit) => hit,
            None => {
                let discovered = discover
                    .discover(endpoint)
                    .await
                    .map_err(|err| err.into())?;
                let fresh = Arc::new(WeightedInstances::from(discovered));
                self.router.insert(key, Arc::clone(&fresh));
                fresh
            }
        };
        Ok(InstancePicker {
            sum_of_weights: shared_instances.sum_of_weights,
            shared_instances,
            last_offset: None,
            iter_times: 0,
        })
    }

    /// Replaces the cached instance list for `changes.key` with `changes.all`.
    ///
    /// Keys that are not in the cache are left alone; nothing is inserted for them.
    fn rebalance(&self, changes: Change<D::Key>) {
        match self.router.entry(changes.key.clone()) {
            Entry::Occupied(slot) => {
                slot.replace_entry(Arc::new(WeightedInstances::from(changes.all)));
            }
            Entry::Vacant(_) => {}
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rand::{RngCore, rng};

    use super::{LoadBalance, WeightedRandomBalance};
    use crate::{
        context::Endpoint,
        discovery::{StaticDiscover, WeightedStaticDiscover},
    };

    /// Draining a picker over two static instances yields two distinct addresses.
    #[tokio::test]
    async fn test_weighted_random() {
        let endpoint = Endpoint::new("".into());
        let discover = StaticDiscover::from(vec![
            "127.0.0.1:8000".parse().unwrap(),
            "127.0.0.2:9000".parse().unwrap(),
        ]);
        let balancer = WeightedRandomBalance::with_discover(&discover);
        let picked: Vec<_> = balancer
            .get_picker(&endpoint, &discover)
            .await
            .unwrap()
            .collect();
        assert_eq!(picked.len(), 2);
        assert_ne!(picked[0], picked[1]);
    }

    /// With 100 instances of weight 1, every instance is picked first at roughly the
    /// same rate.
    #[tokio::test]
    async fn test_weighted_random_load_balance_same_instance() {
        let cycle = 1000;

        let endpoint = Endpoint::new("".into());
        let mut weighted_instances = Vec::with_capacity(100);
        let mut total_weight = 0;
        for i in 0..100 {
            let weight = 1;
            total_weight += weight;
            weighted_instances.push((format!("127.0.0.{i}:8000").parse().unwrap(), weight));
        }
        let discover = WeightedStaticDiscover::from(weighted_instances.clone());
        let balancer = WeightedRandomBalance::with_discover(&discover);

        // Count how often each address is the first one a fresh picker yields.
        let rounds = total_weight * cycle;
        let mut actual_weights: HashMap<String, u32> = HashMap::new();
        for _ in 0..rounds {
            let first = balancer
                .get_picker(&endpoint, &discover)
                .await
                .unwrap()
                .next()
                .unwrap();
            *actual_weights.entry(first.to_string()).or_default() += 1;
        }
        assert_eq!(actual_weights.len(), 100);

        for (address, weight) in weighted_instances.iter() {
            let addr = address.to_string();
            let count = actual_weights.get(&addr).copied().unwrap_or(0);

            let expected_rate = (*weight as f64) / (total_weight as f64);
            let actual_rate = (count as f64) / (rounds as f64);

            println!("addr: {addr}, expected: {expected_rate}, actual: {actual_rate}");
            assert!((expected_rate - actual_rate).abs() < 0.01);
        }
    }

    /// With random weights in `1..=100`, each instance's first-pick rate stays close to
    /// its share of the total weight.
    #[tokio::test]
    async fn test_weighted_random_load_balance() {
        let cycle = 10;

        let endpoint = Endpoint::new("".into());
        let mut weighted_instances = Vec::with_capacity(100);
        let mut total_weight = 0;
        for i in 0..100 {
            let weight = rng().next_u32() % 100 + 1;
            total_weight += weight;
            weighted_instances.push((format!("127.0.0.{i}:8000").parse().unwrap(), weight));
        }
        let discover = WeightedStaticDiscover::from(weighted_instances.clone());
        let balancer = WeightedRandomBalance::with_discover(&discover);

        // Count how often each address is the first one a fresh picker yields.
        let rounds = total_weight * cycle;
        let mut actual_weights: HashMap<String, u32> = HashMap::new();
        for _ in 0..rounds {
            let first = balancer
                .get_picker(&endpoint, &discover)
                .await
                .unwrap()
                .next()
                .unwrap();
            *actual_weights.entry(first.to_string()).or_default() += 1;
        }

        for (address, weight) in weighted_instances.iter() {
            let addr = address.to_string();
            let count = actual_weights.get(&addr).copied().unwrap_or(0);

            let expected_rate = (*weight as f64) / (total_weight as f64);
            let actual_rate = (count as f64) / (rounds as f64);

            println!("addr: {addr}, expected: {expected_rate}, actual: {actual_rate}");
            assert!((expected_rate - actual_rate).abs() < 0.01);
        }
    }
}