1use crate::optic::Optic;
5use karpal_core::Monoid;
6
7pub struct Fold<S, A> {
11 #[allow(clippy::type_complexity)]
12 fold_fn: Box<dyn Fn(&S) -> Vec<A>>,
13}
14
15impl<S, A> Optic for Fold<S, A> {}
16
17impl<S, A> Fold<S, A> {
18 pub fn new(fold_fn: impl Fn(&S) -> Vec<A> + 'static) -> Self {
19 Self {
20 fold_fn: Box::new(fold_fn),
21 }
22 }
23
24 pub fn get_all(&self, s: &S) -> Vec<A> {
25 (self.fold_fn)(s)
26 }
27
28 pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
30 (self.fold_fn)(s)
31 .into_iter()
32 .map(&f)
33 .fold(R::empty(), |acc, r| acc.combine(r))
34 }
35
36 pub fn any(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
38 (self.fold_fn)(s).iter().any(&f)
39 }
40
41 pub fn all(&self, s: &S, f: impl Fn(&A) -> bool) -> bool {
43 (self.fold_fn)(s).iter().all(&f)
44 }
45
46 pub fn find(&self, s: &S, f: impl Fn(&A) -> bool) -> Option<A> {
48 (self.fold_fn)(s).into_iter().find(|a| f(a))
49 }
50
51 pub fn length(&self, s: &S) -> usize {
53 (self.fold_fn)(s).len()
54 }
55
56 pub fn then<B>(self, inner: Fold<A, B>) -> ComposedFold<S, B>
58 where
59 S: 'static,
60 A: 'static,
61 B: 'static,
62 {
63 let outer_fn = self.fold_fn;
64 let inner_fn = inner.fold_fn;
65 ComposedFold {
66 fold_fn: Box::new(move |s| {
67 outer_fn(s).into_iter().flat_map(|a| inner_fn(&a)).collect()
68 }),
69 }
70 }
71}
72
73pub struct ComposedFold<S, A> {
75 #[allow(clippy::type_complexity)]
76 fold_fn: Box<dyn Fn(&S) -> Vec<A>>,
77}
78
79impl<S, A> Optic for ComposedFold<S, A> {}
80
81impl<S, A> ComposedFold<S, A> {
82 pub fn get_all(&self, s: &S) -> Vec<A> {
83 (self.fold_fn)(s)
84 }
85
86 pub fn fold_map<R: Monoid>(&self, s: &S, f: impl Fn(A) -> R) -> R {
87 (self.fold_fn)(s)
88 .into_iter()
89 .map(&f)
90 .fold(R::empty(), |acc, r| acc.combine(r))
91 }
92
93 pub fn length(&self, s: &S) -> usize {
94 (self.fold_fn)(s).len()
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// A fold that focuses on every element of a `Vec<i32>`.
    fn vec_fold() -> Fold<Vec<i32>, i32> {
        Fold::new(|v: &Vec<i32>| v.clone())
    }

    #[test]
    fn fold_get_all() {
        let fold = vec_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_map_sum() {
        let fold = vec_fold();
        let sum: i32 = fold.fold_map(&vec![1, 2, 3], |x| x);
        assert_eq!(sum, 6);
    }

    #[test]
    fn fold_map_string() {
        let fold = vec_fold();
        let result: String = fold.fold_map(&vec![1, 2, 3], |x| x.to_string());
        assert_eq!(result, "123");
    }

    #[test]
    fn fold_any() {
        let fold = vec_fold();
        assert!(fold.any(&vec![1, 2, 3], |x| *x > 2));
        assert!(!fold.any(&vec![1, 2, 3], |x| *x > 5));
    }

    #[test]
    fn fold_all() {
        let fold = vec_fold();
        assert!(fold.all(&vec![1, 2, 3], |x| *x > 0));
        assert!(!fold.all(&vec![1, 2, 3], |x| *x > 1));
    }

    #[test]
    fn fold_find() {
        let fold = vec_fold();
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 1), Some(2));
        assert_eq!(fold.find(&vec![1, 2, 3], |x| *x > 5), None);
    }

    #[test]
    fn fold_length() {
        let fold = vec_fold();
        assert_eq!(fold.length(&vec![1, 2, 3]), 3);
        assert_eq!(fold.length(&Vec::<i32>::new()), 0);
    }

    // Boundary behavior on a source with no focused values: `any` is false,
    // `all` is vacuously true, `find` finds nothing, `fold_map` yields the empty monoid.
    #[test]
    fn fold_empty_source() {
        let fold = vec_fold();
        let empty = Vec::<i32>::new();
        assert!(!fold.any(&empty, |_| true));
        assert!(fold.all(&empty, |_| false));
        assert_eq!(fold.find(&empty, |_| true), None);
        let joined: String = fold.fold_map(&empty, |x| x.to_string());
        assert_eq!(joined, "");
    }

    // `then` must flatten inner results in outer order, skipping empty inner focuses.
    #[test]
    fn fold_then_composes() {
        let outer = Fold::new(|v: &Vec<Vec<i32>>| v.clone());
        let composed = outer.then(vec_fold());
        let nested = vec![vec![1, 2], vec![], vec![3]];
        assert_eq!(composed.get_all(&nested), vec![1, 2, 3]);
        assert_eq!(composed.length(&nested), 3);
        let sum: i32 = composed.fold_map(&nested, |x| x);
        assert_eq!(sum, 6);
        assert_eq!(composed.length(&Vec::new()), 0);
    }

    #[test]
    fn fold_from_traversal() {
        use crate::traversal::Traversal;
        let trav = Traversal::new(
            |v: &Vec<i32>| v.clone(),
            |v: Vec<i32>, f: &dyn Fn(i32) -> i32| v.into_iter().map(f).collect::<Vec<_>>(),
        );
        let fold = trav.to_fold();
        assert_eq!(fold.get_all(&vec![1, 2, 3]), vec![1, 2, 3]);
    }

    #[test]
    fn fold_from_lens() {
        use crate::lens::Lens;

        #[derive(Clone)]
        struct Point {
            x: i32,
        }
        let lens = Lens::new(|p: &Point| p.x, |_p: Point, x| Point { x });
        let fold = lens.to_fold();
        assert_eq!(fold.get_all(&Point { x: 42 }), vec![42]);
    }
}