use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex};
use super::types::SequenceError;
/// Opaque identifier that a `GapFreeManager` assigns to each reservation.
///
/// Cheap to copy and usable as a map/set key; the inner counter is private so
/// callers cannot forge identifiers.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ReservationId(u64);
/// Outcome of a successful `GapFreeManager::reserve` call.
///
/// The reserved sequence key stays locked until this handle is handed back to
/// the manager's `commit` or `rollback`.
#[derive(Clone, Debug)]
pub struct ReservationHandle {
    /// Identifier the manager assigned to this reservation.
    pub id: ReservationId,
    /// Key of the sequence that was reserved.
    pub sequence_key: String,
    /// Value returned by the `advance_fn` passed to `reserve`.
    pub value: i64,
}
struct SequenceLock {
locked: Mutex<bool>,
unlocked: Condvar,
}
pub struct GapFreeManager {
locks: Mutex<HashMap<String, Arc<SequenceLock>>>,
next_id: Mutex<u64>,
}
impl GapFreeManager {
pub fn new() -> Self {
Self {
locks: Mutex::new(HashMap::new()),
next_id: Mutex::new(1),
}
}
pub fn reserve(
&self,
sequence_key: &str,
advance_fn: impl FnOnce() -> Result<i64, SequenceError>,
) -> Result<ReservationHandle, SequenceError> {
let lock = {
let mut locks = self.locks.lock().unwrap_or_else(|p| p.into_inner());
locks
.entry(sequence_key.to_string())
.or_insert_with(|| {
Arc::new(SequenceLock {
locked: Mutex::new(false),
unlocked: Condvar::new(),
})
})
.clone()
};
{
let mut is_locked = lock.locked.lock().unwrap_or_else(|p| p.into_inner());
while *is_locked {
is_locked = lock
.unlocked
.wait(is_locked)
.unwrap_or_else(|p| p.into_inner());
}
*is_locked = true;
}
let value = match advance_fn() {
Ok(v) => v,
Err(e) => {
self.unlock_sequence(&lock);
return Err(e);
}
};
let id = {
let mut next = self.next_id.lock().unwrap_or_else(|p| p.into_inner());
let id = ReservationId(*next);
*next += 1;
id
};
Ok(ReservationHandle {
id,
sequence_key: sequence_key.to_string(),
value,
})
}
pub fn commit(&self, handle: &ReservationHandle) {
let locks = self.locks.lock().unwrap_or_else(|p| p.into_inner());
if let Some(lock) = locks.get(&handle.sequence_key) {
self.unlock_sequence(lock);
}
}
pub fn rollback(&self, handle: &ReservationHandle, rollback_fn: impl FnOnce()) {
rollback_fn();
let locks = self.locks.lock().unwrap_or_else(|p| p.into_inner());
if let Some(lock) = locks.get(&handle.sequence_key) {
self.unlock_sequence(lock);
}
}
fn unlock_sequence(&self, lock: &SequenceLock) {
let mut is_locked = lock.locked.lock().unwrap_or_else(|p| p.into_inner());
*is_locked = false;
lock.unlocked.notify_one();
}
}
impl Default for GapFreeManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI64, Ordering};

    /// Advance function shared by the tests: bumps `counter` and returns the new
    /// value, i.e. a sequence whose next number is `previous + 1`.
    fn bump(counter: &AtomicI64, ordering: Ordering) -> Result<i64, SequenceError> {
        Ok(counter.fetch_add(1, ordering) + 1)
    }

    #[test]
    fn reserve_commit() {
        let manager = GapFreeManager::new();
        let seq = AtomicI64::new(0);

        let reservation = manager
            .reserve("1:test", || bump(&seq, Ordering::Relaxed))
            .unwrap();
        assert_eq!(reservation.value, 1);

        manager.commit(&reservation);
        assert_eq!(seq.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn reserve_rollback() {
        let manager = GapFreeManager::new();
        let seq = AtomicI64::new(0);

        let reservation = manager
            .reserve("1:test", || bump(&seq, Ordering::Relaxed))
            .unwrap();
        assert_eq!(reservation.value, 1);
        assert_eq!(seq.load(Ordering::Relaxed), 1);

        // The rollback closure undoes the advance performed inside `reserve`.
        manager.rollback(&reservation, || {
            seq.fetch_sub(1, Ordering::Relaxed);
        });
        assert_eq!(seq.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn sequential_reservations() {
        let manager = GapFreeManager::new();
        let seq = AtomicI64::new(0);

        let first = manager
            .reserve("1:test", || bump(&seq, Ordering::Relaxed))
            .unwrap();
        manager.commit(&first);

        let second = manager
            .reserve("1:test", || bump(&seq, Ordering::Relaxed))
            .unwrap();
        assert_eq!(second.value, 2);
        manager.commit(&second);
    }

    #[test]
    fn different_sequences_independent() {
        let manager = GapFreeManager::new();
        let (seq_a, seq_b) = (AtomicI64::new(0), AtomicI64::new(0));

        // Holding "1:seq_a" must not block a reservation on "1:seq_b".
        let on_a = manager
            .reserve("1:seq_a", || bump(&seq_a, Ordering::Relaxed))
            .unwrap();
        let on_b = manager
            .reserve("1:seq_b", || bump(&seq_b, Ordering::Relaxed))
            .unwrap();
        assert_eq!(on_a.value, 1);
        assert_eq!(on_b.value, 1);

        manager.commit(&on_a);
        manager.commit(&on_b);
    }

    #[test]
    fn concurrent_serialization() {
        let manager = Arc::new(GapFreeManager::new());
        let seq = Arc::new(AtomicI64::new(0));
        let phase = Arc::new(AtomicI64::new(0));

        let first = manager
            .reserve("1:test", || bump(&seq, Ordering::SeqCst))
            .unwrap();
        assert_eq!(first.value, 1);
        phase.store(1, Ordering::SeqCst);

        let worker = {
            let manager = Arc::clone(&manager);
            let seq = Arc::clone(&seq);
            let phase = Arc::clone(&phase);
            std::thread::spawn(move || {
                // Blocks until the main thread commits `first`, which happens
                // only after it has moved `phase` to 2.
                let second = manager
                    .reserve("1:test", || bump(&seq, Ordering::SeqCst))
                    .unwrap();
                assert!(phase.load(Ordering::SeqCst) >= 2);
                assert_eq!(second.value, 2);
                manager.commit(&second);
            })
        };

        std::thread::sleep(std::time::Duration::from_millis(50));
        phase.store(2, Ordering::SeqCst);
        manager.commit(&first);

        worker.join().unwrap();
        assert_eq!(seq.load(Ordering::SeqCst), 2);
    }
}