// karpal_effect/state_t.rs

1#![allow(clippy::type_complexity)]
2
3use core::marker::PhantomData;
4
5use karpal_core::hkt::HKT;
6
7use crate::classes::{ApplicativeSt, ChainSt, FunctorSt};
8use crate::trans::MonadTrans;
9
10#[cfg(all(not(feature = "std"), feature = "alloc"))]
11use alloc::rc::Rc;
12#[cfg(feature = "std")]
13use std::rc::Rc;
14
/// The StateT monad transformer: threads a state of type `S` through an inner monad `M`.
///
/// As an `HKT`, `StateTF<S, M>::Of<A>` is the type
/// `Box<dyn Fn(S) -> M::Of<(S, A)>>`. That is a function from an incoming state to an
/// inner-monad value carrying the outgoing state together with a result.
///
/// ReaderT's environment is read-only. Here, by contrast, each step may return a new
/// state, and that new state is handed to the next step. Instantiating `M` with
/// `IdentityF` gives a computation equivalent to the State monad in `karpal-core`'s
/// adjunction module.
///
/// The struct itself is a zero-sized, type-level tag. It carries no data.
pub struct StateTF<S, M>(PhantomData<(S, M)>);
23
24impl<S: 'static, M: HKT + 'static> HKT for StateTF<S, M> {
25    type Of<A> = Box<dyn Fn(S) -> M::Of<(S, A)>>;
26}
27
28impl<S: Clone + 'static, M: FunctorSt + 'static> MonadTrans<M> for StateTF<S, M> {
29    fn lift<A: 'static>(ma: M::Of<A>) -> Box<dyn Fn(S) -> M::Of<(S, A)>>
30    where
31        M::Of<A>: Clone,
32    {
33        Box::new(move |s| M::fmap_st(ma.clone(), move |a| (s.clone(), a)))
34    }
35}
36
37/// StateT `pure`: wrap a value without modifying state.
38pub fn state_t_pure<S: Clone + 'static, M: ApplicativeSt + 'static, A: Clone + 'static>(
39    a: A,
40) -> Box<dyn Fn(S) -> M::Of<(S, A)>> {
41    Box::new(move |s| M::pure_st((s, a.clone())))
42}
43
44/// StateT `fmap`: apply a function to the result, leaving state unchanged.
45pub fn state_t_fmap<S: 'static, M: FunctorSt + 'static, A: 'static, B: 'static>(
46    fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
47    f: impl Fn(A) -> B + 'static,
48) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
49    let f_rc = Rc::new(f);
50    Box::new(move |s| {
51        let f_inner = f_rc.clone();
52        M::fmap_st(fa(s), move |(s2, a)| (s2, f_inner(a)))
53    })
54}
55
56/// StateT `chain`: sequence stateful computations, threading state.
57///
58/// The state from the first computation is passed to the second.
59pub fn state_t_chain<S: Clone + 'static, M: ChainSt + 'static, A: 'static, B: 'static>(
60    fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
61    f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
62) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
63    let f_rc = Rc::new(f);
64    Box::new(move |s| {
65        let f_inner = f_rc.clone();
66        M::chain_st(fa(s), move |(s2, a)| {
67            let state_b = f_inner(a);
68            state_b(s2)
69        })
70    })
71}
72
73/// StateT `get`: read the current state.
74pub fn state_t_get<S: Clone + 'static, M: ApplicativeSt + 'static>()
75-> Box<dyn Fn(S) -> M::Of<(S, S)>> {
76    Box::new(|s: S| {
77        let s2 = s.clone();
78        M::pure_st((s, s2))
79    })
80}
81
82/// StateT `put`: replace the state.
83pub fn state_t_put<S: Clone + 'static, M: ApplicativeSt + 'static>(
84    new_state: S,
85) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
86    Box::new(move |_| M::pure_st((new_state.clone(), ())))
87}
88
89/// StateT `modify`: apply a function to the state.
90pub fn state_t_modify<S: Clone + 'static, M: ApplicativeSt + 'static>(
91    f: impl Fn(S) -> S + 'static,
92) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
93    Box::new(move |s| {
94        let new_s = f(s);
95        M::pure_st((new_s, ()))
96    })
97}
98
99/// StateT `run`: run the computation with initial state.
100pub fn state_t_run<S, M: HKT, A>(state: &dyn Fn(S) -> M::Of<(S, A)>, initial: S) -> M::Of<(S, A)> {
101    state(initial)
102}
103
// --- Type-class instances for StateTF: FunctorSt and ChainSt (ApplicativeSt is intentionally omitted) ---
105
106impl<S: 'static, M: FunctorSt + 'static> FunctorSt for StateTF<S, M> {
107    fn fmap_st<A: 'static, B: 'static>(
108        fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
109        f: impl Fn(A) -> B + 'static,
110    ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
111        state_t_fmap::<S, M, A, B>(fa, f)
112    }
113}
114
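// Note: ApplicativeSt is not implemented for StateTF. Its `pure_st` would have to turn a
// single `A` into a `Box<dyn Fn(S) -> M::Of<(S, A)>>`. That closure is `Fn` and may be
// called any number of times, so it must hand out a fresh `A` on every call. This is
// only possible when `A: Clone`. Use the standalone `state_t_pure` function instead,
// which carries that `A: Clone` bound.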
118
119impl<S: Clone + 'static, M: ChainSt + 'static> ChainSt for StateTF<S, M> {
120    fn chain_st<A: 'static, B: 'static>(
121        fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
122        f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
123    ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
124        state_t_chain::<S, M, A, B>(fa, f)
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::{IdentityF, OptionF};

    // --- primitive operations ---

    #[test]
    fn state_t_pure_identity() {
        let program = state_t_pure::<i32, IdentityF, _>(42);
        assert_eq!(program(0), (0, 42));
    }

    #[test]
    fn state_t_pure_option() {
        let program = state_t_pure::<i32, OptionF, _>(42);
        assert_eq!(program(0), Some((0, 42)));
    }

    #[test]
    fn state_t_get_test() {
        let read = state_t_get::<i32, OptionF>();
        assert_eq!(read(42), Some((42, 42)));
    }

    #[test]
    fn state_t_put_test() {
        let write = state_t_put::<i32, OptionF>(99);
        assert_eq!(write(0), Some((99, ())));
    }

    #[test]
    fn state_t_modify_test() {
        let increment = state_t_modify::<i32, OptionF>(|n| n + 1);
        assert_eq!(increment(5), Some((6, ())));
    }

    #[test]
    fn state_t_fmap_test() {
        let base = state_t_pure::<i32, OptionF, _>(10);
        let tripled = state_t_fmap::<i32, OptionF, _, _>(base, |v| v * 3);
        assert_eq!(tripled(0), Some((0, 30)));
    }

    // --- sequencing ---

    #[test]
    fn state_t_chain_threads_state() {
        // Read the state (10), add it to itself, then read the new state (20).
        let add_then_read = |x: i32| {
            state_t_chain::<i32, OptionF, _, _>(
                state_t_modify::<i32, OptionF>(move |s| s + x),
                |_| state_t_get::<i32, OptionF>(),
            )
        };
        let program =
            state_t_chain::<i32, OptionF, _, _>(state_t_get::<i32, OptionF>(), add_then_read);
        assert_eq!(program(10), Some((20, 20)));
    }

    #[test]
    fn state_t_chain_with_none() {
        // The OptionF inner monad lets a step abort the whole computation.
        let program = state_t_chain::<i32, OptionF, _, _>(
            state_t_get::<i32, OptionF>(),
            |x| -> Box<dyn Fn(i32) -> Option<(i32, i32)>> {
                if x <= 100 {
                    Box::new(|_| None) // short-circuit
                } else {
                    state_t_pure::<i32, OptionF, _>(x)
                }
            },
        );
        assert_eq!(program(10), None);
        assert_eq!(program(200), Some((200, 200)));
    }

    // --- lifting and running ---

    #[test]
    fn state_t_lift_option() {
        let inner: Option<i32> = Some(42);
        let lifted = StateTF::<i32, OptionF>::lift(inner);
        assert_eq!(lifted(99), Some((99, 42)));
    }

    #[test]
    fn state_t_lift_none() {
        let inner = None::<i32>;
        let lifted = StateTF::<i32, OptionF>::lift(inner);
        assert_eq!(lifted(99), None);
    }

    #[test]
    fn state_t_run_test() {
        let program = state_t_pure::<i32, OptionF, _>(42);
        let outcome = state_t_run::<i32, OptionF, i32>(&*program, 0);
        assert_eq!(outcome, Some((0, 42)));
    }

    // --- trait impls ---

    #[test]
    fn state_t_functor_st_trait() {
        let base = state_t_pure::<i32, OptionF, _>(5);
        let mapped = StateTF::<i32, OptionF>::fmap_st(base, |x| x + 1);
        assert_eq!(mapped(0), Some((0, 6)));
    }

    #[test]
    fn state_t_chain_st_trait() {
        let base = state_t_pure::<i32, OptionF, _>(5);
        let chained = StateTF::<i32, OptionF>::chain_st(base, |x| {
            state_t_pure::<i32, OptionF, _>(x + 10)
        });
        assert_eq!(chained(0), Some((0, 15)));
    }
}
234
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    /// Boxed StateT computation over `Option` with `i32` state and `i32` result.
    type Prog = Box<dyn Fn(i32) -> Option<(i32, i32)>>;

    /// A computation that really touches the state: add `a`, then read the state back.
    /// Laws checked only on `pure` cannot catch state-threading bugs, so the identity
    /// laws below use this instead.
    fn add_then_get(a: i32) -> Prog {
        state_t_chain::<i32, OptionF, _, _>(
            state_t_modify::<i32, OptionF>(move |st| st + a),
            |_| state_t_get::<i32, OptionF>(),
        )
    }

    /// Kleisli arrow: add `x` to the state, then return `x * 2`.
    fn bump(x: i32) -> Prog {
        state_t_chain::<i32, OptionF, _, _>(
            state_t_modify::<i32, OptionF>(move |st| st + x),
            move |_| state_t_pure::<i32, OptionF, _>(x * 2),
        )
    }

    /// Kleisli arrow: return `state + x` without changing the state.
    fn observe(x: i32) -> Prog {
        state_t_chain::<i32, OptionF, _, _>(
            state_t_get::<i32, OptionF>(),
            move |st| state_t_pure::<i32, OptionF, _>(st + x),
        )
    }

    proptest! {
        // Monad left identity: chain(pure(a), f) == f(a)
        #[test]
        fn state_t_monad_left_identity(a in -100i32..100, s in -100i32..100) {
            let left = state_t_chain::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                bump,
            );
            let right = bump(a);
            prop_assert_eq!(left(s), right(s));
        }

        // Monad right identity: chain(m, pure) == m, with a state-modifying m
        #[test]
        fn state_t_monad_right_identity(a in -100i32..100, s in -100i32..100) {
            let m = add_then_get(a);
            let left = state_t_chain::<i32, OptionF, _, _>(
                add_then_get(a),
                |x| state_t_pure::<i32, OptionF, _>(x),
            );
            prop_assert_eq!(left(s), m(s));
        }

        // Monad associativity: chain(chain(m, f), g) == chain(m, |x| chain(f(x), g))
        #[test]
        fn state_t_monad_associativity(a in -100i32..100, s in -100i32..100) {
            let left = state_t_chain::<i32, OptionF, _, _>(
                state_t_chain::<i32, OptionF, _, _>(add_then_get(a), bump),
                observe,
            );
            let right = state_t_chain::<i32, OptionF, _, _>(add_then_get(a), |x| {
                state_t_chain::<i32, OptionF, _, _>(bump(x), observe)
            });
            prop_assert_eq!(left(s), right(s));
        }

        // Functor identity: fmap(m, id) == m, with a state-modifying m
        #[test]
        fn state_t_functor_identity(a in -100i32..100, s in -100i32..100) {
            let m = add_then_get(a);
            let mapped = state_t_fmap::<i32, OptionF, _, _>(add_then_get(a), |x| x);
            prop_assert_eq!(mapped(s), m(s));
        }

        // Functor composition: fmap(fmap(m, f), g) == fmap(m, g . f)
        #[test]
        fn state_t_functor_composition(a in -100i32..100, s in -100i32..100) {
            let f = |x: i32| x + 1;
            let g = |x: i32| x * 2;
            let left = state_t_fmap::<i32, OptionF, _, _>(
                state_t_fmap::<i32, OptionF, _, _>(add_then_get(a), f),
                g,
            );
            let right = state_t_fmap::<i32, OptionF, _, _>(add_then_get(a), move |x| g(f(x)));
            prop_assert_eq!(left(s), right(s));
        }

        // State: get then put restores
        #[test]
        fn state_t_get_put(s in -100i32..100) {
            let program = state_t_chain::<i32, OptionF, _, _>(
                state_t_get::<i32, OptionF>(),
                |x| state_t_put::<i32, OptionF>(x),
            );
            prop_assert_eq!(program(s), Some((s, ())));
        }

        // State: put then get returns what was put
        #[test]
        fn state_t_put_get(s in -100i32..100, new_s in -100i32..100) {
            let program = state_t_chain::<i32, OptionF, _, _>(
                state_t_put::<i32, OptionF>(new_s),
                |_| state_t_get::<i32, OptionF>(),
            );
            prop_assert_eq!(program(s), Some((new_s, new_s)));
        }

        // State: modify(f) behaves like get >>= (put . f)
        #[test]
        fn state_t_modify_is_get_then_put(s in -100i32..100) {
            let via_modify = state_t_modify::<i32, OptionF>(|st| st * 3);
            let via_get_put = state_t_chain::<i32, OptionF, _, _>(
                state_t_get::<i32, OptionF>(),
                |st| state_t_put::<i32, OptionF>(st * 3),
            );
            prop_assert_eq!(via_modify(s), via_get_put(s));
        }
    }
}
298}