Skip to main content

karpal_optics/
fold.rs

1use crate::optic::Optic;
2use karpal_core::Monoid;
3
/// A read-only optic with any number of foci.
///
/// Think of it as a `Traversal` that can only be read from: it extracts zero
/// or more foci from a structure but provides no way to write them back.
pub struct Fold<S, A> {
    // The extraction function is boxed, so `Fold<S, A>` does not carry the
    // concrete closure type; the allow silences clippy's `type_complexity`
    // lint on this boxed-closure field type.
    #[allow(clippy::type_complexity)]
    fold_fn: Box<dyn Fn(&S) -> Vec<A> + 'static>,
}
11
// Registers `Fold` as an `Optic`; no trait items are overridden here.
impl<S, A> Optic for Fold<S, A> {}
13
14impl<S, A> Fold<S, A> {
15    pub fn new(fold_fn: impl Fn(&S) -> Vec<A> + 'static) -> Self {
16        Self {
17            fold_fn: Box::new(fold_fn),
18        }
19    }
20
21    pub fn get_all(&self, s: &S) -> Vec<A> {
22        (self.fold_fn)(s)
23    }
24
25    /// Map each focus to a monoid value and combine them.
26    pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
27        (self.fold_fn)(s)
28            .into_iter()
29            .map(&f)
30            .fold(R::empty(), |acc, r| acc.combine(r))
31    }
32
33    /// Check if any focus satisfies a predicate.
34    pub fn any(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
35        (self.fold_fn)(s).iter().any(&f)
36    }
37
38    /// Check if all foci satisfy a predicate.
39    pub fn all(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
40        (self.fold_fn)(s).iter().all(&f)
41    }
42
43    /// Find the first focus satisfying a predicate.
44    pub fn find(&self, s: &S, f: impl Fn(&A) -> bool) -> Option<A> {
45        (self.fold_fn)(s).into_iter().find(|a| f(a))
46    }
47
48    /// Count the number of foci.
49    pub fn length(&self, s: &S) -> usize {
50        (self.fold_fn)(s).len()
51    }
52
53    /// Compose with another fold for deeper read-only access.
54    pub fn then<B>(self, inner: Fold<A, B>) -> ComposedFold<S, B>
55    where
56        S: 'static,
57        A: 'static,
58        B: 'static,
59    {
60        let outer_fn = self.fold_fn;
61        let inner_fn = inner.fold_fn;
62        ComposedFold {
63            fold_fn: Box::new(move |s| {
64                outer_fn(s).into_iter().flat_map(|a| inner_fn(&a)).collect()
65            }),
66        }
67    }
68}
69
/// The result of chaining folds with `Fold::then`.
///
/// Holds the flattened extraction function as a single boxed closure.
pub struct ComposedFold<S, A> {
    // Boxed closure field; the allow silences clippy's `type_complexity`
    // lint on this field type.
    #[allow(clippy::type_complexity)]
    fold_fn: Box<dyn Fn(&S) -> Vec<A> + 'static>,
}
75
// Registers `ComposedFold` as an `Optic`; no trait items are overridden here.
impl<S, A> Optic for ComposedFold<S, A> {}
77
78impl<S, A> ComposedFold<S, A> {
79    pub fn get_all(&self, s: &S) -> Vec<A> {
80        (self.fold_fn)(s)
81    }
82
83    pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
84        (self.fold_fn)(s)
85            .into_iter()
86            .map(&f)
87            .fold(R::empty(), |acc, r| acc.combine(r))
88    }
89
90    pub fn length(&self, s: &S) -> usize {
91        (self.fold_fn)(s).len()
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// A fold that focuses on every element of a `Vec<i32>`.
    fn vec_fold() -> Fold<Vec<i32>, i32> {
        Fold::new(|v: &Vec<i32>| v.clone())
    }

    #[test]
    fn fold_get_all() {
        let fold = vec_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_map_sum() {
        let fold = vec_fold();
        let sum: i32 = fold.fold_map(&vec![1, 2, 3], |x| x);
        assert_eq!(sum, 6);
    }

    #[test]
    fn fold_map_string() {
        let fold = vec_fold();
        let result: String = fold.fold_map(&vec![1, 2, 3], |x| x.to_string());
        assert_eq!(result, "123");
    }

    #[test]
    fn fold_any() {
        let fold = vec_fold();
        assert!(fold.any(&vec![1, 2, 3], |x| *x > 2));
        assert!(!fold.any(&vec![1, 2, 3], |x| *x > 5));
    }

    #[test]
    fn fold_all() {
        let fold = vec_fold();
        assert!(fold.all(&vec![1, 2, 3], |x| *x > 0));
        assert!(!fold.all(&vec![1, 2, 3], |x| *x > 1));
    }

    #[test]
    fn fold_find() {
        let fold = vec_fold();
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 1), Some(2));
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 5), None);
    }

    #[test]
    fn fold_length() {
        let fold = vec_fold();
        assert_eq!(fold.length(&vec![1, 2, 3]), 3);
        assert_eq!(fold.length(&Vec::<i32>::new()), 0);
    }

    #[test]
    fn fold_empty_source() {
        let fold = vec_fold();
        let empty = Vec::<i32>::new();
        assert_eq!(fold.get_all(&empty), Vec::<i32>::new());
        // `any` over no foci is false; `all` over no foci is vacuously true.
        assert!(!fold.any(&empty, |_| true));
        assert!(fold.all(&empty, |_| false));
        assert_eq!(fold.find(&empty, |_| true), None);
        // With no foci, `fold_map` yields the monoid identity.
        let joined: String = fold.fold_map(&empty, |x| x.to_string());
        assert_eq!(joined, "");
    }

    #[test]
    fn fold_then_composes() {
        let rows = Fold::new(|m: &Vec<Vec<i32>>| m.clone());
        let composed = rows.then(vec_fold());
        let matrix = vec![vec![1, 2], vec![], vec![3]];
        assert_eq!(composed.get_all(&matrix), vec![1, 2, 3]);
        assert_eq!(composed.length(&matrix), 3);
        let joined: String = composed.fold_map(&matrix, |x| x.to_string());
        assert_eq!(joined, "123");
        assert_eq!(composed.length(&Vec::new()), 0);
    }

    #[test]
    fn fold_from_traversal() {
        use crate::traversal::Traversal;
        let trav = Traversal::new(
            |v: &Vec<i32>| v.clone(),
            |v: Vec<i32>, f: &dyn Fn(i32) -> i32| v.into_iter().map(f).collect::<Vec<_>>(),
        );
        let fold = trav.to_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_from_lens() {
        use crate::lens::Lens;

        #[derive(Clone)]
        struct Point {
            x: i32,
        }
        let lens = Lens::new(|p: &Point| p.x, |_p: Point, x| Point { x });
        let fold = lens.to_fold();
        assert_eq!(fold.get_all(&Point { x: 42 }), vec![42]);
    }
}
174}