Skip to main content

karpal_free/
cofree.rs

1#[cfg(feature = "std")]
2use std::boxed::Box;
3
4#[cfg(all(not(feature = "std"), feature = "alloc"))]
5use alloc::boxed::Box;
6
7use core::marker::PhantomData;
8
9use karpal_core::comonad::Comonad;
10use karpal_core::extend::Extend;
11use karpal_core::functor::Functor;
12use karpal_core::hkt::HKT;
13
14/// Cofree Comonad — the dual of the Free Monad.
15///
16/// `Cofree<F, A>` pairs a value `A` at each node with a branching structure
17/// determined by `F`. The choice of `F` determines the shape:
18/// - `OptionF` → a non-empty list (finite stream)
19/// - `VecF` → a rose tree
20/// - `IdentityF` → an infinite stream
21///
22/// Each node carries a value (`head`) and subtrees (`tail`).
23pub struct Cofree<F: HKT, A> {
24    /// The value at this node.
25    pub head: A,
26    /// The subtrees, with branching structure determined by `F`.
27    pub tail: Box<F::Of<Cofree<F, A>>>,
28}
29
30impl<F: HKT, A> Cofree<F, A> {
31    /// Create a Cofree node with the given head and tail.
32    pub fn new(head: A, tail: F::Of<Cofree<F, A>>) -> Self {
33        Cofree {
34            head,
35            tail: Box::new(tail),
36        }
37    }
38
39    /// Extract the head value.
40    pub fn extract(&self) -> A
41    where
42        A: Clone,
43    {
44        self.head.clone()
45    }
46}
47
48impl<F: HKT + Functor, A> Cofree<F, A> {
49    /// Map a function over all head values in the tree.
50    pub fn fmap<B>(self, f: impl Fn(A) -> B) -> Cofree<F, B> {
51        self.fmap_inner(&f)
52    }
53
54    fn fmap_inner<B>(self, f: &dyn Fn(A) -> B) -> Cofree<F, B> {
55        Cofree {
56            head: f(self.head),
57            tail: Box::new(F::fmap(*self.tail, |child| child.fmap_inner(f))),
58        }
59    }
60
61    /// Apply a context-aware function to every position in the tree.
62    ///
63    /// At each node, `f` receives the entire sub-cofree rooted at that node
64    /// and produces the new head value.
65    pub fn extend<B>(self, f: impl Fn(&Cofree<F, A>) -> B) -> Cofree<F, B>
66    where
67        A: Clone,
68    {
69        self.extend_inner(&f)
70    }
71
72    fn extend_inner<B>(self, f: &dyn Fn(&Cofree<F, A>) -> B) -> Cofree<F, B>
73    where
74        A: Clone,
75    {
76        let b = f(&self);
77        let new_tail = F::fmap(*self.tail, |child| child.extend_inner(f));
78        Cofree {
79            head: b,
80            tail: Box::new(new_tail),
81        }
82    }
83
84    /// Build a Cofree from a seed value and an unfolding function.
85    ///
86    /// The function takes a seed and returns `(head, F<Seed>)` — the value
87    /// at this node and seeds for the subtrees.
88    pub fn unfold<Seed>(seed: Seed, f: impl Fn(&Seed) -> (A, F::Of<Seed>)) -> Self {
89        Self::unfold_inner(seed, &f)
90    }
91
92    #[allow(clippy::type_complexity)]
93    fn unfold_inner<Seed>(seed: Seed, f: &dyn Fn(&Seed) -> (A, F::Of<Seed>)) -> Self {
94        let (head, f_seeds) = f(&seed);
95        let tail = F::fmap(f_seeds, |child_seed| Self::unfold_inner(child_seed, f));
96        Cofree {
97            head,
98            tail: Box::new(tail),
99        }
100    }
101}
102
103/// HKT marker for `Cofree<F, _>`.
104pub struct CofreeF<F: HKT>(PhantomData<F>);
105
106impl<F: HKT> HKT for CofreeF<F> {
107    type Of<T> = Cofree<F, T>;
108}
109
110impl<F: HKT + Functor> Functor for CofreeF<F> {
111    fn fmap<A, B>(fa: Cofree<F, A>, f: impl Fn(A) -> B) -> Cofree<F, B> {
112        fa.fmap(f)
113    }
114}
115
116impl<F: HKT + Functor> Extend for CofreeF<F> {
117    fn extend<A, B>(wa: Cofree<F, A>, f: impl Fn(&Cofree<F, A>) -> B) -> Cofree<F, B>
118    where
119        A: Clone,
120    {
121        wa.extend_inner(&f)
122    }
123}
124
125impl<F: HKT + Functor> Comonad for CofreeF<F> {
126    fn extract<A: Clone>(wa: &Cofree<F, A>) -> A {
127        wa.head.clone()
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use karpal_core::hkt::OptionF;

    /// Build a non-empty `OptionF` stream holding `values` in order.
    ///
    /// # Panics
    ///
    /// Panics if `values` is empty: a `Cofree<OptionF, _>` always has at least
    /// one node, so an empty stream has no faithful representation (the old
    /// behaviour silently invented a `0` node instead).
    fn make_stream(values: &[i32]) -> Cofree<OptionF, i32> {
        let (&first, rest) = values
            .split_first()
            .expect("make_stream needs at least one value (Cofree<OptionF, _> is non-empty)");
        let tail = if rest.is_empty() {
            None
        } else {
            Some(make_stream(rest))
        };
        Cofree::new(first, tail)
    }

    /// Borrow the node following `node` in an `OptionF` stream, if any.
    fn next_node(node: &Cofree<OptionF, i32>) -> Option<&Cofree<OptionF, i32>> {
        (*node.tail).as_ref()
    }

    /// Assert that the stream starting at `node` holds exactly `expected`,
    /// including its length (no extra trailing nodes).
    fn assert_heads(node: &Cofree<OptionF, i32>, expected: &[i32]) {
        let mut current = Some(node);
        for (index, want) in expected.iter().enumerate() {
            let Some(n) = current else {
                panic!("stream ended after {index} nodes, expected {}", expected.len());
            };
            assert_eq!(n.head, *want, "head mismatch at index {index}");
            current = next_node(n);
        }
        assert!(
            current.is_none(),
            "stream has more than {} nodes",
            expected.len()
        );
    }

    #[test]
    fn extract_head() {
        let cofree = make_stream(&[1, 2, 3]);
        assert_eq!(cofree.extract(), 1);
    }

    #[test]
    fn fmap_cofree() {
        let cofree = make_stream(&[1, 2, 3]);
        let mapped = cofree.fmap(|x| x * 10);
        assert_heads(&mapped, &[10, 20, 30]);
    }

    #[test]
    fn unfold_option_stream() {
        // Unfold a countdown: 3, 2, 1, 0, done
        let cofree = Cofree::<OptionF, i32>::unfold(3, |&seed| {
            if seed <= 0 {
                (seed, None)
            } else {
                (seed, Some(seed - 1))
            }
        });
        assert_heads(&cofree, &[3, 2, 1, 0]);
    }

    #[test]
    fn unfold_then_extract() {
        let cofree = Cofree::<OptionF, i32>::unfold(42, |&seed| (seed, None));
        assert_eq!(cofree.extract(), 42);
        assert!(next_node(&cofree).is_none());
    }

    #[test]
    fn extend_cofree() {
        // For the stream [1, 2, 3], each new head is the current head plus the
        // next node's head (or 0 at the end): [1 + 2, 2 + 3, 3 + 0].
        let cofree = make_stream(&[1, 2, 3]);
        let extended = cofree.extend(|w| {
            let next = next_node(w).map_or(0, |c| c.head);
            w.head + next
        });
        assert_heads(&extended, &[3, 5, 3]);
    }

    #[test]
    fn comonad_trait_works() {
        let cofree = make_stream(&[42, 1]);
        let result = <CofreeF<OptionF> as Comonad>::extract(&cofree);
        assert_eq!(result, 42);
    }

    #[test]
    fn extend_trait_works() {
        let cofree = make_stream(&[10, 20]);
        let extended = <CofreeF<OptionF> as Extend>::extend(cofree, |w| w.head * 2);
        assert_heads(&extended, &[20, 40]);
    }

    #[test]
    fn functor_trait_works() {
        let cofree = make_stream(&[5]);
        let mapped = <CofreeF<OptionF> as Functor>::fmap(cofree, |x| x + 100);
        assert_heads(&mapped, &[105]);
    }
}
226
#[cfg(test)]
mod law_tests {
    use super::*;
    use karpal_core::hkt::OptionF;
    use proptest::prelude::*;

    fn make_singleton(v: i32) -> Cofree<OptionF, i32> {
        Cofree::new(v, None)
    }

    fn make_pair(a: i32, b: i32) -> Cofree<OptionF, i32> {
        Cofree::new(a, Some(Cofree::new(b, None)))
    }

    /// Structural equality for `OptionF` streams (`Cofree` has no `PartialEq`).
    fn same_stream(a: &Cofree<OptionF, i32>, b: &Cofree<OptionF, i32>) -> bool {
        a.head == b.head
            && match (&*a.tail, &*b.tail) {
                (Some(x), Some(y)) => same_stream(x, y),
                (None, None) => true,
                _ => false,
            }
    }

    /// Deep-copy an `OptionF` stream (`Cofree` does not implement `Clone`).
    fn copy_stream(w: &Cofree<OptionF, i32>) -> Cofree<OptionF, i32> {
        Cofree::new(w.head, (*w.tail).as_ref().map(copy_stream))
    }

    /// Context-dependent co-Kleisli arrow: head plus the next head (0 at the end).
    fn with_next_sum(w: &Cofree<OptionF, i32>) -> i32 {
        let next = (*w.tail).as_ref().map_or(0, |c| c.head);
        w.head.wrapping_add(next)
    }

    /// Context-dependent co-Kleisli arrow: head minus the next head (0 at the end).
    fn with_next_diff(w: &Cofree<OptionF, i32>) -> i32 {
        let next = (*w.tail).as_ref().map_or(0, |c| c.head);
        w.head.wrapping_sub(next)
    }

    proptest! {
        // Comonad law: extract(extend(w, f)) == f(w)
        #[test]
        fn extract_extend(x in any::<i32>(), y in any::<i32>()) {
            let w = make_pair(x, y);
            let f = |w: &Cofree<OptionF, i32>| w.head.wrapping_mul(2);
            let expected = f(&w);
            let extended = <CofreeF<OptionF> as Extend>::extend(w, f);
            let result = <CofreeF<OptionF> as Comonad>::extract(&extended);
            prop_assert_eq!(result, expected);
        }

        // Comonad law: extend(w, extract) == w, checked on every node.
        #[test]
        fn extend_extract(x in any::<i32>(), y in any::<i32>()) {
            let result = <CofreeF<OptionF> as Extend>::extend(
                make_pair(x, y),
                <CofreeF<OptionF> as Comonad>::extract,
            );
            prop_assert!(same_stream(&result, &make_pair(x, y)));
        }

        // Comonad law (associativity):
        // extend(extend(w, f), g) == extend(w, |w'| g(extend(w', f)))
        #[test]
        fn extend_associativity(x in any::<i32>(), y in any::<i32>()) {
            let left = <CofreeF<OptionF> as Extend>::extend(
                <CofreeF<OptionF> as Extend>::extend(make_pair(x, y), with_next_sum),
                with_next_diff,
            );
            let right = <CofreeF<OptionF> as Extend>::extend(
                make_pair(x, y),
                |w: &Cofree<OptionF, i32>| {
                    with_next_diff(&<CofreeF<OptionF> as Extend>::extend(
                        copy_stream(w),
                        with_next_sum,
                    ))
                },
            );
            prop_assert!(same_stream(&left, &right));
        }

        // Functor identity: fmap(id, w) == w
        #[test]
        fn functor_identity(x in any::<i32>(), y in any::<i32>()) {
            let result = <CofreeF<OptionF> as Functor>::fmap(make_pair(x, y), |a| a);
            prop_assert!(same_stream(&result, &make_pair(x, y)));
        }

        // Functor identity also holds for a single-node stream.
        #[test]
        fn functor_identity_singleton(x in any::<i32>()) {
            let result = <CofreeF<OptionF> as Functor>::fmap(make_singleton(x), |a| a);
            prop_assert!(same_stream(&result, &make_singleton(x)));
        }

        // Functor composition: fmap(g . f, w) == fmap(g, fmap(f, w))
        #[test]
        fn functor_composition(x in any::<i32>(), y in any::<i32>()) {
            let f = |a: i32| a.wrapping_add(1);
            let g = |a: i32| a.wrapping_mul(2);

            let left = <CofreeF<OptionF> as Functor>::fmap(make_pair(x, y), |a| g(f(a)));
            let right = <CofreeF<OptionF> as Functor>::fmap(
                <CofreeF<OptionF> as Functor>::fmap(make_pair(x, y), f),
                g,
            );
            prop_assert!(same_stream(&left, &right));
        }
    }
}