1use crate::{Fold, FoldContext, FoldOutcome};
4
/// Runs two folds over the same entries one after the other, letting the
/// first fold's final state shape the context the second fold runs with.
///
/// Build it with `SequentialFold::new` and run it with `SequentialFold::execute`.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than
/// on the struct itself. This lets the type be named (in fields, aliases, or
/// other generic code) without restating the `Fold`/`Fn` bounds.
pub struct SequentialFold<L, S1, S2, F1, F2, M> {
    /// Fold that runs first, with the caller's context.
    first: F1,
    /// Fold that runs second, with the context produced by `context_mapper`.
    second: F2,
    /// Maps the first fold's state and the caller's context to the context
    /// used by `second`.
    context_mapper: M,
    /// Ties the entry type and both state types to the struct without storing them.
    _phantom: std::marker::PhantomData<(L, S1, S2)>,
}
23
24impl<L, S1, S2, F1, F2, M> SequentialFold<L, S1, S2, F1, F2, M>
25where
26 F1: Fold<L, S1>,
27 F2: Fold<L, S2>,
28 M: Fn(&S1, &FoldContext) -> FoldContext,
29{
30 pub fn new(first: F1, second: F2, context_mapper: M) -> Self {
32 Self {
33 first,
34 second,
35 context_mapper,
36 _phantom: std::marker::PhantomData,
37 }
38 }
39
40 pub fn execute<'a, I>(
42 &self,
43 entries: I,
44 context: &FoldContext,
45 ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
46 where
47 I: IntoIterator<Item = &'a L> + Clone,
48 L: 'a,
49 {
50 let result1 = self.first.derive(entries.clone(), context);
51 let context2 = (self.context_mapper)(&result1.state, context);
52 let result2 = self.second.derive(entries, &context2);
53 (result1, result2)
54 }
55}
56
/// Runs two independent folds over the same entries with the same context.
///
/// Neither fold sees the other's state. Build it with `DualFold::new` and run
/// it with `DualFold::execute`.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than
/// on the struct itself, so the type can be named without restating them.
pub struct DualFold<L, S1, S2, F1, F2> {
    /// Fold whose outcome is returned first.
    fold1: F1,
    /// Fold whose outcome is returned second.
    fold2: F2,
    /// Ties the entry type and both state types to the struct without storing them.
    _phantom: std::marker::PhantomData<(L, S1, S2)>,
}
70
71impl<L, S1, S2, F1, F2> DualFold<L, S1, S2, F1, F2>
72where
73 F1: Fold<L, S1>,
74 F2: Fold<L, S2>,
75{
76 pub fn new(fold1: F1, fold2: F2) -> Self {
78 Self {
79 fold1,
80 fold2,
81 _phantom: std::marker::PhantomData,
82 }
83 }
84
85 pub fn execute<'a, I>(
87 &self,
88 entries: I,
89 context: &FoldContext,
90 ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
91 where
92 I: IntoIterator<Item = &'a L> + Clone,
93 L: 'a,
94 {
95 let result1 = self.fold1.derive(entries.clone(), context);
96 let result2 = self.fold2.derive(entries, context);
97 (result1, result2)
98 }
99}
100
/// Wraps a fold so that only entries accepted by a predicate are reduced.
///
/// Rejected entries leave the state untouched. Initialization and
/// finalization are delegated to the wrapped fold unchanged.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than
/// on the struct itself, so the type can be named without restating them.
pub struct FilterFold<L, S, F, P> {
    /// Fold that receives the accepted entries.
    inner: F,
    /// Decides which entries are forwarded to `inner`.
    predicate: P,
    /// Ties the entry type `L` and state type `S` to the struct without storing them.
    _phantom: std::marker::PhantomData<(L, S)>,
}
111
112impl<L, S, F, P> FilterFold<L, S, F, P>
113where
114 F: Fold<L, S>,
115 P: Fn(&L) -> bool,
116{
117 pub fn new(inner: F, predicate: P) -> Self {
119 Self {
120 inner,
121 predicate,
122 _phantom: std::marker::PhantomData,
123 }
124 }
125}
126
127impl<L, S, F, P> Fold<L, S> for FilterFold<L, S, F, P>
128where
129 L: Send + Sync,
130 S: Send + Sync,
131 F: Fold<L, S>,
132 P: Fn(&L) -> bool + Send + Sync,
133{
134 fn init(&self, context: &FoldContext) -> S {
135 self.inner.init(context)
136 }
137
138 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
139 if (self.predicate)(entry) {
140 self.inner.reduce(state, entry, context)
141 } else {
142 state
143 }
144 }
145
146 fn finalize(&self, state: S, context: &FoldContext) -> S {
147 self.inner.finalize(state, context)
148 }
149}
150
/// Adapts a fold over `L2` entries into a fold over `L1` entries.
///
/// Each incoming entry is converted with a mapper before it is reduced.
/// Initialization and finalization are delegated to the wrapped fold unchanged.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than
/// on the struct itself, so the type can be named without restating them.
pub struct MapFold<L1, L2, S, F, M> {
    /// Fold that consumes the mapped `L2` values.
    inner: F,
    /// Converts each incoming `&L1` into an owned `L2`.
    mapper: M,
    /// Ties the input, mapped, and state types to the struct without storing them.
    _phantom: std::marker::PhantomData<(L1, L2, S)>,
}
161
162impl<L1, L2, S, F, M> MapFold<L1, L2, S, F, M>
163where
164 F: Fold<L2, S>,
165 M: Fn(&L1) -> L2,
166{
167 pub fn new(inner: F, mapper: M) -> Self {
169 Self {
170 inner,
171 mapper,
172 _phantom: std::marker::PhantomData,
173 }
174 }
175}
176
177impl<L1, L2, S, F, M> Fold<L1, S> for MapFold<L1, L2, S, F, M>
178where
179 L1: Send + Sync,
180 L2: Send + Sync,
181 S: Send + Sync,
182 F: Fold<L2, S>,
183 M: Fn(&L1) -> L2 + Send + Sync,
184{
185 fn init(&self, context: &FoldContext) -> S {
186 self.inner.init(context)
187 }
188
189 fn reduce(&self, state: S, entry: &L1, context: &FoldContext) -> S {
190 let mapped = (self.mapper)(entry);
191 self.inner.reduce(state, &mapped, context)
192 }
193
194 fn finalize(&self, state: S, context: &FoldContext) -> S {
195 self.inner.finalize(state, context)
196 }
197}
198
199pub fn filter<L, S, F, P>(inner: F, predicate: P) -> FilterFold<L, S, F, P>
201where
202 F: Fold<L, S>,
203 P: Fn(&L) -> bool,
204{
205 FilterFold::new(inner, predicate)
206}
207
208pub fn map<L1, L2, S, F, M>(inner: F, mapper: M) -> MapFold<L1, L2, S, F, M>
210where
211 F: Fold<L2, S>,
212 M: Fn(&L1) -> L2,
213{
214 MapFold::new(inner, mapper)
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::fold_fn;

    #[test]
    fn test_filter_fold() {
        let entries = [1, 2, 3, 4, 5, 6];
        let count_all = fold_fn(|_ctx| 0usize, |n, _entry: &i32, _ctx| n + 1);
        let count_evens = filter(count_all, |value: &i32| *value % 2 == 0);

        let outcome = count_evens.derive(entries.iter(), &FoldContext::new());

        // Only 2, 4 and 6 pass the predicate.
        assert_eq!(outcome.state, 3);
    }

    #[test]
    fn test_map_fold() {
        let entries = [1, 2, 3];
        let sum = fold_fn(|_ctx| 0i32, |acc, value: &i32, _ctx| acc + value);
        let sum_of_doubles = map(sum, |value: &i32| value * 2);

        let outcome = sum_of_doubles.derive(entries.iter(), &FoldContext::new());

        // (1 + 2 + 3) * 2
        assert_eq!(outcome.state, 12);
    }

    #[test]
    fn test_dual_fold() {
        let entries = [1, 2, 3, 4, 5];
        let pair = DualFold::new(
            fold_fn(|_ctx| 0i32, |acc, value: &i32, _ctx| acc + value),
            fold_fn(|_ctx| 0usize, |n, _value: &i32, _ctx| n + 1),
        );

        let (summed, counted) = pair.execute(entries.iter(), &FoldContext::new());

        assert_eq!(summed.state, 15);
        assert_eq!(counted.state, 5);
    }

    #[test]
    fn test_sequential_fold() {
        let entries = [1, 2, 3];
        let counter = fold_fn(|_ctx| 0usize, |n, _value: &i32, _ctx| n + 1);
        let seeded_sum = fold_fn(
            |ctx: &FoldContext| ctx.extra.get("count").and_then(|v| v.as_i64()).unwrap_or(0) as i32,
            |acc, value: &i32, _ctx| acc + value,
        );
        let pipeline = SequentialFold::new(counter, seeded_sum, |count, ctx| {
            let mut next = ctx.clone();
            *next.extra_mut() = serde_json::json!({"count": *count});
            next
        });

        let (counted, summed) = pipeline.execute(entries.iter(), &FoldContext::new());

        assert_eq!(counted.state, 3);
        // Seeded with the count (3), then adds 1 + 2 + 3.
        assert_eq!(summed.state, 9);
    }
}