use self::error::WtmError;
use core::{borrow::Borrow, hash::Hash};
use super::*;
/// A write transaction manager.
///
/// It buffers modifications in a pending-writes container (`P`), records
/// read/conflict keys in a conflict manager (`C`), and hands the buffered
/// entries over when the transaction is committed. `K` and `V` are the key
/// and value types of the buffered entries.
pub struct Wtm<K, V, C, P> {
/// Read timestamp of this transaction; returned by `version()` and used as
/// the version of entries created by `insert`.
pub(super) read_ts: u64,
/// Running sum of the estimated sizes of the entries accepted so far,
/// checked against the pending-writes batch size limit.
pub(super) size: u64,
/// Number of modifications accepted so far, checked against the
/// pending-writes batch entry limit.
pub(super) count: u64,
/// Shared oracle used to obtain commit timestamps and to report finished
/// reads and commits.
pub(super) orc: Arc<Oracle<C>>,
/// Conflict manager; moved out (left as `None`) when a commit timestamp is
/// requested and put back if the oracle reports a conflict.
pub(super) conflict_manager: Option<C>,
/// Buffered writes; moved out (left as `None`) once a commit timestamp has
/// been obtained.
pub(super) pending_writes: Option<P>,
/// Entries displaced from the pending writes by a later write to the same
/// key that carried a different version.
pub(super) duplicate_writes: OneOrMore<Entry<K, V>>,
/// Set once the transaction has been discarded.
pub(super) discarded: bool,
/// Set once the read timestamp has been reported as done to the oracle.
pub(super) done_read: bool,
}
/// Discards the transaction when the manager is dropped, unless that has
/// already happened.
impl<K, V, C, P> Drop for Wtm<K, V, C, P> {
  fn drop(&mut self) {
    // Already discarded explicitly: nothing left to release.
    if self.discarded {
      return;
    }
    self.discard();
  }
}
impl<K, V, C, P> Wtm<K, V, C, P> {
  /// Returns the read timestamp this transaction operates at.
  #[inline]
  pub const fn version(&self) -> u64 {
    self.read_ts
  }

  /// Overwrites the read timestamp of this transaction.
  ///
  /// Hidden from the generated documentation.
  #[doc(hidden)]
  #[inline]
  pub fn __set_read_version(&mut self, version: u64) {
    self.read_ts = version;
  }

  /// Returns a shared reference to the pending-writes container, or `None`
  /// if it is no longer held by this transaction.
  #[inline]
  pub fn pwm(&self) -> Option<&P> {
    Option::as_ref(&self.pending_writes)
  }

  /// Returns a shared reference to the conflict manager, or `None` if it is
  /// not held by this transaction.
  #[inline]
  pub fn cm(&self) -> Option<&C> {
    Option::as_ref(&self.conflict_manager)
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: Cm<Key = K>,
{
  /// Returns a `Marker` wrapping the conflict manager, or `None` when no
  /// conflict manager is held.
  pub fn marker(&mut self) -> Option<Marker<'_, C>> {
    let cm = self.conflict_manager.as_mut()?;
    Some(Marker::new(cm))
  }

  /// Returns a `Marker` wrapping the conflict manager together with a shared
  /// reference to the pending writes, or `None` when no conflict manager is
  /// held.
  ///
  /// # Panics
  ///
  /// Panics if a conflict manager is held but the pending writes are not.
  pub fn marker_with_pm(&mut self) -> Option<(Marker<'_, C>, &P)> {
    let cm = self.conflict_manager.as_mut()?;
    let pwm = self.pending_writes.as_ref().unwrap();
    Some((Marker::new(cm), pwm))
  }

  /// Records `k` as read in the conflict manager, if one is held.
  pub fn mark_read(&mut self, k: &K) {
    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_read(k);
    }
  }

  /// Records `k` as a conflict key in the conflict manager, if one is held.
  pub fn mark_conflict(&mut self, k: &K) {
    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_conflict(k);
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: Cm<Key = K>,
  P: Pwm<Key = K, Value = V>,
{
  /// Buffers an insertion of `value` under `key`.
  ///
  /// # Errors
  ///
  /// Propagates whatever the internal insertion path reports, e.g.
  /// `TransactionError::Discard` for a discarded transaction.
  pub fn insert(&mut self, key: K, value: V) -> Result<(), TransactionError<C::Error, P::Error>> {
    self.insert_with_in(key, value)
  }

  /// Buffers a removal of `key`. The removal entry is created with version `0`.
  ///
  /// # Errors
  ///
  /// Propagates whatever the internal modification path reports, e.g.
  /// `TransactionError::Discard` for a discarded transaction.
  pub fn remove(&mut self, key: K) -> Result<(), TransactionError<C::Error, P::Error>> {
    let removal = Entry {
      data: EntryData::Remove(key),
      version: 0,
    };
    self.modify(removal)
  }

  /// Rolls back the pending writes first and the conflict manager second.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the error of whichever rollback failed.
  ///
  /// # Panics
  ///
  /// Panics if the pending writes or the conflict manager are not held.
  pub fn rollback(&mut self) -> Result<(), TransactionError<C::Error, P::Error>> {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_mut().unwrap();
    pwm.rollback().map_err(TransactionError::Pwm)?;

    let cm = self.conflict_manager.as_mut().unwrap();
    cm.rollback().map_err(TransactionError::Cm)
  }

  /// Checks whether `key` is present in the pending writes.
  ///
  /// - `Ok(None)`: the key is not buffered; it is recorded as read in the
  ///   conflict manager (if one is held).
  /// - `Ok(Some(true))`: the key is buffered with a value.
  /// - `Ok(Some(false))`: the key is buffered without a value.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn contains_key(
    &mut self,
    key: &K,
  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>> {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm.get(key).map_err(TransactionError::pending)? {
      Some(ent) => Ok(Some(ent.value.is_some())),
      None => {
        // Only reads that were not served from the buffer are tracked.
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read(key);
        }
        Ok(None)
      }
    }
  }

  /// Looks `key` up in the pending writes.
  ///
  /// Returns `Ok(Some(_))` only when the key is buffered with a value.
  /// A key buffered without a value yields `Ok(None)`; a key that is not
  /// buffered also yields `Ok(None)` and is additionally recorded as read in
  /// the conflict manager (if one is held).
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn get<'a, 'b: 'a>(
    &'a mut self,
    key: &'b K,
  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>> {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm.get(key).map_err(TransactionError::Pwm)? {
      Some(e) => Ok(e.value.as_ref().map(|value| EntryRef {
        data: EntryDataRef::Insert { key, value },
        version: e.version,
      })),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read(key);
        }
        Ok(None)
      }
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: Cm<Key = K>,
  P: Pwm<Key = K, Value = V>,
{
  /// Commits the transaction by handing all buffered entries to `apply`.
  ///
  /// When nothing is buffered the transaction is discarded and `Ok(())` is
  /// returned without calling `apply`. Otherwise a commit timestamp is
  /// obtained, `apply` is invoked with the entries, the oracle is told the
  /// commit is done and the transaction is discarded, regardless of whether
  /// `apply` succeeded.
  ///
  /// # Errors
  ///
  /// - `TransactionError::Discard` if the transaction was already discarded.
  /// - `TransactionError::Conflict` if a conflict was detected; in that case
  ///   the transaction is *not* discarded.
  /// - The error returned by `apply`, wrapped with `WtmError::commit`.
  ///
  /// # Panics
  ///
  /// Panics if the pending writes are not held.
  pub fn commit<F, E>(&mut self, apply: F) -> Result<(), WtmError<C::Error, P::Error, E>>
  where
    F: FnOnce(OneOrMore<Entry<K, V>>) -> Result<(), E>,
    E: std::error::Error,
  {
    if self.discarded {
      return Err(TransactionError::Discard.into());
    }

    let nothing_buffered = self.pending_writes.as_ref().unwrap().is_empty();
    if nothing_buffered {
      self.discard();
      return Ok(());
    }

    let (commit_ts, entries) = match self.commit_entries() {
      Ok(prepared) => prepared,
      Err(err) => {
        // A conflict leaves the transaction usable; anything else ends it.
        if !matches!(err, TransactionError::Conflict) {
          self.discard();
        }
        return Err(err.into());
      }
    };

    // The same cleanup runs whether or not `apply` succeeds.
    let outcome = apply(entries);
    self.orc().done_commit(commit_ts);
    self.discard();
    outcome.map_err(|e| WtmError::commit(e))
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: CmEquivalent<Key = K>,
  P: Pwm<Key = K, Value = V>,
{
  /// Records the borrowed key `k` as read in the conflict manager, if one is
  /// held.
  pub fn mark_read_equivalent<Q>(&mut self, k: &Q)
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Hash,
  {
    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_read_equivalent(k);
    }
  }

  /// Records the borrowed key `k` as a conflict key in the conflict manager,
  /// if one is held.
  pub fn mark_conflict_equivalent<Q>(&mut self, k: &Q)
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Hash,
  {
    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_conflict_equivalent(k);
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: CmEquivalent<Key = K>,
  P: PwmEquivalent<Key = K, Value = V>,
{
  /// Checks whether the borrowed `key` is present in the pending writes.
  ///
  /// - `Ok(None)`: the key is not buffered; it is recorded as read in the
  ///   conflict manager (if one is held).
  /// - `Ok(Some(true))`: the key is buffered with a value.
  /// - `Ok(Some(false))`: the key is buffered without a value.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn contains_key_equivalent<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Hash,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_equivalent(key)
      .map_err(TransactionError::pending)?
    {
      Some(ent) => Ok(Some(ent.value.is_some())),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_equivalent(key);
        }
        Ok(None)
      }
    }
  }

  /// Looks the borrowed `key` up in the pending writes.
  ///
  /// Returns `Ok(Some(_))` only when the key is buffered with a value.
  /// A key buffered without a value yields `Ok(None)`; a key that is not
  /// buffered also yields `Ok(None)` and is recorded as read in the conflict
  /// manager (if one is held).
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn get_equivalent<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Hash,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_entry_equivalent(key)
      .map_err(TransactionError::Pwm)?
    {
      Some((k, e)) => Ok(e.value.as_ref().map(|value| EntryRef {
        data: EntryDataRef::Insert { key: k, value },
        version: e.version,
      })),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_equivalent(key);
        }
        Ok(None)
      }
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: CmComparable<Key = K>,
  P: PwmEquivalent<Key = K, Value = V>,
{
  /// Checks whether the borrowed `key` is present in the pending writes,
  /// using the equivalence-based lookup of the pending writes and the
  /// comparison-based read marking of the conflict manager.
  ///
  /// - `Ok(None)`: the key is not buffered; it is recorded as read in the
  ///   conflict manager (if one is held).
  /// - `Ok(Some(true))`: the key is buffered with a value.
  /// - `Ok(Some(false))`: the key is buffered without a value.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn contains_key_comparable_cm_equivalent_pm<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Ord + Hash,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_equivalent(key)
      .map_err(TransactionError::pending)?
    {
      Some(ent) => Ok(Some(ent.value.is_some())),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_comparable(key);
        }
        Ok(None)
      }
    }
  }

  /// Looks the borrowed `key` up in the pending writes, using the
  /// equivalence-based lookup of the pending writes and the comparison-based
  /// read marking of the conflict manager.
  ///
  /// Returns `Ok(Some(_))` only when the key is buffered with a value.
  /// A key buffered without a value yields `Ok(None)`; a key that is not
  /// buffered also yields `Ok(None)` and is recorded as read in the conflict
  /// manager (if one is held).
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn get_comparable_cm_equivalent_pm<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Ord + Hash,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_entry_equivalent(key)
      .map_err(TransactionError::Pwm)?
    {
      Some((k, e)) => Ok(e.value.as_ref().map(|value| EntryRef {
        data: EntryDataRef::Insert { key: k, value },
        version: e.version,
      })),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_comparable(key);
        }
        Ok(None)
      }
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: CmComparable<Key = K>,
  P: Pwm<Key = K, Value = V>,
{
  /// Records the borrowed key `k` as read in the conflict manager, if one is
  /// held.
  pub fn mark_read_comparable<Q>(&mut self, k: &Q)
  where
    K: Borrow<Q>,
    Q: ?Sized + Ord,
  {
    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_read_comparable(k);
    }
  }

  /// Records the borrowed key `k` as a conflict key in the conflict manager,
  /// if one is held.
  pub fn mark_conflict_comparable<Q>(&mut self, k: &Q)
  where
    K: Borrow<Q>,
    Q: ?Sized + Ord,
  {
    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_conflict_comparable(k);
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: CmComparable<Key = K>,
  P: PwmComparable<Key = K, Value = V>,
{
  /// Checks whether the borrowed `key` is present in the pending writes.
  ///
  /// - `Ok(None)`: the key is not buffered; it is recorded as read in the
  ///   conflict manager (if one is held).
  /// - `Ok(Some(true))`: the key is buffered with a value.
  /// - `Ok(Some(false))`: the key is buffered without a value.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn contains_key_comparable<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Ord,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_comparable(key)
      .map_err(TransactionError::pending)?
    {
      Some(ent) => Ok(Some(ent.value.is_some())),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_comparable(key);
        }
        Ok(None)
      }
    }
  }

  /// Looks the borrowed `key` up in the pending writes.
  ///
  /// Returns `Ok(Some(_))` only when the key is buffered with a value.
  /// A key buffered without a value yields `Ok(None)`; a key that is not
  /// buffered also yields `Ok(None)` and is recorded as read in the conflict
  /// manager (if one is held).
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn get_comparable<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Ord,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_entry_comparable(key)
      .map_err(TransactionError::Pwm)?
    {
      Some((k, e)) => Ok(e.value.as_ref().map(|value| EntryRef {
        data: EntryDataRef::Insert { key: k, value },
        version: e.version,
      })),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_comparable(key);
        }
        Ok(None)
      }
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: CmEquivalent<Key = K>,
  P: PwmComparable<Key = K, Value = V>,
{
  /// Checks whether the borrowed `key` is present in the pending writes,
  /// using the comparison-based lookup of the pending writes and the
  /// equivalence-based read marking of the conflict manager.
  ///
  /// - `Ok(None)`: the key is not buffered; it is recorded as read in the
  ///   conflict manager (if one is held).
  /// - `Ok(Some(true))`: the key is buffered with a value.
  /// - `Ok(Some(false))`: the key is buffered without a value.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn contains_key_equivalent_cm_comparable_pm<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Ord + Hash,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_comparable(key)
      .map_err(TransactionError::pending)?
    {
      Some(ent) => Ok(Some(ent.value.is_some())),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_equivalent(key);
        }
        Ok(None)
      }
    }
  }

  /// Looks the borrowed `key` up in the pending writes, using the
  /// comparison-based lookup of the pending writes and the equivalence-based
  /// read marking of the conflict manager.
  ///
  /// Returns `Ok(Some(_))` only when the key is buffered with a value.
  /// A key buffered without a value yields `Ok(None)`; a key that is not
  /// buffered also yields `Ok(None)` and is recorded as read in the conflict
  /// manager (if one is held).
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Discard` if the transaction was discarded,
  /// or the pending-writes lookup error.
  pub fn get_equivalent_cm_comparable_pm<'a, 'b: 'a, Q>(
    &'a mut self,
    key: &'b Q,
  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
  where
    K: Borrow<Q>,
    Q: ?Sized + Eq + Ord + Hash,
  {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_ref().unwrap();
    match pwm
      .get_entry_comparable(key)
      .map_err(TransactionError::Pwm)?
    {
      Some((k, e)) => Ok(e.value.as_ref().map(|value| EntryRef {
        data: EntryDataRef::Insert { key: k, value },
        version: e.version,
      })),
      None => {
        if let Some(cm) = self.conflict_manager.as_mut() {
          cm.mark_read_equivalent(key);
        }
        Ok(None)
      }
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: Cm<Key = K> + Send,
  P: Pwm<Key = K, Value = V> + Send,
{
  /// Commits the transaction on a freshly spawned thread.
  ///
  /// The commit timestamp is obtained on the calling thread. `apply` then
  /// runs on the spawned thread with the buffered entries, the oracle is told
  /// the commit is done, and `callback` receives the result of `apply`. The
  /// returned handle yields whatever `callback` returns.
  ///
  /// When nothing is buffered the transaction is discarded and `callback` is
  /// invoked with `Ok(())` without calling `apply`.
  ///
  /// Once the commit timestamp has been obtained the transaction is marked as
  /// discarded, so later calls on this manager report
  /// `TransactionError::Discard`.
  ///
  /// # Errors
  ///
  /// - `TransactionError::Discard` if the transaction was already discarded.
  /// - `TransactionError::Conflict` if a conflict was detected; in that case
  ///   the transaction is *not* discarded.
  ///
  /// # Panics
  ///
  /// Panics if the pending writes are not held.
  pub fn commit_with_callback<F, E, R>(
    &mut self,
    apply: F,
    callback: impl FnOnce(Result<(), E>) -> R + Send + 'static,
  ) -> Result<std::thread::JoinHandle<R>, WtmError<C::Error, P::Error, E>>
  where
    K: Send + 'static,
    V: Send + 'static,
    F: FnOnce(OneOrMore<Entry<K, V>>) -> Result<(), E> + Send + 'static,
    E: std::error::Error,
    R: Send + 'static,
    C: 'static,
  {
    if self.discarded {
      return Err(WtmError::transaction(TransactionError::Discard));
    }

    if self.pending_writes.as_ref().unwrap().is_empty() {
      // Nothing to commit.
      self.discard();
      return Ok(std::thread::spawn(move || callback(Ok(()))));
    }

    let (commit_ts, entries) = self.commit_entries().map_err(|e| match e {
      // A conflict leaves the transaction usable; anything else ends it.
      TransactionError::Conflict => e,
      _ => {
        self.discard();
        e
      }
    })?;

    // The pending writes have been moved into `entries`, so this manager can
    // no longer serve reads or writes. Mark it discarded (as `commit` does)
    // so that later calls fail with `Discard` instead of panicking on the
    // emptied `pending_writes`.
    self.discard();

    let orc = self.orc.clone();
    Ok(std::thread::spawn(move || {
      let outcome = apply(entries);
      // The oracle is notified whether or not `apply` succeeded.
      orc.done_commit(commit_ts);
      callback(outcome)
    }))
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: Cm<Key = K>,
  P: Pwm<Key = K, Value = V>,
{
  /// Builds an insert entry stamped with the transaction's read timestamp
  /// and buffers it.
  fn insert_with_in(
    &mut self,
    key: K,
    value: V,
  ) -> Result<(), TransactionError<C::Error, P::Error>> {
    let version = self.read_ts;
    self.modify(Entry {
      data: EntryData::Insert { key, value },
      version,
    })
  }

  /// Validates `ent`, accounts for it against the batch limits, records its
  /// key as a conflict key and stores it in the pending writes.
  ///
  /// If the pending writes already hold an entry for the same key, that
  /// entry is removed first; it is kept in `duplicate_writes` only when its
  /// version differs from the new entry's version.
  ///
  /// # Errors
  ///
  /// - `TransactionError::Discard` if the transaction was discarded.
  /// - `TransactionError::LargeTxn` if the entry count or the estimated size
  ///   would reach the pending-writes limits.
  /// - `TransactionError::Pwm` for validation, removal or insertion failures.
  ///
  /// # Panics
  ///
  /// Panics if the pending writes are not held.
  fn modify(&mut self, ent: Entry<K, V>) -> Result<(), TransactionError<C::Error, P::Error>> {
    if self.discarded {
      return Err(TransactionError::Discard);
    }

    let pwm = self.pending_writes.as_mut().unwrap();
    pwm.validate_entry(&ent).map_err(TransactionError::Pwm)?;

    let next_count = self.count + 1;
    let next_size = self.size + pwm.estimate_size(&ent);
    let over_limit =
      next_count >= pwm.max_batch_entries() || next_size >= pwm.max_batch_size();
    if over_limit {
      return Err(TransactionError::LargeTxn);
    }
    self.count = next_count;
    self.size = next_size;

    if let Some(cm) = self.conflict_manager.as_mut() {
      cm.mark_conflict(ent.key());
    }

    let new_version = ent.version;
    let (key, value) = ent.split();

    match pwm.remove_entry(&key).map_err(TransactionError::Pwm)? {
      // Same key, different version: keep the displaced entry.
      Some((old_key, old_value)) if old_value.version != new_version => {
        self
          .duplicate_writes
          .push(Entry::unsplit(old_key, old_value));
      }
      // No previous entry, or same version: the new entry simply replaces it.
      _ => {}
    }

    pwm.insert(key, value).map_err(TransactionError::Pwm)?;
    Ok(())
  }
}
impl<K, V, C, P> Wtm<K, V, C, P>
where
  C: Cm<Key = K>,
  P: Pwm<Key = K, Value = V>,
{
  /// Requests a commit timestamp from the oracle and, on success, drains the
  /// pending and duplicate writes into a single list of entries stamped with
  /// that timestamp.
  ///
  /// The guard returned by the oracle's `write_serialize_lock` is kept alive
  /// until this function returns.
  ///
  /// # Errors
  ///
  /// Returns `TransactionError::Conflict` when the oracle reports a conflict;
  /// the conflict manager is put back and the buffered writes are untouched.
  ///
  /// # Panics
  ///
  /// Panics if the pending writes are not held, or if the oracle returns a
  /// commit timestamp of `0`.
  fn commit_entries(
    &mut self,
  ) -> Result<(u64, OneOrMore<Entry<K, V>>), TransactionError<C::Error, P::Error>> {
    let _write_lock = self.orc.write_serialize_lock.lock();

    // Hand the conflict manager (if any) over to the oracle.
    let conflict_manager = self.conflict_manager.take();

    match self
      .orc
      .new_commit_ts(&mut self.done_read, self.read_ts, conflict_manager)
    {
      CreateCommitTimestampResult::Conflict(conflict_manager) => {
        self.conflict_manager = conflict_manager;
        Err(TransactionError::Conflict)
      }
      CreateCommitTimestampResult::Timestamp(commit_ts) => {
        let pending_writes = self.pending_writes.take().unwrap();
        let duplicate_writes = mem::take(&mut self.duplicate_writes);

        // Size the output from the local `duplicate_writes`: the field on
        // `self` has just been emptied by `mem::take`.
        let mut entries =
          OneOrMore::with_capacity(pending_writes.len() + duplicate_writes.len());

        let process_entry = |entries: &mut OneOrMore<Entry<K, V>>, mut ent: Entry<K, V>| {
          ent.version = commit_ts;
          entries.push(ent);
        };

        pending_writes
          .into_iter()
          .for_each(|(k, v)| process_entry(&mut entries, Entry::unsplit(k, v)));
        duplicate_writes
          .into_iter()
          .for_each(|ent| process_entry(&mut entries, ent));

        assert_ne!(commit_ts, 0);
        Ok((commit_ts, entries))
      }
    }
  }
}
impl<K, V, C, P> Wtm<K, V, C, P> {
  /// Reports the read timestamp as done to the oracle's read mark, at most
  /// once per transaction.
  fn done_read(&mut self) {
    if self.done_read {
      return;
    }
    self.done_read = true;
    let read_ts = self.read_ts;
    self.orc().read_mark.done(read_ts).unwrap();
  }

  /// Borrows the oracle shared by this transaction.
  #[inline]
  fn orc(&self) -> &Oracle<C> {
    &self.orc
  }

  /// Marks the transaction as discarded and releases its read mark.
  /// Calling it again is a no-op.
  pub fn discard(&mut self) {
    if !self.discarded {
      self.discarded = true;
      self.done_read();
    }
  }

  /// Returns `true` once the transaction has been discarded.
  #[inline]
  pub const fn is_discard(&self) -> bool {
    self.discarded
  }
}
#[cfg(test)]
mod tests {
  use std::{collections::BTreeSet, convert::Infallible, marker::PhantomData};

  use super::*;

  /// Exercises markers and equivalent-key lookups with `HashCm` and
  /// `IndexMapPwm`.
  #[test]
  fn wtm() {
    let tm = Tm::<String, u64, HashCm<String>, IndexMapPwm<String, u64>>::new("test", 0);
    let mut txn = tm.write(Default::default(), Default::default()).unwrap();
    assert!(!txn.is_discard());
    assert!(txn.pwm().is_some());
    assert!(txn.cm().is_some());

    let mut marker = txn.marker().unwrap();
    marker.mark(&"1".to_owned());
    marker.mark_equivalent("3");
    marker.mark_conflict(&"2".to_owned());
    marker.mark_conflict_equivalent("4");

    txn.mark_read(&"2".to_owned());
    txn.mark_conflict(&"1".to_owned());
    txn.mark_conflict_equivalent("2");
    txn.mark_read_equivalent("3");

    txn.insert("5".into(), 5).unwrap();
    assert_eq!(txn.contains_key_equivalent("5").unwrap(), Some(true));
    assert_eq!(
      txn.get_equivalent("5").unwrap().unwrap().value().unwrap(),
      &5
    );

    // A key that was never written is not found in the buffer.
    assert_eq!(txn.contains_key_equivalent("6").unwrap(), None);
    assert_eq!(txn.get_equivalent("6").unwrap(), None);
  }

  /// Conflict manager for tests: it records the address of each key
  /// reference it is given rather than the key itself.
  struct TestCm<K> {
    conflict_keys: BTreeSet<usize>,
    reads: BTreeSet<usize>,
    _m: PhantomData<K>,
  }

  impl<K> Cm for TestCm<K> {
    type Error = Infallible;
    type Key = K;
    type Options = ();

    fn new(_options: Self::Options) -> Result<Self, Self::Error> {
      let cm = Self {
        conflict_keys: BTreeSet::new(),
        reads: BTreeSet::new(),
        _m: PhantomData,
      };
      Ok(cm)
    }

    fn mark_read(&mut self, key: &Self::Key) {
      let addr = key as *const K as usize;
      self.reads.insert(addr);
    }

    fn mark_conflict(&mut self, key: &Self::Key) {
      let addr = key as *const K as usize;
      self.conflict_keys.insert(addr);
    }

    fn has_conflict(&self, other: &Self) -> bool {
      // True when any recorded read is among `other`'s conflict keys;
      // an empty read set never conflicts.
      self
        .reads
        .iter()
        .any(|read| other.conflict_keys.contains(read))
    }

    fn rollback(&mut self) -> Result<(), Self::Error> {
      self.reads.clear();
      self.conflict_keys.clear();
      Ok(())
    }
  }

  impl<K> CmComparable for TestCm<K> {
    fn mark_read_comparable<Q>(&mut self, key: &Q)
    where
      Self::Key: Borrow<Q>,
      Q: Ord + ?Sized,
    {
      let addr = key as *const Q as *const () as usize;
      self.reads.insert(addr);
    }

    fn mark_conflict_comparable<Q>(&mut self, key: &Q)
    where
      Self::Key: Borrow<Q>,
      Q: Ord + ?Sized,
    {
      let addr = key as *const Q as *const () as usize;
      self.conflict_keys.insert(addr);
    }
  }

  /// Exercises the comparable conflict manager together with the
  /// equivalence-based `IndexMapPwm`.
  #[test]
  fn wtm2() {
    let tm = Tm::<Arc<u64>, u64, TestCm<Arc<u64>>, IndexMapPwm<Arc<u64>, u64>>::new("test", 0);
    let mut txn = tm.write(Default::default(), ()).unwrap();
    assert!(!txn.is_discard());
    assert!(txn.pwm().is_some());
    assert!(txn.cm().is_some());

    let mut marker = txn.marker().unwrap();
    let k1 = Arc::new(1);
    let k2 = Arc::new(2);
    let k3 = Arc::new(3);
    let k4 = Arc::new(4);
    let k5 = Arc::new(5);
    marker.mark(&k1);
    marker.mark_comparable(&k3);
    marker.mark_conflict(&k2);
    marker.mark_conflict_comparable(&k4);

    txn.mark_read(&k2);
    txn.mark_conflict(&k1);
    txn.mark_conflict_comparable(&k2);
    txn.mark_read_comparable(&k3);

    txn.insert(k5.clone(), 5).unwrap();
    assert_eq!(
      txn.contains_key_comparable_cm_equivalent_pm(&k5).unwrap(),
      Some(true)
    );
    assert_eq!(
      txn
        .get_comparable_cm_equivalent_pm(&k5)
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // A key that was never written is not found in the buffer.
    let k6 = Arc::new(6);
    assert_eq!(
      txn.contains_key_comparable_cm_equivalent_pm(&k6).unwrap(),
      None
    );
    assert_eq!(txn.get_comparable_cm_equivalent_pm(&k6).unwrap(), None);
  }

  /// Exercises the comparable conflict manager together with the
  /// comparison-based `BTreePwm`.
  #[test]
  fn wtm3() {
    let tm = Tm::<Arc<u64>, u64, TestCm<Arc<u64>>, BTreePwm<Arc<u64>, u64>>::new("test", 0);
    let mut txn = tm.write((), ()).unwrap();
    assert!(!txn.is_discard());
    assert!(txn.pwm().is_some());
    assert!(txn.cm().is_some());

    let mut marker = txn.marker().unwrap();
    let k1 = Arc::new(1);
    let k2 = Arc::new(2);
    let k3 = Arc::new(3);
    let k4 = Arc::new(4);
    let k5 = Arc::new(5);
    marker.mark(&k1);
    marker.mark_comparable(&k3);
    marker.mark_conflict(&k2);
    marker.mark_conflict_comparable(&k4);

    txn.mark_read(&k2);
    txn.mark_conflict(&k1);
    txn.mark_conflict_comparable(&k2);
    txn.mark_read_comparable(&k3);

    txn.insert(k5.clone(), 5).unwrap();
    assert_eq!(txn.contains_key_comparable(&k5).unwrap(), Some(true));
    assert_eq!(
      txn.get_comparable(&k5).unwrap().unwrap().value().unwrap(),
      &5
    );

    // A key that was never written is not found in the buffer.
    let k6 = Arc::new(6);
    assert_eq!(txn.contains_key_comparable(&k6).unwrap(), None);
    assert_eq!(txn.get_comparable(&k6).unwrap(), None);
  }
}