// karpal_free/freer.rs

// Copyright (C) 2026 Industrial Algebra
// SPDX-License-Identifier: Apache-2.0

4#[cfg(feature = "std")]
5use std::boxed::Box;
6
7#[cfg(all(not(feature = "std"), feature = "alloc"))]
8use alloc::boxed::Box;
9
10#[cfg(feature = "std")]
11use std::rc::Rc;
12
13#[cfg(all(not(feature = "std"), feature = "alloc"))]
14use alloc::rc::Rc;
15
16use core::marker::PhantomData;
17
18use karpal_core::applicative::Applicative;
19use karpal_core::chain::Chain;
20use karpal_core::functor::Functor;
21use karpal_core::hkt::HKT;
22use karpal_core::natural::NaturalTransformation;
23
// ---- Private step types for the existential encoding ----

26/// Dyn-safe trait for an effect step in the Freer monad.
27///
28/// Each step stores `∃B. (F::Of<B>, B → Freer<F, A>)` — an effect value
29/// and a continuation. The intermediate type `B` is hidden behind the
30/// trait object.
31trait FreerStep<F: HKT + 'static, A: 'static> {
32    /// Lower this step to `F::Of<Freer<F, A>>` by applying F::fmap.
33    /// Only called during `fold_map`, where `F: Functor` is available.
34    fn lower_step(self: Box<Self>) -> F::Of<Freer<F, A>>
35    where
36        F: Functor;
37}
38
39/// A leaf step: stores an effect `F::Of<B>` and a continuation `B → Freer<F, A>`.
40struct ImpureStep<F: HKT + 'static, A: 'static, B: 'static> {
41    effect: F::Of<B>,
42    cont: Box<dyn Fn(B) -> Freer<F, A>>,
43}
44
45impl<F: HKT + 'static, A: 'static, B: 'static> FreerStep<F, A> for ImpureStep<F, A, B> {
46    fn lower_step(self: Box<Self>) -> F::Of<Freer<F, A>>
47    where
48        F: Functor,
49    {
50        F::fmap(self.effect, self.cont)
51    }
52}
53
54/// A chained step: wraps an inner step with a deferred chain operation.
55///
56/// When lowered, first lowers the inner step to get `F::Of<Freer<F, Src>>`,
57/// then fmaps `chain_rc` over it to produce `F::Of<Freer<F, A>>`.
58struct ChainedStep<F: HKT + 'static, Src: 'static, A: 'static> {
59    inner: Box<dyn FreerStep<F, Src>>,
60    chain_fn: Rc<dyn Fn(Src) -> Freer<F, A>>,
61}
62
63impl<F: HKT + 'static, Src: 'static, A: 'static> FreerStep<F, A> for ChainedStep<F, Src, A> {
64    fn lower_step(self: Box<Self>) -> F::Of<Freer<F, A>>
65    where
66        F: Functor,
67    {
68        let f_freer_src: F::Of<Freer<F, Src>> = self.inner.lower_step();
69        let chain_fn = self.chain_fn;
70        F::fmap(f_freer_src, move |freer_src: Freer<F, Src>| {
71            freer_src.chain_rc(chain_fn.clone())
72        })
73    }
74}
75
// ---- Public Freer type ----

78#[allow(private_interfaces)]
79/// Freer Monad — a free monad that does not require `F: Functor`.
80///
81/// `Freer<F, A>` stores a computation as a tree of effect steps, each
82/// containing `∃B. (F B, B → Freer F A)`. The existential `B` is erased
83/// via a dyn-safe trait, deferring the `Functor` requirement to `fold_map`.
84///
85/// ```text
86/// Pure(a)       — a finished computation
87/// Impure(step)  — an effect step with continuation
88/// ```
89///
90/// # When to use Freer vs Free
91///
92/// - Use `Free<F, A>` when `F: Functor` — simpler, no overhead.
93/// - Use `Freer<F, A>` when `F` is NOT a functor, or when you want to
94///   build computations without the functor constraint.
95pub enum Freer<F: HKT + 'static, A: 'static> {
96    /// A pure value — the computation is finished.
97    Pure(A),
98    /// An effect step with erased intermediate type.
99    Impure(Box<dyn FreerStep<F, A>>),
100}
101
102impl<F: HKT + 'static, A: 'static> Freer<F, A> {
103    /// Wrap a pure value into the freer monad.
104    pub fn pure(a: A) -> Self {
105        Freer::Pure(a)
106    }
107
108    /// Lift a single effect `F<A>` into the freer monad.
109    ///
110    /// No `F: Functor` required.
111    pub fn lift_f(fa: F::Of<A>) -> Self
112    where
113        F::Of<A>: 'static,
114    {
115        Freer::Impure(Box::new(ImpureStep {
116            effect: fa,
117            cont: Box::new(Freer::Pure),
118        }))
119    }
120
121    /// Map a function over the result of this computation.
122    ///
123    /// No `F: Functor` required. Implemented via `chain`.
124    pub fn fmap<B: 'static>(self, f: impl Fn(A) -> B + 'static) -> Freer<F, B> {
125        self.chain(move |a| Freer::Pure(f(a)))
126    }
127
128    /// Monadic bind — sequence this computation with a function that
129    /// produces the next computation.
130    ///
131    /// No `F: Functor` required. The closure is shared via `Rc` across
132    /// deferred chain layers.
133    pub fn chain<B: 'static>(self, f: impl Fn(A) -> Freer<F, B> + 'static) -> Freer<F, B> {
134        let f = Rc::new(f);
135        self.chain_rc(f)
136    }
137
138    fn chain_rc<B: 'static>(self, f: Rc<dyn Fn(A) -> Freer<F, B>>) -> Freer<F, B> {
139        match self {
140            Freer::Pure(a) => f(a),
141            Freer::Impure(step) => Freer::Impure(Box::new(ChainedStep {
142                inner: step,
143                chain_fn: f,
144            })),
145        }
146    }
147
148    /// Interpret this freer monad into a target monad `M` using a natural
149    /// transformation `NT: F ~> M`.
150    ///
151    /// This requires `F: Functor` (to lower each step) and
152    /// `M: Applicative + Chain` (for the target monad operations).
153    pub fn fold_map<M, NT>(self) -> M::Of<A>
154    where
155        F: Functor,
156        M: Applicative + Chain,
157        NT: NaturalTransformation<F, M>,
158    {
159        match self {
160            Freer::Pure(a) => M::pure(a),
161            Freer::Impure(step) => {
162                // Lower the step to get F<Freer<F, A>>
163                let f_freer: F::Of<Freer<F, A>> = step.lower_step();
164                // Apply NT to get M<Freer<F, A>>
165                let m_freer: M::Of<Freer<F, A>> = NT::transform(f_freer);
166                // Chain with recursive fold_map
167                M::chain(m_freer, |freer| freer.fold_map::<M, NT>())
168            }
169        }
170    }
171}
172
173/// HKT marker for `Freer<F, _>`.
174///
175/// Note: Cannot implement `HKT` or `Functor` due to Rust's GAT limitations
176/// (`type Of<T>` cannot add `T: 'static` in impl when trait doesn't have it).
177/// Use `Freer::fmap` directly.
178pub struct FreerF<F: HKT + 'static>(PhantomData<F>);
179
#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::Cell;
    use karpal_core::hkt::OptionF;

    #[test]
    fn pure_value() {
        let freer = Freer::<OptionF, i32>::pure(42);
        match freer {
            Freer::Pure(v) => assert_eq!(v, 42),
            Freer::Impure(_) => panic!("expected Pure"),
        }
    }

    #[test]
    fn lift_f_some() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(10));
        match freer {
            Freer::Impure(_) => {} // correct
            Freer::Pure(_) => panic!("expected Impure"),
        }
    }

    #[test]
    fn chain_pure() {
        let freer = Freer::<OptionF, i32>::pure(1).chain(|x| Freer::pure(x + 1));
        match freer {
            Freer::Pure(v) => assert_eq!(v, 2),
            _ => panic!("expected Pure"),
        }
    }

    #[test]
    fn fmap_pure() {
        let freer = Freer::<OptionF, i32>::pure(5).fmap(|x| x * 3);
        match freer {
            Freer::Pure(v) => assert_eq!(v, 15),
            _ => panic!("expected Pure"),
        }
    }

    #[test]
    fn chain_associativity() {
        // (m >>= f) >>= g
        let left = Freer::<OptionF, i32>::pure(5)
            .chain(|x| Freer::pure(x + 1))
            .chain(|x| Freer::pure(x * 2));

        // m >>= (\x -> f(x) >>= g)
        let right = Freer::<OptionF, i32>::pure(5)
            .chain(|x| Freer::<OptionF, i32>::pure(x + 1).chain(|y| Freer::pure(y * 2)));

        match (left, right) {
            (Freer::Pure(l), Freer::Pure(r)) => assert_eq!(l, r),
            _ => panic!("expected both Pure"),
        }
    }

    // Associativity again, but every step is an effect, so the nested
    // `ChainedStep` lowering path is what actually gets compared.
    #[test]
    fn chain_associativity_with_effects() {
        // (m >>= f) >>= g
        let left = Freer::<OptionF, i32>::lift_f(Some(2))
            .chain(|x| Freer::lift_f(Some(x + 1)))
            .chain(|y| Freer::lift_f(Some(y * 3)));

        // m >>= (\x -> f(x) >>= g)
        let right = Freer::<OptionF, i32>::lift_f(Some(2)).chain(|x| {
            Freer::<OptionF, i32>::lift_f(Some(x + 1)).chain(|y| Freer::lift_f(Some(y * 3)))
        });

        let left = left.fold_map::<OptionF, OptionId>();
        let right = right.fold_map::<OptionF, OptionId>();
        assert_eq!(left, right);
        assert_eq!(left, Some(9)); // (2 + 1) * 3
    }

    // Natural transformation: Option ~> Option (identity)
    struct OptionId;
    impl NaturalTransformation<OptionF, OptionF> for OptionId {
        fn transform<A>(fa: Option<A>) -> Option<A> {
            fa
        }
    }

    #[test]
    fn fold_map_pure() {
        let freer = Freer::<OptionF, i32>::pure(42);
        let result = freer.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(42));
    }

    #[test]
    fn fold_map_lift() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(10));
        let result = freer.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(10));
    }

    #[test]
    fn fold_map_chain() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(3)).chain(|x| Freer::lift_f(Some(x * 10)));
        let result = freer.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(30));
    }

    #[test]
    fn fold_map_lift_none() {
        let freer = Freer::<OptionF, i32>::lift_f(None);
        let result = freer.fold_map::<OptionF, OptionId>();
        assert_eq!(result, None);
    }

    // A `None` effect must short-circuit: the bound continuation never runs.
    #[test]
    fn fold_map_none_skips_continuation() {
        let calls = Rc::new(Cell::new(0u32));
        let seen = Rc::clone(&calls);
        let freer = Freer::<OptionF, i32>::lift_f(None).chain(move |x| {
            seen.set(seen.get() + 1);
            Freer::lift_f(Some(x + 1))
        });
        assert_eq!(freer.fold_map::<OptionF, OptionId>(), None);
        assert_eq!(calls.get(), 0);
    }

    #[test]
    fn fmap_lift_then_fold() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(5)).fmap(|x| x + 10);
        let result = freer.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(15));
    }

    #[test]
    fn chain_lift_multiple() {
        let freer = Freer::<OptionF, i32>::lift_f(Some(1))
            .chain(|x| Freer::lift_f(Some(x + 1)))
            .chain(|x| Freer::lift_f(Some(x * 10)));
        let result = freer.fold_map::<OptionF, OptionId>();
        assert_eq!(result, Some(20)); // (1+1)*10
    }
}

#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    /// Returns the value of an already-finished computation, or `None` if it
    /// still holds effect steps.
    fn extract_pure<F: HKT + 'static, A: 'static>(freer: Freer<F, A>) -> Option<A> {
        match freer {
            Freer::Pure(a) => Some(a),
            Freer::Impure(_) => None,
        }
    }

    /// Identity natural transformation `Option ~> Option`, used to observe
    /// effectful computations through `fold_map`.
    struct OptionId;
    impl NaturalTransformation<OptionF, OptionF> for OptionId {
        fn transform<A>(fa: Option<A>) -> Option<A> {
            fa
        }
    }

    /// Interprets an `OptionF` program back into a plain `Option`.
    fn interpret(freer: Freer<OptionF, i32>) -> Option<i32> {
        freer.fold_map::<OptionF, OptionId>()
    }

    proptest! {
        // Monad left identity: pure(a) >>= f == f(a)
        #[test]
        fn monad_left_identity(x in any::<i32>()) {
            let left = extract_pure(
                Freer::<OptionF, i32>::pure(x)
                    .chain(|a| Freer::pure(a.wrapping_mul(2))),
            );
            let right = Some(x.wrapping_mul(2));
            prop_assert_eq!(left, right);
        }

        // Monad right identity: m >>= pure == m
        #[test]
        fn monad_right_identity(x in any::<i32>()) {
            let result = extract_pure(Freer::<OptionF, i32>::pure(x).chain(Freer::pure));
            prop_assert_eq!(result, Some(x));
        }

        // Left identity with an effectful continuation, observed via fold_map.
        #[test]
        fn monad_left_identity_effectful(x in any::<i32>()) {
            let left = interpret(
                Freer::<OptionF, i32>::pure(x)
                    .chain(|a| Freer::lift_f(Some(a.wrapping_mul(2)))),
            );
            let right = interpret(Freer::<OptionF, i32>::lift_f(Some(x.wrapping_mul(2))));
            prop_assert_eq!(left, right);
        }

        // Right identity for an effectful computation, observed via fold_map.
        #[test]
        fn monad_right_identity_effectful(x in any::<i32>()) {
            let result = interpret(Freer::<OptionF, i32>::lift_f(Some(x)).chain(Freer::pure));
            prop_assert_eq!(result, Some(x));
        }

        // Monad associativity: (m >>= f) >>= g == m >>= (\a -> f(a) >>= g)
        #[test]
        fn monad_associativity(x in any::<i32>()) {
            let f = |a: i32| Freer::<OptionF, i32>::lift_f(Some(a.wrapping_add(1)));
            let g = |b: i32| Freer::<OptionF, i32>::lift_f(Some(b.wrapping_mul(3)));
            let left = interpret(Freer::<OptionF, i32>::lift_f(Some(x)).chain(f).chain(g));
            let right = interpret(
                Freer::<OptionF, i32>::lift_f(Some(x)).chain(move |a| f(a).chain(g)),
            );
            prop_assert_eq!(left, right);
            prop_assert_eq!(left, Some(x.wrapping_add(1).wrapping_mul(3)));
        }
    }
}