Skip to main content

karpal_free/
codensity.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(feature = "std")]
5use std::boxed::Box;
6
7#[cfg(all(not(feature = "std"), feature = "alloc"))]
8use alloc::boxed::Box;
9
10use core::marker::PhantomData;
11
12use karpal_core::applicative::Applicative;
13use karpal_core::chain::Chain;
14use karpal_core::hkt::HKT;
15
16/// Private dyn-safe trait for the Codensity computation tree.
17///
18/// The only eliminator is `to_monad`, which collapses the tree into `F::Of<A>`
19/// using `F::pure` and `F::chain`. The generic `lower_with` (∀R) cannot be
20/// made dyn-safe, so we only support lowering into F's own monad.
21trait CodensityInner<F: HKT + 'static, A: 'static> {
22    /// Collapse this computation into `F::Of<A>`.
23    fn to_monad(self: Box<Self>) -> F::Of<A>
24    where
25        F: Applicative + Chain;
26}
27
28/// Pure layer: stores a value.
29struct CodensityPure<F: HKT + 'static, A: 'static> {
30    value: A,
31    _marker: PhantomData<F>,
32}
33
34impl<F: HKT + 'static, A: 'static> CodensityInner<F, A> for CodensityPure<F, A> {
35    fn to_monad(self: Box<Self>) -> F::Of<A>
36    where
37        F: Applicative + Chain,
38    {
39        F::pure(self.value)
40    }
41}
42
43/// Map layer: wraps an inner computation with a transform.
44struct CodensityMap<F: HKT + 'static, Src: 'static, A: 'static> {
45    inner: Box<dyn CodensityInner<F, Src>>,
46    transform: Box<dyn Fn(Src) -> A>,
47}
48
49impl<F: HKT + 'static, Src: 'static, A: 'static> CodensityInner<F, A> for CodensityMap<F, Src, A> {
50    fn to_monad(self: Box<Self>) -> F::Of<A>
51    where
52        F: Applicative + Chain,
53    {
54        let f_src: F::Of<Src> = self.inner.to_monad();
55        F::fmap(f_src, self.transform)
56    }
57}
58
59/// Bind layer: wraps an inner computation with a monadic bind function.
60struct CodensityBind<F: HKT + 'static, Src: 'static, A: 'static> {
61    inner: Box<dyn CodensityInner<F, Src>>,
62    bind_fn: Box<dyn Fn(Src) -> Codensity<F, A>>,
63}
64
65impl<F: HKT + 'static, Src: 'static, A: 'static> CodensityInner<F, A> for CodensityBind<F, Src, A> {
66    fn to_monad(self: Box<Self>) -> F::Of<A>
67    where
68        F: Applicative + Chain,
69    {
70        let f_src: F::Of<Src> = self.inner.to_monad();
71        let bind_fn = self.bind_fn;
72        F::chain(f_src, move |src| (bind_fn)(src).inner.to_monad())
73    }
74}
75
76/// Codensity Monad — the CPS transform of a type constructor `F`.
77///
78/// `Codensity<F, A> ≅ ∀R. (A → F R) → F R`
79///
80/// This is the right Kan extension of `F` along itself (`Ran F F A`),
81/// specialised into a concrete type. The key property: `pure`, `fmap`,
82/// and `chain` require **no bounds on `F`** — only `to_monad` needs
83/// `F: Applicative + Chain`.
84///
85/// # Use cases
86///
87/// - **Monad transformer improvement**: wrapping a free monad in Codensity
88///   can improve asymptotic performance of left-associated binds.
89/// - **CPS conversion**: build computations in CPS, then interpret into
90///   the target monad via `to_monad`.
91///
92/// Note: Due to Rust's GAT limitations (`type Of<T>` cannot add `T: 'static`),
93/// `CodensityF` does not implement `HKT` or `Monad`. Use inherent methods.
94pub struct Codensity<F: HKT + 'static, A: 'static> {
95    inner: Box<dyn CodensityInner<F, A>>,
96}
97
98impl<F: HKT + 'static, A: 'static> Codensity<F, A> {
99    /// Wrap a pure value. No bounds on `F` required.
100    pub fn pure(a: A) -> Self {
101        Codensity {
102            inner: Box::new(CodensityPure {
103                value: a,
104                _marker: PhantomData,
105            }),
106        }
107    }
108
109    /// Map a function over the result. No bounds on `F` required.
110    pub fn fmap<B: 'static>(self, f: impl Fn(A) -> B + 'static) -> Codensity<F, B> {
111        Codensity {
112            inner: Box::new(CodensityMap {
113                inner: self.inner,
114                transform: Box::new(f),
115            }),
116        }
117    }
118
119    /// Monadic bind. No bounds on `F` required.
120    pub fn chain<B: 'static>(self, f: impl Fn(A) -> Codensity<F, B> + 'static) -> Codensity<F, B> {
121        Codensity {
122            inner: Box::new(CodensityBind {
123                inner: self.inner,
124                bind_fn: Box::new(f),
125            }),
126        }
127    }
128
129    /// Collapse the computation into `F::Of<A>`.
130    ///
131    /// This is the standard way to extract a result from Codensity.
132    /// Requires `F: Applicative + Chain` (i.e., F must be a monad).
133    pub fn to_monad(self) -> F::Of<A>
134    where
135        F: Applicative + Chain,
136    {
137        self.inner.to_monad()
138    }
139}
140
141/// Marker type for `Codensity<F, _>`.
142///
143/// Note: Cannot implement `HKT` or `Monad` due to Rust's GAT limitations.
144/// Use `Codensity::pure`, `Codensity::fmap`, `Codensity::chain` directly.
145pub struct CodensityF<F: HKT + 'static>(PhantomData<F>);
146
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::OptionF;

    /// Start an `Option`-backed computation from an `i32`.
    fn seed(n: i32) -> Codensity<OptionF, i32> {
        Codensity::pure(n)
    }

    #[test]
    fn pure_to_monad() {
        assert_eq!(seed(42).to_monad(), Some(42));
    }

    #[test]
    fn fmap_to_monad() {
        let tripled = seed(5).fmap(|n| n * 3);
        assert_eq!(tripled.to_monad(), Some(15));
    }

    #[test]
    fn chain_to_monad() {
        let bumped = seed(10).chain(|n| Codensity::pure(n + 1));
        assert_eq!(bumped.to_monad(), Some(11));
    }

    #[test]
    fn chain_multiple() {
        // Steps run in order: (1 + 1) * 10 + 5 = 25.
        let pipeline = seed(1)
            .chain(|n| Codensity::pure(n + 1))
            .chain(|n| Codensity::pure(n * 10))
            .chain(|n| Codensity::pure(n + 5));
        assert_eq!(pipeline.to_monad(), Some(25));
    }

    #[test]
    fn fmap_then_chain() {
        let doubled = seed(3).fmap(|n| n * 2);
        let shifted = doubled.chain(|n| Codensity::pure(n + 100));
        assert_eq!(shifted.to_monad(), Some(106));
    }

    #[test]
    fn chain_associativity() {
        let add_one = |n: i32| seed(n + 1);
        let double = |n: i32| seed(n * 2);

        // (m >>= f) >>= g
        let left_nested = seed(5).chain(add_one).chain(double);

        // m >>= (\x -> f(x) >>= g)
        let right_nested = seed(5).chain(move |n| add_one(n).chain(double));

        // Both sides should evaluate to Some((5 + 1) * 2) = Some(12).
        assert_eq!(left_nested.to_monad(), right_nested.to_monad());
    }

    #[test]
    fn monad_left_identity() {
        let bound = seed(4).chain(|n| Codensity::pure(n * 3));
        assert_eq!(bound.to_monad(), Some(12));
    }

    #[test]
    fn monad_right_identity() {
        let rebound = seed(42).chain(Codensity::pure);
        assert_eq!(rebound.to_monad(), Some(42));
    }

    #[test]
    fn fmap_changes_type() {
        let rendered = seed(42).fmap(|n| format!("val={n}"));
        assert_eq!(rendered.to_monad(), Some(String::from("val=42")));
    }
}
227
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    /// Lift an `i32` into an `Option`-backed computation.
    fn lift(n: i32) -> Codensity<OptionF, i32> {
        Codensity::pure(n)
    }

    proptest! {
        // Functor identity: fmap(id) == id
        #[test]
        fn functor_identity(x in any::<i32>()) {
            let mapped = lift(x).fmap(|v| v);
            prop_assert_eq!(mapped.to_monad(), Some(x));
        }

        // Functor composition: fmap(g . f) == fmap(g) . fmap(f)
        #[test]
        fn functor_composition(x in any::<i32>()) {
            let f = |v: i32| v.wrapping_add(1);
            let g = |v: i32| v.wrapping_mul(2);

            // One pass with the composed function ...
            let fused = lift(x).fmap(move |v| g(f(v))).to_monad();
            // ... versus two passes, `f` first and `g` second.
            let staged = lift(x).fmap(f).fmap(g).to_monad();

            prop_assert_eq!(fused, staged);
        }

        // Monad left identity: pure(a) >>= f == f(a)
        #[test]
        fn monad_left_identity(x in any::<i32>()) {
            let bound = lift(x)
                .chain(|v| Codensity::pure(v.wrapping_mul(2)))
                .to_monad();
            let direct = lift(x.wrapping_mul(2)).to_monad();
            prop_assert_eq!(bound, direct);
        }

        // Monad right identity: m >>= pure == m
        #[test]
        fn monad_right_identity(x in any::<i32>()) {
            let rebound = lift(x).chain(Codensity::pure);
            prop_assert_eq!(rebound.to_monad(), Some(x));
        }

        // Monad associativity: (m >>= f) >>= g == m >>= (\x -> f(x) >>= g)
        #[test]
        fn monad_associativity(x in any::<i32>()) {
            let left_nested = lift(x)
                .chain(|v| Codensity::pure(v.wrapping_add(1)))
                .chain(|v| Codensity::pure(v.wrapping_mul(2)))
                .to_monad();

            let right_nested = lift(x)
                .chain(|v| lift(v.wrapping_add(1)).chain(|w| Codensity::pure(w.wrapping_mul(2))))
                .to_monad();

            prop_assert_eq!(left_nested, right_nested);
        }
    }
}