Skip to main content

karpal_free/
free.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(feature = "std")]
5use std::boxed::Box;
6
7#[cfg(all(not(feature = "std"), feature = "alloc"))]
8use alloc::boxed::Box;
9
10use core::marker::PhantomData;
11
12use karpal_core::applicative::Applicative;
13use karpal_core::chain::Chain;
14use karpal_core::functor::Functor;
15use karpal_core::hkt::HKT;
16use karpal_core::natural::NaturalTransformation;
17
18/// Free Monad — builds a monadic computation as a data structure.
19///
20/// `Free<F, A>` represents a program where `F` describes the available
21/// effects and `A` is the result type. Programs are built with `pure`
22/// and `lift_f`, composed with `chain`, and interpreted with `fold_map`
23/// using a natural transformation into any target monad.
24///
25/// ```text
26/// Pure(a)              — a finished computation returning a
27/// Roll(F<Free<F, A>>)  — one layer of effect wrapping a continuation
28/// ```
29pub enum Free<F: HKT, A> {
30    /// A pure value — the computation is finished.
31    Pure(A),
32    /// A layer of effect `F` wrapping a continuation.
33    Roll(Box<F::Of<Free<F, A>>>),
34}
35
36impl<F: HKT, A> Free<F, A> {
37    /// Wrap a pure value into the free monad.
38    pub fn pure(a: A) -> Self {
39        Free::Pure(a)
40    }
41}
42
43impl<F: HKT + Functor, A> Free<F, A> {
44    /// Lift a single effect `F<A>` into the free monad.
45    pub fn lift_f(fa: F::Of<A>) -> Self {
46        Free::Roll(Box::new(F::fmap(fa, Free::Pure)))
47    }
48
49    /// Map a function over the result of this computation.
50    pub fn fmap<B>(self, f: impl Fn(A) -> B) -> Free<F, B> {
51        self.fmap_inner(&f)
52    }
53
54    fn fmap_inner<B>(self, f: &dyn Fn(A) -> B) -> Free<F, B> {
55        match self {
56            Free::Pure(a) => Free::Pure(f(a)),
57            Free::Roll(ff) => Free::Roll(Box::new(F::fmap(*ff, |child| child.fmap_inner(f)))),
58        }
59    }
60
61    /// Monadic bind — sequence this computation with a function that
62    /// produces the next computation.
63    pub fn chain<B>(self, f: impl Fn(A) -> Free<F, B>) -> Free<F, B> {
64        self.chain_inner(&f)
65    }
66
67    fn chain_inner<B>(self, f: &dyn Fn(A) -> Free<F, B>) -> Free<F, B> {
68        match self {
69            Free::Pure(a) => f(a),
70            Free::Roll(ff) => Free::Roll(Box::new(F::fmap(*ff, |child| child.chain_inner(f)))),
71        }
72    }
73
74    /// Interpret this free monad into a target monad `M` using a natural
75    /// transformation `NT: F ~> M`.
76    ///
77    /// This is the core interpreter: it collapses the free structure by
78    /// translating each `F` effect into `M` and sequencing with `M::chain`.
79    pub fn fold_map<M, NT>(self) -> M::Of<A>
80    where
81        M: Applicative + Chain,
82        NT: NaturalTransformation<F, M>,
83    {
84        match self {
85            Free::Pure(a) => M::pure(a),
86            Free::Roll(ff) => {
87                let mapped = F::fmap(*ff, |child| child.fold_map::<M, NT>());
88                let m_ma: M::Of<M::Of<A>> = NT::transform(mapped);
89                M::chain(m_ma, |x| x)
90            }
91        }
92    }
93}
94
/// Type-level marker for the partially applied constructor `Free<F, _>`.
///
/// Its only field is a private `PhantomData`, so code outside this module
/// cannot build a value of it; it is meant as a type-level tag that lets
/// `Free<F, _>` be named as an [`HKT`] and receive trait impls such as
/// [`Functor`].
///
/// The `F: HKT` requirement lives on the impls that need it rather than on
/// the struct definition, so merely naming `FreeF<F>` imposes no bound.
pub struct FreeF<F>(PhantomData<F>);
97
98impl<F: HKT> HKT for FreeF<F> {
99    type Of<T> = Free<F, T>;
100}
101
102impl<F: HKT + Functor> Functor for FreeF<F> {
103    fn fmap<A, B>(fa: Free<F, A>, f: impl Fn(A) -> B) -> Free<F, B> {
104        fa.fmap(f)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::OptionF;

    #[test]
    fn pure_value() {
        let free = Free::<OptionF, i32>::pure(42);
        match free {
            Free::Pure(v) => assert_eq!(v, 42),
            Free::Roll(_) => panic!("expected Pure"),
        }
    }

    #[test]
    fn lift_f_some() {
        let free = Free::<OptionF, i32>::lift_f(Some(1));
        match free {
            Free::Roll(ff) => match *ff {
                Some(Free::Pure(v)) => assert_eq!(v, 1),
                _ => panic!("expected Some(Pure(1))"),
            },
            Free::Pure(_) => panic!("expected Roll"),
        }
    }

    #[test]
    fn fmap_pure() {
        let free = Free::<OptionF, i32>::pure(2).fmap(|x| x * 3);
        match free {
            Free::Pure(v) => assert_eq!(v, 6),
            Free::Roll(_) => panic!("expected Pure"),
        }
    }

    #[test]
    fn fmap_roll() {
        let free = Free::<OptionF, i32>::lift_f(Some(5)).fmap(|x| x + 10);
        match free {
            Free::Roll(ff) => match *ff {
                Some(Free::Pure(v)) => assert_eq!(v, 15),
                _ => panic!("expected Some(Pure(15))"),
            },
            Free::Pure(_) => panic!("expected Roll"),
        }
    }

    #[test]
    fn chain_pure() {
        let free = Free::<OptionF, i32>::pure(1).chain(|x| Free::pure(x + 1));
        match free {
            Free::Pure(v) => assert_eq!(v, 2),
            Free::Roll(_) => panic!("expected Pure"),
        }
    }

    #[test]
    fn chain_roll() {
        let free = Free::<OptionF, i32>::lift_f(Some(10)).chain(|x| Free::pure(x * 2));
        // Roll(Some(Pure(10))).chain(f)
        // = Roll(fmap(Some(Pure(10)), |child| child.chain(f)))
        // = Roll(Some(Pure(10).chain(f)))
        // = Roll(Some(f(10)))
        // = Roll(Some(Pure(20)))
        match free {
            Free::Roll(ff) => match *ff {
                Some(Free::Pure(v)) => assert_eq!(v, 20),
                _ => panic!("expected Some(Pure(20))"),
            },
            Free::Pure(_) => panic!("expected Roll"),
        }
    }

    #[test]
    fn chain_associativity() {
        // Capture-free closures are `Copy`, so they can be passed to `chain`
        // and still be reused on the right-hand side.
        let f = |x: i32| Free::<OptionF, i32>::pure(x + 1);
        let g = |x: i32| Free::<OptionF, i32>::pure(x * 2);

        // m.chain(f).chain(g)
        let left = Free::<OptionF, i32>::pure(5).chain(f).chain(g);

        // m.chain(|x| f(x).chain(g))
        let right = Free::<OptionF, i32>::pure(5).chain(|x| f(x).chain(g));

        match (left, right) {
            (Free::Pure(l), Free::Pure(r)) => {
                assert_eq!(l, r);
                assert_eq!(l, 12);
            }
            _ => panic!("expected both Pure"),
        }
    }

    #[test]
    fn chain_associativity_with_effects() {
        // Same law, but every step performs an effect, so both sides are
        // `Roll`-shaped and are compared through the interpreter.
        let f = |x: i32| Free::<OptionF, i32>::lift_f(Some(x + 1));
        let g = |x: i32| Free::<OptionF, i32>::lift_f(Some(x * 2));

        let left = Free::<OptionF, i32>::lift_f(Some(5)).chain(f).chain(g);
        let right = Free::<OptionF, i32>::lift_f(Some(5)).chain(|x| f(x).chain(g));

        assert_eq!(left.fold_map::<OptionF, OptionId>(), Some(12));
        assert_eq!(right.fold_map::<OptionF, OptionId>(), Some(12));
    }

    // Natural transformation: Option ~> Option (identity). Interpreting with it
    // lets tests observe a program as a plain `Option`.
    struct OptionId;
    impl NaturalTransformation<OptionF, OptionF> for OptionId {
        fn transform<A>(fa: Option<A>) -> Option<A> {
            fa
        }
    }

    #[test]
    fn fold_map_pure() {
        let free = Free::<OptionF, i32>::pure(42);
        let result = free.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(42));
    }

    #[test]
    fn fold_map_roll() {
        let free = Free::<OptionF, i32>::lift_f(Some(10));
        let result = free.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(10));
    }

    #[test]
    fn fold_map_chain_then_interpret() {
        let free = Free::<OptionF, i32>::lift_f(Some(3)).chain(|x| Free::lift_f(Some(x * 10)));
        let result = free.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(30));
    }

    #[test]
    fn fold_map_none_short_circuits() {
        // A `None` effect has no continuation to run, so the bound step is skipped.
        let free = Free::<OptionF, i32>::lift_f(None).chain(|x| Free::lift_f(Some(x + 1)));
        assert_eq!(free.fold_map::<OptionF, OptionId>(), None);
    }

    #[test]
    fn functor_impl_works() {
        let free = Free::<OptionF, i32>::pure(5);
        let result = <FreeF<OptionF> as Functor>::fmap(free, |x| x + 10);
        match result {
            Free::Pure(v) => assert_eq!(v, 15),
            Free::Roll(_) => panic!("expected Pure"),
        }
    }
}
241
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    /// Identity interpreter `Option ~> Option`.
    struct OptionId;
    impl NaturalTransformation<OptionF, OptionF> for OptionId {
        fn transform<A>(fa: Option<A>) -> Option<A> {
            fa
        }
    }

    /// Observe a program by interpreting it into `Option`.
    ///
    /// Matching only on `Free::Pure` would report every `Roll` as the same
    /// value, making the laws vacuous for effectful programs. Interpreting
    /// instead distinguishes a failed (`None`) effect from a successful one
    /// and compares the final results.
    fn run<A>(free: Free<OptionF, A>) -> Option<A> {
        free.fold_map::<OptionF, OptionId>()
    }

    /// Build one of three program shapes around `x`:
    /// `0` → `pure(x)`, `1` → `lift_f(Some(x))`, anything else → `lift_f(None)`.
    fn program(x: i32, shape: u8) -> Free<OptionF, i32> {
        match shape {
            0 => Free::pure(x),
            1 => Free::lift_f(Some(x)),
            _ => Free::lift_f(None),
        }
    }

    proptest! {
        // Functor identity: fmap(id, fa) == fa
        #[test]
        fn functor_identity(x in any::<i32>(), shape in 0u8..3) {
            let result = program(x, shape).fmap(|a| a);
            prop_assert_eq!(run(result), run(program(x, shape)));
        }

        // Functor composition: fmap(g . f, fa) == fmap(g, fmap(f, fa))
        #[test]
        fn functor_composition(x in any::<i32>(), shape in 0u8..3) {
            let f = |a: i32| a.wrapping_add(1);
            let g = |a: i32| a.wrapping_mul(2);

            let left = program(x, shape).fmap(|a| g(f(a)));
            let right = program(x, shape).fmap(f).fmap(g);
            prop_assert_eq!(run(left), run(right));
        }

        // Monad left identity: pure(a).chain(f) == f(a)
        #[test]
        fn monad_left_identity(x in any::<i32>(), shape in 0u8..3) {
            let f = move |a: i32| program(a.wrapping_mul(2), shape);
            let left = Free::<OptionF, i32>::pure(x).chain(&f);
            let right = f(x);
            prop_assert_eq!(run(left), run(right));
        }

        // Monad right identity: m.chain(pure) == m
        #[test]
        fn monad_right_identity(x in any::<i32>(), shape in 0u8..3) {
            let result = program(x, shape).chain(Free::pure);
            prop_assert_eq!(run(result), run(program(x, shape)));
        }

        // Monad associativity: m.chain(f).chain(g) == m.chain(|a| f(a).chain(g))
        #[test]
        fn monad_associativity(
            x in any::<i32>(),
            m_shape in 0u8..3,
            f_shape in 0u8..3,
            g_shape in 0u8..3
        ) {
            // `move` closures capturing only `u8` are `Copy`, so they can be
            // reused on both sides of the law.
            let f = move |a: i32| program(a.wrapping_add(1), f_shape);
            let g = move |a: i32| program(a.wrapping_mul(2), g_shape);

            let left = program(x, m_shape).chain(f).chain(g);
            let right = program(x, m_shape).chain(|a| f(a).chain(g));
            prop_assert_eq!(run(left), run(right));
        }
    }
}