use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;
use std::sync::{Mutex, MutexGuard};
use std::fmt::Debug;
use std::hash::Hash;
use lru::LruCache;
/// Per-thread state kept behind `ThreadLocal::inner`: the local LRU cache
/// plus bookkeeping for invalidation messages received from other handles.
struct Inner<K: Hash + Eq + Debug + Clone, V> {
    /// Index of this handle's sender in `Writer::txs`. A committing writer
    /// skips this index, so a handle never invalidates its own cache.
    tid: usize,
    /// An invalidation already taken off the channel whose `txid` was newer
    /// than the snapshot being opened at the time. It is parked here and
    /// applied by a later transaction.
    last_inv: Option<Invalidate<K>>,
    /// This handle's private LRU cache.
    cache: LruCache<K, V>,
}
/// One thread's handle onto a group of cooperating thread-local LRU caches.
///
/// All handles of a group are created together by `ThreadLocal::new`. Each
/// handle owns a private cache. Write transactions committed on one handle are
/// forwarded to every other handle as invalidation messages over a channel.
pub struct ThreadLocal<K: Hash + Eq + Debug + Clone, V> {
    /// Receives invalidations sent by write transactions on other handles.
    rx: Receiver<Invalidate<K>>,
    /// Writer state shared by the whole group. Holding this lock serialises
    /// write transactions across all handles.
    wrlock: Arc<Mutex<Writer<K>>>,
    /// Id of the most recently committed write transaction (starts at 0),
    /// shared by every handle in the group.
    inv_up_to_txid: Arc<AtomicU64>,
    /// This handle's cache and pending-invalidation state.
    inner: Mutex<Inner<K, V>>,
}
/// Writer state shared by all handles of a group and guarded by the writer lock.
struct Writer<K: Hash + Eq + Debug + Clone> {
    /// One invalidation sender per handle; position `i` feeds the handle whose `tid == i`.
    txs: Vec<Sender<Invalidate<K>>>,
}
/// An open write transaction on one handle's cache.
///
/// The shared writer lock is held for the transaction's whole lifetime, so at
/// most one write transaction exists across all handles of a group. Keys
/// passed to `insert` / `remove` are recorded in `rollback`. `commit` broadcasts
/// them to the other handles; dropping without committing evicts them locally.
pub struct ThreadLocalWriteTxn<'a, K: Hash + Eq + Debug + Clone, V> {
    /// Id this transaction publishes on commit (last committed id + 1).
    txid: u64,
    /// Lock on this handle's cache state.
    parent: MutexGuard<'a, Inner<K, V>>,
    /// Lock on the shared writer state (the senders to every handle).
    guard: MutexGuard<'a, Writer<K>>,
    /// Keys modified by this transaction.
    rollback: HashSet<K>,
    /// Shared committed-txid counter, advanced to `txid` on commit.
    inv_up_to_txid: Arc<AtomicU64>,
}
/// An open read transaction on one handle's cache.
///
/// It holds the handle's cache lock. Invalidations up to the snapshot observed
/// at open time were applied before it was created.
pub struct ThreadLocalReadTxn<'a, K: Hash + Eq + Debug + Clone, V> {
    /// Lock on this handle's cache state.
    parent: MutexGuard<'a, Inner<K, V>>,
}
/// Message asking a handle to evict `k` from its cache once it opens a
/// transaction whose snapshot covers `txid`.
#[derive(Clone)]
struct Invalidate<K: Hash + Eq + Debug + Clone> {
    /// Key modified by the committing write transaction.
    k: K,
    /// Id of the write transaction that modified `k`.
    txid: u64,
}
impl<K, V> ThreadLocal<K, V>
where
K: Hash + Eq + Debug + Clone,
{
pub fn new(threads: usize, capacity: usize) -> Vec<Self> {
assert!(threads > 0);
let capacity = NonZeroUsize::new(capacity).unwrap();
let (txs, rxs): (Vec<_>, Vec<_>) = (0..threads).map(|_| channel::<Invalidate<K>>()).unzip();
let inv_up_to_txid = Arc::new(AtomicU64::new(0));
let wrlock = Arc::new(Mutex::new(Writer { txs }));
rxs.into_iter()
.enumerate()
.map(|(tid, rx)| ThreadLocal {
rx,
wrlock: wrlock.clone(),
inv_up_to_txid: inv_up_to_txid.clone(),
inner: Mutex::new(Inner {
tid,
last_inv: None,
cache: LruCache::new(capacity),
}),
})
.collect()
}
pub fn read(&mut self) -> ThreadLocalReadTxn<'_, K, V> {
let txid = self.inv_up_to_txid.load(Ordering::Acquire);
let parent = self.invalidate(txid);
ThreadLocalReadTxn { parent }
}
pub fn write(&mut self) -> ThreadLocalWriteTxn<'_, K, V> {
let guard = self.wrlock.lock().unwrap();
let inv_up_to_txid = self.inv_up_to_txid.clone();
let txid = self.inv_up_to_txid.load(Ordering::Acquire);
let txid = txid + 1;
let parent = self.invalidate(txid);
ThreadLocalWriteTxn {
txid,
parent,
guard,
rollback: HashSet::new(),
inv_up_to_txid,
}
}
#[allow(mismatched_lifetime_syntaxes)]
fn invalidate(&self, up_to: u64) -> MutexGuard<Inner<K, V>> {
let mut inner = self.inner.lock().unwrap();
if let Some(inv_txid) = inner.last_inv.as_ref().map(|inv| inv.txid) {
if inv_txid > up_to {
return inner;
} else {
let mut inv = None;
std::mem::swap(&mut inv, &mut inner.last_inv);
let inv = inv.unwrap();
inner.cache.pop(&inv.k);
}
}
while let Ok(inv) = self.rx.try_recv() {
if inv.txid > up_to {
inner.last_inv = Some(inv);
return inner;
} else {
inner.cache.pop(&inv.k);
}
}
inner
}
}
impl<K, V> ThreadLocalWriteTxn<'_, K, V>
where
    K: Hash + Eq + Debug + Clone,
{
    /// Looks up `k` in this handle's cache via `LruCache::get`, which also
    /// refreshes the entry's recency on a hit.
    pub fn get(&mut self, k: &K) -> Option<&V> {
        let cache = &mut self.parent.cache;
        cache.get(k)
    }

    /// Reports whether `k` is cached. It goes through `get`, so a hit also
    /// counts as a use.
    pub fn contains_key(&mut self, k: &K) -> bool {
        self.get(k).is_some()
    }

    /// Caches `v` under `k` and records `k` as modified by this transaction.
    /// Returns the value previously cached under `k`, if any.
    pub fn insert(&mut self, k: K, v: V) -> Option<V> {
        let tracked = k.clone();
        self.rollback.insert(tracked);
        self.parent.cache.put(k, v)
    }

    /// Evicts `k` from this handle's cache and records it as modified by this
    /// transaction. Returns the evicted value, if any.
    pub fn remove(&mut self, k: &K) -> Option<V> {
        self.rollback.insert(K::clone(k));
        self.parent.cache.pop(k)
    }

    /// Commits the transaction.
    ///
    /// Every key recorded by `insert` / `remove` is sent, tagged with this
    /// transaction's id, to every other handle so that they evict it. The
    /// shared committed-txid counter is then advanced. This handle's own cache
    /// keeps the new values.
    pub fn commit(mut self) {
        let own_tid = self.parent.tid;
        let txid = self.txid;
        for (idx, sender) in self.guard.txs.iter().enumerate() {
            if idx == own_tid {
                continue;
            }
            for key in &self.rollback {
                // A send fails only if that handle's receiver has been dropped,
                // in which case there is no cache left to invalidate.
                let _ = sender.send(Invalidate {
                    k: key.clone(),
                    txid,
                });
            }
        }
        // Publish only after every message is queued, so a reader that
        // observes this txid (Acquire) also finds the messages on its channel.
        self.inv_up_to_txid.store(txid, Ordering::Release);
        // Emptying the set turns the rollback in `Drop` into a no-op.
        self.rollback.clear();
    }
}
impl<K, V> Drop for ThreadLocalWriteTxn<'_, K, V>
where
    K: Hash + Eq + Debug + Clone,
{
    /// Rolls back an uncommitted transaction by evicting every key it touched
    /// from this handle's cache. `commit` empties the key set first, so after
    /// a commit this does nothing.
    fn drop(&mut self) {
        let cache = &mut self.parent.cache;
        self.rollback.iter().for_each(|key| {
            cache.pop(key);
        });
    }
}
impl<K, V> ThreadLocalReadTxn<'_, K, V>
where
    K: Hash + Eq + Debug + Clone,
{
    /// Looks up `k` in this handle's cache via `LruCache::get`, which also
    /// refreshes the entry's recency on a hit.
    pub fn get(&mut self, k: &K) -> Option<&V> {
        let cache = &mut self.parent.cache;
        cache.get(k)
    }

    /// Reports whether `k` is cached. It goes through `get`, so a hit also
    /// counts as a use.
    pub fn contains_key(&mut self, k: &K) -> bool {
        self.get(k).is_some()
    }

    /// Caches `v` under `k` in this handle's cache only and returns the value
    /// previously cached under `k`, if any. Nothing is sent to other handles.
    pub fn insert(&mut self, k: K, v: V) -> Option<V> {
        let cache = &mut self.parent.cache;
        cache.put(k, v)
    }
}
#[cfg(test)]
mod tests {
use super::ThreadLocal;
// Scenario test over two handles covering five cases:
// - uncommitted writes are invisible to other handles;
// - committed writes invalidate the other handle;
// - read-side inserts persist locally;
// - an aborted (dropped) write rolls back its own inserts;
// - committed removes propagate.
#[test]
fn test_basic() {
let mut cache: Vec<ThreadLocal<u32, u32>> = ThreadLocal::new(2, 8);
// `pop` takes from the back, so cache_a is tid 1 and cache_b is tid 0.
let mut cache_a = cache.pop().unwrap();
let mut cache_b = cache.pop().unwrap();
// A write on a and a read on b are open at the same time.
let mut wr_txn = cache_a.write();
let mut rd_txn = cache_b.read();
wr_txn.insert(1, 1);
wr_txn.insert(2, 2);
assert!(wr_txn.contains_key(&1));
assert!(wr_txn.contains_key(&2));
// a's inserts live only in a's cache, so b does not see them.
assert!(!rd_txn.contains_key(&1));
assert!(!rd_txn.contains_key(&2));
wr_txn.commit();
drop(rd_txn);
// A fresh read on b applies a's invalidations; b never cached 1 or 2.
let mut rd_txn = cache_b.read();
assert!(!rd_txn.contains_key(&1));
assert!(!rd_txn.contains_key(&2));
// Read-side inserts stay in b's own cache across transactions.
rd_txn.insert(1, 1);
rd_txn.insert(2, 2);
drop(rd_txn);
let mut rd_txn = cache_b.read();
assert!(rd_txn.contains_key(&1));
assert!(rd_txn.contains_key(&2));
drop(rd_txn);
// a still holds the entries it committed; b's local inserts sent nothing to a.
let mut wr_txn = cache_a.write();
assert!(wr_txn.contains_key(&1));
assert!(wr_txn.contains_key(&2));
wr_txn.insert(3, 3);
assert!(wr_txn.contains_key(&3));
// Dropping without commit rolls back the insert of 3.
drop(wr_txn);
let mut wr_txn = cache_a.write();
assert!(wr_txn.contains_key(&1));
assert!(wr_txn.contains_key(&2));
assert!(!wr_txn.contains_key(&3));
// Committed removes on a must evict b's own copies of 1 and 2.
wr_txn.remove(&1);
wr_txn.remove(&2);
wr_txn.commit();
let mut rd_txn = cache_b.read();
assert!(!rd_txn.contains_key(&1));
assert!(!rd_txn.contains_key(&2));
}
}