use std::cell::Cell;
use std::sync::atomic::{AtomicU16, AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Process-wide counter from which `allocate_partition_id` draws partition ids.
static NEXT_PARTITION_ID: AtomicU16 = AtomicU16::new(0);

/// Modulus applied to the raw counter value when a partition id is allocated.
const MAX_PARTITIONS: u16 = u16::MAX;

// Widths of the three fields packed into a 64-bit id, from the most
// significant field (partition) to the least significant (sequence).
const PARTITION_BITS: u32 = 16;
const TIMESTAMP_BITS: u32 = 32;
const SEQUENCE_BITS: u32 = 16;

// Left shift that moves each field into position; the sequence sits at bit 0.
const TIMESTAMP_SHIFT: u32 = SEQUENCE_BITS;
const PARTITION_SHIFT: u32 = SEQUENCE_BITS + TIMESTAMP_BITS;

// Masks that isolate a field once it has been shifted down to bit 0.
const SEQUENCE_MASK: u64 = u64::MAX >> (u64::BITS - SEQUENCE_BITS);
const TIMESTAMP_MASK: u64 = u64::MAX >> (u64::BITS - TIMESTAMP_BITS);
const PARTITION_MASK: u64 = u64::MAX >> (u64::BITS - PARTITION_BITS);
// Per-thread state used by `HierarchicalRowIdGenerator::next`. Every generator
// instance used on a given thread reads and writes these same cells.
thread_local! {
// Partition id of this thread, obtained from `allocate_partition_id()` when
// the cell is initialised.
static PARTITION_ID: Cell<u16> = Cell::new(allocate_partition_id());
// Sequence number that will be handed out next on this thread.
static SEQUENCE: Cell<u16> = const { Cell::new(0) };
// Most recent (epoch-adjusted) timestamp recorded by `next` on this thread.
static LAST_TIMESTAMP: Cell<u32> = const { Cell::new(0) };
}
/// Takes the next value from the process-wide partition counter and reduces it
/// modulo `MAX_PARTITIONS`.
///
/// NOTE(review): with `MAX_PARTITIONS == u16::MAX` the counter value 65535 maps
/// to partition 0, and the `AtomicU16` counter itself wraps back to 0 on the
/// following call, so partition ids repeat once the counter is exhausted.
/// Confirm that this reuse is acceptable for long-lived processes.
fn allocate_partition_id() -> u16 {
    let raw = NEXT_PARTITION_ID.fetch_add(1, Ordering::Relaxed);
    raw % MAX_PARTITIONS
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
#[inline]
#[allow(clippy::expect_used)]
fn current_timestamp() -> u32 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("System time before Unix epoch");
    // The cast keeps only the low 32 bits of the second count.
    since_epoch.as_secs() as u32
}
/// Generator of 64-bit row ids laid out as
/// `partition (16 bits) | timestamp (32 bits) | sequence (16 bits)`.
///
/// The sequence and last-timestamp counters live in thread-local storage, so
/// the struct itself only carries configuration.
#[derive(Debug)]
pub struct HierarchicalRowIdGenerator {
// Seconds subtracted (saturating at zero) from the current Unix time before
// it is stored in the timestamp field.
epoch_offset: u32,
// When `Some`, this partition is used instead of the calling thread's
// automatically allocated partition id.
forced_partition: Option<u16>,
}
impl Default for HierarchicalRowIdGenerator {
fn default() -> Self {
Self::new()
}
}
impl HierarchicalRowIdGenerator {
    /// Creates a generator with no epoch offset that stamps ids with the
    /// calling thread's automatically allocated partition.
    pub const fn new() -> Self {
        Self {
            epoch_offset: 0,
            forced_partition: None,
        }
    }

    /// Creates a generator whose timestamps are Unix seconds minus
    /// `epoch_offset` (saturating at zero), using the thread's partition.
    pub const fn with_epoch(epoch_offset: u32) -> Self {
        Self {
            epoch_offset,
            forced_partition: None,
        }
    }

    /// Creates a generator that always stamps ids with `partition` instead of
    /// the calling thread's partition.
    ///
    /// The sequence and timestamp counters are still per thread, so two
    /// threads that force the same partition can produce the same id.
    pub const fn with_partition(partition: u16) -> Self {
        Self {
            epoch_offset: 0,
            forced_partition: Some(partition),
        }
    }

    /// Returns the next id for the calling thread.
    ///
    /// The id is `partition << 48 | timestamp << 16 | sequence`. The timestamp
    /// is the current Unix time in seconds minus the epoch offset (saturating
    /// at zero); the sequence restarts at 0 whenever the timestamp advances.
    ///
    /// The timestamp field never moves backwards on a thread: if the wall
    /// clock has not advanced, or has stepped back, the previous timestamp is
    /// kept and the sequence simply continues.
    ///
    /// When all 65,535 sequence numbers of a timestamp have been used, the
    /// timestamp field is advanced by one second ahead of the wall clock and
    /// the sequence restarts, rather than reusing a sequence number. This
    /// keeps ids unique and strictly increasing per thread without blocking,
    /// at the cost that [`Self::timestamp_of`] may run ahead of real time
    /// during sustained bursts.
    #[inline]
    pub fn next(&self) -> u64 {
        let partition = self
            .forced_partition
            .unwrap_or_else(|| PARTITION_ID.with(|id| id.get()));

        let (timestamp, sequence) = SEQUENCE.with(|seq| {
            LAST_TIMESTAMP.with(|last_ts| {
                let now = current_timestamp().saturating_sub(self.epoch_offset);
                let last = last_ts.get();
                let pending = seq.get();

                if now > last {
                    // The clock moved forward: start a fresh sequence.
                    last_ts.set(now);
                    seq.set(1);
                    (now, 0u16)
                } else if pending < u16::MAX {
                    // Same second (or the clock is behind our logical time):
                    // keep the recorded timestamp and continue the sequence.
                    seq.set(pending + 1);
                    (last, pending)
                } else {
                    // Sequence space for `last` is exhausted. Borrow the next
                    // second instead of restarting the sequence in the same
                    // second, which would repeat ids already handed out.
                    // (At `u32::MAX` the timestamp field cannot advance.)
                    let bumped = last.saturating_add(1);
                    last_ts.set(bumped);
                    seq.set(1);
                    (bumped, 0u16)
                }
            })
        });

        Self::pack(partition, timestamp, sequence)
    }

    /// Returns `count` ids, each produced by [`Self::next`], in generation order.
    #[inline]
    pub fn next_batch(&self, count: usize) -> Vec<u64> {
        (0..count).map(|_| self.next()).collect()
    }

    /// Combines the three fields into a single 64-bit id.
    #[inline]
    const fn pack(partition: u16, timestamp: u32, sequence: u16) -> u64 {
        ((partition as u64) << PARTITION_SHIFT)
            | ((timestamp as u64) << TIMESTAMP_SHIFT)
            | (sequence as u64)
    }

    /// Splits an id into `(partition, timestamp, sequence)`.
    #[inline]
    pub const fn unpack(id: u64) -> (u16, u32, u16) {
        (
            Self::partition_of(id),
            Self::timestamp_of(id),
            Self::sequence_of(id),
        )
    }

    /// Extracts the partition field of `id`.
    #[inline]
    pub const fn partition_of(id: u64) -> u16 {
        ((id >> PARTITION_SHIFT) & PARTITION_MASK) as u16
    }

    /// Extracts the timestamp field of `id`.
    #[inline]
    pub const fn timestamp_of(id: u64) -> u32 {
        ((id >> TIMESTAMP_SHIFT) & TIMESTAMP_MASK) as u32
    }

    /// Extracts the sequence field of `id`.
    #[inline]
    pub const fn sequence_of(id: u64) -> u16 {
        (id & SEQUENCE_MASK) as u16
    }

    /// Returns `true` when `id1` was generated before `id2`.
    ///
    /// Ids are ordered by timestamp first. With equal timestamps, ids from the
    /// same partition are ordered by sequence; ids from different partitions
    /// are not ordered relative to each other, so the result is `false`.
    #[inline]
    pub const fn is_before(id1: u64, id2: u64) -> bool {
        let t1 = Self::timestamp_of(id1);
        let t2 = Self::timestamp_of(id2);
        if t1 != t2 {
            return t1 < t2;
        }
        Self::partition_of(id1) == Self::partition_of(id2)
            && Self::sequence_of(id1) < Self::sequence_of(id2)
    }
}
/// Hands out per-table `u64` row ids, reserving id space `batch_size` ids at a
/// time and tracking one `BatchState` per table name.
#[derive(Debug)]
pub struct BatchRowIdAllocator {
// Number of ids reserved each time a table needs a new batch.
batch_size: u64,
// Allocation state keyed by table name.
allocators: dashmap::DashMap<String, BatchState>,
}
/// Allocation counters for a single table.
#[derive(Debug)]
struct BatchState {
// Start of the next batch that has not been reserved yet; this is the value
// reported by `max_allocated` and `checkpoint_state`.
next_batch: AtomicU64,
// Next id to hand out.
current: AtomicU64,
// Exclusive upper bound of the batch ids are currently drawn from.
batch_end: AtomicU64,
}
impl BatchState {
    /// Builds the state for a table whose first id is `start`, with an initial
    /// batch covering `start..start + batch_size`.
    fn new(start: u64, batch_size: u64) -> Self {
        // The first batch ends exactly where the next one will begin.
        let first_batch_end = start + batch_size;
        Self {
            current: AtomicU64::new(start),
            batch_end: AtomicU64::new(first_batch_end),
            next_batch: AtomicU64::new(first_batch_end),
        }
    }
}
impl BatchRowIdAllocator {
pub fn new(batch_size: u64) -> Self {
Self {
batch_size,
allocators: dashmap::DashMap::new(),
}
}
pub fn initialize_table(&self, table: &str, start_id: u64) {
self.allocators.entry(table.to_string())
.or_insert_with(|| BatchState::new(start_id, self.batch_size));
}
#[inline]
pub fn next(&self, table: &str) -> u64 {
let state = self.allocators.entry(table.to_string())
.or_insert_with(|| BatchState::new(1, self.batch_size));
let current = state.current.fetch_add(1, Ordering::Relaxed);
let batch_end = state.batch_end.load(Ordering::Acquire);
if current < batch_end {
return current;
}
let new_start = state.next_batch.fetch_add(self.batch_size, Ordering::SeqCst);
state.batch_end.store(new_start + self.batch_size, Ordering::Release);
state.current.store(new_start + 1, Ordering::Release);
new_start
}
pub fn max_allocated(&self, table: &str) -> Option<u64> {
self.allocators.get(table).map(|state| {
state.next_batch.load(Ordering::Relaxed)
})
}
pub fn checkpoint_state(&self) -> Vec<(String, u64)> {
self.allocators.iter()
.map(|entry| {
let max = entry.next_batch.load(Ordering::Relaxed);
(entry.key().clone(), max)
})
.collect()
}
pub fn restore_from_checkpoint(&self, table: &str, max_id: u64) {
let safe_start = max_id + self.batch_size; self.allocators.insert(
table.to_string(),
BatchState::new(safe_start, self.batch_size)
);
}
}
/// Row-id source that is either a hierarchical generator or a per-table batch
/// allocator.
pub enum RowIdGenerator {
/// Ids built from partition, timestamp and sequence; the table name is ignored.
Hierarchical(HierarchicalRowIdGenerator),
/// Ids drawn from a counter kept separately for each table.
Batched(BatchRowIdAllocator),
}
impl RowIdGenerator {
    /// Returns the next id from the wrapped generator.
    ///
    /// `table` selects the counter for the batched variant and is ignored by
    /// the hierarchical variant.
    pub fn next(&self, table: &str) -> u64 {
        match self {
            Self::Batched(allocator) => allocator.next(table),
            Self::Hierarchical(generator) => generator.next(),
        }
    }

    /// Returns `count` ids from the wrapped generator, in generation order.
    pub fn next_batch(&self, table: &str, count: usize) -> Vec<u64> {
        match self {
            Self::Batched(allocator) => {
                let mut ids = Vec::with_capacity(count);
                for _ in 0..count {
                    ids.push(allocator.next(table));
                }
                ids
            }
            Self::Hierarchical(generator) => generator.next_batch(count),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use std::thread;

    /// Ids produced by one generator on one thread never repeat.
    #[test]
    fn test_hierarchical_uniqueness() {
        let generator = HierarchicalRowIdGenerator::new();
        let mut seen = HashSet::new();
        for _ in 0..10000 {
            let id = generator.next();
            assert!(seen.insert(id), "Duplicate ID generated: {}", id);
        }
    }

    /// With a fixed partition, ids are strictly increasing.
    #[test]
    fn test_hierarchical_monotonic() {
        let generator = HierarchicalRowIdGenerator::with_partition(0);
        let mut last = 0u64;
        for _ in 0..1000 {
            let id = generator.next();
            assert!(id > last, "IDs not monotonic: {} <= {}", id, last);
            last = id;
        }
    }

    /// Threads get distinct partitions, so their ids never collide.
    #[test]
    fn test_hierarchical_cross_thread() {
        let ids: Arc<Mutex<HashSet<u64>>> = Arc::new(Mutex::new(HashSet::new()));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let ids = Arc::clone(&ids);
                thread::spawn(move || {
                    let generator = HierarchicalRowIdGenerator::new();
                    let local_ids: Vec<u64> = (0..1000).map(|_| generator.next()).collect();
                    let mut guard = ids.lock().unwrap();
                    for id in local_ids {
                        assert!(guard.insert(id), "Cross-thread duplicate: {}", id);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(ids.lock().unwrap().len(), 4000);
    }

    /// `unpack` returns exactly the fields given to `pack`.
    #[test]
    fn test_pack_unpack() {
        let partition = 123u16;
        let timestamp = 1700000000u32;
        let sequence = 456u16;
        let packed = HierarchicalRowIdGenerator::pack(partition, timestamp, sequence);
        let (p, t, s) = HierarchicalRowIdGenerator::unpack(packed);
        assert_eq!(p, partition);
        assert_eq!(t, timestamp);
        assert_eq!(s, sequence);
    }

    /// An uninitialised table starts at 1 and counts up.
    #[test]
    fn test_batch_allocator() {
        let alloc = BatchRowIdAllocator::new(100);
        let id1 = alloc.next("test");
        let id2 = alloc.next("test");
        let id3 = alloc.next("test");
        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        assert_eq!(id3, 3);
    }

    /// Ids stay contiguous when allocation crosses batch boundaries.
    #[test]
    fn test_batch_allocator_cross_batch() {
        let alloc = BatchRowIdAllocator::new(10);
        for expected in 1..=25u64 {
            let id = alloc.next("test");
            assert_eq!(id, expected, "ID {} should be {}", id, expected);
        }
        let max = alloc.max_allocated("test").unwrap();
        assert!(max >= 25, "Max allocated should be >= 25");
    }
}