1#![allow(clippy::type_complexity)]
2
3use core::marker::PhantomData;
4
5use karpal_core::hkt::HKT;
6
7use crate::classes::{ApplicativeSt, ChainSt, FunctorSt};
8use crate::trans::MonadTrans;
9
10#[cfg(all(not(feature = "std"), feature = "alloc"))]
11use alloc::rc::Rc;
12#[cfg(feature = "std")]
13use std::rc::Rc;
14
/// Type-level tag for a state transformer with state type `S` layered over `M`.
///
/// The tag holds no runtime data. Its `HKT` implementation in this file maps
/// a result type `A` to `Box<dyn Fn(S) -> M::Of<(S, A)>>`.
pub struct StateTF<S, M>(
    // Zero-sized marker so the otherwise unused parameters belong to the type.
    PhantomData<(S, M)>,
);
23
24impl<S: 'static, M: HKT + 'static> HKT for StateTF<S, M> {
25 type Of<A> = Box<dyn Fn(S) -> M::Of<(S, A)>>;
26}
27
28impl<S: Clone + 'static, M: FunctorSt + 'static> MonadTrans<M> for StateTF<S, M> {
29 fn lift<A: 'static>(ma: M::Of<A>) -> Box<dyn Fn(S) -> M::Of<(S, A)>>
30 where
31 M::Of<A>: Clone,
32 {
33 Box::new(move |s| M::fmap_st(ma.clone(), move |a| (s.clone(), a)))
34 }
35}
36
37pub fn state_t_pure<S: Clone + 'static, M: ApplicativeSt + 'static, A: Clone + 'static>(
39 a: A,
40) -> Box<dyn Fn(S) -> M::Of<(S, A)>> {
41 Box::new(move |s| M::pure_st((s, a.clone())))
42}
43
44pub fn state_t_fmap<S: 'static, M: FunctorSt + 'static, A: 'static, B: 'static>(
46 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
47 f: impl Fn(A) -> B + 'static,
48) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
49 let f_rc = Rc::new(f);
50 Box::new(move |s| {
51 let f_inner = f_rc.clone();
52 M::fmap_st(fa(s), move |(s2, a)| (s2, f_inner(a)))
53 })
54}
55
56pub fn state_t_chain<S: Clone + 'static, M: ChainSt + 'static, A: 'static, B: 'static>(
60 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
61 f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
62) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
63 let f_rc = Rc::new(f);
64 Box::new(move |s| {
65 let f_inner = f_rc.clone();
66 M::chain_st(fa(s), move |(s2, a)| {
67 let state_b = f_inner(a);
68 state_b(s2)
69 })
70 })
71}
72
73pub fn state_t_get<S: Clone + 'static, M: ApplicativeSt + 'static>()
75-> Box<dyn Fn(S) -> M::Of<(S, S)>> {
76 Box::new(|s: S| {
77 let s2 = s.clone();
78 M::pure_st((s, s2))
79 })
80}
81
82pub fn state_t_put<S: Clone + 'static, M: ApplicativeSt + 'static>(
84 new_state: S,
85) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
86 Box::new(move |_| M::pure_st((new_state.clone(), ())))
87}
88
89pub fn state_t_modify<S: Clone + 'static, M: ApplicativeSt + 'static>(
91 f: impl Fn(S) -> S + 'static,
92) -> Box<dyn Fn(S) -> M::Of<(S, ())>> {
93 Box::new(move |s| {
94 let new_s = f(s);
95 M::pure_st((new_s, ()))
96 })
97}
98
99pub fn state_t_run<S, M: HKT, A>(state: &dyn Fn(S) -> M::Of<(S, A)>, initial: S) -> M::Of<(S, A)> {
101 state(initial)
102}
103
104impl<S: 'static, M: FunctorSt + 'static> FunctorSt for StateTF<S, M> {
107 fn fmap_st<A: 'static, B: 'static>(
108 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
109 f: impl Fn(A) -> B + 'static,
110 ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
111 state_t_fmap::<S, M, A, B>(fa, f)
112 }
113}
114
115impl<S: Clone + 'static, M: ChainSt + 'static> ChainSt for StateTF<S, M> {
120 fn chain_st<A: 'static, B: 'static>(
121 fa: Box<dyn Fn(S) -> M::Of<(S, A)>>,
122 f: impl Fn(A) -> Box<dyn Fn(S) -> M::Of<(S, B)>> + 'static,
123 ) -> Box<dyn Fn(S) -> M::Of<(S, B)>> {
124 state_t_chain::<S, M, A, B>(fa, f)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::{IdentityF, OptionF};

    // `pure` over `IdentityF`: the state passes through and the value is returned.
    #[test]
    fn state_t_pure_identity() {
        let computation = state_t_pure::<i32, IdentityF, _>(42);
        let outcome = computation(0);
        assert_eq!(outcome, (0, 42));
    }

    // `pure` over `OptionF`: the same pair, wrapped in `Some`.
    #[test]
    fn state_t_pure_option() {
        let computation = state_t_pure::<i32, OptionF, _>(42);
        let outcome = computation(0);
        assert_eq!(outcome, Some((0, 42)));
    }

    // `get` yields the current state as the value and keeps the state.
    #[test]
    fn state_t_get_test() {
        let getter = state_t_get::<i32, OptionF>();
        let outcome = getter(42);
        assert_eq!(outcome, Some((42, 42)));
    }

    // `put` discards the incoming state in favour of the stored one.
    #[test]
    fn state_t_put_test() {
        let setter = state_t_put::<i32, OptionF>(99);
        let outcome = setter(0);
        assert_eq!(outcome, Some((99, ())));
    }

    // `modify` applies the function to the incoming state.
    #[test]
    fn state_t_modify_test() {
        let increment = state_t_modify::<i32, OptionF>(|state| state + 1);
        let outcome = increment(5);
        assert_eq!(outcome, Some((6, ())));
    }

    // `fmap` transforms the value and leaves the state alone.
    #[test]
    fn state_t_fmap_test() {
        let base = state_t_pure::<i32, OptionF, _>(10);
        let tripled = state_t_fmap::<i32, OptionF, _, _>(base, |value| value * 3);
        let outcome = tripled(0);
        assert_eq!(outcome, Some((0, 30)));
    }

    // get, then add the value read to the state, then get again.
    #[test]
    fn state_t_chain_threads_state() {
        let double_state = |seen: i32| {
            let add_seen = state_t_modify::<i32, OptionF>(move |state| state + seen);
            state_t_chain::<i32, OptionF, _, _>(add_seen, |_| state_t_get::<i32, OptionF>())
        };
        let program = state_t_chain::<i32, OptionF, _, _>(
            state_t_get::<i32, OptionF>(),
            double_state,
        );
        let outcome = program(10);
        assert_eq!(outcome, Some((20, 20)));
    }

    // A `None` produced by the continuation makes the whole program `None`.
    #[test]
    fn state_t_chain_with_none() {
        let gate = |seen: i32| -> Box<dyn Fn(i32) -> Option<(i32, i32)>> {
            if seen > 100 {
                state_t_pure::<i32, OptionF, _>(seen)
            } else {
                Box::new(|_| None)
            }
        };
        let program = state_t_chain::<i32, OptionF, _, _>(state_t_get::<i32, OptionF>(), gate);
        let rejected = program(10);
        assert_eq!(rejected, None);
        let accepted = program(200);
        assert_eq!(accepted, Some((200, 200)));
    }

    // Lifting `Some` pairs its value with the state supplied to the run.
    #[test]
    fn state_t_lift_option() {
        let lifted = StateTF::<i32, OptionF>::lift(Some(42));
        let outcome = lifted(99);
        assert_eq!(outcome, Some((99, 42)));
    }

    // Lifting `None` stays `None` whatever the state is.
    #[test]
    fn state_t_lift_none() {
        let lifted = StateTF::<i32, OptionF>::lift(None::<i32>);
        let outcome = lifted(99);
        assert_eq!(outcome, None);
    }

    // `state_t_run` applies the computation to the initial state.
    #[test]
    fn state_t_run_test() {
        let computation = state_t_pure::<i32, OptionF, _>(42);
        let outcome = state_t_run::<i32, OptionF, i32>(&*computation, 0);
        assert_eq!(outcome, Some((0, 42)));
    }

    // The `FunctorSt` impl behaves like the free function `state_t_fmap`.
    #[test]
    fn state_t_functor_st_trait() {
        let base = state_t_pure::<i32, OptionF, _>(5);
        let bumped = StateTF::<i32, OptionF>::fmap_st(base, |value| value + 1);
        let outcome = bumped(0);
        assert_eq!(outcome, Some((0, 6)));
    }

    // The `ChainSt` impl behaves like the free function `state_t_chain`.
    #[test]
    fn state_t_chain_st_trait() {
        let base = state_t_pure::<i32, OptionF, _>(5);
        let add_ten = |value: i32| state_t_pure::<i32, OptionF, _>(value + 10);
        let chained = StateTF::<i32, OptionF>::chain_st(base, add_ten);
        let outcome = chained(0);
        assert_eq!(outcome, Some((0, 15)));
    }
}
234
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    proptest! {
        // Left identity: chaining `pure(a)` into `k` equals `k(a)`.
        #[test]
        fn state_t_monad_left_identity(a in -100i32..100, s in -100i32..100) {
            let k = |n: i32| -> Box<dyn Fn(i32) -> Option<(i32, i32)>> {
                state_t_pure::<i32, OptionF, _>(n + 1)
            };
            let start = state_t_pure::<i32, OptionF, _>(a);
            let bound = state_t_chain::<i32, OptionF, _, _>(start, k);
            let direct = k(a);
            prop_assert_eq!(bound(s), direct(s));
        }

        // Right identity: chaining a computation into `pure` changes nothing.
        #[test]
        fn state_t_monad_right_identity(a in -100i32..100, s in -100i32..100) {
            // Boxed computations are not `Clone`, so the reference copy is built separately.
            let reference = state_t_pure::<i32, OptionF, _>(a);
            let subject = state_t_pure::<i32, OptionF, _>(a);
            let bound = state_t_chain::<i32, OptionF, _, _>(
                subject,
                |value| state_t_pure::<i32, OptionF, _>(value),
            );
            prop_assert_eq!(bound(s), reference(s));
        }

        // Functor identity: mapping the identity function changes nothing.
        #[test]
        fn state_t_functor_identity(a in -100i32..100, s in -100i32..100) {
            let reference = state_t_pure::<i32, OptionF, _>(a);
            let subject = state_t_pure::<i32, OptionF, _>(a);
            let mapped = state_t_fmap::<i32, OptionF, _, _>(subject, |value| value);
            prop_assert_eq!(mapped(s), reference(s));
        }

        // Reading the state and writing it straight back leaves it unchanged.
        #[test]
        fn state_t_get_put(s in -100i32..100) {
            let read = state_t_get::<i32, OptionF>();
            let program = state_t_chain::<i32, OptionF, _, _>(
                read,
                |seen| state_t_put::<i32, OptionF>(seen),
            );
            prop_assert_eq!(program(s), Some((s, ())));
        }

        // Writing a state and then reading it yields the written state.
        #[test]
        fn state_t_put_get(s in -100i32..100, new_s in -100i32..100) {
            let write = state_t_put::<i32, OptionF>(new_s);
            let program = state_t_chain::<i32, OptionF, _, _>(
                write,
                |_| state_t_get::<i32, OptionF>(),
            );
            prop_assert_eq!(program(s), Some((new_s, new_s)));
        }
    }
}