Skip to main content

karpal_free/
codensity.rs

1#[cfg(feature = "std")]
2use std::boxed::Box;
3
4#[cfg(all(not(feature = "std"), feature = "alloc"))]
5use alloc::boxed::Box;
6
7use core::marker::PhantomData;
8
9use karpal_core::applicative::Applicative;
10use karpal_core::chain::Chain;
11use karpal_core::hkt::HKT;
12
13/// Private dyn-safe trait for the Codensity computation tree.
14///
15/// The only eliminator is `to_monad`, which collapses the tree into `F::Of<A>`
16/// using `F::pure` and `F::chain`. The generic `lower_with` (∀R) cannot be
17/// made dyn-safe, so we only support lowering into F's own monad.
18trait CodensityInner<F: HKT + 'static, A: 'static> {
19    /// Collapse this computation into `F::Of<A>`.
20    fn to_monad(self: Box<Self>) -> F::Of<A>
21    where
22        F: Applicative + Chain;
23}
24
25/// Pure layer: stores a value.
26struct CodensityPure<F: HKT + 'static, A: 'static> {
27    value: A,
28    _marker: PhantomData<F>,
29}
30
31impl<F: HKT + 'static, A: 'static> CodensityInner<F, A> for CodensityPure<F, A> {
32    fn to_monad(self: Box<Self>) -> F::Of<A>
33    where
34        F: Applicative + Chain,
35    {
36        F::pure(self.value)
37    }
38}
39
40/// Map layer: wraps an inner computation with a transform.
41struct CodensityMap<F: HKT + 'static, Src: 'static, A: 'static> {
42    inner: Box<dyn CodensityInner<F, Src>>,
43    transform: Box<dyn Fn(Src) -> A>,
44}
45
46impl<F: HKT + 'static, Src: 'static, A: 'static> CodensityInner<F, A> for CodensityMap<F, Src, A> {
47    fn to_monad(self: Box<Self>) -> F::Of<A>
48    where
49        F: Applicative + Chain,
50    {
51        let f_src: F::Of<Src> = self.inner.to_monad();
52        F::fmap(f_src, self.transform)
53    }
54}
55
56/// Bind layer: wraps an inner computation with a monadic bind function.
57struct CodensityBind<F: HKT + 'static, Src: 'static, A: 'static> {
58    inner: Box<dyn CodensityInner<F, Src>>,
59    bind_fn: Box<dyn Fn(Src) -> Codensity<F, A>>,
60}
61
62impl<F: HKT + 'static, Src: 'static, A: 'static> CodensityInner<F, A> for CodensityBind<F, Src, A> {
63    fn to_monad(self: Box<Self>) -> F::Of<A>
64    where
65        F: Applicative + Chain,
66    {
67        let f_src: F::Of<Src> = self.inner.to_monad();
68        let bind_fn = self.bind_fn;
69        F::chain(f_src, move |src| (bind_fn)(src).inner.to_monad())
70    }
71}
72
73/// Codensity Monad — the CPS transform of a type constructor `F`.
74///
75/// `Codensity<F, A> ≅ ∀R. (A → F R) → F R`
76///
77/// This is the right Kan extension of `F` along itself (`Ran F F A`),
78/// specialised into a concrete type. The key property: `pure`, `fmap`,
79/// and `chain` require **no bounds on `F`** — only `to_monad` needs
80/// `F: Applicative + Chain`.
81///
82/// # Use cases
83///
84/// - **Monad transformer improvement**: wrapping a free monad in Codensity
85///   can improve asymptotic performance of left-associated binds.
86/// - **CPS conversion**: build computations in CPS, then interpret into
87///   the target monad via `to_monad`.
88///
89/// Note: Due to Rust's GAT limitations (`type Of<T>` cannot add `T: 'static`),
90/// `CodensityF` does not implement `HKT` or `Monad`. Use inherent methods.
91pub struct Codensity<F: HKT + 'static, A: 'static> {
92    inner: Box<dyn CodensityInner<F, A>>,
93}
94
95impl<F: HKT + 'static, A: 'static> Codensity<F, A> {
96    /// Wrap a pure value. No bounds on `F` required.
97    pub fn pure(a: A) -> Self {
98        Codensity {
99            inner: Box::new(CodensityPure {
100                value: a,
101                _marker: PhantomData,
102            }),
103        }
104    }
105
106    /// Map a function over the result. No bounds on `F` required.
107    pub fn fmap<B: 'static>(self, f: impl Fn(A) -> B + 'static) -> Codensity<F, B> {
108        Codensity {
109            inner: Box::new(CodensityMap {
110                inner: self.inner,
111                transform: Box::new(f),
112            }),
113        }
114    }
115
116    /// Monadic bind. No bounds on `F` required.
117    pub fn chain<B: 'static>(self, f: impl Fn(A) -> Codensity<F, B> + 'static) -> Codensity<F, B> {
118        Codensity {
119            inner: Box::new(CodensityBind {
120                inner: self.inner,
121                bind_fn: Box::new(f),
122            }),
123        }
124    }
125
126    /// Collapse the computation into `F::Of<A>`.
127    ///
128    /// This is the standard way to extract a result from Codensity.
129    /// Requires `F: Applicative + Chain` (i.e., F must be a monad).
130    pub fn to_monad(self) -> F::Of<A>
131    where
132        F: Applicative + Chain,
133    {
134        self.inner.to_monad()
135    }
136}
137
138/// Marker type for `Codensity<F, _>`.
139///
140/// Note: Cannot implement `HKT` or `Monad` due to Rust's GAT limitations.
141/// Use `Codensity::pure`, `Codensity::fmap`, `Codensity::chain` directly.
142pub struct CodensityF<F: HKT + 'static>(PhantomData<F>);
143
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::OptionF;

    #[test]
    fn pure_to_monad() {
        let c = Codensity::<OptionF, i32>::pure(42);
        let result = c.to_monad();
        assert_eq!(result, Some(42));
    }

    #[test]
    fn fmap_to_monad() {
        let c = Codensity::<OptionF, i32>::pure(5).fmap(|x| x * 3);
        let result = c.to_monad();
        assert_eq!(result, Some(15));
    }

    #[test]
    fn chain_to_monad() {
        let c = Codensity::<OptionF, i32>::pure(10).chain(|x| Codensity::pure(x + 1));
        let result = c.to_monad();
        assert_eq!(result, Some(11));
    }

    #[test]
    fn chain_multiple() {
        let c = Codensity::<OptionF, i32>::pure(1)
            .chain(|x| Codensity::pure(x + 1))
            .chain(|x| Codensity::pure(x * 10))
            .chain(|x| Codensity::pure(x + 5));
        let result = c.to_monad();
        // (1 + 1) * 10 + 5 = 25
        assert_eq!(result, Some(25));
    }

    #[test]
    fn fmap_then_chain() {
        let c = Codensity::<OptionF, i32>::pure(3)
            .fmap(|x| x * 2)
            .chain(|x| Codensity::pure(x + 100));
        let result = c.to_monad();
        assert_eq!(result, Some(106));
    }

    #[test]
    fn chain_then_fmap() {
        let c = Codensity::<OptionF, i32>::pure(7)
            .chain(|x| Codensity::pure(x - 2))
            .fmap(|x| x * 4);
        // (7 - 2) * 4 = 20
        assert_eq!(c.to_monad(), Some(20));
    }

    #[test]
    fn chain_associativity() {
        // (m >>= f) >>= g
        let left = Codensity::<OptionF, i32>::pure(5)
            .chain(|x| Codensity::pure(x + 1))
            .chain(|x| Codensity::pure(x * 2))
            .to_monad();

        // m >>= (\x -> f(x) >>= g)
        let right = Codensity::<OptionF, i32>::pure(5)
            .chain(|x| Codensity::<OptionF, i32>::pure(x + 1).chain(|y| Codensity::pure(y * 2)))
            .to_monad();

        assert_eq!(left, right);
        // Both sides must equal Some((5 + 1) * 2) = Some(12).
        assert_eq!(left, Some(12));
    }

    #[test]
    fn monad_left_identity() {
        // pure(a) >>= f  ==  f(a)
        let f = |x: i32| Codensity::<OptionF, i32>::pure(x * 3);
        let left = Codensity::<OptionF, i32>::pure(4).chain(f).to_monad();
        let right = f(4).to_monad();
        assert_eq!(left, right);
        assert_eq!(left, Some(12));
    }

    #[test]
    fn monad_right_identity() {
        // m >>= pure  ==  m
        let left = Codensity::<OptionF, i32>::pure(42)
            .chain(Codensity::pure)
            .to_monad();
        let right = Codensity::<OptionF, i32>::pure(42).to_monad();
        assert_eq!(left, right);
        assert_eq!(left, Some(42));
    }

    #[test]
    fn fmap_changes_type() {
        let c = Codensity::<OptionF, i32>::pure(42).fmap(|x| format!("val={x}"));
        assert_eq!(c.to_monad(), Some("val=42".to_string()));
    }
}
224
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    proptest! {
        // Functor identity: fmap(id) == id
        #[test]
        fn functor_identity(x in any::<i32>()) {
            let result = Codensity::<OptionF, i32>::pure(x)
                .fmap(|a| a)
                .to_monad();
            prop_assert_eq!(result, Some(x));
        }

        // Functor composition: fmap(g . f) == fmap(g) . fmap(f)
        // (mapping `f` first and then `g` equals mapping their composition).
        #[test]
        fn functor_composition(x in any::<i32>()) {
            let f = |a: i32| a.wrapping_add(1);
            let g = |a: i32| a.wrapping_mul(2);

            let left = Codensity::<OptionF, i32>::pure(x)
                .fmap(move |a| g(f(a)))
                .to_monad();
            let right = Codensity::<OptionF, i32>::pure(x)
                .fmap(f)
                .fmap(g)
                .to_monad();
            prop_assert_eq!(left, right);
        }

        // Monad left identity: pure(a) >>= f == f(a)
        #[test]
        fn monad_left_identity(x in any::<i32>()) {
            let f = |a: i32| Codensity::<OptionF, i32>::pure(a.wrapping_mul(2));

            let left = Codensity::<OptionF, i32>::pure(x).chain(f).to_monad();
            let right = f(x).to_monad();
            prop_assert_eq!(left, right);
        }

        // Monad right identity: m >>= pure == m
        #[test]
        fn monad_right_identity(x in any::<i32>()) {
            let left = Codensity::<OptionF, i32>::pure(x)
                .chain(Codensity::pure)
                .to_monad();
            let right = Codensity::<OptionF, i32>::pure(x).to_monad();
            prop_assert_eq!(left, right);
        }

        // Monad associativity: (m >>= f) >>= g == m >>= (\x -> f(x) >>= g)
        #[test]
        fn monad_associativity(x in any::<i32>()) {
            let f = |a: i32| Codensity::<OptionF, i32>::pure(a.wrapping_add(1));
            let g = |a: i32| Codensity::<OptionF, i32>::pure(a.wrapping_mul(2));

            let left = Codensity::<OptionF, i32>::pure(x)
                .chain(f)
                .chain(g)
                .to_monad();

            let right = Codensity::<OptionF, i32>::pure(x)
                .chain(move |a| f(a).chain(g))
                .to_monad();

            prop_assert_eq!(left, right);
        }
    }
}