loadwise-core 0.1.0

Core traits, strategies, and in-memory stores for loadwise
Documentation
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Mutex;

use super::{SelectionContext, Strategy};
use crate::{Node, Weighted};

/// Smooth weighted round-robin, in the style popularised by Nginx.
///
/// Every candidate carries a running "current weight". Each call to `select`:
/// 1. adds the weight of every node not excluded by the selection context to
///    that node's current weight,
/// 2. picks the eligible node whose current weight is now the highest (the
///    lowest index wins a tie),
/// 3. subtracts the summed weight of the participating nodes from the winner.
///
/// Picks come out interleaved rather than in bursts, while still following the
/// nodes' relative weights — e.g. weights `2:1` yield `x, y, x, x, y, x`.
///
/// The state is keyed on a fingerprint of the candidate list that covers node
/// IDs **and** weights, in order. Adding, removing, or reordering nodes, or
/// changing any node's weight, discards the accumulated current weights and
/// starts again from zero.
pub struct WeightedRoundRobin {
    state: Mutex<WrrState>,
}

impl std::fmt::Debug for WeightedRoundRobin {
    /// Renders as `WeightedRoundRobin { .. }`. The mutex-guarded scheduling
    /// state is not printed, so formatting never touches the lock.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("WeightedRoundRobin");
        out.finish_non_exhaustive()
    }
}

/// Scheduling state shared between calls, stored behind the strategy's mutex.
struct WrrState {
    /// Fingerprint of the candidate list that `weights` was built for.
    fingerprint: u64,
    /// Running current weight per candidate, indexed like the candidate slice.
    weights: Vec<i64>,
}

impl WeightedRoundRobin {
    /// Creates a strategy that has not yet seen any candidate set.
    ///
    /// The stored fingerprint starts at 0 and the per-node running weights
    /// start empty; `select` rebuilds them for whatever candidates it is given.
    ///
    /// This is a `const fn`, so a strategy can be placed in a `static` without
    /// lazy initialisation.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            state: Mutex::new(WrrState {
                fingerprint: 0,
                weights: Vec::new(),
            }),
        }
    }
}

impl Default for WeightedRoundRobin {
    fn default() -> Self {
        Self::new()
    }
}

/// Hashes a candidate list into the fingerprint used to key WRR state.
///
/// The hash covers the list length followed by every node's ID and weight, in
/// order. Because weights are part of it, changing a node's weight produces a
/// different fingerprint, just like adding, removing, or reordering nodes.
fn wrr_fingerprint<N: Weighted + Node>(candidates: &[N]) -> u64 {
    let mut digest = DefaultHasher::new();
    candidates.len().hash(&mut digest);
    candidates.iter().for_each(|candidate| {
        candidate.id().hash(&mut digest);
        candidate.weight().hash(&mut digest);
    });
    digest.finish()
}

impl<N: Weighted + Node> Strategy<N> for WeightedRoundRobin {
    /// Picks the next candidate by smooth weighted round-robin.
    ///
    /// Returns the index into `candidates` of the chosen node. Returns `None`
    /// when `candidates` is empty or no node is eligible. A node is eligible
    /// when `ctx` does not exclude its index and its weight is non-zero. Ties
    /// on the running weight go to the lowest index.
    ///
    /// The running weights are reset whenever the candidate fingerprint (IDs
    /// and weights, in order) or the candidate count differs from what the
    /// stored state was built for.
    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        // Hash before locking to keep the critical section short.
        let fingerprint = wrr_fingerprint(candidates);
        // The guarded data is only a scheduling hint and is re-validated right
        // below, so a poisoned lock (another caller panicked mid-selection) is
        // recovered instead of turning into a panic for every later caller.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        // Reset when the candidate set changes. The length check also covers a
        // fingerprint that matches the stored one while `weights` has the wrong
        // size (the initial 0, or a hash collision). Without it, the indexing
        // below would panic.
        if state.fingerprint != fingerprint || state.weights.len() != candidates.len() {
            state.fingerprint = fingerprint;
            state.weights.clear();
            state.weights.resize(candidates.len(), 0);
        }

        let mut total_weight: i64 = 0;
        let mut best_idx = None;
        let mut best_weight = i64::MIN;

        for (i, node) in candidates.iter().enumerate() {
            if ctx.is_excluded(i) {
                continue;
            }
            let weight = i64::from(node.weight());
            // A zero-weight node adds nothing to the total, so it must never
            // win. Without this guard it could, once exclusions leave the other
            // eligible nodes with negative running weights.
            if weight == 0 {
                continue;
            }
            total_weight += weight;
            state.weights[i] += weight;
            if state.weights[i] > best_weight {
                best_weight = state.weights[i];
                best_idx = Some(i);
            }
        }

        // `best_idx` is `None` exactly when no node was eligible (total weight 0).
        let idx = best_idx?;
        state.weights[idx] -= total_weight;
        Some(idx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal weighted node used by the tests below.
    struct W {
        id: &'static str,
        weight: u32,
    }

    impl W {
        fn new(id: &'static str, weight: u32) -> Self {
            Self { id, weight }
        }
    }

    impl Node for W {
        type Id = &'static str;
        fn id(&self) -> &&'static str {
            &self.id
        }
    }

    impl Weighted for W {
        fn weight(&self) -> u32 {
            self.weight
        }
    }

    #[test]
    fn respects_weights() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 5), W::new("b", 1), W::new("c", 1)];
        let ctx = SelectionContext::default();

        let mut counts = [0u32; 3];
        for _ in 0..70 {
            let idx = wrr.select(&nodes, &ctx).unwrap();
            counts[idx] += 1;
        }

        // 5:1:1 ratio over 70 rounds = 50:10:10
        assert_eq!(counts[0], 50);
        assert_eq!(counts[1], 10);
        assert_eq!(counts[2], 10);
    }

    #[test]
    fn smooth_distribution() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("x", 2), W::new("y", 1)];
        let ctx = SelectionContext::default();

        let sequence: Vec<usize> = (0..6)
            .map(|_| wrr.select(&nodes, &ctx).unwrap())
            .collect();
        assert_eq!(sequence, vec![0, 1, 0, 0, 1, 0]);
    }

    #[test]
    fn skips_excluded() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 3), W::new("b", 1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        // Only node b (index 1) is eligible
        assert_eq!(wrr.select(&nodes, &ctx), Some(1));
    }

    #[test]
    fn all_excluded_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 1), W::new("b", 1)];
        let ctx = SelectionContext::builder().exclude(vec![0, 1]).build();
        assert_eq!(wrr.select(&nodes, &ctx), None);
    }

    #[test]
    fn empty_candidates_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes: [W; 0] = [];
        assert_eq!(wrr.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn zero_total_weight_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 0), W::new("b", 0)];
        assert_eq!(wrr.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn resets_on_candidate_change() {
        let wrr = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();

        let nodes_v1 = [W::new("a", 2), W::new("b", 1)];
        let _ = wrr.select(&nodes_v1, &ctx);
        let _ = wrr.select(&nodes_v1, &ctx);

        // Change candidate set — state should reset, not corrupt
        let nodes_v2 = [W::new("b", 1), W::new("c", 3)];
        let idx = wrr.select(&nodes_v2, &ctx).unwrap();
        // After reset, node with weight 3 should win first round
        assert_eq!(idx, 1);
    }

    #[test]
    fn resets_on_weight_change() {
        let wrr = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();

        // One pick from equal weights leaves running weights a = -1, b = 1.
        let nodes_v1 = [W::new("a", 1), W::new("b", 1)];
        assert_eq!(wrr.select(&nodes_v1, &ctx), Some(0));

        // Same IDs, but "a" now has weight 2. The fingerprint changes, so the
        // state resets and "a" wins; stale state would have picked "b" (2 > 1).
        let nodes_v2 = [W::new("a", 2), W::new("b", 1)];
        assert_eq!(wrr.select(&nodes_v2, &ctx), Some(0));
    }
}