1use crate::monad::Monad;
4use std::iter::FromIterator;
5
6pub struct StateT<'a, S, M, A>
7where
8 M: 'a + Monad<Item=(A, S)>,
9{
10 pub run_state_t: Box<dyn 'a + Fn(S) -> M>,
11}
12
13impl<'a, A, S, M> StateT<'a, S, M, A>
14 where
15 A: 'a + Clone,
16 S: 'a + Clone,
17 M: 'a + Monad<Item =(A, S)> + FromIterator<(A,S)> {
19 pub fn pure(x: A) -> Self
21 {
22 StateT { run_state_t: Box::new( move |s: S| M::pure(( x.clone(), s)))} }
24
25
26 pub fn lift<N>(n: N) -> Self
27 where
28 N: 'a + Clone + Monad<Item=A>,
29 {
30 StateT { run_state_t: Box::new(
31 move |s| n.clone().into_iter().map( | a| (a, s.clone())
33 ).collect::<M>()
34 )}
35 }
36
37 pub fn lift_iter<I>(it: I) -> Self
38 where
39 I: 'a + Clone + Iterator<Item=A>,
40 {
41 StateT { run_state_t: Box::new(
42 move |s| it.clone().map( | a| (a, s.clone())
44 ).collect::<M>()
45 )}
46 }
47
48 pub fn bind<N, B, F: 'a>(self, f: F) -> StateT<'a, S, N, B>
50 where
51 F: 'a + Copy + Fn(A) -> StateT<'a, S, N, B>,
52 N: 'a + Monad<Item=(B, S)> + FromIterator<(B, S)>,
53 B: 'a,
54 {
55 StateT { run_state_t: Box::new( move |s: S| {
56 let m = (*self.run_state_t) (s); let g = move |(v, s1)| (* f( v).run_state_t) (s1);
58 M::bind( m, g).collect::<N>()
59 })}
60
61 }
62
63
64 pub fn initial_state(self, s: S) -> M {
66 (*self.run_state_t) (s)
67 }
68}
69
70pub fn get<'a, S>() -> StateT<'a, S, Vec<(S, S)>, S>
71 where
72 S: 'a + Clone,
73{
74 StateT { run_state_t: Box::new( |s: S| {let p = (s.clone(), s); Vec::pure(p)}
75 )}
76}
77
78pub fn put<'a, S>( s: S) -> StateT<'a, S, Vec<((), S)>, ()>
79 where
80 S: 'a + Clone,
81{
82 StateT { run_state_t: Box::new( move |_| {let p = ((), s.clone()); Vec::pure(p)}
83 )}
84}
85
/// Do-notation for `StateT` computations whose inner monad is `Vec`.
///
/// Every `;`-terminated statement is desugared into a nested `bind`, and the
/// final expression is returned. Supported forms:
///
/// - `pure e`: `StateT::pure(e)`
/// - `lift m`: `StateT::lift(m)`
/// - `guard cond; rest`: lifts `vec![()]` if `cond` holds, else `vec![]`,
///   then binds `rest`
/// - `_ <- m; rest`: binds `m`, discarding its value
/// - `x <- lift_iter it; rest`: `StateT::lift_iter(it)` bound to `x`
/// - `&x <- lift m; rest` / `x <- lift m; rest`: `StateT::lift(m)` bound to
///   `x` (by reference pattern or by value)
/// - `let x = e; rest`: `StateT::pure(e)` bound to `x`
/// - `x <- m; rest`: the `StateT` computation `m` bound to `x`
/// - `e`: the expression itself
///
/// The expansion names `StateT` and `St` without a path. Both therefore
/// resolve at the invocation site and must be in scope there; `St` is used as
/// the state type. Recursive calls go through `$crate::stt_mdo!`, so the macro
/// also works when invoked by path from another crate without importing it.
#[macro_export]
macro_rules! stt_mdo {
    (pure $e:expr) => [StateT::<'_, St, Vec<_>, _>::pure($e)];

    (lift $nested_monad:expr) => [StateT::<'_, St, Vec<_>, _>::lift($nested_monad)];

    (guard $boolean:expr ; $($rest:tt)*) => [
        StateT::<'_, St, Vec<_>, _>::lift(if $boolean { vec![()] } else { vec![] })
            .bind(move |_| { $crate::stt_mdo!($($rest)*) })
    ];

    (_ <- $monad:expr ; $($rest:tt)* ) => [
        StateT::bind(($monad), move |_| { $crate::stt_mdo!($($rest)*) })
    ];

    ($v:ident <- lift_iter $it:expr ; $($rest:tt)* ) => [
        StateT::<'_, St, Vec<_>, _>::lift_iter($it)
            .bind(move |$v| { $crate::stt_mdo!($($rest)*) })
    ];

    (& $v:ident <- lift $nested_monad:expr ; $($rest:tt)* ) => [
        StateT::<'_, St, Vec<_>, _>::lift($nested_monad)
            .bind(move |& $v| { $crate::stt_mdo!($($rest)*) })
    ];

    ($v:ident <- lift $nested_monad:expr ; $($rest:tt)* ) => [
        StateT::<'_, St, Vec<_>, _>::lift($nested_monad)
            .bind(move |$v| { $crate::stt_mdo!($($rest)*) })
    ];

    (let $v:ident = $e:expr ; $($rest:tt)* ) => [
        StateT::bind(
            StateT::<'_, St, Vec<_>, _>::pure($e),
            move |$v| { $crate::stt_mdo!($($rest)*) },
        )
    ];

    ($v:ident <- $monad:expr ; $($rest:tt)* ) => [
        StateT::bind(($monad), move |$v| { $crate::stt_mdo!($($rest)*) })
    ];

    ($monad:expr ) => [$monad];
}
126