use {Atomic, Guard};
use std::sync::atomic;
/// A cell holding an optional boxed `T` that is modified through a
/// read-compute-swap retry loop (see `update`).
///
/// All state lives in the wrapped `Atomic<T>`; this type adds no fields of
/// its own.
pub struct Stm<T> {
/// The atomic slot containing the current value (or nothing).
inner: Atomic<T>,
}
impl<T> Stm<T> {
    /// Creates a container initially holding `data`.
    ///
    /// Passing `None` creates an empty container.
    pub fn new(data: Option<Box<T>>) -> Stm<T> {
        Stm {
            inner: Atomic::new(data),
        }
    }

    /// Replaces the contents with the result of applying `f` to the current
    /// contents.
    ///
    /// A snapshot of the current value is loaded and handed to `f`; the value
    /// `f` returns is then stored only if the container still points at the
    /// snapshot that was read. If the compare-and-store fails, the whole
    /// sequence is retried, so `f` may be invoked more than once and should
    /// not rely on being called exactly one time.
    ///
    /// Returning `None` from `f` empties the container.
    pub fn update<F>(&self, f: F)
    where
        F: Fn(Option<Guard<T>>) -> Option<Box<T>>,
        T: 'static,
    {
        loop {
            // `Acquire` pairs with the `Release` store below so that the
            // pointee written by the publishing thread is visible before `f`
            // dereferences the snapshot.
            let snapshot = self.inner.load(atomic::Ordering::Acquire);
            // Remember the raw pointer now: the guard itself is moved into `f`.
            let snapshot_ptr = snapshot.as_ref().map(Guard::as_ptr);
            let ret = f(snapshot);

            // `Release` publishes the freshly built box to later `Acquire`
            // loads.
            // NOTE(review): assumes `Atomic::compare_and_store` accepts
            // `Release` as its ordering argument — confirm against `Atomic`.
            if self
                .inner
                .compare_and_store(snapshot_ptr, ret, atomic::Ordering::Release)
                .is_ok()
            {
                break;
            }
        }
    }

    /// Loads the current contents, or `None` if the container is empty.
    ///
    /// Uses `Acquire` ordering so the returned guard can be dereferenced
    /// after a concurrent `update` has published a new value.
    pub fn load(&self) -> Option<Guard<T>> {
        self.inner.load(atomic::Ordering::Acquire)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::sync::Arc;

    /// Exercises the empty -> value -> incremented value -> empty sequence
    /// on one thread.
    #[test]
    fn single_threaded() {
        let cell = Stm::new(None);

        // Fill the empty container.
        cell.update(|_| Some(Box::new(4)));
        // Derive the next value from the previous one.
        cell.update(|prev| {
            let bumped = *prev.unwrap() + 1;
            Some(Box::new(bumped))
        });
        // Observe the incremented value, then clear the container.
        cell.update(|prev| {
            assert!(*prev.unwrap() == 5);
            None
        });

        assert!(cell.load().is_none());
    }

    /// Sixteen threads each perform one million increments; the final value
    /// must equal the total number of increments.
    #[test]
    fn multi_threaded() {
        let shared = Arc::new(Stm::new(Some(Box::new(0))));

        // `collect` spawns every worker before any of them is joined.
        let workers: Vec<_> = (0..16)
            .map(|_| {
                let local = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..1_000_000 {
                        local.update(|cur| Some(Box::new(*cur.unwrap() + 1)));
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(*shared.load().unwrap(), 16_000_000);
    }
}