use crate::error::{MetricsError, Result};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::hash::{DefaultHasher, Hash, Hasher};
/// Outcome of routing a single element: names the partition that was chosen.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PartitionAssignment {
/// Index of the selected partition; the partitioners in this file use it to
/// index their per-partition `stats` vectors.
pub partition_id: usize,
}
/// Per-partition counters maintained by the partitioners in this file.
/// All fields start at zero via `Default`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PartitionStats {
/// Number of elements assigned to this partition so far.
pub element_count: u64,
/// Number of times `LoadBalancedPartitioner::assign` picked this partition.
pub load_balance_selections: u64,
/// Last depth recorded through `LoadBalancedPartitioner::report_queue_depth`.
pub current_queue_depth: usize,
}
/// Partitioner that routes each key to `hash(key) % num_partitions`.
#[derive(Debug, Clone)]
pub struct HashPartitioner {
// Number of partitions; validated to be non-zero by `new`.
num_partitions: usize,
// One stats entry per partition, indexed by partition id.
stats: Vec<PartitionStats>,
}
impl HashPartitioner {
    /// Builds a hash partitioner over `num_partitions` partitions with zeroed
    /// statistics.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `num_partitions` is zero.
    pub fn new(num_partitions: usize) -> Result<Self> {
        match num_partitions {
            0 => Err(MetricsError::InvalidInput(
                "HashPartitioner requires at least 1 partition".to_string(),
            )),
            count => Ok(Self {
                num_partitions: count,
                stats: vec![PartitionStats::default(); count],
            }),
        }
    }

    /// Hashes `key` with `DefaultHasher`, reduces the hash modulo the
    /// partition count, and bumps that partition's `element_count`.
    ///
    /// Within one partitioner, equal keys land in the same partition.
    pub fn assign<K: Hash>(&mut self, key: &K) -> PartitionAssignment {
        let digest = {
            let mut state = DefaultHasher::new();
            key.hash(&mut state);
            state.finish()
        };
        // The 64-bit digest is narrowed to `usize` before the modulo.
        let slot = (digest as usize) % self.num_partitions;
        let entry = &mut self.stats[slot];
        entry.element_count += 1;
        PartitionAssignment { partition_id: slot }
    }

    /// Per-partition statistics, indexed by partition id.
    #[inline]
    pub fn stats(&self) -> &[PartitionStats] {
        self.stats.as_slice()
    }

    /// Number of partitions this partitioner distributes over.
    #[inline]
    pub fn num_partitions(&self) -> usize {
        self.num_partitions
    }
}
/// Partitioner that cycles through partitions in order: 0, 1, ..., n-1, 0, ...
#[derive(Debug, Clone)]
pub struct RoundRobinPartitioner {
// Number of partitions; validated to be non-zero by `new`.
num_partitions: usize,
// Partition id that the next `assign` call will return.
cursor: usize,
// One stats entry per partition, indexed by partition id.
stats: Vec<PartitionStats>,
}
impl RoundRobinPartitioner {
    /// Builds a round-robin partitioner whose first assignment goes to
    /// partition 0.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `num_partitions` is zero.
    pub fn new(num_partitions: usize) -> Result<Self> {
        if num_partitions > 0 {
            let stats = vec![PartitionStats::default(); num_partitions];
            Ok(Self {
                num_partitions,
                cursor: 0,
                stats,
            })
        } else {
            Err(MetricsError::InvalidInput(
                "RoundRobinPartitioner requires at least 1 partition".to_string(),
            ))
        }
    }

    /// Returns the partition under the cursor, advances the cursor (wrapping
    /// back to 0 after the last partition), and bumps the chosen partition's
    /// `element_count`.
    pub fn assign(&mut self) -> PartitionAssignment {
        let following = (self.cursor + 1) % self.num_partitions;
        // Swap in the next position and keep the one we are handing out.
        let slot = std::mem::replace(&mut self.cursor, following);
        self.stats[slot].element_count += 1;
        PartitionAssignment { partition_id: slot }
    }

    /// Per-partition statistics, indexed by partition id.
    #[inline]
    pub fn stats(&self) -> &[PartitionStats] {
        self.stats.as_slice()
    }

    /// Number of partitions this partitioner cycles over.
    #[inline]
    pub fn num_partitions(&self) -> usize {
        self.num_partitions
    }
}
/// Partitioner that routes a numeric value to a partition by comparing it
/// against a list of split points.
#[derive(Debug, Clone)]
pub struct RangePartitioner<F: Float + std::fmt::Debug + Copy + PartialOrd> {
// Boundaries between adjacent partitions; there is one fewer split point
// than there are partitions.
split_points: Vec<F>,
// Total number of partitions (`split_points.len() + 1`).
num_partitions: usize,
// One stats entry per partition, indexed by partition id.
stats: Vec<PartitionStats>,
}
impl<F: Float + std::fmt::Debug + Copy + PartialOrd> RangePartitioner<F> {
    /// Builds an equal-width range partitioner over `[domain_min, domain_max]`.
    ///
    /// The domain is cut into `num_partitions` slices of width
    /// `(domain_max - domain_min) / num_partitions`; split point `i`
    /// (for `i` in `1..num_partitions`) sits at `domain_min + step * i`.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when:
    /// * `num_partitions` is zero,
    /// * `domain_max` is not strictly greater than `domain_min`,
    /// * the domain width is NaN or infinite (NaN/infinite bounds, or a width
    ///   that overflows the float type),
    /// * a partition index cannot be represented in `F`.
    pub fn new(num_partitions: usize, domain_min: F, domain_max: F) -> Result<Self> {
        if num_partitions == 0 {
            return Err(MetricsError::InvalidInput(
                "RangePartitioner requires at least 1 partition".to_string(),
            ));
        }
        if domain_max <= domain_min {
            return Err(MetricsError::InvalidInput(
                "RangePartitioner domain_max must be > domain_min".to_string(),
            ));
        }
        let range = domain_max - domain_min;
        // A NaN bound slips past the `<=` check above (every comparison with
        // NaN is false); it, infinite bounds and an overflowing width all show
        // up here as a non-finite range, which would yield NaN/inf split points.
        if !range.is_finite() {
            return Err(MetricsError::InvalidInput(
                "RangePartitioner domain width (domain_max - domain_min) must be finite"
                    .to_string(),
            ));
        }

        // Report an unrepresentable index as an error instead of panicking.
        let to_float = |n: usize| {
            F::from(n).ok_or_else(|| {
                MetricsError::InvalidInput(format!(
                    "RangePartitioner cannot represent {n} in the float type"
                ))
            })
        };

        let step = range / to_float(num_partitions)?;
        let split_points = (1..num_partitions)
            .map(|i| to_float(i).map(|k| domain_min + step * k))
            .collect::<Result<Vec<F>>>()?;

        Ok(Self {
            split_points,
            num_partitions,
            stats: vec![PartitionStats::default(); num_partitions],
        })
    }

    /// Builds a range partitioner from explicit split points, producing
    /// `split_points.len() + 1` partitions. An empty vector yields a single
    /// partition that receives every value.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when any split point is NaN or
    /// when the split points are not strictly ascending.
    pub fn from_split_points(split_points: Vec<F>) -> Result<Self> {
        // NaN compares false against everything, so it would pass the
        // ordering check below and can never act as a usable boundary.
        if split_points.iter().any(|&sp| sp.is_nan()) {
            return Err(MetricsError::InvalidInput(
                "RangePartitioner split_points must not contain NaN".to_string(),
            ));
        }
        if split_points.windows(2).any(|pair| pair[0] >= pair[1]) {
            return Err(MetricsError::InvalidInput(
                "RangePartitioner split_points must be strictly ascending".to_string(),
            ));
        }
        let num_partitions = split_points.len() + 1;
        Ok(Self {
            split_points,
            num_partitions,
            stats: vec![PartitionStats::default(); num_partitions],
        })
    }

    /// Routes `value` to the partition whose range contains it and bumps that
    /// partition's `element_count`.
    ///
    /// The chosen partition is the index of the first split point strictly
    /// greater than `value`. Values below every split point go to partition 0;
    /// values at or above the last split point go to the last partition. A
    /// value equal to a split point belongs to the partition on its right.
    /// A NaN `value` is routed to the last partition.
    pub fn assign(&mut self, value: F) -> PartitionAssignment {
        // Split points are sorted, so a binary search suffices. The predicate
        // is "value is not less than sp" rather than "sp <= value" so that an
        // incomparable (NaN) value keeps falling through to the last partition.
        let partition_id = self
            .split_points
            .partition_point(|sp| value.partial_cmp(sp) != Some(std::cmp::Ordering::Less));
        self.stats[partition_id].element_count += 1;
        PartitionAssignment { partition_id }
    }

    /// Per-partition statistics, indexed by partition id.
    #[inline]
    pub fn stats(&self) -> &[PartitionStats] {
        &self.stats
    }

    /// Number of partitions (one more than the number of split points).
    #[inline]
    pub fn num_partitions(&self) -> usize {
        self.num_partitions
    }
}
/// Partitioner that routes each element to the partition with the smallest
/// reported queue depth.
#[derive(Debug, Clone)]
pub struct LoadBalancedPartitioner {
// Number of partitions; validated to be non-zero by `new`.
num_partitions: usize,
// Latest depth per partition as supplied through `report_queue_depth`;
// `assign` reads these values but does not modify them.
queue_depths: Vec<usize>,
// One stats entry per partition, indexed by partition id.
stats: Vec<PartitionStats>,
}
impl LoadBalancedPartitioner {
    /// Builds a load-balanced partitioner with every queue depth and every
    /// statistic starting at zero.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `num_partitions` is zero.
    pub fn new(num_partitions: usize) -> Result<Self> {
        if num_partitions > 0 {
            Ok(Self {
                num_partitions,
                queue_depths: vec![0; num_partitions],
                stats: vec![PartitionStats::default(); num_partitions],
            })
        } else {
            Err(MetricsError::InvalidInput(
                "LoadBalancedPartitioner requires at least 1 partition".to_string(),
            ))
        }
    }

    /// Picks the partition with the smallest recorded queue depth (the lowest
    /// index wins ties) and bumps its `element_count` and
    /// `load_balance_selections`.
    ///
    /// The recorded depths themselves are only changed by
    /// `report_queue_depth`, not by this call.
    pub fn assign(&mut self) -> PartitionAssignment {
        let mut lightest = 0;
        // Strict `<` keeps the earliest partition among equally loaded ones.
        for (candidate, &depth) in self.queue_depths.iter().enumerate().skip(1) {
            if depth < self.queue_depths[lightest] {
                lightest = candidate;
            }
        }
        let entry = &mut self.stats[lightest];
        entry.element_count += 1;
        entry.load_balance_selections += 1;
        PartitionAssignment {
            partition_id: lightest,
        }
    }

    /// Records `depth` as the current queue depth of `partition_id`, both for
    /// routing decisions and in that partition's statistics.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `partition_id` is not below
    /// the partition count.
    pub fn report_queue_depth(&mut self, partition_id: usize, depth: usize) -> Result<()> {
        if partition_id < self.num_partitions {
            self.queue_depths[partition_id] = depth;
            self.stats[partition_id].current_queue_depth = depth;
            Ok(())
        } else {
            Err(MetricsError::InvalidInput(format!(
                "partition_id {partition_id} out of range (num_partitions={})",
                self.num_partitions
            )))
        }
    }

    /// Most recently reported queue depth of every partition.
    #[inline]
    pub fn queue_depths(&self) -> &[usize] {
        self.queue_depths.as_slice()
    }

    /// Per-partition statistics, indexed by partition id.
    #[inline]
    pub fn stats(&self) -> &[PartitionStats] {
        self.stats.as_slice()
    }

    /// Number of partitions this partitioner balances over.
    #[inline]
    pub fn num_partitions(&self) -> usize {
        self.num_partitions
    }
}
/// Named collection of per-partitioner statistics snapshots.
#[derive(Debug, Default)]
pub struct PartitionRegistry {
// Maps a caller-chosen partitioner name to its per-partition stats.
partitioner_stats: HashMap<String, Vec<PartitionStats>>,
}
impl PartitionRegistry {
pub fn new() -> Self {
Self {
partitioner_stats: HashMap::new(),
}
}
pub fn register(&mut self, name: impl Into<String>, stats: Vec<PartitionStats>) {
self.partitioner_stats.insert(name.into(), stats);
}
pub fn get(&self, name: &str) -> Option<&[PartitionStats]> {
self.partitioner_stats.get(name).map(|v| v.as_slice())
}
pub fn imbalance_score(&self, name: &str) -> Option<f64> {
let stats = self.partitioner_stats.get(name)?;
if stats.is_empty() {
return None;
}
let total: u64 = stats.iter().map(|s| s.element_count).sum();
if total == 0 {
return Some(0.0);
}
let n = stats.len() as f64;
let mean = total as f64 / n;
let max = stats.iter().map(|s| s.element_count).max().unwrap_or(0) as f64;
Some((max / mean.max(f64::EPSILON)) - 1.0)
}
pub fn registered_names(&self) -> Vec<&str> {
self.partitioner_stats.keys().map(|s| s.as_str()).collect()
}
}
/// Set of FIFO queues fed by a round-robin partitioner: each pushed value is
/// appended to the queue of the partition chosen for it.
#[derive(Debug)]
pub struct PartitionedStream<T> {
// One FIFO queue per partition, indexed by partition id.
queues: Vec<VecDeque<T>>,
// Chooses the destination queue for every pushed value.
partitioner: RoundRobinPartitioner,
}
// No `Clone` bound on `T`: nothing here copies elements. The queues are built
// one by one rather than with `vec![..; n]`, which is what used to require it.
impl<T> PartitionedStream<T> {
    /// Creates a stream with `num_partitions` empty queues.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `num_partitions` is zero
    /// (propagated from `RoundRobinPartitioner::new`).
    pub fn new(num_partitions: usize) -> Result<Self> {
        let partitioner = RoundRobinPartitioner::new(num_partitions)?;
        let queues = (0..num_partitions).map(|_| VecDeque::new()).collect();
        Ok(Self {
            queues,
            partitioner,
        })
    }

    /// Appends `value` to the back of the queue of the next partition in
    /// round-robin order.
    pub fn push(&mut self, value: T) {
        let assignment = self.partitioner.assign();
        self.queues[assignment.partition_id].push_back(value);
    }

    /// Removes and returns the oldest value queued for `partition_id`, or
    /// `Ok(None)` when that queue is empty.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `partition_id` is out of range.
    pub fn pop(&mut self, partition_id: usize) -> Result<Option<T>> {
        self.queues
            .get_mut(partition_id)
            .map(|queue| queue.pop_front())
            .ok_or_else(|| {
                MetricsError::InvalidInput(format!("partition_id {partition_id} out of range"))
            })
    }

    /// Number of values currently queued for `partition_id`.
    ///
    /// # Errors
    ///
    /// Returns `MetricsError::InvalidInput` when `partition_id` is out of range.
    pub fn queue_depth(&self, partition_id: usize) -> Result<usize> {
        self.queues
            .get(partition_id)
            .map(|queue| queue.len())
            .ok_or_else(|| {
                MetricsError::InvalidInput(format!("partition_id {partition_id} out of range"))
            })
    }

    /// Number of partitions (queues) in this stream.
    #[inline]
    pub fn num_partitions(&self) -> usize {
        self.queues.len()
    }
}
// Unit tests covering each partitioner, the registry, and the partitioned stream.
#[cfg(test)]
mod tests {
use super::*;
// The same key must map to the same partition, and ids must stay in range.
#[test]
fn hash_partitioner_deterministic() {
let mut p = HashPartitioner::new(4).expect("valid");
let a1 = p.assign(&"hello");
let a2 = p.assign(&"hello");
assert_eq!(a1.partition_id, a2.partition_id, "same key same partition");
let a3 = p.assign(&"world");
assert!(a3.partition_id < 4);
}
// Every assignment is counted exactly once across the per-partition stats.
#[test]
fn hash_partitioner_stats() {
let mut p = HashPartitioner::new(3).expect("valid");
for i in 0_u64..30 {
p.assign(&i);
}
let total: u64 = p.stats().iter().map(|s| s.element_count).sum();
assert_eq!(total, 30);
}
// Round-robin cycles 0, 1, 2 and wraps around.
#[test]
fn round_robin_even_distribution() {
let mut p = RoundRobinPartitioner::new(3).expect("valid");
let assignments: Vec<usize> = (0..9).map(|_| p.assign().partition_id).collect();
assert_eq!(assignments, vec![0, 1, 2, 0, 1, 2, 0, 1, 2]);
}
// With a single partition every assignment goes to partition 0.
#[test]
fn round_robin_single_partition() {
let mut p = RoundRobinPartitioner::new(1).expect("valid");
for _ in 0..5 {
assert_eq!(p.assign().partition_id, 0);
}
}
// Domain [0, 9] in 3 partitions gives split points 3 and 6; a value equal to
// a split point belongs to the partition on its right.
#[test]
fn range_partitioner_equal_width() {
let mut p = RangePartitioner::<f64>::new(3, 0.0, 9.0).expect("valid");
assert_eq!(p.assign(1.0).partition_id, 0);
assert_eq!(p.assign(3.0).partition_id, 1);
assert_eq!(p.assign(6.0).partition_id, 2);
assert_eq!(p.assign(8.9).partition_id, 2);
}
// Values outside the domain land in the first or last partition.
#[test]
fn range_partitioner_clamping() {
let mut p = RangePartitioner::<f64>::new(3, 0.0, 9.0).expect("valid");
assert_eq!(p.assign(-5.0).partition_id, 0);
assert_eq!(p.assign(100.0).partition_id, 2);
}
// An empty or inverted domain is rejected.
#[test]
fn range_partitioner_invalid_domain() {
assert!(RangePartitioner::<f64>::new(3, 5.0, 5.0).is_err());
assert!(RangePartitioner::<f64>::new(3, 10.0, 5.0).is_err());
}
// Two explicit split points produce three partitions.
#[test]
fn range_partitioner_from_split_points() {
let mut p = RangePartitioner::<f64>::from_split_points(vec![10.0, 20.0]).expect("valid");
assert_eq!(p.assign(5.0).partition_id, 0);
assert_eq!(p.assign(15.0).partition_id, 1);
assert_eq!(p.assign(25.0).partition_id, 2);
}
// With all depths equal partition 0 is chosen; once partition 0 reports a
// deeper queue, another partition is preferred.
#[test]
fn load_balanced_routes_to_least_loaded() {
let mut p = LoadBalancedPartitioner::new(3).expect("valid");
assert_eq!(p.assign().partition_id, 0);
p.report_queue_depth(0, 10).expect("valid id");
let next = p.assign().partition_id;
assert!(next == 1 || next == 2);
}
// Reporting a depth for a non-existent partition is an error.
#[test]
fn load_balanced_invalid_partition_id() {
let mut p = LoadBalancedPartitioner::new(2).expect("valid");
assert!(p.report_queue_depth(5, 10).is_err());
}
// Perfectly even counts yield an imbalance score of zero.
#[test]
fn partition_registry_imbalance() {
let mut registry = PartitionRegistry::new();
let stats = vec![
PartitionStats { element_count: 100, ..Default::default() },
PartitionStats { element_count: 100, ..Default::default() },
PartitionStats { element_count: 100, ..Default::default() },
];
registry.register("rr", stats);
let score = registry.imbalance_score("rr").expect("exists");
assert!((score - 0.0).abs() < 1e-10);
}
// Nine pushes over three partitions leave three values in each queue.
#[test]
fn partitioned_stream_fan_out() {
let mut stream: PartitionedStream<i32> = PartitionedStream::new(3).expect("valid");
for i in 0..9 {
stream.push(i);
}
for pid in 0..3 {
assert_eq!(stream.queue_depth(pid).expect("valid"), 3);
}
}
// Values come back out of the queue they were routed to.
#[test]
fn partitioned_stream_pop() {
let mut stream: PartitionedStream<i32> = PartitionedStream::new(2).expect("valid");
stream.push(10);
stream.push(20);
let v0 = stream.pop(0).expect("ok").expect("some");
let v1 = stream.pop(1).expect("ok").expect("some");
assert_eq!(v0, 10);
assert_eq!(v1, 20);
}
}