use dashmap::DashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Tuning knobs for how long keyless records stick to a single partition.
#[derive(Debug, Clone)]
pub struct StickyPartitionerConfig {
    /// Number of keyless records routed to one partition before moving to the next one.
    /// A value of `0` disables this limit, leaving `linger_duration` as the only trigger.
    ///
    /// NOTE(review): this is compared against a record *count*, although the default
    /// (`16 * 1024`) looks like a byte-oriented batch size — confirm the intended unit.
    pub batch_size: u32,
    /// Maximum time a partition stays sticky before rotating, independent of `batch_size`.
    pub linger_duration: Duration,
}

impl Default for StickyPartitionerConfig {
    /// 16 384 records per batch and a 100 ms linger window.
    fn default() -> Self {
        let batch_size = 16 * 1024;
        let linger_duration = Duration::from_millis(100);
        Self {
            batch_size,
            linger_duration,
        }
    }
}
/// Sticky-partition bookkeeping for a single topic.
///
/// Every field is an atomic so the state can be updated through a shared `&` reference.
struct TopicStickyState {
    /// Partition currently receiving this topic's keyless records.
    current_partition: AtomicU32,
    /// Records counted toward the current batch.
    messages_in_batch: AtomicU64,
    /// When the current batch started, in nanoseconds since [`EPOCH`].
    batch_start_nanos: AtomicU64,
}

/// Reference instant, initialised on first use. Batch start times are stored as nanosecond
/// offsets from it so they fit in an `AtomicU64`.
static EPOCH: std::sync::LazyLock<Instant> = std::sync::LazyLock::new(Instant::now);

impl TopicStickyState {
    /// Creates state pinned to `initial_partition` with an empty batch starting now.
    fn new(initial_partition: u32) -> Self {
        let started_at = Self::now_nanos();
        Self {
            current_partition: AtomicU32::new(initial_partition),
            messages_in_batch: AtomicU64::new(0),
            batch_start_nanos: AtomicU64::new(started_at),
        }
    }

    /// Nanoseconds elapsed since [`EPOCH`].
    ///
    /// The `as` cast truncates the `u128` nanosecond count, which only matters after
    /// roughly 584 years of uptime.
    #[inline]
    fn now_nanos() -> u64 {
        let since_epoch: Duration = EPOCH.elapsed();
        since_epoch.as_nanos() as u64
    }

    /// How long the current batch has been open.
    ///
    /// Clamped to zero if the stored start is later than "now" (e.g. another thread reset
    /// the batch between the two reads).
    #[inline]
    fn batch_elapsed(&self) -> Duration {
        // Read the start before sampling the clock, so "now" is taken after the start it is
        // compared against.
        let started = self.batch_start_nanos.load(Ordering::Relaxed);
        let now = Self::now_nanos();
        now.checked_sub(started)
            .map_or(Duration::ZERO, Duration::from_nanos)
    }

    /// Marks the current instant as the start of a new batch.
    #[inline]
    fn reset_batch_start(&self) {
        let now = Self::now_nanos();
        self.batch_start_nanos.store(now, Ordering::Relaxed);
    }
}
pub struct StickyPartitioner {
config: StickyPartitionerConfig,
topic_states: DashMap<String, TopicStickyState>,
global_counter: AtomicU32,
}
impl StickyPartitioner {
pub fn new() -> Self {
Self::with_config(StickyPartitionerConfig::default())
}
pub fn with_config(config: StickyPartitionerConfig) -> Self {
Self {
config,
topic_states: DashMap::new(),
global_counter: AtomicU32::new(0),
}
}
pub fn partition(&self, topic: &str, key: Option<&[u8]>, num_partitions: u32) -> u32 {
if num_partitions == 0 {
return 0;
}
match key {
Some(k) => self.hash_partition(k, num_partitions),
None => self.sticky_partition(topic, num_partitions),
}
}
fn hash_partition(&self, key: &[u8], num_partitions: u32) -> u32 {
rivven_core::hash::murmur2_partition(key, num_partitions)
}
fn sticky_partition(&self, topic: &str, num_partitions: u32) -> u32 {
if let Some(state) = self.topic_states.get(topic) {
if !self.should_rotate(&state) {
state.messages_in_batch.fetch_add(1, Ordering::Relaxed);
return state.current_partition.load(Ordering::Relaxed) % num_partitions;
}
let current = state.current_partition.load(Ordering::Relaxed);
let next = (current + 1) % num_partitions;
state.current_partition.store(next, Ordering::Relaxed);
state.messages_in_batch.store(1, Ordering::Relaxed);
state.reset_batch_start();
return next;
}
let initial = self.global_counter.fetch_add(1, Ordering::Relaxed) % num_partitions;
let state = TopicStickyState::new(initial);
state.messages_in_batch.store(1, Ordering::Relaxed);
self.topic_states.insert(topic.to_string(), state);
initial
}
fn should_rotate(&self, state: &TopicStickyState) -> bool {
if self.config.batch_size > 0 {
let count = state.messages_in_batch.load(Ordering::Relaxed);
if count >= self.config.batch_size as u64 {
return true;
}
}
state.batch_elapsed() >= self.config.linger_duration
}
pub fn reset_topic(&self, topic: &str) {
self.topic_states.remove(topic);
}
pub fn retain_topics(&self, active_topics: &std::collections::HashSet<String>) {
self.topic_states
.retain(|topic, _| active_topics.contains(topic));
}
}
impl Default for StickyPartitioner {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Sends `count` keyless records to `topic` and collects the chosen partitions.
    fn keyless_run(
        partitioner: &StickyPartitioner,
        topic: &str,
        count: usize,
        num_partitions: u32,
    ) -> Vec<u32> {
        (0..count)
            .map(|_| partitioner.partition(topic, None, num_partitions))
            .collect()
    }

    /// Config whose linger window is long enough that only `batch_size` triggers rotation.
    fn count_only(batch_size: u32) -> StickyPartitionerConfig {
        StickyPartitionerConfig {
            batch_size,
            linger_duration: Duration::from_secs(60),
        }
    }

    #[test]
    fn test_keyed_messages_same_partition() {
        let partitioner = StickyPartitioner::new();
        let key = b"user-123";
        let p1 = partitioner.partition("topic", Some(key), 10);
        let p2 = partitioner.partition("topic", Some(key), 10);
        let p3 = partitioner.partition("topic", Some(key), 10);
        assert_eq!(p1, p2);
        assert_eq!(p2, p3);
    }

    #[test]
    fn test_keyless_messages_sticky() {
        let partitioner = StickyPartitioner::with_config(count_only(100));
        let partitions = keyless_run(&partitioner, "topic", 100, 10);
        let first = partitions[0];
        assert!(
            partitions.iter().all(|&p| p == first),
            "Messages within batch should go to same partition"
        );
    }

    #[test]
    fn test_batch_rotation() {
        let partitioner = StickyPartitioner::with_config(count_only(10));
        let partitions = keyless_run(&partitioner, "topic", 25, 10);
        let unique: HashSet<_> = partitions.iter().collect();
        assert!(unique.len() >= 2, "Should have rotated partitions");

        // Exactly `batch_size` records per batch, each batch on the following partition.
        let first = partitions[0];
        let second = (first + 1) % 10;
        let third = (first + 2) % 10;
        assert!(partitions[..10].iter().all(|&p| p == first), "{partitions:?}");
        assert!(partitions[10..20].iter().all(|&p| p == second), "{partitions:?}");
        assert!(partitions[20..].iter().all(|&p| p == third), "{partitions:?}");
    }

    #[test]
    fn test_different_topics_different_partitions() {
        let partitioner = StickyPartitioner::new();
        let p1 = partitioner.partition("topic-a", None, 100);
        let p2 = partitioner.partition("topic-b", None, 100);
        let p3 = partitioner.partition("topic-c", None, 100);
        assert!(
            p1 != p2 && p2 != p3 && p1 != p3,
            "Different topics should get different initial partitions"
        );
    }

    #[test]
    fn test_zero_partitions_returns_zero() {
        let partitioner = StickyPartitioner::new();
        assert_eq!(partitioner.partition("topic", None, 0), 0);
        assert_eq!(partitioner.partition("topic", Some(&b"key"[..]), 0), 0);
    }

    #[test]
    fn test_zero_linger_rotates_every_record() {
        let config = StickyPartitionerConfig {
            batch_size: 0,
            linger_duration: Duration::ZERO,
        };
        let partitioner = StickyPartitioner::with_config(config);
        assert_eq!(keyless_run(&partitioner, "topic", 6, 3), vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_reset_topic_starts_fresh() {
        let partitioner = StickyPartitioner::with_config(count_only(100));
        assert_eq!(partitioner.partition("a", None, 10), 0);
        partitioner.reset_topic("a");
        // The topic is re-registered from the global counter.
        assert_eq!(partitioner.partition("a", None, 10), 1);
    }

    #[test]
    fn test_retain_topics_drops_inactive_state() {
        let partitioner = StickyPartitioner::with_config(count_only(100));
        assert_eq!(partitioner.partition("a", None, 10), 0);
        assert_eq!(partitioner.partition("b", None, 10), 1);
        let active: HashSet<String> = ["a".to_string()].into_iter().collect();
        partitioner.retain_topics(&active);
        // "a" kept its sticky state; "b" was dropped and gets a fresh starting partition.
        assert_eq!(partitioner.partition("a", None, 10), 0);
        assert_eq!(partitioner.partition("b", None, 10), 2);
    }

    #[test]
    fn test_murmur2_deterministic() {
        let key = b"test-key-12345";
        let h1 = rivven_core::hash::murmur2(key);
        let h2 = rivven_core::hash::murmur2(key);
        assert_eq!(h1, h2, "Same key should produce same hash");
        let h3 = rivven_core::hash::murmur2(b"different-key");
        assert_ne!(h1, h3, "Different keys should produce different hashes");
    }

    #[test]
    fn test_key_distribution() {
        let partitioner = StickyPartitioner::new();
        let num_partitions = 8;
        let mut counts = vec![0u32; num_partitions as usize];
        for i in 0..1000 {
            let key = format!("user-{}", i);
            let partition = partitioner.partition("topic", Some(key.as_bytes()), num_partitions);
            counts[partition as usize] += 1;
        }
        for (i, &count) in counts.iter().enumerate() {
            assert!(
                count >= 50,
                "Partition {} only got {} keys (expected >= 50)",
                i,
                count
            );
            assert!(
                count <= 200,
                "Partition {} got {} keys (expected <= 200)",
                i,
                count
            );
        }
    }
}