Skip to main content

khive_fold/
compose.rs

1//! Fold composition utilities
2
3use crate::{Fold, FoldContext, FoldOutcome};
4
5/// Sequential fold — run one fold, then use its output to inform another.
6pub struct SequentialFold<L, S1, S2, F1, F2, M>
7where
8    F1: Fold<L, S1>,
9    F2: Fold<L, S2>,
10    M: Fn(&S1, &FoldContext) -> FoldContext,
11{
12    first: F1,
13    second: F2,
14    context_mapper: M,
15    _phantom: std::marker::PhantomData<(L, S1, S2)>,
16}
17
18impl<L, S1, S2, F1, F2, M> SequentialFold<L, S1, S2, F1, F2, M>
19where
20    F1: Fold<L, S1>,
21    F2: Fold<L, S2>,
22    M: Fn(&S1, &FoldContext) -> FoldContext,
23{
24    /// Create a sequential fold.
25    pub fn new(first: F1, second: F2, context_mapper: M) -> Self {
26        Self {
27            first,
28            second,
29            context_mapper,
30            _phantom: std::marker::PhantomData,
31        }
32    }
33
34    /// Execute the sequential fold.
35    pub fn execute<'a, I>(
36        &self,
37        entries: I,
38        context: &FoldContext,
39    ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
40    where
41        I: IntoIterator<Item = &'a L> + Clone,
42        L: 'a,
43    {
44        let result1 = self.first.derive(entries.clone(), context);
45        let context2 = (self.context_mapper)(&result1.state, context);
46        let result2 = self.second.derive(entries, &context2);
47        (result1, result2)
48    }
49}
50
51/// Dual fold — run two independent folds over the same entries sequentially.
52pub struct DualFold<L, S1, S2, F1, F2>
53where
54    F1: Fold<L, S1>,
55    F2: Fold<L, S2>,
56{
57    fold1: F1,
58    fold2: F2,
59    _phantom: std::marker::PhantomData<(L, S1, S2)>,
60}
61
62impl<L, S1, S2, F1, F2> DualFold<L, S1, S2, F1, F2>
63where
64    F1: Fold<L, S1>,
65    F2: Fold<L, S2>,
66{
67    /// Create a dual fold.
68    pub fn new(fold1: F1, fold2: F2) -> Self {
69        Self {
70            fold1,
71            fold2,
72            _phantom: std::marker::PhantomData,
73        }
74    }
75
76    /// Execute both folds over the same entries.
77    pub fn execute<'a, I>(
78        &self,
79        entries: I,
80        context: &FoldContext,
81    ) -> (FoldOutcome<S1>, FoldOutcome<S2>)
82    where
83        I: IntoIterator<Item = &'a L> + Clone,
84        L: 'a,
85    {
86        let result1 = self.fold1.derive(entries.clone(), context);
87        let result2 = self.fold2.derive(entries, context);
88        (result1, result2)
89    }
90}
91
92/// Filter fold — only process entries matching a predicate.
93pub struct FilterFold<L, S, F, P>
94where
95    F: Fold<L, S>,
96    P: Fn(&L) -> bool,
97{
98    inner: F,
99    predicate: P,
100    _phantom: std::marker::PhantomData<(L, S)>,
101}
102
103impl<L, S, F, P> FilterFold<L, S, F, P>
104where
105    F: Fold<L, S>,
106    P: Fn(&L) -> bool,
107{
108    /// Create a filter fold.
109    pub fn new(inner: F, predicate: P) -> Self {
110        Self {
111            inner,
112            predicate,
113            _phantom: std::marker::PhantomData,
114        }
115    }
116}
117
118impl<L, S, F, P> Fold<L, S> for FilterFold<L, S, F, P>
119where
120    L: Send + Sync,
121    S: Send + Sync,
122    F: Fold<L, S>,
123    P: Fn(&L) -> bool + Send + Sync,
124{
125    fn init(&self, context: &FoldContext) -> S {
126        self.inner.init(context)
127    }
128
129    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
130        if (self.predicate)(entry) {
131            self.inner.reduce(state, entry, context)
132        } else {
133            state
134        }
135    }
136
137    fn finalize(&self, state: S, context: &FoldContext) -> S {
138        self.inner.finalize(state, context)
139    }
140}
141
142/// Map fold — transform entries before folding.
143pub struct MapFold<L1, L2, S, F, M>
144where
145    F: Fold<L2, S>,
146    M: Fn(&L1) -> L2,
147{
148    inner: F,
149    mapper: M,
150    _phantom: std::marker::PhantomData<(L1, L2, S)>,
151}
152
153impl<L1, L2, S, F, M> MapFold<L1, L2, S, F, M>
154where
155    F: Fold<L2, S>,
156    M: Fn(&L1) -> L2,
157{
158    /// Create a map fold.
159    pub fn new(inner: F, mapper: M) -> Self {
160        Self {
161            inner,
162            mapper,
163            _phantom: std::marker::PhantomData,
164        }
165    }
166}
167
168impl<L1, L2, S, F, M> Fold<L1, S> for MapFold<L1, L2, S, F, M>
169where
170    L1: Send + Sync,
171    L2: Send + Sync,
172    S: Send + Sync,
173    F: Fold<L2, S>,
174    M: Fn(&L1) -> L2 + Send + Sync,
175{
176    fn init(&self, context: &FoldContext) -> S {
177        self.inner.init(context)
178    }
179
180    fn reduce(&self, state: S, entry: &L1, context: &FoldContext) -> S {
181        let mapped = (self.mapper)(entry);
182        self.inner.reduce(state, &mapped, context)
183    }
184
185    fn finalize(&self, state: S, context: &FoldContext) -> S {
186        self.inner.finalize(state, context)
187    }
188}
189
190/// Helper to create a filter fold.
191pub fn filter<L, S, F, P>(inner: F, predicate: P) -> FilterFold<L, S, F, P>
192where
193    F: Fold<L, S>,
194    P: Fn(&L) -> bool,
195{
196    FilterFold::new(inner, predicate)
197}
198
199/// Helper to create a map fold.
200pub fn map<L1, L2, S, F, M>(inner: F, mapper: M) -> MapFold<L1, L2, S, F, M>
201where
202    F: Fold<L2, S>,
203    M: Fn(&L1) -> L2,
204{
205    MapFold::new(inner, mapper)
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fold::fold_fn;

    #[test]
    fn test_filter_fold() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let filtered = filter(counter, |e: &i32| *e % 2 == 0);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = filtered.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 3);
    }

    /// A predicate that rejects everything, or an empty input, must leave the
    /// initial state untouched.
    #[test]
    fn test_filter_fold_no_matches_or_empty_input() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let rejects_all = filter(counter, |_e: &i32| false);
        let entries = [1, 2, 3];
        assert_eq!(rejects_all.derive(entries.iter(), &FoldContext::new()).state, 0);
        let empty: [i32; 0] = [];
        assert_eq!(rejects_all.derive(empty.iter(), &FoldContext::new()).state, 0);
    }

    #[test]
    fn test_map_fold() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let doubled = map(summer, |e: &i32| e * 2);
        let entries = [1, 2, 3];
        let result = doubled.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 12);
    }

    /// The mapper may produce a different entry type than it consumes.
    #[test]
    fn test_map_fold_changes_entry_type() {
        let total_len = fold_fn(|_ctx| 0usize, |total, entry: &String, _ctx| total + entry.len());
        let stringified = map(total_len, |e: &i32| e.to_string());
        let entries = [7, 42, 100];
        let result = stringified.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 6);
    }

    #[test]
    fn test_dual_fold() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let dual = DualFold::new(summer, counter);
        let entries = [1, 2, 3, 4, 5];
        let (sum_result, count_result) = dual.execute(entries.iter(), &FoldContext::new());
        assert_eq!(sum_result.state, 15);
        assert_eq!(count_result.state, 5);
    }

    /// Both folds must report their initial state on empty input.
    #[test]
    fn test_dual_fold_empty_input() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let dual = DualFold::new(summer, counter);
        let entries: [i32; 0] = [];
        let (sum_result, count_result) = dual.execute(entries.iter(), &FoldContext::new());
        assert_eq!(sum_result.state, 0);
        assert_eq!(count_result.state, 0);
    }

    #[test]
    fn test_sequential_fold() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let summer = fold_fn(
            |ctx: &FoldContext| ctx.extra.get("count").and_then(|v| v.as_i64()).unwrap_or(0) as i32,
            |sum, entry: &i32, _ctx| sum + entry,
        );
        let sequential = SequentialFold::new(counter, summer, |count, ctx| {
            let mut new_ctx = ctx.clone();
            *new_ctx.extra_mut() = serde_json::json!({"count": *count});
            new_ctx
        });
        let entries = [1, 2, 3];
        let (count_result, sum_result) = sequential.execute(entries.iter(), &FoldContext::new());
        assert_eq!(count_result.state, 3);
        // The second fold starts from the first fold's count (3): 3 + 1 + 2 + 3.
        assert_eq!(sum_result.state, 9);
    }
}