Skip to main content

khive_fold/
fold.rs

1//! Core Fold trait
2
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::{FoldContext, FoldOutcome};
7
8/// Core fold trait for deriving state from entries.
9///
10/// A fold is the "measurement operator" that collapses a sequence of
11/// entries into a derived state. The fold is parameterized by:
12/// - L: The entry type (LogEntry, AtomEntry, MemoryEntry, etc.)
13/// - S: The derived state type
14///
15/// Folds are deterministic: same entries + same context = same state.
16pub trait Fold<L, S>: Send + Sync {
17    /// Get the initial state before any entries are processed.
18    fn init(&self, context: &FoldContext) -> S;
19
20    /// Process a single entry and return the new state.
21    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S;
22
23    /// Finalize the state after all entries are processed.
24    ///
25    /// Default implementation returns state unchanged.
26    #[inline]
27    fn finalize(&self, state: S, _context: &FoldContext) -> S {
28        state
29    }
30
31    /// Derive state from an iterator of entries.
32    ///
33    /// This is the main entry point for using a fold.
34    fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
35    where
36        Self: Sized,
37        I: IntoIterator<Item = &'a L>,
38        L: 'a,
39    {
40        let mut state = self.init(context);
41        let mut count = 0;
42
43        for entry in entries {
44            state = self.reduce(state, entry, context);
45            count += 1;
46        }
47
48        FoldOutcome::new(self.finalize(state, context), count)
49    }
50
51    /// Derive state with a filter.
52    fn derive_filtered<'a, I, F>(
53        &self,
54        entries: I,
55        context: &FoldContext,
56        filter: F,
57    ) -> FoldOutcome<S>
58    where
59        Self: Sized,
60        I: IntoIterator<Item = &'a L>,
61        L: 'a,
62        F: Fn(&L) -> bool,
63    {
64        let mut state = self.init(context);
65        let mut count = 0;
66
67        for entry in entries {
68            if filter(entry) {
69                state = self.reduce(state, entry, context);
70                count += 1;
71            }
72        }
73
74        FoldOutcome::new(self.finalize(state, context), count)
75    }
76}
77
78/// Failure returned by fallible fold operations.
79#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
80pub enum FoldFailure {
81    /// The supplied state variant does not match the fold variant.
82    #[error("Fold state mismatch: expected {expected}, got {actual}")]
83    StateMismatch {
84        /// State variant expected by the fold.
85        expected: &'static str,
86        /// State variant supplied by the caller.
87        actual: &'static str,
88    },
89}
90
91/// Fallible fold step API for reducers that can reject invalid state shapes.
92pub trait TryFold<L, S>: Fold<L, S> {
93    /// Process a single entry and return an error instead of panicking.
94    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure>;
95}
96
97impl<L, S, T> Fold<L, S> for Box<T>
98where
99    T: Fold<L, S> + ?Sized,
100{
101    #[inline]
102    fn init(&self, context: &FoldContext) -> S {
103        (**self).init(context)
104    }
105
106    #[inline]
107    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
108        (**self).reduce(state, entry, context)
109    }
110
111    #[inline]
112    fn finalize(&self, state: S, context: &FoldContext) -> S {
113        (**self).finalize(state, context)
114    }
115}
116
117impl<L, S, T> TryFold<L, S> for Box<T>
118where
119    T: TryFold<L, S> + ?Sized,
120{
121    #[inline]
122    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
123        (**self).try_step(state, entry, context)
124    }
125}
126
127impl<L, S, T> Fold<L, S> for Arc<T>
128where
129    T: Fold<L, S> + ?Sized,
130{
131    #[inline]
132    fn init(&self, context: &FoldContext) -> S {
133        (**self).init(context)
134    }
135
136    #[inline]
137    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
138        (**self).reduce(state, entry, context)
139    }
140
141    #[inline]
142    fn finalize(&self, state: S, context: &FoldContext) -> S {
143        (**self).finalize(state, context)
144    }
145}
146
147impl<L, S, T> TryFold<L, S> for Arc<T>
148where
149    T: TryFold<L, S> + ?Sized,
150{
151    #[inline]
152    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
153        (**self).try_step(state, entry, context)
154    }
155}
156
157/// A boxed fold for dynamic dispatch.
158pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Send + Sync>;
159
160/// Helper to create a fold from closures.
161pub struct FnFold<L, S, I, St, F>
162where
163    I: Fn(&FoldContext) -> S,
164    St: Fn(S, &L, &FoldContext) -> S,
165    F: Fn(S, &FoldContext) -> S,
166{
167    initial_fn: I,
168    step_fn: St,
169    finalize_fn: F,
170    _phantom: PhantomData<(L, S)>,
171}
172
173impl<L, S, I, St, F> FnFold<L, S, I, St, F>
174where
175    I: Fn(&FoldContext) -> S,
176    St: Fn(S, &L, &FoldContext) -> S,
177    F: Fn(S, &FoldContext) -> S,
178{
179    /// Create a new FnFold.
180    pub fn new(initial: I, step: St, finalize: F) -> Self {
181        Self {
182            initial_fn: initial,
183            step_fn: step,
184            finalize_fn: finalize,
185            _phantom: PhantomData,
186        }
187    }
188}
189
190impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
191where
192    L: Send + Sync,
193    S: Send + Sync,
194    I: Fn(&FoldContext) -> S + Send + Sync,
195    St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
196    F: Fn(S, &FoldContext) -> S + Send + Sync,
197{
198    #[inline]
199    fn init(&self, context: &FoldContext) -> S {
200        (self.initial_fn)(context)
201    }
202
203    #[inline]
204    fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
205        (self.step_fn)(state, entry, context)
206    }
207
208    #[inline]
209    fn finalize(&self, state: S, context: &FoldContext) -> S {
210        (self.finalize_fn)(state, context)
211    }
212}
213
214impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
215where
216    L: Send + Sync,
217    S: Send + Sync,
218    I: Fn(&FoldContext) -> S + Send + Sync,
219    St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
220    F: Fn(S, &FoldContext) -> S + Send + Sync,
221{
222    #[inline]
223    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
224        Ok((self.step_fn)(state, entry, context))
225    }
226}
227
228/// Create a fold from just initial and step functions (no finalize).
229pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
230where
231    L: Send + Sync,
232    S: Send + Sync,
233    I: Fn(&FoldContext) -> S + Send + Sync,
234    St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
235{
236    FnFold::new(initial, step, |s, _| s)
237}
238
/// A zero-allocation fold that counts every entry it sees.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand rather than derived.
/// The only field is a `PhantomData<fn(&L)>` marker, so none of these traits
/// should require anything of the entry type `L`. A derive would add
/// `L: Clone` / `L: Copy` / `L: Debug` bounds and make, for example,
/// `CountFold<String>` non-`Copy`.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

impl<L> Clone for CountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CountFold").finish()
    }
}
244
245impl<L> CountFold<L> {
246    /// Create a new count fold.
247    #[must_use]
248    pub fn new() -> Self {
249        Self {
250            _phantom: PhantomData,
251        }
252    }
253}
254
255impl<L> Default for CountFold<L> {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261impl<L> Fold<L, usize> for CountFold<L> {
262    #[inline]
263    fn init(&self, _context: &FoldContext) -> usize {
264        0
265    }
266
267    #[inline]
268    fn reduce(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
269        state.saturating_add(1)
270    }
271}
272
273impl<L> TryFold<L, usize> for CountFold<L> {
274    #[inline]
275    fn try_step(
276        &self,
277        state: usize,
278        entry: &L,
279        context: &FoldContext,
280    ) -> Result<usize, FoldFailure> {
281        Ok(self.reduce(state, entry, context))
282    }
283}
284
/// A zero-allocation count fold with a function-pointer predicate.
///
/// Only entries accepted by the predicate are counted.
///
/// `Clone` and `Copy` are implemented by hand. The only field is a function
/// pointer, which is always `Copy`, whereas `#[derive(Clone, Copy)]` would
/// require `L: Clone` / `L: Copy`. Under the derive, for example,
/// `FilterCountFold<String>` was not `Copy`.
pub struct FilterCountFold<L> {
    predicate: fn(&L) -> bool,
}

impl<L> Clone for FilterCountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}
290
291impl<L> std::fmt::Debug for FilterCountFold<L> {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        f.debug_struct("FilterCountFold").finish()
294    }
295}
296
297impl<L> FilterCountFold<L> {
298    /// Create a new filtered count fold.
299    #[must_use]
300    pub fn new(predicate: fn(&L) -> bool) -> Self {
301        Self { predicate }
302    }
303}
304
305impl<L> Fold<L, usize> for FilterCountFold<L> {
306    #[inline]
307    fn init(&self, _context: &FoldContext) -> usize {
308        0
309    }
310
311    #[inline]
312    fn reduce(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
313        if (self.predicate)(entry) {
314            state.saturating_add(1)
315        } else {
316            state
317        }
318    }
319}
320
321impl<L> TryFold<L, usize> for FilterCountFold<L> {
322    #[inline]
323    fn try_step(
324        &self,
325        state: usize,
326        entry: &L,
327        context: &FoldContext,
328    ) -> Result<usize, FoldFailure> {
329        Ok(self.reduce(state, entry, context))
330    }
331}
332
/// A zero-allocation i64 summation fold with a function-pointer projection.
///
/// `Clone` and `Copy` are implemented by hand. The only field is a function
/// pointer, which is always `Copy`, whereas `#[derive(Clone, Copy)]` would
/// require `L: Clone` / `L: Copy` of the entry type.
pub struct SumI64Fold<L> {
    project: fn(&L) -> i64,
}

impl<L> Clone for SumI64Fold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}
338
339impl<L> std::fmt::Debug for SumI64Fold<L> {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        f.debug_struct("SumI64Fold").finish()
342    }
343}
344
345impl<L> SumI64Fold<L> {
346    /// Create a new summation fold.
347    #[must_use]
348    pub fn new(project: fn(&L) -> i64) -> Self {
349        Self { project }
350    }
351}
352
353impl<L> Fold<L, i64> for SumI64Fold<L> {
354    #[inline]
355    fn init(&self, _context: &FoldContext) -> i64 {
356        0
357    }
358
359    #[inline]
360    fn reduce(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
361        state.saturating_add((self.project)(entry))
362    }
363}
364
365impl<L> TryFold<L, i64> for SumI64Fold<L> {
366    #[inline]
367    fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
368        Ok(self.reduce(state, entry, context))
369    }
370}
371
/// A zero-allocation existential fold with a function-pointer predicate.
///
/// The result is `true` once any entry satisfies the predicate.
///
/// `Clone` and `Copy` are implemented by hand. The only field is a function
/// pointer, which is always `Copy`, whereas `#[derive(Clone, Copy)]` would
/// require `L: Clone` / `L: Copy` of the entry type.
pub struct AnyFold<L> {
    predicate: fn(&L) -> bool,
}

impl<L> Clone for AnyFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}
377
378impl<L> std::fmt::Debug for AnyFold<L> {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        f.debug_struct("AnyFold").finish()
381    }
382}
383
384impl<L> AnyFold<L> {
385    /// Create a new existential fold.
386    #[must_use]
387    pub fn new(predicate: fn(&L) -> bool) -> Self {
388        Self { predicate }
389    }
390}
391
392impl<L> Fold<L, bool> for AnyFold<L> {
393    #[inline]
394    fn init(&self, _context: &FoldContext) -> bool {
395        false
396    }
397
398    #[inline]
399    fn reduce(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
400        state || (self.predicate)(entry)
401    }
402}
403
404impl<L> TryFold<L, bool> for AnyFold<L> {
405    #[inline]
406    fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
407        Ok(self.reduce(state, entry, context))
408    }
409}
410
/// State produced by a [`CommonFold`]; the variant mirrors the kind of fold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommonFoldState {
    /// Output of the counting folds (plain and filtered).
    Count(usize),
    /// Output of the i64 summation fold.
    SumI64(i64),
    /// Output of the existential ("does any entry match") fold.
    Any(bool),
}

impl CommonFoldState {
    /// Variant name, used when reporting a state/fold mismatch.
    #[inline]
    fn kind(self) -> &'static str {
        match self {
            CommonFoldState::Any(..) => "Any",
            CommonFoldState::SumI64(..) => "SumI64",
            CommonFoldState::Count(..) => "Count",
        }
    }
}
432
433/// Enum-dispatch fold for common patterns that would otherwise use `Box<dyn Fold>`.
434///
435/// This is intentionally limited to a few hot, allocation-free cases so callers
436/// can avoid vtable dispatch in heterogeneous fold collections.
437#[derive(Clone)]
438pub enum CommonFold<L> {
439    /// Count every entry.
440    Count(CountFold<L>),
441    /// Count entries matching a predicate.
442    FilterCount(FilterCountFold<L>),
443    /// Sum projected i64 values.
444    SumI64(SumI64Fold<L>),
445    /// Return whether any entry matches the predicate.
446    Any(AnyFold<L>),
447}
448
449impl<L> std::fmt::Debug for CommonFold<L> {
450    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451        match self {
452            Self::Count(_) => f.write_str("CommonFold::Count"),
453            Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
454            Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
455            Self::Any(_) => f.write_str("CommonFold::Any"),
456        }
457    }
458}
459
460impl<L> CommonFold<L> {
461    /// Create a counting common fold.
462    #[must_use]
463    pub fn count() -> Self {
464        Self::Count(CountFold::new())
465    }
466
467    /// Create a filtered-count common fold.
468    #[must_use]
469    pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
470        Self::FilterCount(FilterCountFold::new(predicate))
471    }
472
473    /// Create an i64 summation common fold.
474    #[must_use]
475    pub fn sum_i64(project: fn(&L) -> i64) -> Self {
476        Self::SumI64(SumI64Fold::new(project))
477    }
478
479    /// Create an existential common fold.
480    #[must_use]
481    pub fn any(predicate: fn(&L) -> bool) -> Self {
482        Self::Any(AnyFold::new(predicate))
483    }
484
485    #[inline]
486    fn expected_state_kind(&self) -> &'static str {
487        match self {
488            Self::Count(_) | Self::FilterCount(_) => "Count",
489            Self::SumI64(_) => "SumI64",
490            Self::Any(_) => "Any",
491        }
492    }
493
494    /// Process a single entry and return an error if the state shape is invalid.
495    pub fn try_step(
496        &self,
497        state: CommonFoldState,
498        entry: &L,
499        context: &FoldContext,
500    ) -> Result<CommonFoldState, FoldFailure> {
501        match (self, state) {
502            (Self::Count(inner), CommonFoldState::Count(count)) => {
503                Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
504            }
505            (Self::FilterCount(inner), CommonFoldState::Count(count)) => {
506                Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
507            }
508            (Self::SumI64(inner), CommonFoldState::SumI64(sum)) => {
509                Ok(CommonFoldState::SumI64(inner.reduce(sum, entry, context)))
510            }
511            (Self::Any(inner), CommonFoldState::Any(any)) => {
512                Ok(CommonFoldState::Any(inner.reduce(any, entry, context)))
513            }
514            (kind, state) => Err(FoldFailure::StateMismatch {
515                expected: kind.expected_state_kind(),
516                actual: state.kind(),
517            }),
518        }
519    }
520}
521
522impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
523    #[inline]
524    fn init(&self, _context: &FoldContext) -> CommonFoldState {
525        match self {
526            Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
527            Self::SumI64(_) => CommonFoldState::SumI64(0),
528            Self::Any(_) => CommonFoldState::Any(false),
529        }
530    }
531
532    /// # Panics
533    ///
534    /// Panics if `state` does not match the variant expected by `self`.
535    /// Use [`TryFold::try_step`] to handle the mismatch as an error instead.
536    #[inline]
537    fn reduce(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
538        self.try_step(state, entry, context)
539            .unwrap_or_else(|err| panic!("{err}"))
540    }
541}
542
543impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
544    #[inline]
545    fn try_step(
546        &self,
547        state: CommonFoldState,
548        entry: &L,
549        context: &FoldContext,
550    ) -> Result<CommonFoldState, FoldFailure> {
551        CommonFold::try_step(self, state, entry, context)
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_fn() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let entries = [1, 2, 3, 4, 5];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 5);
        assert_eq!(result.entries_processed, 5);
    }

    #[test]
    fn test_fold_fn_sum() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5];
        let result = summer.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 15);
    }

    #[test]
    fn test_fold_filtered() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = summer.derive_filtered(entries.iter(), &FoldContext::new(), |e| *e % 2 == 0);
        assert_eq!(result.state, 12);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_boxed_fold_derive() {
        // Clippy's `box_default` would suggest `Box::default()`, but this test
        // deliberately builds a `BoxedFold` trait object from a concrete fold.
        #[allow(clippy::box_default)]
        let counter: BoxedFold<i32, usize> = Box::new(CountFold::new());
        let entries = [1, 2, 3, 4];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 4);
    }

    #[test]
    fn test_common_fold_count() {
        let fold = CommonFold::<i32>::count();
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
    }

    #[test]
    fn test_common_fold_sum() {
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| *value as i64);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::SumI64(6));
    }

    #[test]
    fn test_common_fold_filter_count() {
        let fold = CommonFold::<i32>::filter_count(|value: &i32| value % 2 == 0);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
        // `derive` reduces every entry; only matching ones increase the count.
        assert_eq!(result.entries_processed, 6);
    }

    #[test]
    fn test_common_fold_any() {
        let fold = CommonFold::<i32>::any(|value: &i32| *value > 2);
        let ctx = FoldContext::new();
        let with_match = [1, 2, 3];
        let without_match = [1, 2];
        assert_eq!(
            fold.derive(with_match.iter(), &ctx).state,
            CommonFoldState::Any(true)
        );
        assert_eq!(
            fold.derive(without_match.iter(), &ctx).state,
            CommonFoldState::Any(false)
        );
    }

    #[test]
    fn test_derive_on_empty_input_returns_initial_state() {
        let ctx = FoldContext::new();
        let entries: [i32; 0] = [];

        let counted = CommonFold::<i32>::count().derive(entries.iter(), &ctx);
        assert_eq!(counted.state, CommonFoldState::Count(0));
        assert_eq!(counted.entries_processed, 0);

        let any = AnyFold::new(|value: &i32| *value == 7).derive(entries.iter(), &ctx);
        assert!(!any.state);
    }

    #[test]
    fn count_folds_saturate_on_overflow() {
        let context = FoldContext::new();
        let entry = 1;

        let count = CountFold::new();
        assert_eq!(count.reduce(usize::MAX, &entry, &context), usize::MAX);

        let filtered = FilterCountFold::new(|_: &i32| true);
        assert_eq!(filtered.reduce(usize::MAX, &entry, &context), usize::MAX);
    }

    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let context = FoldContext::new();
        let fold = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(fold.reduce(i64::MAX, &1, &context), i64::MAX);
    }

    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::count();
        let err = TryFold::try_step(&fold, CommonFoldState::SumI64(0), &1, &context).unwrap_err();
        assert_eq!(
            err,
            FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64"
            }
        );
    }

    #[test]
    #[should_panic(expected = "Fold state mismatch: expected Count, got Any")]
    fn common_fold_reduce_panics_on_state_mismatch() {
        let fold = CommonFold::<i32>::count();
        let _ = fold.reduce(CommonFoldState::Any(false), &1, &FoldContext::new());
    }

    #[test]
    fn test_any_fold() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 7, 9];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(result.state);
    }

    #[test]
    fn fold_is_deterministic_no_timing() {
        // Same inputs must produce equal FoldOutcome (PartialEq holds).
        let fold = fold_fn(|_ctx| 0usize, |c, _: &i32, _ctx| c + 1);
        let entries = [1, 2, 3];
        let ctx = FoldContext::new();
        let a = fold.derive(entries.iter(), &ctx);
        let b = fold.derive(entries.iter(), &ctx);
        assert_eq!(a, b);
    }
}