Skip to main content

khive_fold/
compose.rs

1//! Fold composition utilities
2
3use crate::{Fold, FoldContext, FoldOutcome};
4
/// Sequential fold — run one fold, then use its output to inform another.
///
/// The first fold (`F1`) runs over the entries and produces a state `S1`. The
/// `context_mapper` (`M`) receives a reference to that state plus the original
/// [`FoldContext`] and returns the context handed to the second fold (`F2`),
/// which then runs over the same entries.
///
/// # Type Compatibility Invariant
///
/// Both `F1` and `F2` must accept the same entry type `L`. The state types
/// `S1` and `S2` are independent. The `context_mapper` is the only bridge
/// between the two folds.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on
/// the struct itself, so the type can be named without repeating them.
pub struct SequentialFold<L, S1, S2, F1, F2, M> {
    first: F1,
    second: F2,
    context_mapper: M,
    // `fn() -> (..)` marks the type parameters as used (covariantly) without
    // claiming ownership of an `L`, `S1` or `S2` value: none is ever stored, so
    // they must not influence `Send`/`Sync` or drop checking of this struct.
    _phantom: std::marker::PhantomData<fn() -> (L, S1, S2)>,
}
23
24impl<L, S1, S2, F1, F2, M> SequentialFold<L, S1, S2, F1, F2, M>
25where
26    F1: Fold<L, S1>,
27    F2: Fold<L, S2>,
28    M: Fn(&S1, &FoldContext) -> FoldContext,
29{
30    /// Create a sequential fold.
31    pub fn new(first: F1, second: F2, context_mapper: M) -> Self {
32        Self {
33            first,
34            second,
35            context_mapper,
36            _phantom: std::marker::PhantomData,
37        }
38    }
39
40    /// Execute the sequential fold.
41    pub fn execute<'a, I>(
42        &self,
43        entries: I,
44        context: &FoldContext,
45    ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
46    where
47        I: IntoIterator<Item = &'a L> + Clone,
48        L: 'a,
49    {
50        let result1 = self.first.derive(entries.clone(), context);
51        let context2 = (self.context_mapper)(&result1.state, context);
52        let result2 = self.second.derive(entries, &context2);
53        (result1, result2)
54    }
55}
56
/// Dual fold — run two folds independently over the same entries.
///
/// Execution is sequential (not parallel threads). The name "dual" reflects
/// that two independent folds are combined, not that they run concurrently.
///
/// Both folds receive the same entries and the same context; neither sees the
/// other's state. Trait bounds are declared on the `impl` blocks rather than on
/// the struct, so the type can be named without repeating them.
pub struct DualFold<L, S1, S2, F1, F2> {
    fold1: F1,
    fold2: F2,
    // `fn() -> (..)` marks the type parameters as used (covariantly) without
    // claiming ownership of an `L`, `S1` or `S2` value: none is ever stored, so
    // they must not influence `Send`/`Sync` or drop checking of this struct.
    _phantom: std::marker::PhantomData<fn() -> (L, S1, S2)>,
}
70
71impl<L, S1, S2, F1, F2> DualFold<L, S1, S2, F1, F2>
72where
73    F1: Fold<L, S1>,
74    F2: Fold<L, S2>,
75{
76    /// Create a dual fold.
77    pub fn new(fold1: F1, fold2: F2) -> Self {
78        Self {
79            fold1,
80            fold2,
81            _phantom: std::marker::PhantomData,
82        }
83    }
84
85    /// Execute both folds over the same entries.
86    pub fn execute<'a, I>(
87        &self,
88        entries: I,
89        context: &FoldContext,
90    ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
91    where
92        I: IntoIterator<Item = &'a L> + Clone,
93        L: 'a,
94    {
95        let result1 = self.fold1.derive(entries.clone(), context);
96        let result2 = self.fold2.derive(entries, context);
97        (result1, result2)
98    }
99}
100
/// Filter fold — only process entries matching a predicate.
///
/// Entries for which `predicate` returns `false` are skipped and leave the
/// inner fold's state unchanged. Trait bounds are declared on the `impl` blocks
/// rather than on the struct, so the type can be named without repeating them.
pub struct FilterFold<L, S, F, P> {
    inner: F,
    predicate: P,
    // `fn() -> (..)` marks `L` and `S` as used (covariantly) without claiming
    // ownership of a value of either type: none is ever stored here.
    _phantom: std::marker::PhantomData<fn() -> (L, S)>,
}
111
112impl<L, S, F, P> FilterFold<L, S, F, P>
113where
114    F: Fold<L, S>,
115    P: Fn(&L) -> bool,
116{
117    /// Create a filter fold.
118    pub fn new(inner: F, predicate: P) -> Self {
119        Self {
120            inner,
121            predicate,
122            _phantom: std::marker::PhantomData,
123        }
124    }
125}
126
127impl<L, S, F, P> Fold<L, S> for FilterFold<L, S, F, P>
128where
129    L: Send + Sync,
130    S: Send + Sync,
131    F: Fold<L, S>,
132    P: Fn(&L) -> bool + Send + Sync,
133{
134    fn init(&self, context: &FoldContext) -> S {
135        self.inner.init(context)
136    }
137
138    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
139        if (self.predicate)(entry) {
140            self.inner.reduce(state, entry, context)
141        } else {
142            state
143        }
144    }
145
146    fn finalize(&self, state: S, context: &FoldContext) -> S {
147        self.inner.finalize(state, context)
148    }
149}
150
/// Map fold — transform entries before folding.
///
/// Each incoming `L1` entry is converted to an `L2` by `mapper`, and the result
/// is handed to the inner fold. Trait bounds are declared on the `impl` blocks
/// rather than on the struct, so the type can be named without repeating them.
pub struct MapFold<L1, L2, S, F, M> {
    inner: F,
    mapper: M,
    // `fn() -> (..)` marks the type parameters as used (covariantly) without
    // claiming ownership of an `L1`, `L2` or `S` value: none is stored here.
    _phantom: std::marker::PhantomData<fn() -> (L1, L2, S)>,
}
161
162impl<L1, L2, S, F, M> MapFold<L1, L2, S, F, M>
163where
164    F: Fold<L2, S>,
165    M: Fn(&L1) -> L2,
166{
167    /// Create a map fold.
168    pub fn new(inner: F, mapper: M) -> Self {
169        Self {
170            inner,
171            mapper,
172            _phantom: std::marker::PhantomData,
173        }
174    }
175}
176
177impl<L1, L2, S, F, M> Fold<L1, S> for MapFold<L1, L2, S, F, M>
178where
179    L1: Send + Sync,
180    L2: Send + Sync,
181    S: Send + Sync,
182    F: Fold<L2, S>,
183    M: Fn(&L1) -> L2 + Send + Sync,
184{
185    fn init(&self, context: &FoldContext) -> S {
186        self.inner.init(context)
187    }
188
189    fn reduce(&self, state: S, entry: &L1, context: &FoldContext) -> S {
190        let mapped = (self.mapper)(entry);
191        self.inner.reduce(state, &mapped, context)
192    }
193
194    fn finalize(&self, state: S, context: &FoldContext) -> S {
195        self.inner.finalize(state, context)
196    }
197}
198
199/// Helper to create a filter fold.
200pub fn filter<L, S, F, P>(inner: F, predicate: P) -> FilterFold<L, S, F, P>
201where
202    F: Fold<L, S>,
203    P: Fn(&L) -> bool,
204{
205    FilterFold::new(inner, predicate)
206}
207
208/// Helper to create a map fold.
209pub fn map<L1, L2, S, F, M>(inner: F, mapper: M) -> MapFold<L1, L2, S, F, M>
210where
211    F: Fold<L2, S>,
212    M: Fn(&L1) -> L2,
213{
214    MapFold::new(inner, mapper)
215}
216
#[cfg(test)]
mod tests {
    //! Unit tests for the fold composition adapters, including edge cases
    //! (empty input, predicates that reject everything, type-changing maps,
    //! and nested composition).

    use super::*;
    use crate::fold::fold_fn;

    #[test]
    fn test_filter_fold() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let filtered = filter(counter, |e: &i32| *e % 2 == 0);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = filtered.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 3);
    }

    #[test]
    fn test_filter_fold_rejecting_everything_keeps_initial_state() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let filtered = filter(counter, |_e: &i32| false);
        let entries = [1, 2, 3];
        let result = filtered.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 0);
    }

    #[test]
    fn test_map_fold() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let doubled = map(summer, |e: &i32| e * 2);
        let entries = [1, 2, 3];
        let result = doubled.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 12);
    }

    #[test]
    fn test_map_fold_changes_entry_type() {
        // i32 entries are turned into Strings before the inner fold sees them.
        let total_len = fold_fn(|_ctx| 0usize, |acc, entry: &String, _ctx| acc + entry.len());
        let stringified = map(total_len, |e: &i32| e.to_string());
        let entries = [1, 22, 333];
        let result = stringified.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 6);
    }

    #[test]
    fn test_filter_over_map_composes() {
        // Keep even entries, double them, then sum: (2 + 4 + 6) * 2.
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let doubled = map(summer, |e: &i32| e * 2);
        let evens_doubled = filter(doubled, |e: &i32| *e % 2 == 0);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = evens_doubled.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 24);
    }

    #[test]
    fn test_dual_fold() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let dual = DualFold::new(summer, counter);
        let entries = [1, 2, 3, 4, 5];
        let (sum_result, count_result) = dual.execute(entries.iter(), &FoldContext::new());
        assert_eq!(sum_result.state, 15);
        assert_eq!(count_result.state, 5);
    }

    #[test]
    fn test_dual_fold_empty_input() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let dual = DualFold::new(summer, counter);
        let entries: [i32; 0] = [];
        let (sum_result, count_result) = dual.execute(entries.iter(), &FoldContext::new());
        assert_eq!(sum_result.state, 0);
        assert_eq!(count_result.state, 0);
    }

    #[test]
    fn test_sequential_fold() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let summer = fold_fn(
            |ctx: &FoldContext| ctx.extra.get("count").and_then(|v| v.as_i64()).unwrap_or(0) as i32,
            |sum, entry: &i32, _ctx| sum + entry,
        );
        let sequential = SequentialFold::new(counter, summer, |count, ctx| {
            let mut new_ctx = ctx.clone();
            *new_ctx.extra_mut() = serde_json::json!({"count": *count});
            new_ctx
        });
        let entries = [1, 2, 3];
        let (count_result, sum_result) = sequential.execute(entries.iter(), &FoldContext::new());
        assert_eq!(count_result.state, 3);
        assert_eq!(sum_result.state, 9);
    }
}