pick_fast 0.1.7

High-performance weighted random load balancer for selecting low-latency nodes with atomic EMA weight updates. / 高性能加权随机负载均衡器,用于随机选择低延迟节点,支持基于原子操作的指数移动平均权重更新。
Documentation
#![cfg_attr(docsrs, feature(doc_cfg))]

use std::{
  marker::PhantomData,
  ops::Deref,
  sync::atomic::{AtomicU32, Ordering},
};

/// 策略 Trait: 定义耗时与权重的转换逻辑
pub trait Rank {
  /// 计算权重
  /// - 输入: `latency_us` (微秒)
  /// - 输出: `weight` (权重值)
  fn calc(latency_us: u32) -> u32;

  /// 初始权重 (默认为 1,实现慢启动)
  fn init() -> u32 {
    1
  }
}

/// 默认策略: 倒数模型 (Inverse)
/// 逻辑: Weight = BASE / Latency
/// 场景: 适用于 N <= 256 的通用负载均衡
pub struct Inverse;

impl Rank for Inverse {
  #[inline(always)]
  fn calc(latency: u32) -> u32 {
    // 调整为 2^22 (约 4,194,304)
    // 计算: 256 * 2^22 = 2^30 (约10亿) < u32::MAX (约42亿)
    // 即使有 256 个节点且延时均为 1us,也不会溢出。
    const BASE: u32 = 1 << 22;

    // 精度说明:
    // 1ms (1000us) -> 权重 4194
    // 50ms(50000us)-> 权重 83
    // 1s  (10^6us) -> 权重 4
    // 精度足够区分不同节点节点的性能差异。

    BASE / latency.max(1)
  }
}

/// 节点: 包含数据和权重
pub struct Node<T> {
  /// 节点数据
  pub data: T,
  /// 节点权重 (原子)
  pub weight: AtomicU32,
}

impl<T> Node<T> {
  /// 创建节点
  pub fn new(data: T, weight: u32) -> Self {
    Self {
      data,
      weight: AtomicU32::new(weight),
    }
  }
}

impl<T> Deref for Node<T> {
  type Target = T;
  fn deref(&self) -> &Self::Target {
    &self.data
  }
}

/// 选中节点的句柄
/// 包含节点引用和索引
pub struct Handle<'a, T> {
  pub index: usize,
  pub node: &'a Node<T>,
}

impl<'a, T> Deref for Handle<'a, T> {
  type Target = T;
  fn deref(&self) -> &Self::Target {
    &self.node.data
  }
}

// ========================================================================
// 核心结构 (PickFast)
// ========================================================================

/// 极速加权负载均衡器
///
/// 泛型参数:
/// - `T`: 节点数据类型
/// - `M`: 权重模型 (默认为 `Inverse`)
pub struct PickFast<T, M = Inverse> {
  /// 节点列表
  pub li: Vec<Node<T>>,

  /// 总权重 (原子缓存)
  pub total: AtomicU32,

  _marker: PhantomData<M>,
}

unsafe impl<T: Sync, M> Sync for PickFast<T, M> {}
unsafe impl<T: Send, M> Send for PickFast<T, M> {}

impl<T, M: Rank> PickFast<T, M> {
  /// 创建一个新的选择器
  pub fn new(data: impl IntoIterator<Item = T>) -> Self {
    let init_val = M::init();
    let li: Vec<Node<T>> = data.into_iter().map(|d| Node::new(d, init_val)).collect();

    let n = li.len();
    assert!(n > 0, "PickFast: node count must be > 0");

    // 如果 N 非常大,这里给个运行时提醒
    if n > 256 {
      log::warn!("PickFast n={n} is large, ensure Rank won't overflow u32");
    }

    // 初始化总权重
    let total = AtomicU32::new(init_val * (n as u32));

    Self {
      li,
      total,
      _marker: PhantomData,
    }
  }

  /// 节点数量
  #[inline(always)]
  pub fn len(&self) -> usize {
    self.li.len()
  }

  /// 是否为空
  #[inline(always)]
  pub fn is_empty(&self) -> bool {
    self.li.is_empty()
  }

  /// 极速挑选 (Pick)
  ///
  /// O(1) 获取总权重 -> O(N) 扫描
  #[inline(always)]
  pub fn pick(&self) -> Handle<'_, T> {
    let total_w = self.total.load(Ordering::Relaxed);

    if total_w == 0 {
      return Handle {
        index: 0,
        node: &self.li[0],
      };
    }

    // 随机目标
    let target = fastrand::u32(0..total_w);
    let mut sum = 0;

    // 扫描
    for (i, node) in self.li.iter().enumerate() {
      sum += node.weight.load(Ordering::Relaxed);
      if sum > target {
        return Handle { index: i, node };
      }
    }

    // 兜底 (处理并发更新时的微小窗口)
    let last = self.li.len() - 1;
    Handle {
      index: last,
      node: &self.li[last],
    }
  }

  /// 设定观测值 (Set)
  ///
  /// 传入观测值 (如耗时),内部自动计算权重并平滑更新
  #[inline(always)]
  pub fn set(&self, index: usize, val: u32) {
    if index >= self.li.len() {
      return;
    }

    let target_w = M::calc(val);

    // CAS 更新单节点权重 / CAS update node weight
    // EMA公式: New = max(1, (Old + Target) / 2) / EMA formula: New = max(1, (Old + Target) / 2)
    let _ = self.li[index]
      .weight
      .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| {
        Some(((old + target_w) >> 1).max(1))
      })
      .map(|prev| {
        // 修正总权重 / Adjust total weight
        let new_w = ((prev + target_w) >> 1).max(1);
        if new_w > prev {
          self.total.fetch_add(new_w - prev, Ordering::Relaxed);
        } else {
          self.total.fetch_sub(prev - new_w, Ordering::Relaxed);
        }
      });
  }

  /// 标记节点失败 (Failed)
  ///
  /// 将指定节点权重减半,最低为1,用于处理节点故障或性能下降
  #[inline(always)]
  pub fn failed(&self, index: usize) {
    if index >= self.li.len() {
      return;
    }

    // CAS 更新单节点权重,减半但不低于1 / CAS update node weight, halve but not below 1
    let _ = self.li[index]
      .weight
      .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| {
        Some((old >> 1).max(1))
      })
      .map(|prev| {
        // 修正总权重 / Adjust total weight
        let new_w = (prev >> 1).max(1);
        self.total.fetch_sub(prev - new_w, Ordering::Relaxed);
      });
  }

  /// 返回循环迭代器,起始位置使用加权随机选择
  ///
  /// 使用与 `pick()` 相同的加权随机算法选择起始位置
  #[cfg(feature = "iter")]
  pub fn iter(&self) -> citer::CIter<'_, Node<T>> {
    let total_w = self.total.load(Ordering::Relaxed);

    let start_pos = if total_w == 0 {
      0
    } else {
      let target = fastrand::u32(0..total_w);
      let mut sum = 0;

      let mut pos = 0;
      for (i, node) in self.li.iter().enumerate() {
        sum += node.weight.load(Ordering::Relaxed);
        if sum > target {
          pos = i;
          break;
        }
      }
      pos
    };

    citer::CIter::new(&self.li[..], start_pos)
  }
}