Skip to main content

khive_fold/
fold.rs

//! Core [`Fold`] trait, its fallible [`TryFold`] extension, and built-in zero-allocation folds.
2
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::{FoldContext, FoldOutcome};
7
8/// Core fold trait: collapses a sequence of entries into deterministic derived state.
9pub trait Fold<L, S>: Send + Sync {
10    /// Get the initial state before any entries are processed.
11    fn init(&self, context: &FoldContext) -> S;
12
13    /// Process a single entry and return the new state.
14    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S;
15
16    /// Finalize the state after all entries are processed; default returns state unchanged.
17    #[inline]
18    fn finalize(&self, state: S, _context: &FoldContext) -> S {
19        state
20    }
21
22    /// Derive state from an iterator of entries.
23    fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
24    where
25        Self: Sized,
26        I: IntoIterator<Item = &'a L>,
27        L: 'a,
28    {
29        let mut state = self.init(context);
30        let mut count = 0;
31
32        for entry in entries {
33            state = self.reduce(state, entry, context);
34            count += 1;
35        }
36
37        FoldOutcome::new(self.finalize(state, context), count)
38    }
39
40    /// Derive state with a filter.
41    fn derive_filtered<'a, I, F>(
42        &self,
43        entries: I,
44        context: &FoldContext,
45        filter: F,
46    ) -> FoldOutcome<S>
47    where
48        Self: Sized,
49        I: IntoIterator<Item = &'a L>,
50        L: 'a,
51        F: Fn(&L) -> bool,
52    {
53        let mut state = self.init(context);
54        let mut count = 0;
55
56        for entry in entries {
57            if filter(entry) {
58                state = self.reduce(state, entry, context);
59                count += 1;
60            }
61        }
62
63        FoldOutcome::new(self.finalize(state, context), count)
64    }
65}
66
67/// Failure returned by fallible fold operations.
68#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
69pub enum FoldFailure {
70    /// The supplied state variant does not match the fold variant.
71    #[error("Fold state mismatch: expected {expected}, got {actual}")]
72    StateMismatch {
73        /// State variant expected by the fold.
74        expected: &'static str,
75        /// State variant supplied by the caller.
76        actual: &'static str,
77    },
78}
79
80/// Fallible fold step API for reducers that can reject invalid state shapes.
81pub trait TryFold<L, S>: Fold<L, S> {
82    /// Process a single entry and return an error instead of panicking.
83    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure>;
84}
85
86impl<L, S, T> Fold<L, S> for Box<T>
87where
88    T: Fold<L, S> + ?Sized,
89{
90    #[inline]
91    fn init(&self, context: &FoldContext) -> S {
92        (**self).init(context)
93    }
94
95    #[inline]
96    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
97        (**self).reduce(state, entry, context)
98    }
99
100    #[inline]
101    fn finalize(&self, state: S, context: &FoldContext) -> S {
102        (**self).finalize(state, context)
103    }
104}
105
106impl<L, S, T> TryFold<L, S> for Box<T>
107where
108    T: TryFold<L, S> + ?Sized,
109{
110    #[inline]
111    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
112        (**self).try_step(state, entry, context)
113    }
114}
115
116impl<L, S, T> Fold<L, S> for Arc<T>
117where
118    T: Fold<L, S> + ?Sized,
119{
120    #[inline]
121    fn init(&self, context: &FoldContext) -> S {
122        (**self).init(context)
123    }
124
125    #[inline]
126    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
127        (**self).reduce(state, entry, context)
128    }
129
130    #[inline]
131    fn finalize(&self, state: S, context: &FoldContext) -> S {
132        (**self).finalize(state, context)
133    }
134}
135
136impl<L, S, T> TryFold<L, S> for Arc<T>
137where
138    T: TryFold<L, S> + ?Sized,
139{
140    #[inline]
141    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
142        (**self).try_step(state, entry, context)
143    }
144}
145
146/// A boxed fold for dynamic dispatch.
147pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Send + Sync>;
148
/// Helper to create a fold from closures.
///
/// Stores the `initial`, `step` and `finalize` callables used by its [`Fold`] implementation.
/// The closure bounds live on the `impl` blocks rather than on the struct. This means the type
/// can be named (in fields, aliases, signatures) without repeating them.
pub struct FnFold<L, S, I, St, F> {
    initial_fn: I,
    step_fn: St,
    finalize_fn: F,
    _phantom: PhantomData<(L, S)>,
}
161
162impl<L, S, I, St, F> FnFold<L, S, I, St, F>
163where
164    I: Fn(&FoldContext) -> S,
165    St: Fn(S, &L, &FoldContext) -> S,
166    F: Fn(S, &FoldContext) -> S,
167{
168    /// Create a new FnFold.
169    pub fn new(initial: I, step: St, finalize: F) -> Self {
170        Self {
171            initial_fn: initial,
172            step_fn: step,
173            finalize_fn: finalize,
174            _phantom: PhantomData,
175        }
176    }
177}
178
179impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
180where
181    L: Send + Sync,
182    S: Send + Sync,
183    I: Fn(&FoldContext) -> S + Send + Sync,
184    St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
185    F: Fn(S, &FoldContext) -> S + Send + Sync,
186{
187    #[inline]
188    fn init(&self, context: &FoldContext) -> S {
189        (self.initial_fn)(context)
190    }
191
192    #[inline]
193    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
194        (self.step_fn)(state, entry, context)
195    }
196
197    #[inline]
198    fn finalize(&self, state: S, context: &FoldContext) -> S {
199        (self.finalize_fn)(state, context)
200    }
201}
202
203impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
204where
205    L: Send + Sync,
206    S: Send + Sync,
207    I: Fn(&FoldContext) -> S + Send + Sync,
208    St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
209    F: Fn(S, &FoldContext) -> S + Send + Sync,
210{
211    #[inline]
212    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
213        Ok((self.step_fn)(state, entry, context))
214    }
215}
216
217/// Create a fold from just initial and step functions (no finalize).
218pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
219where
220    L: Send + Sync,
221    S: Send + Sync,
222    I: Fn(&FoldContext) -> S + Send + Sync,
223    St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
224{
225    FnFold::new(initial, step, |s, _| s)
226}
227
/// A zero-allocation count fold.
///
/// Only a `PhantomData<fn(&L)>` marker is stored, tying the fold to its entry type without
/// owning any `L`.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

// Implemented by hand: `#[derive(Debug, Clone, Copy)]` would add `L: Debug`, `L: Clone` and
// `L: Copy` bounds even though no `L` is stored, making e.g. a fold over a non-`Clone` entry
// type needlessly non-`Copy`.
impl<L> Clone for CountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CountFold").finish()
    }
}
233
234impl<L> CountFold<L> {
235    /// Create a new count fold.
236    #[must_use]
237    pub fn new() -> Self {
238        Self {
239            _phantom: PhantomData,
240        }
241    }
242}
243
244impl<L> Default for CountFold<L> {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250impl<L> Fold<L, usize> for CountFold<L> {
251    #[inline]
252    fn init(&self, _context: &FoldContext) -> usize {
253        0
254    }
255
256    #[inline]
257    fn reduce(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
258        state.saturating_add(1)
259    }
260}
261
262impl<L> TryFold<L, usize> for CountFold<L> {
263    #[inline]
264    fn try_step(
265        &self,
266        state: usize,
267        entry: &L,
268        context: &FoldContext,
269    ) -> Result<usize, FoldFailure> {
270        Ok(self.reduce(state, entry, context))
271    }
272}
273
/// A zero-allocation count fold with a function-pointer predicate.
pub struct FilterCountFold<L> {
    predicate: fn(&L) -> bool,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `L: Clone` / `L: Copy`, yet the
// only field is a function pointer, which is `Copy` for every `L`.
impl<L> Clone for FilterCountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}
279
280impl<L> std::fmt::Debug for FilterCountFold<L> {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        f.debug_struct("FilterCountFold").finish()
283    }
284}
285
286impl<L> FilterCountFold<L> {
287    /// Create a new filtered count fold.
288    #[must_use]
289    pub fn new(predicate: fn(&L) -> bool) -> Self {
290        Self { predicate }
291    }
292}
293
294impl<L> Fold<L, usize> for FilterCountFold<L> {
295    #[inline]
296    fn init(&self, _context: &FoldContext) -> usize {
297        0
298    }
299
300    #[inline]
301    fn reduce(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
302        if (self.predicate)(entry) {
303            state.saturating_add(1)
304        } else {
305            state
306        }
307    }
308}
309
310impl<L> TryFold<L, usize> for FilterCountFold<L> {
311    #[inline]
312    fn try_step(
313        &self,
314        state: usize,
315        entry: &L,
316        context: &FoldContext,
317    ) -> Result<usize, FoldFailure> {
318        Ok(self.reduce(state, entry, context))
319    }
320}
321
/// A zero-allocation i64 summation fold with a function-pointer projection.
pub struct SumI64Fold<L> {
    project: fn(&L) -> i64,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `L: Clone` / `L: Copy`, yet the
// only field is a function pointer, which is `Copy` for every `L`.
impl<L> Clone for SumI64Fold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}
327
328impl<L> std::fmt::Debug for SumI64Fold<L> {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        f.debug_struct("SumI64Fold").finish()
331    }
332}
333
334impl<L> SumI64Fold<L> {
335    /// Create a new summation fold.
336    #[must_use]
337    pub fn new(project: fn(&L) -> i64) -> Self {
338        Self { project }
339    }
340}
341
342impl<L> Fold<L, i64> for SumI64Fold<L> {
343    #[inline]
344    fn init(&self, _context: &FoldContext) -> i64 {
345        0
346    }
347
348    #[inline]
349    fn reduce(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
350        state.saturating_add((self.project)(entry))
351    }
352}
353
354impl<L> TryFold<L, i64> for SumI64Fold<L> {
355    #[inline]
356    fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
357        Ok(self.reduce(state, entry, context))
358    }
359}
360
/// A zero-allocation existential fold with a function-pointer predicate.
pub struct AnyFold<L> {
    predicate: fn(&L) -> bool,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `L: Clone` / `L: Copy`, yet the
// only field is a function pointer, which is `Copy` for every `L`.
impl<L> Clone for AnyFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}
366
367impl<L> std::fmt::Debug for AnyFold<L> {
368    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369        f.debug_struct("AnyFold").finish()
370    }
371}
372
373impl<L> AnyFold<L> {
374    /// Create a new existential fold.
375    #[must_use]
376    pub fn new(predicate: fn(&L) -> bool) -> Self {
377        Self { predicate }
378    }
379}
380
381impl<L> Fold<L, bool> for AnyFold<L> {
382    #[inline]
383    fn init(&self, _context: &FoldContext) -> bool {
384        false
385    }
386
387    #[inline]
388    fn reduce(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
389        state || (self.predicate)(entry)
390    }
391}
392
393impl<L> TryFold<L, bool> for AnyFold<L> {
394    #[inline]
395    fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
396        Ok(self.reduce(state, entry, context))
397    }
398}
399
/// Unified state returned by [`CommonFold`]: one variant per family of built-in folds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CommonFoldState {
    /// Running count, shared by the plain and the filtered counting folds.
    Count(usize),
    /// Running `i64` sum.
    SumI64(i64),
    /// Whether any entry has matched so far.
    Any(bool),
}
410
411impl CommonFoldState {
412    #[inline]
413    fn kind(self) -> &'static str {
414        match self {
415            Self::Count(_) => "Count",
416            Self::SumI64(_) => "SumI64",
417            Self::Any(_) => "Any",
418        }
419    }
420}
421
422/// Enum-dispatch fold for common allocation-free patterns without vtable overhead.
423#[derive(Clone)]
424pub enum CommonFold<L> {
425    /// Count every entry.
426    Count(CountFold<L>),
427    /// Count entries matching a predicate.
428    FilterCount(FilterCountFold<L>),
429    /// Sum projected i64 values.
430    SumI64(SumI64Fold<L>),
431    /// Return whether any entry matches the predicate.
432    Any(AnyFold<L>),
433}
434
435impl<L> std::fmt::Debug for CommonFold<L> {
436    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437        match self {
438            Self::Count(_) => f.write_str("CommonFold::Count"),
439            Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
440            Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
441            Self::Any(_) => f.write_str("CommonFold::Any"),
442        }
443    }
444}
445
446impl<L> CommonFold<L> {
447    /// Create a counting common fold.
448    #[must_use]
449    pub fn count() -> Self {
450        Self::Count(CountFold::new())
451    }
452
453    /// Create a filtered-count common fold.
454    #[must_use]
455    pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
456        Self::FilterCount(FilterCountFold::new(predicate))
457    }
458
459    /// Create an i64 summation common fold.
460    #[must_use]
461    pub fn sum_i64(project: fn(&L) -> i64) -> Self {
462        Self::SumI64(SumI64Fold::new(project))
463    }
464
465    /// Create an existential common fold.
466    #[must_use]
467    pub fn any(predicate: fn(&L) -> bool) -> Self {
468        Self::Any(AnyFold::new(predicate))
469    }
470
471    #[inline]
472    fn expected_state_kind(&self) -> &'static str {
473        match self {
474            Self::Count(_) | Self::FilterCount(_) => "Count",
475            Self::SumI64(_) => "SumI64",
476            Self::Any(_) => "Any",
477        }
478    }
479
480    /// Process a single entry and return an error if the state shape is invalid.
481    pub fn try_step(
482        &self,
483        state: CommonFoldState,
484        entry: &L,
485        context: &FoldContext,
486    ) -> Result<CommonFoldState, FoldFailure> {
487        match (self, state) {
488            (Self::Count(inner), CommonFoldState::Count(count)) => {
489                Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
490            }
491            (Self::FilterCount(inner), CommonFoldState::Count(count)) => {
492                Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
493            }
494            (Self::SumI64(inner), CommonFoldState::SumI64(sum)) => {
495                Ok(CommonFoldState::SumI64(inner.reduce(sum, entry, context)))
496            }
497            (Self::Any(inner), CommonFoldState::Any(any)) => {
498                Ok(CommonFoldState::Any(inner.reduce(any, entry, context)))
499            }
500            (kind, state) => Err(FoldFailure::StateMismatch {
501                expected: kind.expected_state_kind(),
502                actual: state.kind(),
503            }),
504        }
505    }
506}
507
508impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
509    #[inline]
510    fn init(&self, _context: &FoldContext) -> CommonFoldState {
511        match self {
512            Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
513            Self::SumI64(_) => CommonFoldState::SumI64(0),
514            Self::Any(_) => CommonFoldState::Any(false),
515        }
516    }
517
518    /// Panics if `state` variant does not match `self`; use `try_step` for fallible version.
519    #[inline]
520    fn reduce(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
521        self.try_step(state, entry, context)
522            .unwrap_or_else(|err| panic!("{err}"))
523    }
524}
525
526impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
527    #[inline]
528    fn try_step(
529        &self,
530        state: CommonFoldState,
531        entry: &L,
532        context: &FoldContext,
533    ) -> Result<CommonFoldState, FoldFailure> {
534        CommonFold::try_step(self, state, entry, context)
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_fn() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let entries = [1, 2, 3, 4, 5];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 5);
        assert_eq!(result.entries_processed, 5);
    }

    #[test]
    fn test_fold_fn_sum() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5];
        let result = summer.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 15);
    }

    #[test]
    fn test_fold_filtered() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = summer.derive_filtered(entries.iter(), &FoldContext::new(), |e| *e % 2 == 0);
        assert_eq!(result.state, 12);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn fn_fold_applies_finalize() {
        // The finalize closure must run exactly once, after every entry was reduced.
        let fold = FnFold::new(
            |_ctx| 0i32,
            |sum, entry: &i32, _ctx| sum + entry,
            |sum, _ctx| sum * 10,
        );
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 60);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_boxed_fold_derive() {
        // REASON: box_default fires because CountFold implements Default, but the test
        // explicitly exercises the BoxedFold type alias which requires Box::new().
        #[allow(clippy::box_default)]
        let counter: BoxedFold<i32, usize> = Box::new(CountFold::new());
        let entries = [1, 2, 3, 4];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 4);
    }

    #[test]
    fn smart_pointer_wrappers_delegate() {
        let ctx = FoldContext::new();
        let entries = [10, 20, 30];

        let shared = Arc::new(CountFold::<i32>::new());
        assert_eq!(shared.derive(entries.iter(), &ctx).state, 3);

        let boxed = Box::new(CommonFold::<i32>::count());
        let stepped = TryFold::try_step(&boxed, CommonFoldState::Count(0), &10, &ctx);
        assert_eq!(stepped, Ok(CommonFoldState::Count(1)));
    }

    #[test]
    fn test_common_fold_count() {
        let fold = CommonFold::<i32>::count();
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
    }

    #[test]
    fn test_common_fold_sum() {
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| *value as i64);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::SumI64(6));
    }

    #[test]
    fn test_common_fold_filter_count_and_any() {
        let ctx = FoldContext::new();
        let entries = [1, 2, 3, 4];

        let evens = CommonFold::<i32>::filter_count(|value: &i32| *value % 2 == 0);
        assert_eq!(evens.derive(entries.iter(), &ctx).state, CommonFoldState::Count(2));

        let has_three = CommonFold::<i32>::any(|value: &i32| *value == 3);
        assert_eq!(has_three.derive(entries.iter(), &ctx).state, CommonFoldState::Any(true));

        let has_nine = CommonFold::<i32>::any(|value: &i32| *value == 9);
        assert_eq!(has_nine.derive(entries.iter(), &ctx).state, CommonFoldState::Any(false));
    }

    #[test]
    fn count_folds_saturate_on_overflow() {
        let context = FoldContext::new();
        let entry = 1;

        let count = CountFold::new();
        assert_eq!(count.reduce(usize::MAX, &entry, &context), usize::MAX);

        let filtered = FilterCountFold::new(|_: &i32| true);
        assert_eq!(filtered.reduce(usize::MAX, &entry, &context), usize::MAX);
    }

    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let context = FoldContext::new();
        let fold = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(fold.reduce(i64::MAX, &1, &context), i64::MAX);
    }

    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::count();
        let err = TryFold::try_step(&fold, CommonFoldState::SumI64(0), &1, &context).unwrap_err();
        assert_eq!(
            err,
            FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64"
            }
        );
    }

    #[test]
    #[should_panic(expected = "Fold state mismatch: expected Count, got Any")]
    fn common_fold_reduce_panics_on_mismatch() {
        let fold = CommonFold::<i32>::count();
        let _ = fold.reduce(CommonFoldState::Any(false), &1, &FoldContext::new());
    }

    #[test]
    fn test_any_fold() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 7, 9];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(result.state);
    }

    #[test]
    fn test_any_fold_without_match() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(!result.state);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn fold_is_deterministic_no_timing() {
        // Same inputs must produce equal FoldOutcome (PartialEq holds).
        let fold = fold_fn(|_ctx| 0usize, |c, _: &i32, _ctx| c + 1);
        let entries = [1, 2, 3];
        let ctx = FoldContext::new();
        let a = fold.derive(entries.iter(), &ctx);
        let b = fold.derive(entries.iter(), &ctx);
        assert_eq!(a, b);
    }
}