use std::marker::PhantomData;
use crate::prelude::*;
/// Marker trait for the type-level states a [`Transaction`] can be tagged with.
pub trait TransactionState {}

/// Type-level marker for the committed state.
///
/// It is the default `S` parameter of [`Transaction`]. Being zero-sized, it is
/// freely `Copy` and comparable.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Committed;

impl TransactionState for Committed {}

/// Type-level marker for an in-progress transaction.
///
/// `Transaction<'_, T, Transacting>` is the instantiation that implements
/// [`TransactionBuilder`] and [`TransactionBehavior`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Transacting;

impl TransactionState for Transacting {}
/// Crate-internal constructor for a transaction over a borrowed `State`.
///
/// The transaction keeps the exclusive borrow of the state for `'a` and carries
/// an optional proposed value.
pub(crate) trait TransactionBuilder<'a> {
    /// Value type held by the underlying `State`.
    type Item: ToString;

    /// Builds a transaction over `state` that proposes `value` (which may be `None`).
    fn new(state: &'a mut State<Self::Item>, value: Option<Self::Item>) -> Self;
}
/// Ways of finishing a transaction: unconditionally, or only if a validator agrees.
pub trait TransactionBehavior {
    /// Value type carried by the transaction.
    type Item: ToString;

    /// Validator that receives the `(previous, current)` pair and returns `Err`
    /// to reject the transition.
    type F: Fn(Option<&Self::Item>, Option<&Self::Item>) -> Result<(), UserError>;

    /// Applies the proposed value and keeps it only if `lambda` returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when `lambda` rejects the transition.
    fn commit_if_ok(self, lambda: Self::F) -> Result<(), Error>;

    /// Applies the proposed value and commits it without any validation.
    fn commit(self);
}
/// A proposed change to a `State<T>`, tagged at the type level with `S`.
///
/// Because it holds `&'a mut State<T>`, nothing else can access the state while
/// the transaction is alive.
#[derive(Debug)]
pub struct Transaction<'a, T: ToString, S = Committed> {
    /// Exclusive borrow of the state being changed.
    state: &'a mut State<T>,
    /// Proposed new value; `None` is a legitimate proposal.
    value: Option<T>,
    /// Zero-sized tag recording the typestate `S`.
    _state: PhantomData<S>,
}
impl<'a, T: ToString> TransactionBuilder<'a> for Transaction<'a, T, Transacting> {
    type Item = T;

    /// Stores `state` and the proposed `value`; `state` itself is left untouched here.
    fn new(state: &'a mut State<T>, value: Option<T>) -> Self {
        let _state = PhantomData;
        Self { state, value, _state }
    }
}
impl<T> TransactionBehavior for Transaction<'_, T, Transacting>
where
T: ToString,
{
type Item = T;
type F = fn(Option<&T>, Option<&T>) -> Result<(), UserError>;
fn commit_if_ok(self, lambda: Self::F) -> Result<(), Error> {
self.state.modify(self.value);
match lambda(self.state.get_previous(), self.state.get()) {
Ok(()) => {
self.state.commit();
Ok(())
}
Err(e) => {
self.state.rollback();
Err(Error::transaction_rejected(
self.state.get_previous(),
self.state.get(),
e,
))
}
}
}
fn commit(self) {
self.state.modify(self.value);
self.state.commit();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_transaction() {
        let mut state = State::default();
        assert!(state.get().is_none());

        let outcome = state.transact(Some(42)).commit_if_ok(|before, after| {
            if after > before {
                Ok(())
            } else {
                Err("current must be greater than previous".into())
            }
        });

        assert!(outcome.is_ok());
        assert_eq!(state.get(), Some(&42));
    }

    #[test]
    fn test_reusable_function() {
        let mut state = State::default();
        let must_increase = |old: Option<&i32>, new: Option<&i32>| -> Result<(), UserError> {
            if new > old {
                Ok(())
            } else {
                Err("b must be greater than a".into())
            }
        };

        // (proposed value, accepted?, expected current, expected previous)
        let cases = [
            (42, true, Some(&42), None),
            (43, true, Some(&43), Some(&42)),
            (40, false, Some(&43), Some(&42)),
            (39, false, Some(&43), Some(&42)),
        ];
        for (proposed, accepted, current, previous) in cases {
            let outcome = state.transact(Some(proposed)).commit_if_ok(must_increase);
            assert_eq!(outcome.is_ok(), accepted);
            assert_eq!(state.get(), current);
            assert_eq!(state.get_previous(), previous);
        }
    }

    #[test]
    fn test_non_confirmed() {
        let mut state = State::default();

        // (value to commit, expected current, expected previous)
        let steps = [
            (Some(42), Some(&42), None),
            (Some(43), Some(&43), Some(&42)),
            (None, None, Some(&43)),
        ];
        for (value, current, previous) in steps {
            state.transact(value).commit();
            assert_eq!(state.get(), current);
            assert_eq!(state.get_previous(), previous);
        }
    }
}