Skip to main content

karpal_effect/
state_t.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(clippy::type_complexity)]
5
6use core::marker::PhantomData;
7
8use karpal_core::hkt::HKT;
9
10use crate::classes::{ApplicativeSt, ChainSt, FunctorSt};
11use crate::trans::MonadTrans;
12
13#[cfg(all(not(feature = "std"), feature = "alloc"))]
14use alloc::rc::Rc;
15#[cfg(feature = "std")]
16use std::rc::Rc;
17
/// Type-level tag for the StateT monad transformer, which layers threaded state
/// of type `S` on top of an inner monad `M`.
///
/// The tag carries no runtime data; it only selects the representation
///
/// `StateTF<S, M>::Of<A> = Box<dyn Fn(S) -> M::Of<(S, A)>>`
///
/// i.e. a function from an incoming state to an inner-monad effect yielding the
/// outgoing state together with a result.
///
/// In contrast to ReaderT, the state is threaded through computations and may be
/// modified along the way. With `IdentityF` as the inner monad this is equivalent
/// to the State monad from `karpal-core`'s adjunction module.
pub struct StateTF<S, M>(PhantomData<(S, M)>);
26
27impl<S: 'static, M: HKT + 'static> HKT for StateTF<S, M> {
28    type Of<A> = Box<dyn Fn(S) -> M::Of<(S, A)>>;
29}
30
31impl<S: Clone + 'static, M: FunctorSt + 'static> MonadTrans<M> for StateTF<S, M> {
32    fn lift<A: 'static>(ma: M::Of<A>) -> Box<dyn Fn(S) -> M::Of<(S, A)>>
33    where
34        M::Of<A>: Clone,
35    {
36        Box::new(move |s| M::fmap_st(ma.clone(), move |a| (s.clone(), a)))
37    }
38}
39
40/// StateT `pure`: wrap a value without modifying state.
41pub fn state_t_pure<S: Clone + 'static, M: ApplicativeSt + 'static, A: Clone + 'static>(
42    a: A,
43) -> Box<dyn Fn(S) -> M::Of<(S, A)>> {
44    Box::new(move |s| M::pure_st((s, a.clone())))
45}
46
47/// StateT `fmap`: apply a function to the result, leaving state unchanged.
48pub fn state_t_fmap<S: 'static, M: FunctorSt + 'static, A: 'static, B: 'static>(
49    fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
50    f: impl Fn(A) -> B + 'static,
51) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
52    let f_rc = Rc::new(f);
53    Box::new(move |s| {
54        let f_inner = f_rc.clone();
55        M::fmap_st(fa(s), move |(s2, a)| (s2, f_inner(a)))
56    })
57}
58
59/// StateT `chain`: sequence stateful computations, threading state.
60///
61/// The state from the first computation is passed to the second.
62pub fn state_t_chain<S: Clone + 'static, M: ChainSt + 'static, A: 'static, B: 'static>(
63    fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
64    f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
65) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
66    let f_rc = Rc::new(f);
67    Box::new(move |s| {
68        let f_inner = f_rc.clone();
69        M::chain_st(fa(s), move |(s2, a)| {
70            let state_b = f_inner(a);
71            state_b(s2)
72        })
73    })
74}
75
76/// StateT `get`: read the current state.
77pub fn state_t_get<S: Clone + 'static, M: ApplicativeSt + 'static>()
78-> Box<dyn Fn(S) -> M::Of<(S, S)>> {
79    Box::new(|s: S| {
80        let s2 = s.clone();
81        M::pure_st((s, s2))
82    })
83}
84
85/// StateT `put`: replace the state.
86pub fn state_t_put<S: Clone + 'static, M: ApplicativeSt + 'static>(
87    new_state: S,
88) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
89    Box::new(move |_| M::pure_st((new_state.clone(), ())))
90}
91
92/// StateT `modify`: apply a function to the state.
93pub fn state_t_modify<S: Clone + 'static, M: ApplicativeSt + 'static>(
94    f: impl Fn(S) -> S + 'static,
95) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
96    Box::new(move |s| {
97        let new_s = f(s);
98        M::pure_st((new_s, ()))
99    })
100}
101
102/// StateT `run`: run the computation with initial state.
103pub fn state_t_run<S, M: HKT, A>(state: &dyn Fn(S) -> M::Of<(S, A)>, initial: S) -> M::Of<(S, A)> {
104    state(initial)
105}
106
// --- FunctorSt / ChainSt for StateTF (ApplicativeSt is intentionally omitted; see the note below) ---
108
109impl<S: 'static, M: FunctorSt + 'static> FunctorSt for StateTF<S, M> {
110    fn fmap_st<A: 'static, B: 'static>(
111        fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
112        f: impl Fn(A) -> B + 'static,
113    ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
114        state_t_fmap::<S, M, A, B>(fa, f)
115    }
116}
117
// Note: `ApplicativeSt` is deliberately not implemented for `StateTF`. `pure_st`
// receives a single `A`, but the `Box<dyn Fn(S) -> M::Of<(S, A)>>` it would have to
// return can be called repeatedly, which needs `A: Clone`.
// Use the standalone `state_t_pure` function instead (it requires `A: Clone`).
121
122impl<S: Clone + 'static, M: ChainSt + 'static> ChainSt for StateTF<S, M> {
123    fn chain_st<A: 'static, B: 'static>(
124        fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
125        f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
126    ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
127        state_t_chain::<S, M, A, B>(fa, f)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::{IdentityF, OptionF};

    // Standalone combinators

    /// `pure` over `IdentityF` yields the bare `(state, value)` pair.
    #[test]
    fn state_t_pure_identity() {
        let computation = state_t_pure::<i32, IdentityF, _>(42);
        let outcome = computation(0);
        assert_eq!(outcome, (0, 42));
    }

    /// `pure` over `OptionF` wraps the pair in `Some`.
    #[test]
    fn state_t_pure_option() {
        let computation = state_t_pure::<i32, OptionF, _>(42);
        let outcome = computation(0);
        assert_eq!(outcome, Some((0, 42)));
    }

    /// `get` returns the state as the result and leaves it in place.
    #[test]
    fn state_t_get_test() {
        let read_state = state_t_get::<i32, OptionF>();
        assert_eq!(read_state(42), Some((42, 42)));
    }

    /// `put` discards the incoming state in favour of the supplied one.
    #[test]
    fn state_t_put_test() {
        let overwrite = state_t_put::<i32, OptionF>(99);
        assert_eq!(overwrite(0), Some((99, ())));
    }

    /// `modify` runs the function over the incoming state.
    #[test]
    fn state_t_modify_test() {
        let increment = state_t_modify::<i32, OptionF>(|current| current + 1);
        assert_eq!(increment(5), Some((6, ())));
    }

    /// `fmap` changes the result and not the state.
    #[test]
    fn state_t_fmap_test() {
        let ten = state_t_pure::<i32, OptionF, _>(10);
        let tripled = state_t_fmap::<i32, OptionF, _, _>(ten, |value| value * 3);
        assert_eq!(tripled(0), Some((0, 30)));
    }

    /// State written by one step is visible to the next.
    #[test]
    fn state_t_chain_threads_state() {
        // Read the state, add that value back onto the state, then read again.
        let read_state = state_t_get::<i32, OptionF>();
        let program = state_t_chain::<i32, OptionF, _, _>(read_state, |seen| {
            let add_seen = state_t_modify::<i32, OptionF>(move |current| current + seen);
            state_t_chain::<i32, OptionF, _, _>(add_seen, |_| state_t_get::<i32, OptionF>())
        });
        // get 10, modify +10, get 20
        assert_eq!(program(10), Some((20, 20)));
    }

    /// A `None` from the inner `OptionF` monad aborts the whole computation.
    #[test]
    fn state_t_chain_with_none() {
        let read_state = state_t_get::<i32, OptionF>();
        let program = state_t_chain::<i32, OptionF, _, _>(
            read_state,
            |seen| -> Box<dyn Fn(i32) -> Option<(i32, i32)>> {
                if seen > 100 {
                    return state_t_pure::<i32, OptionF, _>(seen);
                }
                // Short-circuit: the inner monad produces nothing.
                Box::new(|_| None)
            },
        );
        assert_eq!(program(10), None);
        assert_eq!(program(200), Some((200, 200)));
    }

    /// Lifting `Some` pairs the value with the untouched state.
    #[test]
    fn state_t_lift_option() {
        let lifted = StateTF::<i32, OptionF>::lift(Some(42));
        let outcome = lifted(99);
        assert_eq!(outcome, Some((99, 42)));
    }

    /// Lifting `None` stays `None` regardless of the state.
    #[test]
    fn state_t_lift_none() {
        let lifted = StateTF::<i32, OptionF>::lift(None::<i32>);
        let outcome = lifted(99);
        assert_eq!(outcome, None);
    }

    /// `run` is plain application of the computation to the initial state.
    #[test]
    fn state_t_run_test() {
        let computation = state_t_pure::<i32, OptionF, _>(42);
        let outcome = state_t_run::<i32, OptionF, i32>(&*computation, 0);
        assert_eq!(outcome, Some((0, 42)));
    }

    // Trait impls

    /// The `FunctorSt` impl behaves like the standalone `state_t_fmap`.
    #[test]
    fn state_t_functor_st_trait() {
        let five = state_t_pure::<i32, OptionF, _>(5);
        let bumped = StateTF::<i32, OptionF>::fmap_st(five, |value| value + 1);
        assert_eq!(bumped(0), Some((0, 6)));
    }

    /// The `ChainSt` impl behaves like the standalone `state_t_chain`.
    #[test]
    fn state_t_chain_st_trait() {
        let five = state_t_pure::<i32, OptionF, _>(5);
        let program = StateTF::<i32, OptionF>::chain_st(five, |value| {
            state_t_pure::<i32, OptionF, _>(value + 10)
        });
        assert_eq!(program(0), Some((0, 15)));
    }
}
237
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    proptest! {
        // Monad left identity: chain(pure(a), f) == f(a)
        #[test]
        fn state_t_monad_left_identity(a in -100i32..100, s in -100i32..100) {
            let successor = |value: i32| -> Box<dyn Fn(i32) -> Option<(i32, i32)>> {
                state_t_pure::<i32, OptionF, _>(value + 1)
            };
            let direct = successor(a);
            let chained = state_t_chain::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                successor,
            );
            prop_assert_eq!(chained(s), direct(s));
        }

        // Monad right identity: chain(m, pure) == m
        #[test]
        fn state_t_monad_right_identity(a in -100i32..100, s in -100i32..100) {
            let plain = state_t_pure::<i32, OptionF, _>(a);
            let chained = state_t_chain::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                |value| state_t_pure::<i32, OptionF, _>(value),
            );
            prop_assert_eq!(chained(s), plain(s));
        }

        // Functor identity: fmap(m, id) == m
        #[test]
        fn state_t_functor_identity(a in -100i32..100, s in -100i32..100) {
            let plain = state_t_pure::<i32, OptionF, _>(a);
            let mapped = state_t_fmap::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                |value| value,
            );
            prop_assert_eq!(mapped(s), plain(s));
        }

        // State: writing back what was just read leaves the state as it was
        #[test]
        fn state_t_get_put(s in -100i32..100) {
            let read_then_write = state_t_chain::<i32, OptionF, _, _>(
                state_t_get::<i32, OptionF>(),
                |seen| state_t_put::<i32, OptionF>(seen),
            );
            prop_assert_eq!(read_then_write(s), Some((s, ())));
        }

        // State: reading right after a write observes the written value
        #[test]
        fn state_t_put_get(s in -100i32..100, new_s in -100i32..100) {
            let write_then_read = state_t_chain::<i32, OptionF, _, _>(
                state_t_put::<i32, OptionF>(new_s),
                |_| state_t_get::<i32, OptionF>(),
            );
            prop_assert_eq!(write_then_read(s), Some((new_s, new_s)));
        }
    }
}