Skip to main content

karpal_optics/
fold.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::optic::Optic;
5use karpal_core::Monoid;
6
/// Read-only optic that focuses on zero or more parts of a structure.
///
/// It is a `Traversal` without the write half: foci can be extracted and
/// inspected, but a `Fold` offers no way to replace them.
pub struct Fold<S, A> {
    /// Extracts every focus of a source value as an owned `Vec`.
    // Allowed because the boxed-closure type is written inline rather than aliased.
    #[allow(clippy::type_complexity)]
    fold_fn: Box<dyn Fn(&S) -> Vec<A> + 'static>,
}
14
15impl<S, A> Optic for Fold<S, A> {}
16
17impl<S, A> Fold<S, A> {
18    pub fn new(fold_fn: impl Fn(&S) -> Vec<A> + 'static) -> Self {
19        Self {
20            fold_fn: Box::new(fold_fn),
21        }
22    }
23
24    pub fn get_all(&self, s: &S) -> Vec<A> {
25        (self.fold_fn)(s)
26    }
27
28    /// Map each focus to a monoid value and combine them.
29    pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
30        (self.fold_fn)(s)
31            .into_iter()
32            .map(&f)
33            .fold(R::empty(), |acc, r| acc.combine(r))
34    }
35
36    /// Check if any focus satisfies a predicate.
37    pub fn any(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
38        (self.fold_fn)(s).iter().any(&f)
39    }
40
41    /// Check if all foci satisfy a predicate.
42    pub fn all(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
43        (self.fold_fn)(s).iter().all(&f)
44    }
45
46    /// Find the first focus satisfying a predicate.
47    pub fn find(&self, s: &S, f: impl Fn(&A) -> bool) -> Option<A> {
48        (self.fold_fn)(s).into_iter().find(|a| f(a))
49    }
50
51    /// Count the number of foci.
52    pub fn length(&self, s: &S) -> usize {
53        (self.fold_fn)(s).len()
54    }
55
56    /// Compose with another fold for deeper read-only access.
57    pub fn then<B>(self, inner: Fold<A, B>) -> ComposedFold<S, B>
58    where
59        S: 'static,
60        A: 'static,
61        B: 'static,
62    {
63        let outer_fn = self.fold_fn;
64        let inner_fn = inner.fold_fn;
65        ComposedFold {
66            fold_fn: Box::new(move |s| {
67                outer_fn(s).into_iter().flat_map(|a| inner_fn(&a)).collect()
68            }),
69        }
70    }
71}
72
/// Fold produced by chaining folds together (see `Fold::then`).
///
/// The whole chain is stored as a single boxed closure from the source `S`
/// to the final foci `A`.
pub struct ComposedFold<S, A> {
    /// Runs the chained extraction and returns the final foci.
    // Allowed because the boxed-closure type is written inline rather than aliased.
    #[allow(clippy::type_complexity)]
    fold_fn: Box<dyn Fn(&S) -> Vec<A> + 'static>,
}
78
79impl<S, A> Optic for ComposedFold<S, A> {}
80
81impl<S, A> ComposedFold<S, A> {
82    pub fn get_all(&self, s: &S) -> Vec<A> {
83        (self.fold_fn)(s)
84    }
85
86    pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
87        (self.fold_fn)(s)
88            .into_iter()
89            .map(&f)
90            .fold(R::empty(), |acc, r| acc.combine(r))
91    }
92
93    pub fn length(&self, s: &S) -> usize {
94        (self.fold_fn)(s).len()
95    }
96}
97
#[cfg(test)]
mod tests {
    // Unit tests for `Fold` and `ComposedFold`, including composition via
    // `Fold::then` and the empty-source edge cases of each query.
    use super::*;

    fn vec_fold() -> Fold<Vec<i32>, i32> {
        Fold::new(|v: &Vec<i32>| v.clone())
    }

    /// Outer fold over a vector of vectors, used to exercise `Fold::then`.
    fn nested_fold() -> Fold<Vec<Vec<i32>>, Vec<i32>> {
        Fold::new(|v: &Vec<Vec<i32>>| v.clone())
    }

    #[test]
    fn fold_get_all() {
        let fold = vec_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_map_sum() {
        let fold = vec_fold();
        let sum: i32 = fold.fold_map(&vec![1, 2, 3], |x| x);
        assert_eq!(sum, 6);
    }

    #[test]
    fn fold_map_string() {
        let fold = vec_fold();
        let result: String = fold.fold_map(&vec![1, 2, 3], |x| x.to_string());
        assert_eq!(result, "123");
    }

    #[test]
    fn fold_any() {
        let fold = vec_fold();
        assert!(fold.any(&vec![1, 2, 3], |x| *x > 2));
        assert!(!fold.any(&vec![1, 2, 3], |x| *x > 5));
    }

    #[test]
    fn fold_all() {
        let fold = vec_fold();
        assert!(fold.all(&vec![1, 2, 3], |x| *x > 0));
        assert!(!fold.all(&vec![1, 2, 3], |x| *x > 1));
    }

    #[test]
    fn fold_find() {
        let fold = vec_fold();
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 1), Some(2));
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 5), None);
    }

    #[test]
    fn fold_length() {
        let fold = vec_fold();
        assert_eq!(fold.length(&vec![1, 2, 3]), 3);
        assert_eq!(fold.length(&Vec::<i32>::new()), 0);
    }

    #[test]
    fn fold_queries_on_empty_source() {
        let fold = vec_fold();
        let empty: Vec<i32> = Vec::new();
        assert!(fold.get_all(&empty).is_empty());
        // With no foci, `any` is false and `all` is vacuously true.
        assert!(!fold.any(&empty, |_| true));
        assert!(fold.all(&empty, |_| false));
        assert_eq!(fold.find(&empty, |_| true), None);
        // Folding nothing yields the monoid identity.
        let sum: i32 = fold.fold_map(&empty, |x| x);
        assert_eq!(sum, 0);
        let text: String = fold.fold_map(&empty, |x| x.to_string());
        assert_eq!(text, "");
    }

    #[test]
    fn fold_then_flattens_in_order() {
        let composed = nested_fold().then(vec_fold());
        let nested = vec![vec![1, 2], vec![], vec![3]];
        assert_eq!(composed.get_all(&nested), vec![1, 2, 3]);
        assert_eq!(composed.length(&nested), 3);
    }

    #[test]
    fn composed_fold_map() {
        let composed = nested_fold().then(vec_fold());
        let nested = vec![vec![1, 2], vec![3]];
        let sum: i32 = composed.fold_map(&nested, |x| x);
        assert_eq!(sum, 6);
        let text: String = composed.fold_map(&nested, |x| x.to_string());
        assert_eq!(text, "123");
    }

    #[test]
    fn composed_fold_on_empty_source() {
        let composed = nested_fold().then(vec_fold());
        let none: Vec<Vec<i32>> = Vec::new();
        assert!(composed.get_all(&none).is_empty());
        assert_eq!(composed.length(&none), 0);
        let only_empty_inner = vec![Vec::new(), Vec::new()];
        assert_eq!(composed.length(&only_empty_inner), 0);
    }

    #[test]
    fn fold_from_traversal() {
        use crate::traversal::Traversal;
        let trav = Traversal::new(
            |v: &Vec<i32>| v.clone(),
            |v: Vec<i32>, f: &dyn Fn(i32) -> i32| v.into_iter().map(f).collect::<Vec<_>>(),
        );
        let fold = trav.to_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_from_lens() {
        use crate::lens::Lens;

        #[derive(Clone)]
        struct Point {
            x: i32,
        }
        let lens = Lens::new(|p: &Point| p.x, |_p: Point, x| Point { x });
        let fold = lens.to_fold();
        assert_eq!(fold.get_all(&Point { x: 42 }), vec![42]);
    }
}