1use crate::{Fold, FoldContext, FoldOutcome};
4
/// Runs two folds over the same entries one after the other, letting the
/// first fold's resulting state shape the context seen by the second.
///
/// `first` is derived with the caller's context. `context_mapper` then builds
/// the context for `second` from `first`'s state and the caller's context
/// (see [`SequentialFold::execute`]).
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct SequentialFold<L, S1, S2, F1, F2, M> {
    /// Fold derived first, with the caller's context.
    first: F1,
    /// Fold derived second, with the context returned by `context_mapper`.
    second: F2,
    /// Builds the second fold's context from the first fold's state and the
    /// caller's context.
    context_mapper: M,
    // `fn() -> (..)` marks the type parameters as used without claiming to
    // own any `L`/`S1`/`S2` value. None is ever stored, so `Send`/`Sync` and
    // drop checking depend only on the fields above. Variance stays covariant.
    _phantom: std::marker::PhantomData<fn() -> (L, S1, S2)>,
}
23
24impl<L, S1, S2, F1, F2, M> SequentialFold<L, S1, S2, F1, F2, M>
25where
26 F1: Fold<L, S1>,
27 F2: Fold<L, S2>,
28 M: Fn(&S1, &FoldContext) -> FoldContext,
29{
30 pub fn new(first: F1, second: F2, context_mapper: M) -> Self {
32 Self {
33 first,
34 second,
35 context_mapper,
36 _phantom: std::marker::PhantomData,
37 }
38 }
39
40 pub fn execute<'a, I>(
42 &self,
43 entries: I,
44 context: &FoldContext,
45 ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
46 where
47 I: IntoIterator<Item = &'a L> + Clone,
48 L: 'a,
49 {
50 let result1 = self.first.derive(entries.clone(), context);
51 let context2 = (self.context_mapper)(&result1.state, context);
52 let result2 = self.second.derive(entries, &context2);
53 (result1, result2)
54 }
55}
56
/// Derives two independent folds over the same entries with the same context.
///
/// Neither fold sees the other's state (see [`DualFold::execute`]).
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct DualFold<L, S1, S2, F1, F2> {
    /// Fold producing the first element of the result pair.
    fold1: F1,
    /// Fold producing the second element of the result pair.
    fold2: F2,
    // `fn() -> (..)` marks the type parameters as used without claiming to
    // own any `L`/`S1`/`S2` value. None is ever stored, so `Send`/`Sync` and
    // drop checking depend only on the fields above. Variance stays covariant.
    _phantom: std::marker::PhantomData<fn() -> (L, S1, S2)>,
}
70
71impl<L, S1, S2, F1, F2> DualFold<L, S1, S2, F1, F2>
72where
73 F1: Fold<L, S1>,
74 F2: Fold<L, S2>,
75{
76 pub fn new(fold1: F1, fold2: F2) -> Self {
78 Self {
79 fold1,
80 fold2,
81 _phantom: std::marker::PhantomData,
82 }
83 }
84
85 pub fn execute<'a, I>(
87 &self,
88 entries: I,
89 context: &FoldContext,
90 ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
91 where
92 I: IntoIterator<Item = &'a L> + Clone,
93 L: 'a,
94 {
95 let result1 = self.fold1.derive(entries.clone(), context);
96 let result2 = self.fold2.derive(entries, context);
97 (result1, result2)
98 }
99}
100
/// A fold adapter that forwards only the entries accepted by `predicate` to
/// the wrapped fold.
///
/// A rejected entry leaves the state unchanged. `initial` and `finalize` are
/// delegated to the inner fold unchanged.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct FilterFold<L, S, F, P> {
    /// Fold that receives the accepted entries.
    inner: F,
    /// Returns `true` for entries that should reach `inner`.
    predicate: P,
    // `fn() -> (..)` marks `L`/`S` as used without claiming to own any such
    // value. None is ever stored, so `Send`/`Sync` and drop checking depend
    // only on the fields above. Variance stays covariant.
    _phantom: std::marker::PhantomData<fn() -> (L, S)>,
}
111
112impl<L, S, F, P> FilterFold<L, S, F, P>
113where
114 F: Fold<L, S>,
115 P: Fn(&L) -> bool,
116{
117 pub fn new(inner: F, predicate: P) -> Self {
119 Self {
120 inner,
121 predicate,
122 _phantom: std::marker::PhantomData,
123 }
124 }
125}
126
127impl<L, S, F, P> Fold<L, S> for FilterFold<L, S, F, P>
128where
129 F: Fold<L, S>,
130 P: Fn(&L) -> bool,
131{
132 fn initial(&self, context: &FoldContext) -> S {
133 self.inner.initial(context)
134 }
135
136 fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
137 if (self.predicate)(entry) {
138 self.inner.step(state, entry, context)
139 } else {
140 state
141 }
142 }
143
144 fn finalize(&self, state: S, context: &FoldContext) -> S {
145 self.inner.finalize(state, context)
146 }
147}
148
/// A fold adapter that turns a fold over `L2` entries into a fold over `L1`
/// entries by converting each entry with `mapper`.
///
/// The converted value is a temporary that lives only for the inner `step`
/// call. `initial` and `finalize` are delegated to the inner fold unchanged.
///
/// Trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct MapFold<L1, L2, S, F, M> {
    /// Fold over the converted (`L2`) entries.
    inner: F,
    /// Converts each incoming `&L1` into the `L2` that `inner` consumes.
    mapper: M,
    // `fn() -> (..)` marks the type parameters as used without claiming to
    // own any such value. None is ever stored, so `Send`/`Sync` and drop
    // checking depend only on the fields above. Variance stays covariant.
    _phantom: std::marker::PhantomData<fn() -> (L1, L2, S)>,
}
159
160impl<L1, L2, S, F, M> MapFold<L1, L2, S, F, M>
161where
162 F: Fold<L2, S>,
163 M: Fn(&L1) -> L2,
164{
165 pub fn new(inner: F, mapper: M) -> Self {
167 Self {
168 inner,
169 mapper,
170 _phantom: std::marker::PhantomData,
171 }
172 }
173}
174
175impl<L1, L2, S, F, M> Fold<L1, S> for MapFold<L1, L2, S, F, M>
176where
177 F: Fold<L2, S>,
178 M: Fn(&L1) -> L2,
179{
180 fn initial(&self, context: &FoldContext) -> S {
181 self.inner.initial(context)
182 }
183
184 fn step(&self, state: S, entry: &L1, context: &FoldContext) -> S {
185 let mapped = (self.mapper)(entry);
186 self.inner.step(state, &mapped, context)
187 }
188
189 fn finalize(&self, state: S, context: &FoldContext) -> S {
190 self.inner.finalize(state, context)
191 }
192}
193
194pub fn filter<L, S, F, P>(inner: F, predicate: P) -> FilterFold<L, S, F, P>
196where
197 F: Fold<L, S>,
198 P: Fn(&L) -> bool,
199{
200 FilterFold::new(inner, predicate)
201}
202
203pub fn map<L1, L2, S, F, M>(inner: F, mapper: M) -> MapFold<L1, L2, S, F, M>
205where
206 F: Fold<L2, S>,
207 M: Fn(&L1) -> L2,
208{
209 MapFold::new(inner, mapper)
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::fold_fn;

    /// Only the even entries (2, 4, 6) should be counted.
    #[test]
    fn test_filter_fold() {
        let count_all = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let count_evens = filter(count_all, |value: &i32| value % 2 == 0);
        let input = [1, 2, 3, 4, 5, 6];
        let outcome = count_evens.derive(input.iter(), &FoldContext::new());
        assert_eq!(outcome.state, 3);
    }

    /// Each entry is doubled before summing: (1 + 2 + 3) * 2 == 12.
    #[test]
    fn test_map_fold() {
        let add_up = fold_fn(|_ctx| 0i32, |total, entry: &i32, _ctx| total + entry);
        let add_doubled = map(add_up, |value: &i32| value * 2);
        let input = [1, 2, 3];
        let outcome = add_doubled.derive(input.iter(), &FoldContext::new());
        assert_eq!(outcome.state, 12);
    }

    /// Both folds see every entry: sum 15, count 5.
    #[test]
    fn test_dual_fold() {
        let add_up = fold_fn(|_ctx| 0i32, |total, entry: &i32, _ctx| total + entry);
        let count_all = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let input = [1, 2, 3, 4, 5];
        let (totals, counts) =
            DualFold::new(add_up, count_all).execute(input.iter(), &FoldContext::new());
        assert_eq!(totals.state, 15);
        assert_eq!(counts.state, 5);
    }

    /// The first pass counts 3 entries; the second pass starts its sum from
    /// that count via the mapped context: 3 + (1 + 2 + 3) == 9.
    #[test]
    fn test_sequential_fold() {
        let count_all = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let seeded_sum = fold_fn(
            |ctx: &FoldContext| {
                let seed = ctx.extra.get("count").and_then(|v| v.as_i64()).unwrap_or(0);
                seed as i32
            },
            |total, entry: &i32, _ctx| total + entry,
        );
        let chained = SequentialFold::new(count_all, seeded_sum, |seen, base_ctx| {
            let mut next_ctx = base_ctx.clone();
            *next_ctx.extra_mut() = serde_json::json!({"count": *seen});
            next_ctx
        });
        let input = [1, 2, 3];
        let (counts, totals) = chained.execute(input.iter(), &FoldContext::new());
        assert_eq!(counts.state, 3);
        assert_eq!(totals.state, 9);
    }
}