use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard};
use crate::{MutationWriteSet, NodeId, RelationshipId};
/// Number of independently-locked shards in each namespace of a `LockTable`.
///
/// Must be a power of two: shard selection masks the id with `SHARD_MASK`
/// rather than taking a modulus.
pub const LOCK_TABLE_SHARDS: usize = 256;

// Reject non-power-of-two shard counts at compile time. With any other value
// `id & SHARD_MASK` stays in bounds but can only ever produce a subset of the
// shard indices, silently leaving the remaining shards unused.
const _: () = assert!(
    LOCK_TABLE_SHARDS.is_power_of_two(),
    "LOCK_TABLE_SHARDS must be a power of two"
);

/// Bit mask mapping an id onto a shard index in `0..LOCK_TABLE_SHARDS`.
const SHARD_MASK: u64 = (LOCK_TABLE_SHARDS as u64) - 1;
/// Table of per-entity mutexes for nodes and relationships.
///
/// Each namespace (nodes, relationships) is split into `LOCK_TABLE_SHARDS`
/// separately locked `HashMap`s, so lookups for ids in different shards do not
/// contend on the same shard mutex. A per-id mutex is created lazily on first
/// request and stays in the table until [`LockTable::prune_unused`] removes it.
/// Without pruning, the table grows by one entry per distinct id ever locked.
pub struct LockTable {
    nodes: [Mutex<HashMap<NodeId, Arc<Mutex<()>>>>; LOCK_TABLE_SHARDS],
    rels: [Mutex<HashMap<RelationshipId, Arc<Mutex<()>>>>; LOCK_TABLE_SHARDS],
}

impl Default for LockTable {
    fn default() -> Self {
        Self::new()
    }
}

impl LockTable {
    /// Creates a table with every shard empty.
    pub fn new() -> Self {
        Self {
            nodes: std::array::from_fn(|_| Mutex::new(HashMap::new())),
            rels: std::array::from_fn(|_| Mutex::new(HashMap::new())),
        }
    }

    /// Returns the mutex guarding node `id`, creating it on first use.
    ///
    /// Repeated calls with the same id return clones of the same `Arc` (until
    /// the entry is pruned). The shard lock is released before this returns,
    /// so callers can block on the returned mutex without stalling other
    /// lookups in the same shard.
    pub fn node_lock_arc(&self, id: NodeId) -> Arc<Mutex<()>> {
        Self::lock_arc_in(&self.nodes[(id & SHARD_MASK) as usize], id)
    }

    /// Returns the mutex guarding relationship `id`, creating it on first use.
    ///
    /// Relationship locks live in a separate namespace from node locks, so a
    /// node and a relationship with the same numeric id get distinct mutexes.
    pub fn rel_lock_arc(&self, id: RelationshipId) -> Arc<Mutex<()>> {
        Self::lock_arc_in(&self.rels[(id & SHARD_MASK) as usize], id)
    }

    /// Removes entries whose mutex is not referenced outside the table.
    ///
    /// Returns the number of entries removed. The check is race-free because
    /// the table only clones an entry's `Arc` while holding that entry's shard
    /// lock, and this method holds the same shard lock while checking. A strong
    /// count of 1 therefore means no thread holds or waits on the mutex, and
    /// none can obtain it. The next lookup for a pruned id creates a fresh,
    /// unlocked mutex.
    pub fn prune_unused(&self) -> usize {
        Self::prune_shards(&self.nodes) + Self::prune_shards(&self.rels)
    }

    /// Looks up, or lazily inserts, the mutex for `id` in a single shard map.
    fn lock_arc_in<K: Eq + std::hash::Hash>(
        shard: &Mutex<HashMap<K, Arc<Mutex<()>>>>,
        id: K,
    ) -> Arc<Mutex<()>> {
        // Poisoning is deliberately ignored: the map only stores `Arc` handles,
        // so a panic in another holder of this shard lock leaves nothing that
        // is unsafe to reuse.
        let mut map = shard.lock().unwrap_or_else(|p| p.into_inner());
        Arc::clone(map.entry(id).or_default())
    }

    /// Drops unreferenced entries from every shard of one namespace.
    fn prune_shards<K>(shards: &[Mutex<HashMap<K, Arc<Mutex<()>>>>]) -> usize {
        shards
            .iter()
            .map(|shard| {
                let mut map = shard.lock().unwrap_or_else(|p| p.into_inner());
                let before = map.len();
                map.retain(|_, lock| Arc::strong_count(lock) > 1);
                before - map.len()
            })
            .sum()
    }
}
/// Exclusive locks on every node and relationship in a `MutationWriteSet`.
///
/// Obtained from `WriteSetLocks::acquire`. All locks are held for as long as
/// this value lives and are released when it is dropped. The type is
/// `#[must_use]` because discarding it releases the locks immediately.
#[must_use = "the write-set locks are released as soon as this value is dropped"]
pub struct WriteSetLocks {
    _guards: Vec<OwnedMutexGuard>,
}
/// A `MutexGuard` bundled with the `Arc` that owns its mutex, so the guard can
/// be stored (e.g. in a `Vec`) without borrowing from a local.
///
/// The guard's lifetime is declared `'static`, but it really borrows the mutex
/// inside `_arc`. The `Drop` impl releases the guard before `_arc` is dropped.
/// Because `MutexGuard` is `!Send`, this type is `!Send` too, so the lock is
/// always released on the thread that acquired it.
struct OwnedMutexGuard {
    // Must be released before `_arc` is dropped; see the `Drop` impl.
    guard: Option<MutexGuard<'static, ()>>,
    // Keeps alive the mutex that `guard` points into.
    _arc: Arc<Mutex<()>>,
}

impl OwnedMutexGuard {
    /// Blocks until `arc`'s mutex is acquired and returns an owning guard.
    ///
    /// A poisoned mutex is still acquired rather than causing a panic. The
    /// protected value is `()`, so there is no data a panicking holder could
    /// have left inconsistent.
    fn lock(arc: Arc<Mutex<()>>) -> Self {
        let guard = arc.lock().unwrap_or_else(|p| p.into_inner());
        // SAFETY:
        // - The guard borrows the `Mutex` stored in the `Arc`'s heap allocation,
        //   not the `arc` handle, so moving `arc` into `Self` below does not
        //   invalidate it.
        // - `Self` keeps that allocation alive for at least as long as the guard:
        //   `Drop` releases the guard first, and field order would too.
        // - The field is private and the guard is never handed out, so the
        //   extended lifetime is not observable.
        // - Only the lifetime parameter changes, so the layouts are identical.
        let guard = unsafe {
            std::mem::transmute::<MutexGuard<'_, ()>, MutexGuard<'static, ()>>(guard)
        };
        Self {
            guard: Some(guard),
            _arc: arc,
        }
    }
}

impl Drop for OwnedMutexGuard {
    fn drop(&mut self) {
        // Unlock while `_arc` (and so the mutex) is still alive. Field
        // declaration order already guarantees this; the explicit `take` keeps
        // it correct even if the fields are ever reordered.
        self.guard.take();
    }
}
impl WriteSetLocks {
pub fn acquire(table: &LockTable, write_set: &MutationWriteSet) -> Self {
let mut node_ids: Vec<NodeId> = write_set.nodes.iter().copied().collect();
node_ids.sort_unstable();
let mut rel_ids: Vec<RelationshipId> = write_set.rels.iter().copied().collect();
rel_ids.sort_unstable();
let mut guards = Vec::with_capacity(node_ids.len() + rel_ids.len());
for id in node_ids {
guards.push(OwnedMutexGuard::lock(table.node_lock_arc(id)));
}
for id in rel_ids {
guards.push(OwnedMutexGuard::lock(table.rel_lock_arc(id)));
}
Self { _guards: guards }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn lock_table_returns_same_arc_for_same_id() {
        let table = LockTable::new();
        let a = table.node_lock_arc(42);
        let b = table.node_lock_arc(42);
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn lock_table_distinct_ids_get_distinct_locks() {
        let table = LockTable::new();
        let a = table.node_lock_arc(1);
        let b = table.node_lock_arc(2);
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn node_and_rel_namespaces_are_separate() {
        let table = LockTable::new();
        let n = table.node_lock_arc(7);
        let r = table.rel_lock_arc(7);
        assert!(!Arc::ptr_eq(&n, &r));
    }

    /// Several threads lock the same node; at most one may be inside the
    /// critical section at any moment.
    #[test]
    fn write_set_locks_serialize_same_id() {
        let table = Arc::new(LockTable::new());
        // Counts threads currently holding the write-set locks. It is an atomic,
        // not a mutex, so the check below can only pass if `WriteSetLocks`
        // itself provides the mutual exclusion.
        let in_section = Arc::new(AtomicUsize::new(0));
        let completed = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let table = Arc::clone(&table);
                let in_section = Arc::clone(&in_section);
                let completed = Arc::clone(&completed);
                thread::spawn(move || {
                    let mut ws = MutationWriteSet::new();
                    ws.nodes.insert(99);
                    let _locks = WriteSetLocks::acquire(&table, &ws);
                    let concurrent = in_section.fetch_add(1, Ordering::SeqCst) + 1;
                    assert_eq!(concurrent, 1, "lock did not serialize");
                    // Widen the window in which a broken lock would let
                    // another thread in.
                    thread::sleep(Duration::from_millis(5));
                    // Leave the section while the locks are still held.
                    in_section.fetch_sub(1, Ordering::SeqCst);
                    completed.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(completed.load(Ordering::SeqCst), 4);
    }

    /// Locks taken by `acquire` are held while the value lives and released
    /// when it is dropped.
    #[test]
    fn write_set_locks_release_on_drop() {
        let table = LockTable::new();
        let mut ws = MutationWriteSet::new();
        ws.nodes.insert(5);
        let node = table.node_lock_arc(5);
        let locks = WriteSetLocks::acquire(&table, &ws);
        assert!(node.try_lock().is_err(), "lock not held after acquire");
        drop(locks);
        assert!(node.try_lock().is_ok(), "lock not released on drop");
    }
}