use crate::fnbox::FnBox;
use lazy_static::lazy_static;
use std::cell::RefCell;
use std::sync::Mutex;
lazy_static! {
    /// Process-wide lock taken by the outermost `commit` on a thread, so that
    /// outermost transactions on different threads never run at the same time.
    static ref TRANSACTION_MUTEX: Mutex<()> = Mutex::default();
}
thread_local! {
    /// The transaction currently open on this thread, or `None` when no
    /// `commit` is running here. A nested `commit` temporarily swaps in its own.
    static CURRENT_TRANSACTION: RefCell<Option<Transaction>> = RefCell::new(None);
}
/// A queued callback, boxed as an `FnBox` trait object and later invoked
/// through `call_box`.
type Callback = Box<dyn FnBox + 'static>;

/// One open transaction: the callbacks queued to run when it is committed.
#[derive(Default)]
pub struct Transaction {
    /// Callbacks queued with [`Transaction::later`], in the order they were queued.
    finalizers: Vec<Callback>,
}

impl Transaction {
    /// Returns a transaction with nothing queued.
    fn new() -> Transaction {
        Self::default()
    }

    /// Appends `callback` to this transaction's queue of deferred callbacks.
    pub fn later<F: FnOnce() + 'static>(&mut self, callback: F) {
        let boxed: Callback = Box::new(callback);
        self.finalizers.push(boxed);
    }

    /// Takes every queued callback out of the transaction, leaving its queue empty.
    fn finalizers(&mut self) -> Vec<Callback> {
        std::mem::replace(&mut self.finalizers, Vec::new())
    }
}
pub fn commit<A, F: FnOnce() -> A>(body: F) -> A {
use std::mem;
let mut prev = CURRENT_TRANSACTION.with(|current| {
let mut prev = Some(Transaction::new());
mem::swap(&mut prev, &mut current.borrow_mut());
prev
});
let _lock = match prev {
None => Some(
TRANSACTION_MUTEX
.lock()
.expect("global transaction mutex poisoned"),
),
Some(_) => None,
};
let result = body();
match prev {
Some(ref mut trans) => with_current(|cur| trans.finalizers.append(&mut cur.finalizers)),
None => loop {
let callbacks = with_current(Transaction::finalizers);
if callbacks.is_empty() {
break;
}
for callback in callbacks {
callback.call_box();
}
},
}
CURRENT_TRANSACTION.with(|current| mem::swap(&mut prev, &mut current.borrow_mut()));
result
}
/// Runs `action` with mutable access to this thread's open transaction and
/// returns its result.
///
/// The thread-local slot stays mutably borrowed while `action` runs. If
/// `action` calls `with_current`, `later` or `commit`, that call panics on the
/// `RefCell` borrow.
///
/// # Panics
///
/// Panics if no transaction is open on the current thread.
pub fn with_current<A, F: FnOnce(&mut Transaction) -> A>(action: F) -> A {
    CURRENT_TRANSACTION.with(|cell| {
        let mut slot = cell.borrow_mut();
        if let Some(open) = slot.as_mut() {
            return action(open);
        }
        panic!("there is no active transaction to register a callback")
    })
}
/// Queues `action` on this thread's open transaction. `commit` runs it after
/// the outermost `commit` body on this thread returns.
///
/// # Panics
///
/// Panics if no transaction is open on the current thread, i.e. when called
/// outside of `commit`.
pub fn later<F>(action: F)
where
    F: FnOnce() + 'static,
{
    with_current(move |open| open.later(action));
}
#[cfg(test)]
mod test {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;

    #[test]
    fn commit_single() {
        let mut v = 3;
        commit(|| v += 5);
        assert_eq!(v, 8);
    }

    #[test]
    fn commit_nested() {
        let mut v = 3;
        commit(|| {
            commit(|| v *= 2);
            v += 4;
        });
        assert_eq!(v, 10);
    }

    #[test]
    fn commits_parallel() {
        use std::sync::{Arc, Mutex};
        use std::thread;
        use std::time::Duration;
        // Each outermost commit holds the global lock across both steps, so the
        // three updates (x -> 2x - 1) never interleave: 3 -> 5 -> 9 -> 17.
        let v = Arc::new(Mutex::new(3));
        let guards: Vec<_> = (0..3)
            .map(|_| {
                let v = Arc::clone(&v);
                thread::spawn(move || {
                    commit(move || {
                        *v.lock().unwrap() *= 2;
                        thread::sleep(Duration::from_millis(1));
                        *v.lock().unwrap() -= 1;
                    })
                })
            })
            .collect();
        for guard in guards {
            guard.join().expect("thread failed");
        }
        assert_eq!(*v.lock().unwrap(), 17);
    }

    #[test]
    fn later_runs_after_body() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let body_log = Rc::clone(&log);
        commit(move || {
            let finalizer_log = Rc::clone(&body_log);
            later(move || finalizer_log.borrow_mut().push("finalizer"));
            body_log.borrow_mut().push("body");
        });
        assert_eq!(*log.borrow(), vec!["body", "finalizer"]);
    }

    #[test]
    fn nested_later_is_deferred_to_outermost_commit() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let body_log = Rc::clone(&log);
        commit(move || {
            let finalizer_log = Rc::clone(&body_log);
            commit(move || later(move || finalizer_log.borrow_mut().push("inner finalizer")));
            body_log.borrow_mut().push("outer body");
        });
        assert_eq!(*log.borrow(), vec!["outer body", "inner finalizer"]);
    }

    #[test]
    fn later_registered_by_a_finalizer_still_runs() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let outer_log = Rc::clone(&log);
        commit(move || {
            later(move || {
                let inner_log = Rc::clone(&outer_log);
                later(move || inner_log.borrow_mut().push("second"));
                outer_log.borrow_mut().push("first");
            });
        });
        assert_eq!(*log.borrow(), vec!["first", "second"]);
    }

    #[test]
    #[should_panic(expected = "there is no active transaction")]
    fn later_outside_commit_panics() {
        later(|| ());
    }
}