1#![allow(clippy::type_complexity)]
5
use core::marker::PhantomData;

use karpal_core::hkt::HKT;

use crate::classes::{ApplicativeSt, ChainSt, FunctorSt};
use crate::trans::MonadTrans;

// `Box` and `Rc` are not in the `core` prelude, so `no_std` + `alloc` builds import both.
#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::boxed::Box;
#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::rc::Rc;
#[cfg(feature = "std")]
use std::rc::Rc;
17
/// Type-level tag for the state monad transformer `StateT<S, M>`.
///
/// The tag carries no data. It only selects the [`HKT`] encoding in which `Of<A>` is a
/// boxed function from an input state `S` to `M::Of<(S, A)>`: a computation in the base
/// type constructor `M` that yields the next state paired with a result.
pub struct StateTF<S, M> {
    _marker: PhantomData<(S, M)>,
}
26
27impl<S: 'static, M: HKT + 'static> HKT for StateTF<S, M> {
28 type Of<A> = Box<dyn Fn(S) -> M::Of<(S, A)>>;
29}
30
31impl<S: Clone + 'static, M: FunctorSt + 'static> MonadTrans<M> for StateTF<S, M> {
32 fn lift<A: 'static>(ma: M::Of<A>) -> Box<dyn Fn(S) -> M::Of<(S, A)>>
33 where
34 M::Of<A>: Clone,
35 {
36 Box::new(move |s| M::fmap_st(ma.clone(), move |a| (s.clone(), a)))
37 }
38}
39
40pub fn state_t_pure<S: Clone + 'static, M: ApplicativeSt + 'static, A: Clone + 'static>(
42 a: A,
43) -> Box<dyn Fn(S) -> M::Of<(S, A)>> {
44 Box::new(move |s| M::pure_st((s, a.clone())))
45}
46
47pub fn state_t_fmap<S: 'static, M: FunctorSt + 'static, A: 'static, B: 'static>(
49 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
50 f: impl Fn(A) -> B + 'static,
51) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
52 let f_rc = Rc::new(f);
53 Box::new(move |s| {
54 let f_inner = f_rc.clone();
55 M::fmap_st(fa(s), move |(s2, a)| (s2, f_inner(a)))
56 })
57}
58
59pub fn state_t_chain<S: Clone + 'static, M: ChainSt + 'static, A: 'static, B: 'static>(
63 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
64 f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
65) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
66 let f_rc = Rc::new(f);
67 Box::new(move |s| {
68 let f_inner = f_rc.clone();
69 M::chain_st(fa(s), move |(s2, a)| {
70 let state_b = f_inner(a);
71 state_b(s2)
72 })
73 })
74}
75
76pub fn state_t_get<S: Clone + 'static, M: ApplicativeSt + 'static>()
78-> Box<dyn Fn(S) -> M::Of<(S, S)>> {
79 Box::new(|s: S| {
80 let s2 = s.clone();
81 M::pure_st((s, s2))
82 })
83}
84
85pub fn state_t_put<S: Clone + 'static, M: ApplicativeSt + 'static>(
87 new_state: S,
88) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
89 Box::new(move |_| M::pure_st((new_state.clone(), ())))
90}
91
92pub fn state_t_modify<S: Clone + 'static, M: ApplicativeSt + 'static>(
94 f: impl Fn(S) -> S + 'static,
95) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
96 Box::new(move |s| {
97 let new_s = f(s);
98 M::pure_st((new_s, ()))
99 })
100}
101
102pub fn state_t_run<S, M: HKT, A>(state: &dyn Fn(S) -> M::Of<(S, A)>, initial: S) -> M::Of<(S, A)> {
104 state(initial)
105}
106
107impl<S: 'static, M: FunctorSt + 'static> FunctorSt for StateTF<S, M> {
110 fn fmap_st<A: 'static, B: 'static>(
111 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
112 f: impl Fn(A) -> B + 'static,
113 ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
114 state_t_fmap::<S, M, A, B>(fa, f)
115 }
116}
117
118impl<S: Clone + 'static, M: ChainSt + 'static> ChainSt for StateTF<S, M> {
123 fn chain_st<A: 'static, B: 'static>(
124 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
125 f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
126 ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
127 state_t_chain::<S, M, A, B>(fa, f)
128 }
129}
130
#[cfg(test)]
mod tests {
    // Unit tests for the `StateT` combinators, run over `IdentityF` and `OptionF`.

    use super::*;
    use karpal_core::hkt::{IdentityF, OptionF};

    /// A `StateT<i32, Option, i32>` computation, spelled with the concrete base type.
    type OptionStateT = Box<dyn Fn(i32) -> Option<(i32, i32)>>;

    #[test]
    fn state_t_pure_identity() {
        let program = state_t_pure::<i32, IdentityF, _>(42);
        assert_eq!(program(0), (0, 42));
    }

    #[test]
    fn state_t_pure_option() {
        let program = state_t_pure::<i32, OptionF, _>(42);
        assert_eq!(program(0), Some((0, 42)));
    }

    #[test]
    fn state_t_get_test() {
        let read = state_t_get::<i32, OptionF>();
        assert_eq!(read(42), Some((42, 42)));
    }

    #[test]
    fn state_t_put_test() {
        let write = state_t_put::<i32, OptionF>(99);
        assert_eq!(write(0), Some((99, ())));
    }

    #[test]
    fn state_t_modify_test() {
        let increment = state_t_modify::<i32, OptionF>(|n| n + 1);
        assert_eq!(increment(5), Some((6, ())));
    }

    #[test]
    fn state_t_fmap_test() {
        let base = state_t_pure::<i32, OptionF, _>(10);
        let tripled = state_t_fmap::<i32, OptionF, _, _>(base, |v| v * 3);
        assert_eq!(tripled(0), Some((0, 30)));
    }

    #[test]
    fn state_t_chain_threads_state() {
        // Read the state (10), add it to itself (20), then read the updated state.
        let add_then_read = |seen: i32| {
            state_t_chain::<i32, OptionF, _, _>(
                state_t_modify::<i32, OptionF>(move |st| st + seen),
                |_| state_t_get::<i32, OptionF>(),
            )
        };
        let program =
            state_t_chain::<i32, OptionF, _, _>(state_t_get::<i32, OptionF>(), add_then_read);
        assert_eq!(program(10), Some((20, 20)));
    }

    #[test]
    fn state_t_chain_with_none() {
        // Short-circuits to `None` unless the current state exceeds 100.
        let guard = |x: i32| -> OptionStateT {
            if x <= 100 {
                Box::new(|_| None)
            } else {
                state_t_pure::<i32, OptionF, _>(x)
            }
        };
        let program = state_t_chain::<i32, OptionF, _, _>(state_t_get::<i32, OptionF>(), guard);
        assert_eq!(program(10), None);
        assert_eq!(program(200), Some((200, 200)));
    }

    #[test]
    fn state_t_lift_option() {
        let lifted = StateTF::<i32, OptionF>::lift(Some(42));
        assert_eq!(lifted(99), Some((99, 42)));
    }

    #[test]
    fn state_t_lift_none() {
        let lifted = StateTF::<i32, OptionF>::lift(None::<i32>);
        assert_eq!(lifted(99), None);
    }

    #[test]
    fn state_t_run_test() {
        let boxed = state_t_pure::<i32, OptionF, _>(42);
        assert_eq!(state_t_run::<i32, OptionF, i32>(&*boxed, 0), Some((0, 42)));
    }

    #[test]
    fn state_t_functor_st_trait() {
        let base = state_t_pure::<i32, OptionF, _>(5);
        let incremented = StateTF::<i32, OptionF>::fmap_st(base, |v| v + 1);
        assert_eq!(incremented(0), Some((0, 6)));
    }

    #[test]
    fn state_t_chain_st_trait() {
        let base = state_t_pure::<i32, OptionF, _>(5);
        let shifted =
            StateTF::<i32, OptionF>::chain_st(base, |v| state_t_pure::<i32, OptionF, _>(v + 10));
        assert_eq!(shifted(0), Some((0, 15)));
    }
}
237
#[cfg(test)]
mod law_tests {
    // Property-based checks of the functor, monad, state and lift laws for `StateT` over
    // `OptionF`. Two computations are equal when they agree on the same input state.

    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    /// A `StateT<i32, Option, i32>` computation, spelled with the concrete base type.
    type OptionStateT = Box<dyn Fn(i32) -> Option<(i32, i32)>>;

    proptest! {
        // Monad left identity: `pure(a) >>= f == f(a)`.
        #[test]
        fn state_t_monad_left_identity(a in -100i32..100, s in -100i32..100) {
            let f = |x: i32| -> OptionStateT {
                state_t_pure::<i32, OptionF, _>(x + 1)
            };
            let left = state_t_chain::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                f,
            );
            let right = f(a);
            prop_assert_eq!(left(s), right(s));
        }

        // Monad right identity: `m >>= pure == m`.
        #[test]
        fn state_t_monad_right_identity(a in -100i32..100, s in -100i32..100) {
            let m = state_t_pure::<i32, OptionF, _>(a);
            let left = state_t_chain::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                |x| state_t_pure::<i32, OptionF, _>(x),
            );
            prop_assert_eq!(left(s), m(s));
        }

        // Monad associativity: `(m >>= f) >>= g == m >>= (|x| f(x) >>= g)`.
        // `f` and `g` both read or write the state, so the law also checks that state is
        // threaded in the same order on both sides.
        #[test]
        fn state_t_monad_associativity(a in -100i32..100, s in -100i32..100) {
            // f: add the argument to the state, then return twice the argument.
            let f = |x: i32| -> OptionStateT {
                state_t_chain::<i32, OptionF, _, _>(
                    state_t_modify::<i32, OptionF>(move |st| st + x),
                    move |_| state_t_pure::<i32, OptionF, _>(x * 2),
                )
            };
            // g: return the current state minus the argument.
            let g = |y: i32| -> OptionStateT {
                state_t_chain::<i32, OptionF, _, _>(
                    state_t_get::<i32, OptionF>(),
                    move |st| state_t_pure::<i32, OptionF, _>(st - y),
                )
            };
            let left = state_t_chain::<i32, OptionF, _, _>(
                state_t_chain::<i32, OptionF, _, _>(state_t_pure::<i32, OptionF, _>(a), f),
                g,
            );
            let right = state_t_chain::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                move |x| state_t_chain::<i32, OptionF, _, _>(f(x), g),
            );
            let expected = Some((s + a, s - a));
            prop_assert_eq!(left(s), expected);
            prop_assert_eq!(right(s), expected);
        }

        // Functor identity: `fmap(m, id) == m`.
        #[test]
        fn state_t_functor_identity(a in -100i32..100, s in -100i32..100) {
            let m = state_t_pure::<i32, OptionF, _>(a);
            let mapped = state_t_fmap::<i32, OptionF, _, _>(
                state_t_pure::<i32, OptionF, _>(a),
                |x| x,
            );
            prop_assert_eq!(mapped(s), m(s));
        }

        // Functor composition: `fmap(fmap(m, f), g) == fmap(m, g . f)`.
        #[test]
        fn state_t_functor_composition(s in -100i32..100) {
            let f = |x: i32| x + 3;
            let g = |x: i32| x * 2;
            let left = state_t_fmap::<i32, OptionF, _, _>(
                state_t_fmap::<i32, OptionF, _, _>(state_t_get::<i32, OptionF>(), f),
                g,
            );
            let right = state_t_fmap::<i32, OptionF, _, _>(
                state_t_get::<i32, OptionF>(),
                move |x| g(f(x)),
            );
            prop_assert_eq!(left(s), right(s));
        }

        // get >>= put leaves the state unchanged.
        #[test]
        fn state_t_get_put(s in -100i32..100) {
            let program = state_t_chain::<i32, OptionF, _, _>(
                state_t_get::<i32, OptionF>(),
                |x| state_t_put::<i32, OptionF>(x),
            );
            prop_assert_eq!(program(s), Some((s, ())));
        }

        // put(n) >> get observes n.
        #[test]
        fn state_t_put_get(s in -100i32..100, new_s in -100i32..100) {
            let program = state_t_chain::<i32, OptionF, _, _>(
                state_t_put::<i32, OptionF>(new_s),
                |_| state_t_get::<i32, OptionF>(),
            );
            prop_assert_eq!(program(s), Some((new_s, new_s)));
        }

        // put(a) >> put(b) == put(b): only the last write is observable.
        #[test]
        fn state_t_put_put(s in -100i32..100, first in -100i32..100, second in -100i32..100) {
            let program = state_t_chain::<i32, OptionF, _, _>(
                state_t_put::<i32, OptionF>(first),
                move |_| state_t_put::<i32, OptionF>(second),
            );
            let only_last = state_t_put::<i32, OptionF>(second);
            prop_assert_eq!(program(s), only_last(s));
            prop_assert_eq!(program(s), Some((second, ())));
        }

        // Transformer law: lifting a pure base value equals `StateT`'s own `pure`.
        #[test]
        fn state_t_lift_pure(a in -100i32..100, s in -100i32..100) {
            let lifted = StateTF::<i32, OptionF>::lift(Some(a));
            let pure = state_t_pure::<i32, OptionF, _>(a);
            prop_assert_eq!(lifted(s), pure(s));
        }
    }
}