Skip to main content

karpal_free/
freer.rs

1#[cfg(feature = "std")]
2use std::boxed::Box;
3
4#[cfg(all(not(feature = "std"), feature = "alloc"))]
5use alloc::boxed::Box;
6
7#[cfg(feature = "std")]
8use std::rc::Rc;
9
10#[cfg(all(not(feature = "std"), feature = "alloc"))]
11use alloc::rc::Rc;
12
13use core::marker::PhantomData;
14
15use karpal_core::applicative::Applicative;
16use karpal_core::chain::Chain;
17use karpal_core::functor::Functor;
18use karpal_core::hkt::HKT;
19use karpal_core::natural::NaturalTransformation;
20
21// ---- Private step types for the existential encoding ----
22
23/// Dyn-safe trait for an effect step in the Freer monad.
24///
25/// Each step stores `∃B. (F::Of<B>, B → Freer<F, A>)` — an effect value
26/// and a continuation. The intermediate type `B` is hidden behind the
27/// trait object.
28trait FreerStep<F: HKT + 'static, A: 'static> {
29    /// Lower this step to `F::Of<Freer<F, A>>` by applying F::fmap.
30    /// Only called during `fold_map`, where `F: Functor` is available.
31    fn lower_step(self: Box<Self>) -> F::Of<Freer<F, A>>
32    where
33        F: Functor;
34}
35
36/// A leaf step: stores an effect `F::Of<B>` and a continuation `B → Freer<F, A>`.
37struct ImpureStep<F: HKT + 'static, A: 'static, B: 'static> {
38    effect: F::Of<B>,
39    cont: Box<dyn Fn(B) -> Freer<F, A>>,
40}
41
42impl<F: HKT + 'static, A: 'static, B: 'static> FreerStep<F, A> for ImpureStep<F, A, B> {
43    fn lower_step(self: Box<Self>) -> F::Of<Freer<F, A>>
44    where
45        F: Functor,
46    {
47        F::fmap(self.effect, self.cont)
48    }
49}
50
51/// A chained step: wraps an inner step with a deferred chain operation.
52///
53/// When lowered, first lowers the inner step to get `F::Of<Freer<F, Src>>`,
54/// then fmaps `chain_rc` over it to produce `F::Of<Freer<F, A>>`.
55struct ChainedStep<F: HKT + 'static, Src: 'static, A: 'static> {
56    inner: Box<dyn FreerStep<F, Src>>,
57    chain_fn: Rc<dyn Fn(Src) -> Freer<F, A>>,
58}
59
60impl<F: HKT + 'static, Src: 'static, A: 'static> FreerStep<F, A> for ChainedStep<F, Src, A> {
61    fn lower_step(self: Box<Self>) -> F::Of<Freer<F, A>>
62    where
63        F: Functor,
64    {
65        let f_freer_src: F::Of<Freer<F, Src>> = self.inner.lower_step();
66        let chain_fn = self.chain_fn;
67        F::fmap(f_freer_src, move |freer_src: Freer<F, Src>| {
68            freer_src.chain_rc(chain_fn.clone())
69        })
70    }
71}
72
73// ---- Public Freer type ----
74
75#[allow(private_interfaces)]
76/// Freer Monad — a free monad that does not require `F: Functor`.
77///
78/// `Freer<F, A>` stores a computation as a tree of effect steps, each
79/// containing `∃B. (F B, B → Freer F A)`. The existential `B` is erased
80/// via a dyn-safe trait, deferring the `Functor` requirement to `fold_map`.
81///
82/// ```text
83/// Pure(a)       — a finished computation
84/// Impure(step)  — an effect step with continuation
85/// ```
86///
87/// # When to use Freer vs Free
88///
89/// - Use `Free<F, A>` when `F: Functor` — simpler, no overhead.
90/// - Use `Freer<F, A>` when `F` is NOT a functor, or when you want to
91///   build computations without the functor constraint.
92pub enum Freer<F: HKT + 'static, A: 'static> {
93    /// A pure value — the computation is finished.
94    Pure(A),
95    /// An effect step with erased intermediate type.
96    Impure(Box<dyn FreerStep<F, A>>),
97}
98
99impl<F: HKT + 'static, A: 'static> Freer<F, A> {
100    /// Wrap a pure value into the freer monad.
101    pub fn pure(a: A) -> Self {
102        Freer::Pure(a)
103    }
104
105    /// Lift a single effect `F<A>` into the freer monad.
106    ///
107    /// No `F: Functor` required.
108    pub fn lift_f(fa: F::Of<A>) -> Self
109    where
110        F::Of<A>: 'static,
111    {
112        Freer::Impure(Box::new(ImpureStep {
113            effect: fa,
114            cont: Box::new(Freer::Pure),
115        }))
116    }
117
118    /// Map a function over the result of this computation.
119    ///
120    /// No `F: Functor` required. Implemented via `chain`.
121    pub fn fmap<B: 'static>(self, f: impl Fn(A) -> B + 'static) -> Freer<F, B> {
122        self.chain(move |a| Freer::Pure(f(a)))
123    }
124
125    /// Monadic bind — sequence this computation with a function that
126    /// produces the next computation.
127    ///
128    /// No `F: Functor` required. The closure is shared via `Rc` across
129    /// deferred chain layers.
130    pub fn chain<B: 'static>(self, f: impl Fn(A) -> Freer<F, B> + 'static) -> Freer<F, B> {
131        let f = Rc::new(f);
132        self.chain_rc(f)
133    }
134
135    fn chain_rc<B: 'static>(self, f: Rc<dyn Fn(A) -> Freer<F, B>>) -> Freer<F, B> {
136        match self {
137            Freer::Pure(a) => f(a),
138            Freer::Impure(step) => Freer::Impure(Box::new(ChainedStep {
139                inner: step,
140                chain_fn: f,
141            })),
142        }
143    }
144
145    /// Interpret this freer monad into a target monad `M` using a natural
146    /// transformation `NT: F ~> M`.
147    ///
148    /// This requires `F: Functor` (to lower each step) and
149    /// `M: Applicative + Chain` (for the target monad operations).
150    pub fn fold_map<M, NT>(self) -> M::Of<A>
151    where
152        F: Functor,
153        M: Applicative + Chain,
154        NT: NaturalTransformation<F, M>,
155    {
156        match self {
157            Freer::Pure(a) => M::pure(a),
158            Freer::Impure(step) => {
159                // Lower the step to get F<Freer<F, A>>
160                let f_freer: F::Of<Freer<F, A>> = step.lower_step();
161                // Apply NT to get M<Freer<F, A>>
162                let m_freer: M::Of<Freer<F, A>> = NT::transform(f_freer);
163                // Chain with recursive fold_map
164                M::chain(m_freer, |freer| freer.fold_map::<M, NT>())
165            }
166        }
167    }
168}
169
170/// HKT marker for `Freer<F, _>`.
171///
172/// Note: Cannot implement `HKT` or `Functor` due to Rust's GAT limitations
173/// (`type Of<T>` cannot add `T: 'static` in impl when trait doesn't have it).
174/// Use `Freer::fmap` directly.
175pub struct FreerF<F: HKT + 'static>(PhantomData<F>);
176
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::OptionF;

    // Natural transformation: Option ~> Option (identity)
    struct OptionId;
    impl NaturalTransformation<OptionF, OptionF> for OptionId {
        fn transform<A>(fa: Option<A>) -> Option<A> {
            fa
        }
    }

    /// Interpret an `Option`-effect computation with the identity transformation.
    fn run(freer: Freer<OptionF, i32>) -> Option<i32> {
        freer.fold_map::<OptionF, OptionId>()
    }

    #[test]
    fn pure_value() {
        let Freer::Pure(v) = Freer::<OptionF, i32>::pure(42) else {
            panic!("expected Pure");
        };
        assert_eq!(v, 42);
    }

    #[test]
    fn lift_f_some() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(10));
        assert!(matches!(freer, Freer::Impure(_)), "expected Impure");
    }

    #[test]
    fn chain_pure() {
        let freer = Freer::<OptionF, i32>::pure(1).chain(|x| Freer::pure(x + 1));
        let Freer::Pure(v) = freer else {
            panic!("expected Pure");
        };
        assert_eq!(v, 2);
    }

    #[test]
    fn fmap_pure() {
        let freer = Freer::<OptionF, i32>::pure(5).fmap(|x| x * 3);
        let Freer::Pure(v) = freer else {
            panic!("expected Pure");
        };
        assert_eq!(v, 15);
    }

    #[test]
    fn chain_associativity() {
        // (m >>= f) >>= g
        let left = Freer::<OptionF, i32>::pure(5)
            .chain(|x| Freer::pure(x + 1))
            .chain(|x| Freer::pure(x * 2));

        // m >>= (\x -> f(x) >>= g)
        let right = Freer::<OptionF, i32>::pure(5)
            .chain(|x| Freer::<OptionF, i32>::pure(x + 1).chain(|y| Freer::pure(y * 2)));

        let (Freer::Pure(l), Freer::Pure(r)) = (left, right) else {
            panic!("expected both Pure");
        };
        assert_eq!(l, r);
    }

    #[test]
    fn fold_map_pure() {
        assert_eq!(run(Freer::<OptionF, i32>::pure(42)), Some(42));
    }

    #[test]
    fn fold_map_lift() {
        assert_eq!(run(Freer::<OptionF, i32>::lift_f(Some(10))), Some(10));
    }

    #[test]
    fn fold_map_chain() {
        let freer =
            Freer::<OptionF, i32>::lift_f(Some(3)).chain(|x| Freer::lift_f(Some(x * 10)));
        assert_eq!(run(freer), Some(30));
    }

    #[test]
    fn fold_map_lift_none() {
        assert_eq!(run(Freer::<OptionF, i32>::lift_f(None)), None);
    }

    #[test]
    fn fmap_lift_then_fold() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(5)).fmap(|x| x + 10);
        assert_eq!(run(freer), Some(15));
    }

    #[test]
    fn chain_lift_multiple() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(1))
            .chain(|x| Freer::lift_f(Some(x + 1)))
            .chain(|x| Freer::lift_f(Some(x * 10)));
        assert_eq!(run(freer), Some(20)); // (1+1)*10
    }
}
287
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    /// Return the value of a finished computation, or `None` for an effect step.
    fn extract_pure<F: HKT + 'static, A: 'static>(freer: Freer<F, A>) -> Option<A> {
        match freer {
            Freer::Pure(a) => Some(a),
            Freer::Impure(_) => None,
        }
    }

    /// Identity natural transformation `Option ~> Option`. It lets effectful
    /// computations be observed through `fold_map`.
    struct OptionId;
    impl NaturalTransformation<OptionF, OptionF> for OptionId {
        fn transform<A>(fa: Option<A>) -> Option<A> {
            fa
        }
    }

    /// Interpret an `Option`-effect computation with the identity transformation.
    fn interpret(freer: Freer<OptionF, i32>) -> Option<i32> {
        freer.fold_map::<OptionF, OptionId>()
    }

    proptest! {
        // Monad left identity: pure(a) >>= f == f(a)
        #[test]
        fn monad_left_identity(x in any::<i32>()) {
            let left = extract_pure(
                Freer::<OptionF, i32>::pure(x)
                    .chain(|a| Freer::pure(a.wrapping_mul(2))),
            );
            let right = Some(x.wrapping_mul(2));
            prop_assert_eq!(left, right);
        }

        // Monad right identity: m >>= pure == m
        #[test]
        fn monad_right_identity(x in any::<i32>()) {
            let result = extract_pure(Freer::<OptionF, i32>::pure(x).chain(Freer::pure));
            prop_assert_eq!(result, Some(x));
        }

        // Right identity on an effectful computation, observed via fold_map.
        // This exercises the ImpureStep / ChainedStep lowering path, which the
        // Pure-only laws above never reach.
        #[test]
        fn monad_right_identity_effectful(o in any::<Option<i32>>()) {
            let left = interpret(Freer::<OptionF, i32>::lift_f(o).chain(Freer::pure));
            let right = interpret(Freer::<OptionF, i32>::lift_f(o));
            prop_assert_eq!(left, right);
        }

        // Associativity on effectful computations:
        // (m >>= f) >>= g == m >>= (\x -> f(x) >>= g)
        #[test]
        fn monad_associativity_effectful(o in any::<Option<i32>>()) {
            // Capture-free closures are `Copy`, so they can be reused on both sides.
            let f = |a: i32| Freer::<OptionF, i32>::lift_f(Some(a.wrapping_add(1)));
            let g = |b: i32| Freer::<OptionF, i32>::lift_f(Some(b.wrapping_mul(2)));
            let left = interpret(Freer::<OptionF, i32>::lift_f(o).chain(f).chain(g));
            let right = interpret(
                Freer::<OptionF, i32>::lift_f(o).chain(move |a| f(a).chain(g)),
            );
            prop_assert_eq!(left, right);
        }
    }
}