1use crate::optic::Optic;
2use karpal_core::Monoid;
3
/// A read-only optic that focuses on zero or more `A` values inside an `S`.
///
/// The focuses are produced on demand by a boxed extraction closure that
/// returns them as a `Vec<A>`, in the order the closure lists them.
pub struct Fold<S, A> {
    // The boxed closure type is intentional; clippy's `type_complexity` lint
    // is silenced here rather than hiding the type behind an alias.
    #[allow(clippy::type_complexity)]
    fold_fn: Box<dyn Fn(&S) -> Vec<A>>,
}

/// Opaque `Debug` output (`Fold { .. }`).
///
/// The stored closure has no printable representation, so it is elided. No
/// `Debug` bound is required on `S` or `A`.
impl<S, A> std::fmt::Debug for Fold<S, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Fold").finish_non_exhaustive()
    }
}
11
12impl<S, A> Optic for Fold<S, A> {}
13
14impl<S, A> Fold<S, A> {
15 pub fn new(fold_fn: impl Fn(&S) -> Vec<A> + 'static) -> Self {
16 Self {
17 fold_fn: Box::new(fold_fn),
18 }
19 }
20
21 pub fn get_all(&self, s: &S) -> Vec<A> {
22 (self.fold_fn)(s)
23 }
24
25 pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
27 (self.fold_fn)(s)
28 .into_iter()
29 .map(&f)
30 .fold(R::empty(), |acc, r| acc.combine(r))
31 }
32
33 pub fn any(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
35 (self.fold_fn)(s).iter().any(&f)
36 }
37
38 pub fn all(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
40 (self.fold_fn)(s).iter().all(&f)
41 }
42
43 pub fn find(&self, s: &S, f: impl Fn(&A) -> bool) -> Option<A> {
45 (self.fold_fn)(s).into_iter().find(|a| f(a))
46 }
47
48 pub fn length(&self, s: &S) -> usize {
50 (self.fold_fn)(s).len()
51 }
52
53 pub fn then<B>(self, inner: Fold<A, B>) -> ComposedFold<S, B>
55 where
56 S: 'static,
57 A: 'static,
58 B: 'static,
59 {
60 let outer_fn = self.fold_fn;
61 let inner_fn = inner.fold_fn;
62 ComposedFold {
63 fold_fn: Box::new(move |s| {
64 outer_fn(s).into_iter().flat_map(|a| inner_fn(&a)).collect()
65 }),
66 }
67 }
68}
69
/// The fold produced by `Fold::then`.
///
/// Focuses are produced on demand by a boxed closure that returns every `A`
/// reachable from an `S` as a `Vec<A>`.
pub struct ComposedFold<S, A> {
    // The boxed closure type is intentional; clippy's `type_complexity` lint
    // is silenced here rather than hiding the type behind an alias.
    #[allow(clippy::type_complexity)]
    fold_fn: Box<dyn Fn(&S) -> Vec<A>>,
}

/// Opaque `Debug` output (`ComposedFold { .. }`).
///
/// The stored closure has no printable representation, so it is elided. No
/// `Debug` bound is required on `S` or `A`.
impl<S, A> std::fmt::Debug for ComposedFold<S, A> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ComposedFold").finish_non_exhaustive()
    }
}
75
76impl<S, A> Optic for ComposedFold<S, A> {}
77
78impl<S, A> ComposedFold<S, A> {
79 pub fn get_all(&self, s: &S) -> Vec<A> {
80 (self.fold_fn)(s)
81 }
82
83 pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
84 (self.fold_fn)(s)
85 .into_iter()
86 .map(&f)
87 .fold(R::empty(), |acc, r| acc.combine(r))
88 }
89
90 pub fn length(&self, s: &S) -> usize {
91 (self.fold_fn)(s).len()
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// A fold whose focuses are exactly the elements of the vector.
    fn vec_fold() -> Fold<Vec<i32>, i32> {
        Fold::new(|v: &Vec<i32>| v.clone())
    }

    #[test]
    fn fold_get_all() {
        let fold = vec_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_map_sum() {
        let fold = vec_fold();
        let sum: i32 = fold.fold_map(&vec![1, 2, 3], |x| x);
        assert_eq!(sum, 6);
    }

    #[test]
    fn fold_map_string() {
        let fold = vec_fold();
        let result: String = fold.fold_map(&vec![1, 2, 3], |x| x.to_string());
        assert_eq!(result, "123");
    }

    #[test]
    fn fold_any() {
        let fold = vec_fold();
        assert!(fold.any(&vec![1, 2, 3], |x| *x > 2));
        assert!(!fold.any(&vec![1, 2, 3], |x| *x > 5));
    }

    #[test]
    fn fold_all() {
        let fold = vec_fold();
        assert!(fold.all(&vec![1, 2, 3], |x| *x > 0));
        assert!(!fold.all(&vec![1, 2, 3], |x| *x > 1));
    }

    #[test]
    fn fold_find() {
        let fold = vec_fold();
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 1), Some(2));
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 5), None);
    }

    #[test]
    fn fold_length() {
        let fold = vec_fold();
        assert_eq!(fold.length(&vec![1, 2, 3]), 3);
        assert_eq!(fold.length(&Vec::<i32>::new()), 0);
    }

    #[test]
    fn fold_on_empty_source() {
        // Edge case: no focuses at all.
        let fold = vec_fold();
        let empty: Vec<i32> = Vec::new();
        assert!(fold.get_all(&empty).is_empty());
        let sum: i32 = fold.fold_map(&empty, |x| x);
        assert_eq!(sum, 0);
        let text: String = fold.fold_map(&empty, |x| x.to_string());
        assert_eq!(text, "");
        assert!(!fold.any(&empty, |_| true));
        assert!(fold.all(&empty, |_| false));
        assert_eq!(fold.find(&empty, |_| true), None);
    }

    #[test]
    fn then_flattens_in_order() {
        let composed = vec_fold().then(Fold::new(|x: &i32| vec![*x, *x * 10]));
        assert_eq!(composed.get_all(&vec![1, 2]), vec![1, 10, 2, 20]);
        assert_eq!(composed.length(&vec![1, 2]), 4);
    }

    #[test]
    fn then_fold_map() {
        let composed = vec_fold().then(Fold::new(|x: &i32| vec![*x, *x * 10]));
        let sum: i32 = composed.fold_map(&vec![1, 2], |x| x);
        assert_eq!(sum, 33);
        let text: String = composed.fold_map(&vec![1, 2], |x| x.to_string());
        assert_eq!(text, "110220");
    }

    #[test]
    fn then_with_empty_inner_or_outer() {
        let no_inner = vec_fold().then(Fold::new(|_: &i32| Vec::<i32>::new()));
        assert_eq!(no_inner.length(&vec![1, 2, 3]), 0);
        assert!(no_inner.get_all(&vec![1, 2, 3]).is_empty());

        let composed = vec_fold().then(Fold::new(|x: &i32| vec![*x]));
        assert!(composed.get_all(&Vec::<i32>::new()).is_empty());
        assert_eq!(composed.length(&Vec::<i32>::new()), 0);
    }

    #[test]
    fn fold_from_traversal() {
        use crate::traversal::Traversal;
        let trav = Traversal::new(
            |v: &Vec<i32>| v.clone(),
            |v: Vec<i32>, f: &dyn Fn(i32) -> i32| v.into_iter().map(f).collect::<Vec<_>>(),
        );
        let fold = trav.to_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_from_lens() {
        use crate::lens::Lens;

        #[derive(Clone)]
        struct Point {
            x: i32,
        }
        let lens = Lens::new(|p: &Point| p.x, |_p: Point, x| Point { x });
        let fold = lens.to_fold();
        assert_eq!(fold.get_all(&Point { x: 42 }), vec![42]);
    }
}
174}