use std::sync::Arc;
use crate::error::{ComposableError, ComposableResult, IntoErrorContext};
use crate::traits::monad::Monad;
use crate::transformers::MonadTransformer;
/// Boxed, thread-safe function that transforms the value half of a
/// `(state, value)` pair while threading the state through unchanged by type.
pub type StateValueMapper<S, A, B> = Box<dyn Fn((S, A)) -> (S, B) + Sync + Send>;
/// Boxed, thread-safe function that merges two `(state, value)` pairs into a
/// single `(state, combined_value)` pair.
pub type StateCombiner<S, A, B, C> = Box<dyn Fn((S, A), (S, B)) -> (S, C) + Sync + Send>;
/// A state-monad transformer: a computation over state `S` whose results
/// live in the base monad `M`, producing values of type `A`.
pub enum StateT<S, M, A> {
    /// A bare value that has not been placed into the base monad.
    Pure(A),
    /// An already-built base-monad value, used as-is regardless of the
    /// incoming state.
    LiftM(M),
    /// A shared, thread-safe transition from an input state to a base-monad
    /// value.
    Effect(Arc<dyn Fn(S) -> M + Sync + Send>),
}
impl<S, M, A> Clone for StateT<S, M, A>
where
S: 'static,
M: Clone + 'static,
A: Clone + 'static,
{
fn clone(&self) -> Self {
match self {
StateT::Pure(a) => StateT::Pure(a.clone()),
StateT::LiftM(m) => StateT::LiftM(m.clone()),
StateT::Effect(f) => StateT::Effect(Arc::clone(f)),
}
}
}
impl<S, M, A> StateT<S, M, A>
where
    S: 'static,
    M: 'static,
    A: 'static,
{
    /// Builds an effectful computation from a state-transition function.
    ///
    /// `f` receives the incoming state and returns the base-monad value
    /// (conventionally carrying a `(state, value)` pair).
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(S) -> M + Send + Sync + 'static,
    {
        StateT::Effect(Arc::new(f))
    }

    /// Runs the computation from `state` and returns the base-monad value.
    ///
    /// A `LiftM` ignores `state` and returns a clone of the lifted value.
    ///
    /// # Panics
    ///
    /// Panics on `StateT::Pure`: without the base monad's `pure` there is no
    /// way to turn the bare value into an `M`.
    pub fn run_state(&self, state: S) -> M
    where
        M: Clone,
        A: Clone,
    {
        match self {
            StateT::Pure(_) => panic!("Cannot run Pure StateT without a base monad"),
            StateT::LiftM(m) => m.clone(),
            StateT::Effect(f) => f(state),
        }
    }

    /// Yields the current state as the value, leaving the state unchanged.
    ///
    /// `pure` injects the resulting `(state, value)` pair into the base monad.
    pub fn get<P>(pure: P) -> StateT<S, M, S>
    where
        P: Fn((S, S)) -> M + Send + Sync + 'static,
        S: Clone + Send + Sync,
    {
        StateT::new(move |s: S| pure((s.clone(), s)))
    }

    /// Replaces the state with `new_state`; the value produced is the
    /// *previous* state.
    pub fn put<P>(new_state: S, pure: P) -> StateT<S, M, S>
    where
        P: Fn((S, S)) -> M + Send + Sync + 'static,
        S: Clone + Send + Sync,
    {
        // Cloned per run so the same computation can be executed repeatedly.
        StateT::new(move |s: S| pure((new_state.clone(), s)))
    }

    /// Transforms the state with `f`, producing `()` as the value.
    pub fn modify<F, P>(f: F, pure: P) -> StateT<S, M, ()>
    where
        F: Fn(S) -> S + Send + Sync + 'static,
        P: Fn((S, ())) -> M + Send + Sync + 'static,
    {
        StateT::new(move |s: S| pure((f(s), ())))
    }

    /// Maps the produced value with `f`.
    ///
    /// `map_fn` is the base monad's `fmap`: it receives the base value and a
    /// mapper over `(state, value)` pairs. `Pure` values are mapped directly.
    pub fn fmap_with<F, B, MapFn>(&self, f: F, map_fn: MapFn) -> StateT<S, M, B>
    where
        F: Fn(A) -> B + Send + Sync + Clone + 'static,
        MapFn: Fn(M, StateValueMapper<S, A, B>) -> M + Send + Sync + 'static,
        S: Clone + Send + Sync + 'static,
        M: Clone + 'static,
        A: Clone + 'static,
        B: 'static,
    {
        match self {
            StateT::Pure(a) => StateT::Pure(f(a.clone())),
            StateT::LiftM(m) => StateT::LiftM(map_fn(m.clone(), Self::make_value_mapper(f))),
            StateT::Effect(run_fn) => {
                let run_fn = Arc::clone(run_fn);
                StateT::new(move |s: S| map_fn(run_fn(s), Self::make_value_mapper(f.clone())))
            },
        }
    }

    /// Sequences `f` after this computation.
    ///
    /// `bind_fn` is the base monad's `bind`: it receives the base value and a
    /// continuation that runs the `StateT` returned by `f` from the
    /// intermediate state.
    ///
    /// # Panics
    ///
    /// The continuation calls `run_state`, so running the result panics if
    /// `f` returns a `StateT::Pure`.
    pub fn bind_with<F, B, BindFn, N>(&self, f: F, bind_fn: BindFn) -> StateT<S, N, B>
    where
        F: Fn(A) -> StateT<S, N, B> + Send + Sync + Clone + 'static,
        BindFn: Fn(M, Box<dyn Fn((S, A)) -> N + Send + Sync>) -> N + Send + Sync + 'static,
        S: Clone + Send + Sync + 'static,
        M: Clone + Send + Sync + 'static,
        A: Clone + 'static,
        N: Clone + 'static,
        B: Clone + 'static,
    {
        match self {
            StateT::Pure(a) => f(a.clone()),
            StateT::LiftM(m) => {
                let m = m.clone();
                // A lifted value already carries its resulting state, so the
                // incoming state is ignored (mirrors `run_state` on `LiftM`).
                StateT::new(move |_: S| {
                    bind_fn(m.clone(), Self::make_continuation::<F, B, N>(f.clone()))
                })
            },
            StateT::Effect(run_fn) => {
                let run_fn = Arc::clone(run_fn);
                StateT::new(move |s: S| {
                    bind_fn(run_fn(s), Self::make_continuation::<F, B, N>(f.clone()))
                })
            },
        }
    }

    /// Combines the values of `self` and `other` with `f`.
    ///
    /// Both computations start from the same input state. `combine_fn` is
    /// the base monad's two-argument lift. The resulting state is the one
    /// produced by `other` (right-most wins), for both lifted and effectful
    /// inputs.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` are different variants.
    pub fn combine_with<B, C, F, CombineFn>(
        &self,
        other: &StateT<S, M, B>,
        f: F,
        combine_fn: CombineFn,
    ) -> StateT<S, M, C>
    where
        F: Fn(A, B) -> C + Send + Sync + Clone + 'static,
        CombineFn: Fn(M, M, StateCombiner<S, A, B, C>) -> M + Send + Sync + 'static,
        S: Clone + Send + Sync + 'static,
        M: Clone + 'static,
        A: Clone + 'static,
        B: Clone + 'static,
        C: 'static,
    {
        match (self, other) {
            (StateT::Pure(a), StateT::Pure(b)) => StateT::Pure(f(a.clone(), b.clone())),
            (StateT::LiftM(m1), StateT::LiftM(m2)) => {
                // Same combiner as the effectful path, so the retained state
                // does not depend on how the inputs were constructed.
                StateT::LiftM(combine_fn(m1.clone(), m2.clone(), Self::make_state_combiner(f)))
            },
            (StateT::Effect(self_run_fn), StateT::Effect(other_run_fn)) => {
                let self_run_fn = Arc::clone(self_run_fn);
                let other_run_fn = Arc::clone(other_run_fn);
                StateT::new(move |s: S| {
                    combine_fn(
                        self_run_fn(s.clone()),
                        other_run_fn(s),
                        Self::make_state_combiner(f.clone()),
                    )
                })
            },
            _ => panic!("Cannot combine StateT variants of different types"),
        }
    }

    /// Builds a computation that yields `a` without touching the state,
    /// using `pure_fn` as the base monad's `pure`.
    pub fn pure<P>(a: A, pure_fn: P) -> Self
    where
        P: Fn((S, A)) -> M + Send + Sync + 'static,
        A: Clone + Send + Sync,
    {
        StateT::new(move |s| pure_fn((s, a.clone())))
    }

    /// Runs the computation from `s` and hands the base-monad result to
    /// `extract_state_fn` (e.g. to project out the final state).
    ///
    /// # Panics
    ///
    /// Panics on `StateT::Pure`, like [`StateT::run_state`].
    pub fn exec_state<F, B>(&self, s: S, extract_state_fn: F) -> B
    where
        F: FnOnce(M) -> B,
        M: Clone,
        A: Clone,
    {
        extract_state_fn(self.run_state(s))
    }

    /// Combines the base-monad results of `self` and `other` with `apply_fn`.
    ///
    /// Two `LiftM` inputs are combined eagerly. Otherwise both sides run from
    /// the same input state each time the result is run; running the result
    /// panics if either side is a `StateT::Pure` (see [`StateT::run_state`]).
    pub fn apply<B, C, ApplyFn>(&self, other: StateT<S, M, B>, apply_fn: ApplyFn) -> StateT<S, M, C>
    where
        A: Clone + Send + Sync + 'static,
        B: Clone + Send + Sync + 'static,
        C: Clone + Send + Sync + 'static,
        S: Clone + Send + Sync + 'static,
        M: Clone + Send + Sync + 'static,
        ApplyFn: Fn(M, M) -> M + Send + Sync + 'static,
    {
        if let (StateT::LiftM(m1), StateT::LiftM(m2)) = (self, &other) {
            return StateT::LiftM(apply_fn(m1.clone(), m2.clone()));
        }
        let self_run = self.clone().into_run_fn();
        // `other` is owned, so it can be moved rather than cloned.
        let other_run = other.into_run_fn();
        StateT::new(move |s: S| apply_fn(self_run(s.clone()), other_run(s)))
    }

    /// Post-processes the base-monad value with `join_fn` (e.g. to flatten
    /// a nested base monad). `Pure` values pass through unchanged.
    pub fn join<JoinFn, OuterM>(&self, join_fn: JoinFn) -> StateT<S, OuterM, A>
    where
        A: Clone + Send + Sync + 'static,
        M: Clone + 'static,
        JoinFn: Fn(M) -> OuterM + Send + Sync + 'static,
        OuterM: 'static,
    {
        match self {
            StateT::Pure(a) => StateT::Pure(a.clone()),
            StateT::LiftM(m) => StateT::LiftM(join_fn(m.clone())),
            StateT::Effect(run_fn) => {
                let run_fn = Arc::clone(run_fn);
                StateT::new(move |s: S| join_fn(run_fn(s)))
            },
        }
    }

    /// Builds the `(state, value)` mapper handed to a base monad's `fmap`.
    fn make_value_mapper<B, F>(f: F) -> StateValueMapper<S, A, B>
    where
        F: Fn(A) -> B + Send + Sync + 'static,
        B: 'static,
    {
        Box::new(move |(state, a): (S, A)| (state, f(a)))
    }

    /// Builds the combiner for `combine_with`; it keeps the right-hand state.
    fn make_state_combiner<B, C, F>(f: F) -> StateCombiner<S, A, B, C>
    where
        F: Fn(A, B) -> C + Send + Sync + 'static,
        B: 'static,
        C: 'static,
    {
        Box::new(move |(_, a): (S, A), (state, b): (S, B)| (state, f(a, b)))
    }

    /// Builds the continuation for `bind_with`: runs `f(value)` from `state`.
    fn make_continuation<F, B, N>(f: F) -> Box<dyn Fn((S, A)) -> N + Send + Sync>
    where
        F: Fn(A) -> StateT<S, N, B> + Send + Sync + 'static,
        N: Clone + 'static,
        B: Clone + 'static,
    {
        Box::new(move |(state, a): (S, A)| f(a).run_state(state))
    }

    /// Converts any variant into a shared run function with `run_state`
    /// semantics (so a `Pure` still panics only when actually run).
    fn into_run_fn(self) -> Arc<dyn Fn(S) -> M + Send + Sync>
    where
        M: Clone + Send + Sync,
        A: Clone + Send + Sync,
    {
        match self {
            StateT::Effect(run_fn) => run_fn,
            non_effect => {
                Arc::new(move |s: S| non_effect.run_state(s)) as Arc<dyn Fn(S) -> M + Send + Sync>
            },
        }
    }
}
impl<S, M, A> MonadTransformer for StateT<S, M, A>
where
    S: Clone + Send + Sync + 'static,
    M: Monad<Source = (S, A)> + Send + Sync + Clone + 'static,
    A: Clone + Send + Sync + 'static,
{
    type BaseMonad = M;

    /// Lifts a base-monad value into an effectful `StateT` that ignores the
    /// incoming state and yields a clone of `base` on every run.
    #[inline]
    fn lift(base: M) -> Self {
        let constant_run = move |_previous_state: S| base.clone();
        StateT::new(constant_run)
    }
}
impl<S, E, A> StateT<S, Result<(S, A), E>, A>
where
    S: Clone + 'static,
    E: 'static,
    A: Send + Sync + 'static,
{
    /// Runs the computation from `state`, wrapping any base-level error in a
    /// [`ComposableError`].
    ///
    /// The base monad here is `Result`, so a `StateT::Pure(a)` can be run
    /// directly: it succeeds with `(state, a)` and leaves the state untouched.
    /// A `LiftM` ignores `state` and yields a clone of the lifted result.
    ///
    /// # Errors
    ///
    /// Returns `Err(ComposableError::new(e))` when the lifted result or the
    /// transition function produces `Err(e)`.
    pub fn try_run_state(&self, state: S) -> ComposableResult<(S, A), E>
    where
        A: Clone,
        E: Clone,
    {
        match self {
            // `Ok` is `Result`'s `pure`, so no external context is needed.
            StateT::Pure(a) => Ok((state, a.clone())),
            StateT::LiftM(result) => result.clone().map_err(ComposableError::new),
            StateT::Effect(run_fn) => run_fn(state).map_err(ComposableError::new),
        }
    }

    /// Like [`StateT::try_run_state`], but attaches `context` to any error.
    ///
    /// # Errors
    ///
    /// Same as `try_run_state`, with the converted `context` added.
    pub fn try_run_state_with_context<C>(&self, state: S, context: C) -> ComposableResult<(S, A), E>
    where
        C: IntoErrorContext,
        A: Clone,
        E: Clone,
    {
        // The context is only converted on the failure path.
        self.try_run_state(state)
            .map_err(|error| error.with_context(context.into_error_context()))
    }

    /// Maps the error type with `f`, leaving successful results untouched.
    pub fn map_error<F, E2>(&self, f: F) -> StateT<S, Result<(S, A), E2>, A>
    where
        F: Fn(E) -> E2 + Send + Sync + 'static,
        E2: 'static,
        A: Clone,
        E: Clone,
    {
        match self {
            StateT::Pure(a) => StateT::Pure(a.clone()),
            StateT::LiftM(result) => StateT::LiftM(result.clone().map_err(f)),
            StateT::Effect(run_fn) => {
                let run_fn = Arc::clone(run_fn);
                StateT::new(move |s: S| run_fn(s).map_err(&f))
            },
        }
    }

    /// Runs from `state` and returns only the produced value.
    ///
    /// # Errors
    ///
    /// Same as [`StateT::try_run_state`].
    pub fn try_eval_state(&self, state: S) -> ComposableResult<A, E>
    where
        A: Clone,
        E: Clone,
    {
        self.try_run_state(state).map(|(_, a)| a)
    }

    /// Runs from `state` with error `context` and returns only the value.
    ///
    /// # Errors
    ///
    /// Same as [`StateT::try_run_state_with_context`].
    pub fn try_eval_state_with_context<C>(&self, state: S, context: C) -> ComposableResult<A, E>
    where
        C: IntoErrorContext,
        A: Clone,
        E: Clone,
    {
        self.try_run_state_with_context(state, context)
            .map(|(_, a)| a)
    }

    /// Runs from `state` and returns only the final state.
    ///
    /// # Errors
    ///
    /// Same as [`StateT::try_run_state`].
    pub fn try_exec_state(&self, state: S) -> ComposableResult<S, E>
    where
        A: Clone,
        E: Clone,
    {
        self.try_run_state(state).map(|(s, _)| s)
    }
}
use crate::datatypes::id::Id;
impl<S, A> StateT<S, Id<(S, A)>, A>
where
    S: Clone + Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
{
    /// Converts this transformer over `Id` into a plain `State`.
    ///
    /// `Id`'s `pure` is trivial, so a `StateT::Pure(a)` becomes an action
    /// that returns `a` and leaves the state unchanged. Without this, the
    /// action would panic when run. The tuple order flips: this transformer
    /// carries `(state, value)`, while the `State` closure returns
    /// `(value, state)`.
    pub fn to_state(self) -> crate::datatypes::state::State<S, A> {
        crate::datatypes::state::State::new(move |s: S| match &self {
            StateT::Pure(value) => (value.clone(), s),
            non_pure => {
                let (new_state, value) = non_pure.run_state(s).unwrap().clone();
                (value, new_state)
            },
        })
    }

    /// Wraps a plain `State` as an effectful `StateT` over `Id`, reordering
    /// its `(value, state)` output into `(state, value)`.
    pub fn from_state(state: crate::datatypes::state::State<S, A>) -> Self {
        StateT::new(move |s: S| {
            let (a, s2) = state.run_state(s);
            Id::new((s2, a))
        })
    }
}