1use crate::{Fold, FoldContext, FoldOutcome};
4
5pub struct SequentialFold<L, S1, S2, F1, F2, M>
7where
8 F1: Fold<L, S1>,
9 F2: Fold<L, S2>,
10 M: Fn(&S1, &FoldContext) -> FoldContext,
11{
12 first: F1,
13 second: F2,
14 context_mapper: M,
15 _phantom: std::marker::PhantomData<(L, S1, S2)>,
16}
17
18impl<L, S1, S2, F1, F2, M> SequentialFold<L, S1, S2, F1, F2, M>
19where
20 F1: Fold<L, S1>,
21 F2: Fold<L, S2>,
22 M: Fn(&S1, &FoldContext) -> FoldContext,
23{
24 pub fn new(first: F1, second: F2, context_mapper: M) -> Self {
26 Self {
27 first,
28 second,
29 context_mapper,
30 _phantom: std::marker::PhantomData,
31 }
32 }
33
34 pub fn execute<'a, I>(
36 &self,
37 entries: I,
38 context: &FoldContext,
39 ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
40 where
41 I: IntoIterator<Item = &'a L> + Clone,
42 L: 'a,
43 {
44 let result1 = self.first.derive(entries.clone(), context);
45 let context2 = (self.context_mapper)(&result1.state, context);
46 let result2 = self.second.derive(entries, &context2);
47 (result1, result2)
48 }
49}
50
51pub struct DualFold<L, S1, S2, F1, F2>
53where
54 F1: Fold<L, S1>,
55 F2: Fold<L, S2>,
56{
57 fold1: F1,
58 fold2: F2,
59 _phantom: std::marker::PhantomData<(L, S1, S2)>,
60}
61
62impl<L, S1, S2, F1, F2> DualFold<L, S1, S2, F1, F2>
63where
64 F1: Fold<L, S1>,
65 F2: Fold<L, S2>,
66{
67 pub fn new(fold1: F1, fold2: F2) -> Self {
69 Self {
70 fold1,
71 fold2,
72 _phantom: std::marker::PhantomData,
73 }
74 }
75
76 pub fn execute<'a, I>(
78 &self,
79 entries: I,
80 context: &FoldContext,
81 ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
82 where
83 I: IntoIterator<Item = &'a L> + Clone,
84 L: 'a,
85 {
86 let result1 = self.fold1.derive(entries.clone(), context);
87 let result2 = self.fold2.derive(entries, context);
88 (result1, result2)
89 }
90}
91
92pub struct FilterFold<L, S, F, P>
94where
95 F: Fold<L, S>,
96 P: Fn(&L) -> bool,
97{
98 inner: F,
99 predicate: P,
100 _phantom: std::marker::PhantomData<(L, S)>,
101}
102
103impl<L, S, F, P> FilterFold<L, S, F, P>
104where
105 F: Fold<L, S>,
106 P: Fn(&L) -> bool,
107{
108 pub fn new(inner: F, predicate: P) -> Self {
110 Self {
111 inner,
112 predicate,
113 _phantom: std::marker::PhantomData,
114 }
115 }
116}
117
118impl<L, S, F, P> Fold<L, S> for FilterFold<L, S, F, P>
119where
120 L: Send + Sync,
121 S: Send + Sync,
122 F: Fold<L, S>,
123 P: Fn(&L) -> bool + Send + Sync,
124{
125 fn init(&self, context: &FoldContext) -> S {
126 self.inner.init(context)
127 }
128
129 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
130 if (self.predicate)(entry) {
131 self.inner.reduce(state, entry, context)
132 } else {
133 state
134 }
135 }
136
137 fn finalize(&self, state: S, context: &FoldContext) -> S {
138 self.inner.finalize(state, context)
139 }
140}
141
142pub struct MapFold<L1, L2, S, F, M>
144where
145 F: Fold<L2, S>,
146 M: Fn(&L1) -> L2,
147{
148 inner: F,
149 mapper: M,
150 _phantom: std::marker::PhantomData<(L1, L2, S)>,
151}
152
153impl<L1, L2, S, F, M> MapFold<L1, L2, S, F, M>
154where
155 F: Fold<L2, S>,
156 M: Fn(&L1) -> L2,
157{
158 pub fn new(inner: F, mapper: M) -> Self {
160 Self {
161 inner,
162 mapper,
163 _phantom: std::marker::PhantomData,
164 }
165 }
166}
167
168impl<L1, L2, S, F, M> Fold<L1, S> for MapFold<L1, L2, S, F, M>
169where
170 L1: Send + Sync,
171 L2: Send + Sync,
172 S: Send + Sync,
173 F: Fold<L2, S>,
174 M: Fn(&L1) -> L2 + Send + Sync,
175{
176 fn init(&self, context: &FoldContext) -> S {
177 self.inner.init(context)
178 }
179
180 fn reduce(&self, state: S, entry: &L1, context: &FoldContext) -> S {
181 let mapped = (self.mapper)(entry);
182 self.inner.reduce(state, &mapped, context)
183 }
184
185 fn finalize(&self, state: S, context: &FoldContext) -> S {
186 self.inner.finalize(state, context)
187 }
188}
189
190pub fn filter<L, S, F, P>(inner: F, predicate: P) -> FilterFold<L, S, F, P>
192where
193 F: Fold<L, S>,
194 P: Fn(&L) -> bool,
195{
196 FilterFold::new(inner, predicate)
197}
198
199pub fn map<L1, L2, S, F, M>(inner: F, mapper: M) -> MapFold<L1, L2, S, F, M>
201where
202 F: Fold<L2, S>,
203 M: Fn(&L1) -> L2,
204{
205 MapFold::new(inner, mapper)
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::fold_fn;

    #[test]
    fn test_filter_fold() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let filtered = filter(counter, |e: &i32| *e % 2 == 0);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = filtered.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 3);
    }

    // A predicate that rejects every entry must leave the inner fold's initial state intact.
    #[test]
    fn test_filter_fold_rejecting_everything_keeps_initial_state() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let filtered = filter(counter, |_e: &i32| false);
        let entries = [1, 2, 3];
        let result = filtered.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 0);
    }

    #[test]
    fn test_map_fold() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let doubled = map(summer, |e: &i32| e * 2);
        let entries = [1, 2, 3];
        let result = doubled.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 12);
    }

    // The predicate sees the original entries; the mapper only runs on accepted ones.
    #[test]
    fn test_filter_then_map_compose() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let pipeline = filter(map(summer, |e: &i32| e * 2), |e: &i32| *e > 1);
        let entries = [1, 2, 3];
        let result = pipeline.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 10);
    }

    #[test]
    fn test_dual_fold() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let dual = DualFold::new(summer, counter);
        let entries = [1, 2, 3, 4, 5];
        let (sum_result, count_result) = dual.execute(entries.iter(), &FoldContext::new());
        assert_eq!(sum_result.state, 15);
        assert_eq!(count_result.state, 5);
    }

    #[test]
    fn test_dual_fold_empty_input() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let dual = DualFold::new(summer, counter);
        let entries: [i32; 0] = [];
        let (sum_result, count_result) = dual.execute(entries.iter(), &FoldContext::new());
        assert_eq!(sum_result.state, 0);
        assert_eq!(count_result.state, 0);
    }

    #[test]
    fn test_sequential_fold() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let summer = fold_fn(
            |ctx: &FoldContext| ctx.extra.get("count").and_then(|v| v.as_i64()).unwrap_or(0) as i32,
            |sum, entry: &i32, _ctx| sum + entry,
        );
        let sequential = SequentialFold::new(counter, summer, |count, ctx| {
            let mut new_ctx = ctx.clone();
            *new_ctx.extra_mut() = serde_json::json!({"count": *count});
            new_ctx
        });
        let entries = [1, 2, 3];
        let (count_result, sum_result) = sequential.execute(entries.iter(), &FoldContext::new());
        assert_eq!(count_result.state, 3);
        assert_eq!(sum_result.state, 9);
    }
}