// khive_fold/compose.rs

//! Fold composition utilities: sequential, dual, filter, and map combinators.
2
3use crate::{Fold, FoldContext, FoldOutcome};
4
/// Sequential fold — run one fold, then use its output to inform another.
///
/// The first fold runs under the caller's context. Its resulting state is
/// passed to `context_mapper`, and the context that mapper returns is the one
/// the second fold runs under.
///
/// # Type Compatibility Invariant
///
/// Both `F1` and `F2` must accept the same entry type `L`. The state types
/// `S1` and `S2` are independent. The `context_mapper` is the only bridge
/// between the two folds.
///
/// # Bounds
///
/// The type definition itself is unbounded; the `Fold` / `Fn` requirements
/// are stated on the `impl` block that provides `new` and `execute`, so they
/// are enforced wherever a value is actually built or run.
pub struct SequentialFold<L, S1, S2, F1, F2, M> {
    /// Fold executed first, under the caller-supplied context.
    first: F1,
    /// Fold executed second, under the context produced by `context_mapper`.
    second: F2,
    /// Bridge from the first fold's state to the second fold's context.
    context_mapper: M,
    /// Ties `L`, `S1` and `S2` to the type without owning values of them.
    /// The `fn() -> _` form keeps the marker `Send + Sync` regardless of the
    /// entry/state types, so only the stored folds and mapper decide whether
    /// the combinator can cross threads.
    _phantom: std::marker::PhantomData<fn() -> (L, S1, S2)>,
}
23
24impl<L, S1, S2, F1, F2, M> SequentialFold<L, S1, S2, F1, F2, M>
25where
26    F1: Fold<L, S1>,
27    F2: Fold<L, S2>,
28    M: Fn(&S1, &FoldContext) -> FoldContext,
29{
30    /// Create a sequential fold.
31    pub fn new(first: F1, second: F2, context_mapper: M) -> Self {
32        Self {
33            first,
34            second,
35            context_mapper,
36            _phantom: std::marker::PhantomData,
37        }
38    }
39
40    /// Execute the sequential fold.
41    pub fn execute<'a, I>(
42        &self,
43        entries: I,
44        context: &FoldContext,
45    ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
46    where
47        I: IntoIterator<Item = &'a L> + Clone,
48        L: 'a,
49    {
50        let result1 = self.first.derive(entries.clone(), context);
51        let context2 = (self.context_mapper)(&result1.state, context);
52        let result2 = self.second.derive(entries, &context2);
53        (result1, result2)
54    }
55}
56
/// Dual fold — run two folds independently over the same entries.
///
/// Execution is sequential (not parallel threads). The name "dual" reflects
/// that two independent folds are combined, not that they run concurrently.
///
/// # Bounds
///
/// The type definition itself is unbounded; the `Fold` requirements are
/// stated on the `impl` block that provides `new` and `execute`.
pub struct DualFold<L, S1, S2, F1, F2> {
    /// Fold whose outcome is returned first.
    fold1: F1,
    /// Fold whose outcome is returned second.
    fold2: F2,
    /// Ties `L`, `S1` and `S2` to the type without owning values of them.
    /// The `fn() -> _` form keeps the marker `Send + Sync` regardless of the
    /// entry/state types.
    _phantom: std::marker::PhantomData<fn() -> (L, S1, S2)>,
}
70
71impl<L, S1, S2, F1, F2> DualFold<L, S1, S2, F1, F2>
72where
73    F1: Fold<L, S1>,
74    F2: Fold<L, S2>,
75{
76    /// Create a dual fold.
77    pub fn new(fold1: F1, fold2: F2) -> Self {
78        Self {
79            fold1,
80            fold2,
81            _phantom: std::marker::PhantomData,
82        }
83    }
84
85    /// Execute both folds over the same entries.
86    pub fn execute<'a, I>(
87        &self,
88        entries: I,
89        context: &FoldContext,
90    ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
91    where
92        I: IntoIterator<Item = &'a L> + Clone,
93        L: 'a,
94    {
95        let result1 = self.fold1.derive(entries.clone(), context);
96        let result2 = self.fold2.derive(entries, context);
97        (result1, result2)
98    }
99}
100
/// Filter fold — only process entries matching a predicate.
///
/// Entries for which the predicate returns `false` leave the state untouched;
/// all other entries are forwarded to the wrapped fold.
///
/// # Bounds
///
/// The type definition itself is unbounded; the `Fold` / `Fn` requirements
/// are stated on the `impl` blocks that construct and run it.
pub struct FilterFold<L, S, F, P> {
    /// Wrapped fold that receives the entries accepted by `predicate`.
    inner: F,
    /// Decides, per entry, whether `inner` gets to see it.
    predicate: P,
    /// Ties `L` and `S` to the type without owning values of them.
    /// The `fn() -> _` form keeps the marker `Send + Sync` regardless of the
    /// entry/state types.
    _phantom: std::marker::PhantomData<fn() -> (L, S)>,
}
111
112impl<L, S, F, P> FilterFold<L, S, F, P>
113where
114    F: Fold<L, S>,
115    P: Fn(&L) -> bool,
116{
117    /// Create a filter fold.
118    pub fn new(inner: F, predicate: P) -> Self {
119        Self {
120            inner,
121            predicate,
122            _phantom: std::marker::PhantomData,
123        }
124    }
125}
126
127impl<L, S, F, P> Fold<L, S> for FilterFold<L, S, F, P>
128where
129    F: Fold<L, S>,
130    P: Fn(&L) -> bool,
131{
132    fn initial(&self, context: &FoldContext) -> S {
133        self.inner.initial(context)
134    }
135
136    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
137        if (self.predicate)(entry) {
138            self.inner.step(state, entry, context)
139        } else {
140            state
141        }
142    }
143
144    fn finalize(&self, state: S, context: &FoldContext) -> S {
145        self.inner.finalize(state, context)
146    }
147}
148
/// Map fold — transform entries before folding.
///
/// Each incoming `L1` entry is converted to an `L2` by the mapper, and the
/// wrapped fold steps over that converted value.
///
/// # Bounds
///
/// The type definition itself is unbounded; the `Fold` / `Fn` requirements
/// are stated on the `impl` blocks that construct and run it.
pub struct MapFold<L1, L2, S, F, M> {
    /// Wrapped fold that consumes the mapped (`L2`) entries.
    inner: F,
    /// Converts a borrowed `L1` entry into the `L2` value `inner` expects.
    mapper: M,
    /// Ties `L1`, `L2` and `S` to the type without owning values of them.
    /// The `fn() -> _` form keeps the marker `Send + Sync` regardless of the
    /// entry/state types.
    _phantom: std::marker::PhantomData<fn() -> (L1, L2, S)>,
}
159
160impl<L1, L2, S, F, M> MapFold<L1, L2, S, F, M>
161where
162    F: Fold<L2, S>,
163    M: Fn(&L1) -> L2,
164{
165    /// Create a map fold.
166    pub fn new(inner: F, mapper: M) -> Self {
167        Self {
168            inner,
169            mapper,
170            _phantom: std::marker::PhantomData,
171        }
172    }
173}
174
175impl<L1, L2, S, F, M> Fold<L1, S> for MapFold<L1, L2, S, F, M>
176where
177    F: Fold<L2, S>,
178    M: Fn(&L1) -> L2,
179{
180    fn initial(&self, context: &FoldContext) -> S {
181        self.inner.initial(context)
182    }
183
184    fn step(&self, state: S, entry: &L1, context: &FoldContext) -> S {
185        let mapped = (self.mapper)(entry);
186        self.inner.step(state, &mapped, context)
187    }
188
189    fn finalize(&self, state: S, context: &FoldContext) -> S {
190        self.inner.finalize(state, context)
191    }
192}
193
194/// Helper to create a filter fold.
195pub fn filter<L, S, F, P>(inner: F, predicate: P) -> FilterFold<L, S, F, P>
196where
197    F: Fold<L, S>,
198    P: Fn(&L) -> bool,
199{
200    FilterFold::new(inner, predicate)
201}
202
203/// Helper to create a map fold.
204pub fn map<L1, L2, S, F, M>(inner: F, mapper: M) -> MapFold<L1, L2, S, F, M>
205where
206    F: Fold<L2, S>,
207    M: Fn(&L1) -> L2,
208{
209    MapFold::new(inner, mapper)
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::fold_fn;

    /// Only the even entries (2, 4, 6) reach the wrapped counter.
    #[test]
    fn test_filter_fold() {
        let data = [1, 2, 3, 4, 5, 6];
        let base_ctx = FoldContext::new();

        let count_all = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let count_even = filter(count_all, |value: &i32| *value % 2 == 0);

        let outcome = count_even.derive(data.iter(), &base_ctx);
        assert_eq!(outcome.state, 3);
    }

    /// Entries are doubled before being summed: 2 + 4 + 6.
    #[test]
    fn test_map_fold() {
        let data = [1, 2, 3];
        let base_ctx = FoldContext::new();

        let sum_all = fold_fn(|_ctx| 0i32, |total, entry: &i32, _ctx| total + entry);
        let sum_doubled = map(sum_all, |value: &i32| value * 2);

        let outcome = sum_doubled.derive(data.iter(), &base_ctx);
        assert_eq!(outcome.state, 12);
    }

    /// Both folds see every entry and report independent results.
    #[test]
    fn test_dual_fold() {
        let data = [1, 2, 3, 4, 5];
        let base_ctx = FoldContext::new();

        let sum_all = fold_fn(|_ctx| 0i32, |total, entry: &i32, _ctx| total + entry);
        let count_all = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let pair = DualFold::new(sum_all, count_all);

        let (sum_outcome, count_outcome) = pair.execute(data.iter(), &base_ctx);
        assert_eq!(sum_outcome.state, 15);
        assert_eq!(count_outcome.state, 5);
    }

    /// The count from the first fold (3) seeds the sum of the second:
    /// 3 + 1 + 2 + 3.
    #[test]
    fn test_sequential_fold() {
        let data = [1, 2, 3];
        let base_ctx = FoldContext::new();

        let count_all = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let seeded_sum = fold_fn(
            |fold_ctx: &FoldContext| {
                let seed = fold_ctx
                    .extra
                    .get("count")
                    .and_then(|v| v.as_i64())
                    .unwrap_or(0);
                seed as i32
            },
            |total, entry: &i32, _ctx| total + entry,
        );
        let pipeline = SequentialFold::new(count_all, seeded_sum, |count, original| {
            let mut mapped = original.clone();
            *mapped.extra_mut() = serde_json::json!({"count": *count});
            mapped
        });

        let (count_outcome, sum_outcome) = pipeline.execute(data.iter(), &base_ctx);
        assert_eq!(count_outcome.state, 3);
        assert_eq!(sum_outcome.state, 9);
    }
}