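//! Core [`Fold`] trait, its fallible [`TryFold`] extension, closure-based
//! folds ([`FnFold`], [`fold_fn`]), and allocation-free folds with enum
//! dispatch via [`CommonFold`].
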
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::{FoldContext, FoldOutcome};
7
8/// Core fold trait for deriving state from entries.
9///
10/// A fold is the "measurement operator" that collapses a sequence of
11/// entries into a derived state. The fold is parameterized by:
12/// - L: The entry type (LogEntry, AtomEntry, MemoryEntry, etc.)
13/// - S: The derived state type
14///
15/// Folds are deterministic: same entries + same context = same state.
16pub trait Fold<L, S> {
17    /// Get the initial state before any entries are processed.
18    fn initial(&self, context: &FoldContext) -> S;
19
20    /// Process a single entry and return the new state.
21    ///
22    /// This is the core step function: state' = step(state, entry, context)
23    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S;
24
25    /// Finalize the state after all entries are processed.
26    ///
27    /// Default implementation returns state unchanged.
28    #[inline]
29    fn finalize(&self, state: S, _context: &FoldContext) -> S {
30        state
31    }
32
33    /// Derive state from an iterator of entries.
34    ///
35    /// This is the main entry point for using a fold.
36    fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
37    where
38        Self: Sized,
39        I: IntoIterator<Item = &'a L>,
40        L: 'a,
41    {
42        let started_at = chrono::Utc::now();
43        let mut state = self.initial(context);
44        let mut count = 0;
45
46        for entry in entries {
47            state = self.step(state, entry, context);
48            count += 1;
49        }
50
51        state = self.finalize(state, context);
52
53        FoldOutcome::with_timing(state, count, context.clone(), started_at)
54    }
55
56    /// Derive state with a filter.
57    fn derive_filtered<'a, I, F>(
58        &self,
59        entries: I,
60        context: &FoldContext,
61        filter: F,
62    ) -> FoldOutcome<S>
63    where
64        Self: Sized,
65        I: IntoIterator<Item = &'a L>,
66        L: 'a,
67        F: Fn(&L) -> bool,
68    {
69        let started_at = chrono::Utc::now();
70        let mut state = self.initial(context);
71        let mut count = 0;
72
73        for entry in entries {
74            if filter(entry) {
75                state = self.step(state, entry, context);
76                count += 1;
77            }
78        }
79
80        state = self.finalize(state, context);
81
82        FoldOutcome::with_timing(state, count, context.clone(), started_at)
83    }
84}
85
86/// Failure returned by fallible fold operations.
87#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
88pub enum FoldFailure {
89    /// The supplied state variant does not match the fold variant.
90    #[error("Fold state mismatch: expected {expected}, got {actual}")]
91    StateMismatch {
92        /// State variant expected by the fold.
93        expected: &'static str,
94        /// State variant supplied by the caller.
95        actual: &'static str,
96    },
97}
98
99/// Fallible fold step API for reducers that can reject invalid state shapes.
100pub trait TryFold<L, S>: Fold<L, S> {
101    /// Process a single entry and return an error instead of panicking.
102    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure>;
103}
104
105impl<L, S, T> Fold<L, S> for Box<T>
106where
107    T: Fold<L, S> + ?Sized,
108{
109    #[inline]
110    fn initial(&self, context: &FoldContext) -> S {
111        (**self).initial(context)
112    }
113
114    #[inline]
115    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
116        (**self).step(state, entry, context)
117    }
118
119    #[inline]
120    fn finalize(&self, state: S, context: &FoldContext) -> S {
121        (**self).finalize(state, context)
122    }
123}
124
125impl<L, S, T> TryFold<L, S> for Box<T>
126where
127    T: TryFold<L, S> + ?Sized,
128{
129    #[inline]
130    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
131        (**self).try_step(state, entry, context)
132    }
133}
134
135impl<L, S, T> Fold<L, S> for Arc<T>
136where
137    T: Fold<L, S> + ?Sized,
138{
139    #[inline]
140    fn initial(&self, context: &FoldContext) -> S {
141        (**self).initial(context)
142    }
143
144    #[inline]
145    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
146        (**self).step(state, entry, context)
147    }
148
149    #[inline]
150    fn finalize(&self, state: S, context: &FoldContext) -> S {
151        (**self).finalize(state, context)
152    }
153}
154
155impl<L, S, T> TryFold<L, S> for Arc<T>
156where
157    T: TryFold<L, S> + ?Sized,
158{
159    #[inline]
160    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
161        (**self).try_step(state, entry, context)
162    }
163}
164
165/// A boxed fold for dynamic dispatch.
166pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Send + Sync>;
167
/// Fold assembled from three closures: `initial`, `step` and `finalize`.
///
/// The closure bounds live on the `impl` blocks rather than on the struct, so
/// the type can be named (in fields, aliases, other generics) without
/// repeating them.
pub struct FnFold<L, S, I, St, F> {
    /// Produces the starting state.
    initial_fn: I,
    /// Advances the state by one entry.
    step_fn: St,
    /// Post-processes the state after the last entry.
    finalize_fn: F,
    // The fold only borrows `L` values and returns `S` values; it never owns
    // either. `fn() -> (L, S)` keeps the struct covariant in both parameters
    // while making `Send`/`Sync` and drop checking depend only on the stored
    // closures. For example, a fold over `Rc` entries can still become a
    // `BoxedFold` when its closures are `Send + Sync`.
    _phantom: PhantomData<fn() -> (L, S)>,
}

181impl<L, S, I, St, F> FnFold<L, S, I, St, F>
182where
183    I: Fn(&FoldContext) -> S,
184    St: Fn(S, &L, &FoldContext) -> S,
185    F: Fn(S, &FoldContext) -> S,
186{
187    /// Create a new FnFold.
188    pub fn new(initial: I, step: St, finalize: F) -> Self {
189        Self {
190            initial_fn: initial,
191            step_fn: step,
192            finalize_fn: finalize,
193            _phantom: PhantomData,
194        }
195    }
196}
197
198impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
199where
200    I: Fn(&FoldContext) -> S,
201    St: Fn(S, &L, &FoldContext) -> S,
202    F: Fn(S, &FoldContext) -> S,
203{
204    #[inline]
205    fn initial(&self, context: &FoldContext) -> S {
206        (self.initial_fn)(context)
207    }
208
209    #[inline]
210    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
211        (self.step_fn)(state, entry, context)
212    }
213
214    #[inline]
215    fn finalize(&self, state: S, context: &FoldContext) -> S {
216        (self.finalize_fn)(state, context)
217    }
218}
219
220impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
221where
222    I: Fn(&FoldContext) -> S,
223    St: Fn(S, &L, &FoldContext) -> S,
224    F: Fn(S, &FoldContext) -> S,
225{
226    #[inline]
227    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
228        Ok((self.step_fn)(state, entry, context))
229    }
230}
231
232/// Create a fold from just initial and step functions (no finalize).
233pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
234where
235    I: Fn(&FoldContext) -> S,
236    St: Fn(S, &L, &FoldContext) -> S,
237{
238    FnFold::new(initial, step, |s, _| s)
239}
240
/// A zero-allocation count fold.
///
/// Stores no data; the phantom marker only records which entry type is counted.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

// `Debug`, `Clone` and `Copy` are implemented by hand. `#[derive]` would add
// `L: Debug` / `L: Clone` / `L: Copy` bounds even though no `L` is stored. That
// makes, for example, `CountFold<String>` non-`Copy`. The `Debug` output matches
// the other folds in this module.
impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CountFold").finish()
    }
}

impl<L> Clone for CountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> CountFold<L> {
    /// Create a new count fold.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<L> Default for CountFold<L> {
    fn default() -> Self {
        Self::new()
    }
}

263impl<L> Fold<L, usize> for CountFold<L> {
264    #[inline]
265    fn initial(&self, _context: &FoldContext) -> usize {
266        0
267    }
268
269    #[inline]
270    fn step(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
271        state.saturating_add(1)
272    }
273}
274
275impl<L> TryFold<L, usize> for CountFold<L> {
276    #[inline]
277    fn try_step(
278        &self,
279        state: usize,
280        entry: &L,
281        context: &FoldContext,
282    ) -> Result<usize, FoldFailure> {
283        Ok(self.step(state, entry, context))
284    }
285}
286
/// A zero-allocation count fold with a function-pointer predicate.
pub struct FilterCountFold<L> {
    predicate: fn(&L) -> bool,
}

// Hand-written `Clone`/`Copy`: `#[derive(Clone, Copy)]` would require
// `L: Clone` / `L: Copy` even though only a function pointer is stored, and
// function pointers are `Copy` for every `L`.
impl<L> Clone for FilterCountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}

impl<L> std::fmt::Debug for FilterCountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FilterCountFold").finish()
    }
}

impl<L> FilterCountFold<L> {
    /// Create a new filtered count fold.
    ///
    /// Only entries for which `predicate` returns `true` are counted.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}

307impl<L> Fold<L, usize> for FilterCountFold<L> {
308    #[inline]
309    fn initial(&self, _context: &FoldContext) -> usize {
310        0
311    }
312
313    #[inline]
314    fn step(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
315        if (self.predicate)(entry) {
316            state.saturating_add(1)
317        } else {
318            state
319        }
320    }
321}
322
323impl<L> TryFold<L, usize> for FilterCountFold<L> {
324    #[inline]
325    fn try_step(
326        &self,
327        state: usize,
328        entry: &L,
329        context: &FoldContext,
330    ) -> Result<usize, FoldFailure> {
331        Ok(self.step(state, entry, context))
332    }
333}
334
/// A zero-allocation i64 summation fold with a function-pointer projection.
pub struct SumI64Fold<L> {
    project: fn(&L) -> i64,
}

// Hand-written `Clone`/`Copy`: `#[derive(Clone, Copy)]` would require
// `L: Clone` / `L: Copy` even though only a function pointer is stored, and
// function pointers are `Copy` for every `L`.
impl<L> Clone for SumI64Fold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}

impl<L> std::fmt::Debug for SumI64Fold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SumI64Fold").finish()
    }
}

impl<L> SumI64Fold<L> {
    /// Create a new summation fold.
    ///
    /// `project` maps each entry to the `i64` value that is added to the running sum.
    #[must_use]
    pub fn new(project: fn(&L) -> i64) -> Self {
        Self { project }
    }
}

355impl<L> Fold<L, i64> for SumI64Fold<L> {
356    #[inline]
357    fn initial(&self, _context: &FoldContext) -> i64 {
358        0
359    }
360
361    #[inline]
362    fn step(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
363        state.saturating_add((self.project)(entry))
364    }
365}
366
367impl<L> TryFold<L, i64> for SumI64Fold<L> {
368    #[inline]
369    fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
370        Ok(self.step(state, entry, context))
371    }
372}
373
/// A zero-allocation existential fold with a function-pointer predicate.
pub struct AnyFold<L> {
    predicate: fn(&L) -> bool,
}

// Hand-written `Clone`/`Copy`: `#[derive(Clone, Copy)]` would require
// `L: Clone` / `L: Copy` even though only a function pointer is stored, and
// function pointers are `Copy` for every `L`.
impl<L> Clone for AnyFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}

impl<L> std::fmt::Debug for AnyFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnyFold").finish()
    }
}

impl<L> AnyFold<L> {
    /// Create a new existential fold.
    ///
    /// The fold's result is `true` if `predicate` accepts at least one entry.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}

394impl<L> Fold<L, bool> for AnyFold<L> {
395    #[inline]
396    fn initial(&self, _context: &FoldContext) -> bool {
397        false
398    }
399
400    #[inline]
401    fn step(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
402        state || (self.predicate)(entry)
403    }
404}
405
406impl<L> TryFold<L, bool> for AnyFold<L> {
407    #[inline]
408    fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
409        Ok(self.step(state, entry, context))
410    }
411}
412
/// State produced by a [`CommonFold`], with one variant per family of common folds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommonFoldState {
    /// Result of a counting fold (plain or filtered).
    Count(usize),
    /// Result of an `i64` summation fold.
    SumI64(i64),
    /// Result of an existential ("does any entry match?") fold.
    Any(bool),
}

impl CommonFoldState {
    /// Name of this variant, as reported in [`FoldFailure::StateMismatch`].
    #[inline]
    fn kind(self) -> &'static str {
        match self {
            Self::Count(..) => "Count",
            Self::SumI64(..) => "SumI64",
            Self::Any(..) => "Any",
        }
    }
}

435/// Enum-dispatch fold for common patterns that would otherwise use `Box<dyn Fold>`.
436///
437/// This is intentionally limited to a few hot, allocation-free cases so callers
438/// can avoid vtable dispatch in heterogeneous fold collections.
439#[derive(Clone)]
440pub enum CommonFold<L> {
441    /// Count every entry.
442    Count(CountFold<L>),
443    /// Count entries matching a predicate.
444    FilterCount(FilterCountFold<L>),
445    /// Sum projected i64 values.
446    SumI64(SumI64Fold<L>),
447    /// Return whether any entry matches the predicate.
448    Any(AnyFold<L>),
449}
450
451impl<L> std::fmt::Debug for CommonFold<L> {
452    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453        match self {
454            Self::Count(_) => f.write_str("CommonFold::Count"),
455            Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
456            Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
457            Self::Any(_) => f.write_str("CommonFold::Any"),
458        }
459    }
460}
461
462impl<L> CommonFold<L> {
463    /// Create a counting common fold.
464    #[must_use]
465    pub fn count() -> Self {
466        Self::Count(CountFold::new())
467    }
468
469    /// Create a filtered-count common fold.
470    #[must_use]
471    pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
472        Self::FilterCount(FilterCountFold::new(predicate))
473    }
474
475    /// Create an i64 summation common fold.
476    #[must_use]
477    pub fn sum_i64(project: fn(&L) -> i64) -> Self {
478        Self::SumI64(SumI64Fold::new(project))
479    }
480
481    /// Create an existential common fold.
482    #[must_use]
483    pub fn any(predicate: fn(&L) -> bool) -> Self {
484        Self::Any(AnyFold::new(predicate))
485    }
486
487    #[inline]
488    fn expected_state_kind(&self) -> &'static str {
489        match self {
490            Self::Count(_) | Self::FilterCount(_) => "Count",
491            Self::SumI64(_) => "SumI64",
492            Self::Any(_) => "Any",
493        }
494    }
495
496    /// Process a single entry and return an error if the state shape is invalid.
497    pub fn try_step(
498        &self,
499        state: CommonFoldState,
500        entry: &L,
501        context: &FoldContext,
502    ) -> Result<CommonFoldState, FoldFailure> {
503        match (self, state) {
504            (Self::Count(inner), CommonFoldState::Count(count)) => {
505                Ok(CommonFoldState::Count(inner.step(count, entry, context)))
506            }
507            (Self::FilterCount(inner), CommonFoldState::Count(count)) => {
508                Ok(CommonFoldState::Count(inner.step(count, entry, context)))
509            }
510            (Self::SumI64(inner), CommonFoldState::SumI64(sum)) => {
511                Ok(CommonFoldState::SumI64(inner.step(sum, entry, context)))
512            }
513            (Self::Any(inner), CommonFoldState::Any(any)) => {
514                Ok(CommonFoldState::Any(inner.step(any, entry, context)))
515            }
516            (kind, state) => Err(FoldFailure::StateMismatch {
517                expected: kind.expected_state_kind(),
518                actual: state.kind(),
519            }),
520        }
521    }
522}
523
524impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
525    #[inline]
526    fn initial(&self, _context: &FoldContext) -> CommonFoldState {
527        match self {
528            Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
529            Self::SumI64(_) => CommonFoldState::SumI64(0),
530            Self::Any(_) => CommonFoldState::Any(false),
531        }
532    }
533
534    /// # Panics
535    ///
536    /// Panics if `state` does not match the variant expected by `self`.
537    /// Use [`TryFold::try_step`] to handle the mismatch as an error instead.
538    #[inline]
539    fn step(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
540        self.try_step(state, entry, context)
541            .unwrap_or_else(|err| panic!("{err}"))
542    }
543}
544
545impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
546    #[inline]
547    fn try_step(
548        &self,
549        state: CommonFoldState,
550        entry: &L,
551        context: &FoldContext,
552    ) -> Result<CommonFoldState, FoldFailure> {
553        CommonFold::try_step(self, state, entry, context)
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_fn() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let entries = [1, 2, 3, 4, 5];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 5);
        assert_eq!(result.entries_processed, 5);
    }

    #[test]
    fn test_fold_fn_sum() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5];
        let result = summer.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 15);
    }

    #[test]
    fn test_fold_filtered() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = summer.derive_filtered(entries.iter(), &FoldContext::new(), |e| *e % 2 == 0);
        assert_eq!(result.state, 12);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_fn_fold_applies_finalize() {
        let fold = FnFold::new(
            |_ctx: &FoldContext| 0i32,
            |sum: i32, entry: &i32, _ctx: &FoldContext| sum + entry,
            |sum: i32, _ctx: &FoldContext| sum * 10,
        );
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 60);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_boxed_fold_derive() {
        #[allow(clippy::box_default)]
        let counter: BoxedFold<i32, usize> = Box::new(CountFold::new());
        let entries = [1, 2, 3, 4];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 4);
    }

    #[test]
    fn test_arc_fold_derive() {
        let counter: Arc<CountFold<i32>> = Arc::new(CountFold::new());
        let entries = [1, 2, 3];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 3);
    }

    #[test]
    fn test_common_fold_count() {
        let fold = CommonFold::<i32>::count();
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
    }

    #[test]
    fn test_common_fold_sum() {
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| *value as i64);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::SumI64(6));
    }

    #[test]
    fn test_common_fold_filter_count_and_any() {
        let context = FoldContext::new();
        let entries = [1, 2, 3, 4, 5];

        let evens = CommonFold::<i32>::filter_count(|value: &i32| *value % 2 == 0);
        assert_eq!(
            evens.derive(entries.iter(), &context).state,
            CommonFoldState::Count(2)
        );

        let has_big = CommonFold::<i32>::any(|value: &i32| *value > 4);
        assert_eq!(
            has_big.derive(entries.iter(), &context).state,
            CommonFoldState::Any(true)
        );

        let has_negative = CommonFold::<i32>::any(|value: &i32| *value < 0);
        assert_eq!(
            has_negative.derive(entries.iter(), &context).state,
            CommonFoldState::Any(false)
        );
    }

    #[test]
    fn count_folds_saturate_on_overflow() {
        let context = FoldContext::new();
        let entry = 1;

        let count = CountFold::new();
        assert_eq!(count.step(usize::MAX, &entry, &context), usize::MAX);

        let filtered = FilterCountFold::new(|_: &i32| true);
        assert_eq!(filtered.step(usize::MAX, &entry, &context), usize::MAX);
    }

    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let context = FoldContext::new();
        let fold = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(fold.step(i64::MAX, &1, &context), i64::MAX);
    }

    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::count();
        let err = TryFold::try_step(&fold, CommonFoldState::SumI64(0), &1, &context).unwrap_err();
        assert_eq!(
            err,
            FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64"
            }
        );
    }

    #[test]
    fn common_fold_try_step_matching_state_advances() {
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| i64::from(*value));
        let next = TryFold::try_step(&fold, CommonFoldState::SumI64(10), &5, &FoldContext::new());
        assert_eq!(next, Ok(CommonFoldState::SumI64(15)));
    }

    #[test]
    #[should_panic(expected = "Fold state mismatch: expected Any, got Count")]
    fn common_fold_step_mismatch_panics() {
        let fold = CommonFold::<i32>::any(|value: &i32| *value > 0);
        let _ = fold.step(CommonFoldState::Count(0), &1, &FoldContext::new());
    }

    #[test]
    fn test_any_fold() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 7, 9];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(result.state);
    }

    #[test]
    fn test_any_fold_empty_input_is_false() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let result = fold.derive(std::iter::empty::<&i32>(), &FoldContext::new());
        assert!(!result.state);
        assert_eq!(result.entries_processed, 0);
    }
}