use std::marker::PhantomData;
use std::sync::Arc;
use crate::{FoldContext, FoldOutcome};
/// A left fold over entries of type `L` that derives a state of type `S`.
///
/// Implementors supply [`Fold::initial`] and [`Fold::step`]; [`Fold::finalize`] defaults to
/// returning the state unchanged. The provided [`Fold::derive`] and [`Fold::derive_filtered`]
/// drive the fold over an iterator and package the result as a [`FoldOutcome`].
pub trait Fold<L, S> {
    /// Produces the starting state before any entry has been seen.
    fn initial(&self, context: &FoldContext) -> S;

    /// Folds a single `entry` into `state`, returning the new state.
    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S;

    /// Post-processes the state once every entry has been stepped.
    ///
    /// The default implementation returns `state` unchanged.
    #[inline]
    fn finalize(&self, state: S, _context: &FoldContext) -> S {
        state
    }

    /// Runs `initial`, then `step` once per entry in iteration order, then `finalize`.
    ///
    /// The returned outcome is built by `FoldOutcome::with_timing` from the finalized state,
    /// the number of entries stepped, a clone of `context`, and a `chrono::Utc::now()`
    /// timestamp captured before `initial` is called.
    fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
    where
        Self: Sized,
        I: IntoIterator<Item = &'a L>,
        L: 'a,
    {
        let started_at = chrono::Utc::now();
        let seed = (self.initial(context), 0);
        let (folded, processed) = entries
            .into_iter()
            .fold(seed, |(acc, seen), entry| {
                (self.step(acc, entry, context), seen + 1)
            });
        let finished = self.finalize(folded, context);
        FoldOutcome::with_timing(finished, processed, context.clone(), started_at)
    }

    /// Like [`Fold::derive`], but only entries for which `filter` returns `true` are stepped.
    ///
    /// `filter` is evaluated once per entry, interleaved with `step`; the processed count
    /// reported in the outcome covers only the entries that passed the filter.
    fn derive_filtered<'a, I, F>(
        &self,
        entries: I,
        context: &FoldContext,
        filter: F,
    ) -> FoldOutcome<S>
    where
        Self: Sized,
        I: IntoIterator<Item = &'a L>,
        L: 'a,
        F: Fn(&L) -> bool,
    {
        let started_at = chrono::Utc::now();
        let seed = (self.initial(context), 0);
        let (folded, processed) = entries
            .into_iter()
            .filter(|&entry| filter(entry))
            .fold(seed, |(acc, seen), entry| {
                (self.step(acc, entry, context), seen + 1)
            });
        let finished = self.finalize(folded, context);
        FoldOutcome::with_timing(finished, processed, context.clone(), started_at)
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum FoldFailure {
#[error("Fold state mismatch: expected {expected}, got {actual}")]
StateMismatch {
expected: &'static str,
actual: &'static str,
},
}
/// A [`Fold`] whose single step can report a [`FoldFailure`] instead of producing a state.
pub trait TryFold<L, S>: Fold<L, S> {
    /// Fallible counterpart of [`Fold::step`].
    ///
    /// # Errors
    ///
    /// Returns a [`FoldFailure`] when `state` cannot be advanced by `entry`; the conditions
    /// are implementation-specific.
    fn try_step(
        &self,
        state: S,
        entry: &L,
        context: &FoldContext,
    ) -> Result<S, FoldFailure>;
}
/// Makes a boxed fold (including `Box<dyn Fold<L, S>>`) usable as a fold itself by
/// forwarding every required method, and `finalize`, to the boxed value.
impl<L, S, T> Fold<L, S> for Box<T>
where
    T: Fold<L, S> + ?Sized,
{
    #[inline]
    fn initial(&self, context: &FoldContext) -> S {
        <T as Fold<L, S>>::initial(&**self, context)
    }

    #[inline]
    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
        <T as Fold<L, S>>::step(&**self, state, entry, context)
    }

    // Forwarded explicitly so an overriding `finalize` on `T` is not replaced by the default.
    #[inline]
    fn finalize(&self, state: S, context: &FoldContext) -> S {
        <T as Fold<L, S>>::finalize(&**self, state, context)
    }
}
/// Forwards `try_step` to the boxed fallible fold.
impl<L, S, T> TryFold<L, S> for Box<T>
where
    T: TryFold<L, S> + ?Sized,
{
    #[inline]
    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
        <T as TryFold<L, S>>::try_step(&**self, state, entry, context)
    }
}
/// Makes a shared `Arc`-wrapped fold usable as a fold by forwarding every required method,
/// and `finalize`, to the pointee.
impl<L, S, T> Fold<L, S> for Arc<T>
where
    T: Fold<L, S> + ?Sized,
{
    #[inline]
    fn initial(&self, context: &FoldContext) -> S {
        <T as Fold<L, S>>::initial(&**self, context)
    }

    #[inline]
    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
        <T as Fold<L, S>>::step(&**self, state, entry, context)
    }

    // Forwarded explicitly so an overriding `finalize` on `T` is not replaced by the default.
    #[inline]
    fn finalize(&self, state: S, context: &FoldContext) -> S {
        <T as Fold<L, S>>::finalize(&**self, state, context)
    }
}
/// Forwards `try_step` to the shared fallible fold.
impl<L, S, T> TryFold<L, S> for Arc<T>
where
    T: TryFold<L, S> + ?Sized,
{
    #[inline]
    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
        <T as TryFold<L, S>>::try_step(&**self, state, entry, context)
    }
}
/// Type-erased fold behind dynamic dispatch; the erased value must be `Send + Sync`.
pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Sync + Send>;
/// A [`Fold`] assembled from three closures: `initial_fn` produces the starting state,
/// `step_fn` folds one entry, and `finalize_fn` post-processes the final state.
///
/// The struct itself carries no trait bounds; the closure signatures are required by the
/// `impl` blocks that construct and use it.
pub struct FnFold<L, S, I, St, F> {
    initial_fn: I,
    step_fn: St,
    finalize_fn: F,
    // `fn() -> (L, S)` keeps `FnFold` covariant in `L` and `S` (as `PhantomData<(L, S)>` did)
    // without claiming to own an `L` or `S`. The struct never stores either, so its
    // `Send`/`Sync` status now depends only on the stored closures. For example, a fold over
    // `Rc` entries can still be a `BoxedFold`.
    _phantom: PhantomData<fn() -> (L, S)>,
}
impl<L, S, I, St, F> FnFold<L, S, I, St, F>
where
    I: Fn(&FoldContext) -> S,
    St: Fn(S, &L, &FoldContext) -> S,
    F: Fn(S, &FoldContext) -> S,
{
    /// Builds a fold from its three phases.
    ///
    /// The `where` bounds on this impl are what let callers pass closures with untyped
    /// parameters (e.g. `|s, _| s` in `fold_fn`): the expected signatures come from them.
    /// Marked `#[must_use]` like the other constructors in this module, because discarding
    /// the new fold is always a mistake.
    #[must_use]
    pub fn new(initial: I, step: St, finalize: F) -> Self {
        Self {
            initial_fn: initial,
            step_fn: step,
            finalize_fn: finalize,
            _phantom: PhantomData,
        }
    }
}
/// Each `Fold` method simply invokes the matching stored closure.
impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
where
    I: Fn(&FoldContext) -> S,
    St: Fn(S, &L, &FoldContext) -> S,
    F: Fn(S, &FoldContext) -> S,
{
    #[inline]
    fn initial(&self, context: &FoldContext) -> S {
        let Self { initial_fn, .. } = self;
        initial_fn(context)
    }

    #[inline]
    fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
        let Self { step_fn, .. } = self;
        step_fn(state, entry, context)
    }

    #[inline]
    fn finalize(&self, state: S, context: &FoldContext) -> S {
        let Self { finalize_fn, .. } = self;
        finalize_fn(state, context)
    }
}
/// `FnFold` steps are infallible: `try_step` runs the step closure and wraps the result in `Ok`.
impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
where
    I: Fn(&FoldContext) -> S,
    St: Fn(S, &L, &FoldContext) -> S,
    F: Fn(S, &FoldContext) -> S,
{
    #[inline]
    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
        let Self { step_fn, .. } = self;
        let next = step_fn(state, entry, context);
        Ok(next)
    }
}
pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
where
I: Fn(&FoldContext) -> S,
St: Fn(S, &L, &FoldContext) -> S,
{
FnFold::new(initial, step, |s, _| s)
}
/// Fold that counts every entry it is stepped with (see its `Fold<L, usize>` impl).
///
/// `Clone`, `Copy` and `Debug` are implemented by hand. `#[derive]` would add `L: Clone`,
/// `L: Copy` and `L: Debug` bounds even though no `L` is ever stored, which would make
/// e.g. `CountFold<String>` non-`Copy`.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

impl<L> Clone for CountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as the former `#[derive(Debug)]`, without requiring `L: Debug`.
        f.debug_struct("CountFold")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<L> CountFold<L> {
    /// Creates a counting fold.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<L> Default for CountFold<L> {
    fn default() -> Self {
        Self::new()
    }
}
/// Counts entries starting from zero; the count saturates at `usize::MAX` rather than wrapping.
impl<L> Fold<L, usize> for CountFold<L> {
    #[inline]
    fn initial(&self, _context: &FoldContext) -> usize {
        usize::MIN
    }

    #[inline]
    fn step(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
        state.checked_add(1).unwrap_or(usize::MAX)
    }
}
impl<L> TryFold<L, usize> for CountFold<L> {
#[inline]
fn try_step(
&self,
state: usize,
entry: &L,
context: &FoldContext,
) -> Result<usize, FoldFailure> {
Ok(self.step(state, entry, context))
}
}
/// Fold that counts the entries for which `predicate` returns `true`
/// (see its `Fold<L, usize>` impl).
///
/// `Clone` and `Copy` are implemented by hand: the struct only holds a function pointer,
/// but `#[derive]` would additionally demand `L: Clone` / `L: Copy`.
pub struct FilterCountFold<L> {
    predicate: fn(&L) -> bool,
}

impl<L> Clone for FilterCountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}

impl<L> std::fmt::Debug for FilterCountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Function pointers carry no useful debug information, so only the name is printed.
        f.debug_struct("FilterCountFold").finish()
    }
}

impl<L> FilterCountFold<L> {
    /// Creates a fold counting entries accepted by `predicate`.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}
/// Adds one (saturating at `usize::MAX`) for each entry the predicate accepts.
impl<L> Fold<L, usize> for FilterCountFold<L> {
    #[inline]
    fn initial(&self, _context: &FoldContext) -> usize {
        usize::MIN
    }

    #[inline]
    fn step(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
        // `usize::from(bool)` is 1 for a match and 0 otherwise, so a rejected entry leaves
        // the count unchanged.
        let matched = (self.predicate)(entry);
        state.saturating_add(usize::from(matched))
    }
}
impl<L> TryFold<L, usize> for FilterCountFold<L> {
#[inline]
fn try_step(
&self,
state: usize,
entry: &L,
context: &FoldContext,
) -> Result<usize, FoldFailure> {
Ok(self.step(state, entry, context))
}
}
/// Fold that sums an `i64` projected from each entry (see its `Fold<L, i64>` impl).
///
/// `Clone` and `Copy` are implemented by hand: the struct only holds a function pointer,
/// but `#[derive]` would additionally demand `L: Clone` / `L: Copy`.
pub struct SumI64Fold<L> {
    project: fn(&L) -> i64,
}

impl<L> Clone for SumI64Fold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}

impl<L> std::fmt::Debug for SumI64Fold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Function pointers carry no useful debug information, so only the name is printed.
        f.debug_struct("SumI64Fold").finish()
    }
}

impl<L> SumI64Fold<L> {
    /// Creates a fold summing `project(entry)` over all entries.
    #[must_use]
    pub fn new(project: fn(&L) -> i64) -> Self {
        Self { project }
    }
}
/// Sums the projected values starting from zero, saturating at `i64::MIN` / `i64::MAX`.
impl<L> Fold<L, i64> for SumI64Fold<L> {
    #[inline]
    fn initial(&self, _context: &FoldContext) -> i64 {
        i64::default()
    }

    #[inline]
    fn step(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
        let delta = (self.project)(entry);
        i64::saturating_add(state, delta)
    }
}
impl<L> TryFold<L, i64> for SumI64Fold<L> {
#[inline]
fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
Ok(self.step(state, entry, context))
}
}
/// Fold that reports whether any entry satisfies `predicate` (see its `Fold<L, bool>` impl).
///
/// `Clone` and `Copy` are implemented by hand: the struct only holds a function pointer,
/// but `#[derive]` would additionally demand `L: Clone` / `L: Copy`.
pub struct AnyFold<L> {
    predicate: fn(&L) -> bool,
}

impl<L> Clone for AnyFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}

impl<L> std::fmt::Debug for AnyFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Function pointers carry no useful debug information, so only the name is printed.
        f.debug_struct("AnyFold").finish()
    }
}

impl<L> AnyFold<L> {
    /// Creates a fold that becomes `true` once `predicate` accepts an entry.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}
/// Starts at `false` and latches to `true` at the first accepted entry.
impl<L> Fold<L, bool> for AnyFold<L> {
    #[inline]
    fn initial(&self, _context: &FoldContext) -> bool {
        bool::default()
    }

    #[inline]
    fn step(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
        // Once a match has been seen the predicate is no longer consulted.
        if state {
            return true;
        }
        (self.predicate)(entry)
    }
}
impl<L> TryFold<L, bool> for AnyFold<L> {
#[inline]
fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
Ok(self.step(state, entry, context))
}
}
/// Tagged state for `CommonFold`: each variant carries the state of one built-in fold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommonFoldState {
    /// Running count, shared by the `Count` and `FilterCount` folds.
    Count(usize),
    /// Running saturating sum.
    SumI64(i64),
    /// Whether any entry has matched so far.
    Any(bool),
}

impl CommonFoldState {
    /// Name of the active variant, as reported in `FoldFailure::StateMismatch`.
    #[inline]
    fn kind(self) -> &'static str {
        match self {
            Self::Count(..) => "Count",
            Self::SumI64(..) => "SumI64",
            Self::Any(..) => "Any",
        }
    }
}
#[derive(Clone)]
pub enum CommonFold<L> {
Count(CountFold<L>),
FilterCount(FilterCountFold<L>),
SumI64(SumI64Fold<L>),
Any(AnyFold<L>),
}
impl<L> std::fmt::Debug for CommonFold<L> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Count(_) => f.write_str("CommonFold::Count"),
Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
Self::Any(_) => f.write_str("CommonFold::Any"),
}
}
}
impl<L> CommonFold<L> {
    /// Fold counting every entry.
    #[must_use]
    pub fn count() -> Self {
        Self::Count(CountFold::default())
    }

    /// Fold counting entries accepted by `predicate`.
    #[must_use]
    pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
        let inner = FilterCountFold::new(predicate);
        Self::FilterCount(inner)
    }

    /// Fold summing `project(entry)` with saturation.
    #[must_use]
    pub fn sum_i64(project: fn(&L) -> i64) -> Self {
        let inner = SumI64Fold::new(project);
        Self::SumI64(inner)
    }

    /// Fold reporting whether any entry is accepted by `predicate`.
    #[must_use]
    pub fn any(predicate: fn(&L) -> bool) -> Self {
        let inner = AnyFold::new(predicate);
        Self::Any(inner)
    }

    /// Name of the `CommonFoldState` variant this fold operates on. `FilterCount` shares the
    /// `Count` state.
    #[inline]
    fn expected_state_kind(&self) -> &'static str {
        match self {
            Self::SumI64(_) => "SumI64",
            Self::Any(_) => "Any",
            Self::Count(_) | Self::FilterCount(_) => "Count",
        }
    }

    /// Advances `state` by one `entry`, dispatching to the wrapped fold.
    ///
    /// # Errors
    ///
    /// Returns [`FoldFailure::StateMismatch`] when `state` is not the variant this fold
    /// produces, e.g. a `SumI64` state handed to a `Count` fold.
    pub fn try_step(
        &self,
        state: CommonFoldState,
        entry: &L,
        context: &FoldContext,
    ) -> Result<CommonFoldState, FoldFailure> {
        let next = match (self, state) {
            (Self::Count(fold), CommonFoldState::Count(n)) => {
                CommonFoldState::Count(fold.step(n, entry, context))
            }
            (Self::FilterCount(fold), CommonFoldState::Count(n)) => {
                CommonFoldState::Count(fold.step(n, entry, context))
            }
            (Self::SumI64(fold), CommonFoldState::SumI64(total)) => {
                CommonFoldState::SumI64(fold.step(total, entry, context))
            }
            (Self::Any(fold), CommonFoldState::Any(seen)) => {
                CommonFoldState::Any(fold.step(seen, entry, context))
            }
            (mismatched, actual) => {
                return Err(FoldFailure::StateMismatch {
                    expected: mismatched.expected_state_kind(),
                    actual: actual.kind(),
                });
            }
        };
        Ok(next)
    }
}
impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
#[inline]
fn initial(&self, _context: &FoldContext) -> CommonFoldState {
match self {
Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
Self::SumI64(_) => CommonFoldState::SumI64(0),
Self::Any(_) => CommonFoldState::Any(false),
}
}
#[inline]
fn step(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
self.try_step(state, entry, context)
.unwrap_or_else(|err| panic!("{err}"))
}
}
/// Exposes the inherent `CommonFold::try_step` through the [`TryFold`] trait.
impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
    #[inline]
    fn try_step(
        &self,
        state: CommonFoldState,
        entry: &L,
        context: &FoldContext,
    ) -> Result<CommonFoldState, FoldFailure> {
        // A `Type::method` path resolves to the inherent method before trait methods, so
        // this calls the inherent `try_step` and does not recurse into this impl.
        CommonFold::<L>::try_step(self, state, entry, context)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_fn() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let entries = [1, 2, 3, 4, 5];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 5);
        assert_eq!(result.entries_processed, 5);
    }

    #[test]
    fn test_fold_fn_sum() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5];
        let result = summer.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 15);
    }

    #[test]
    fn test_fold_filtered() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = summer.derive_filtered(entries.iter(), &FoldContext::new(), |e| *e % 2 == 0);
        assert_eq!(result.state, 12);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_fn_fold_applies_finalize() {
        let fold = FnFold::new(
            |_ctx: &FoldContext| 0i32,
            |sum: i32, entry: &i32, _ctx: &FoldContext| sum + entry,
            |sum: i32, _ctx: &FoldContext| sum * 10,
        );
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 60);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_boxed_fold_derive() {
        // `Box::default()` (clippy's suggestion) cannot name the concrete `CountFold<i32>`
        // when the binding's type is the `dyn Fold` alias, so keep the explicit `Box::new`.
        #[allow(clippy::box_default)]
        let counter: BoxedFold<i32, usize> = Box::new(CountFold::new());
        let entries = [1, 2, 3, 4];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 4);
    }

    #[test]
    fn test_common_fold_count() {
        let fold = CommonFold::<i32>::count();
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
    }

    #[test]
    fn test_common_fold_sum() {
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| *value as i64);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::SumI64(6));
    }

    #[test]
    fn test_common_fold_filter_count() {
        let fold = CommonFold::<i32>::filter_count(|value: &i32| *value % 2 == 1);
        let entries = [1, 2, 3, 4, 5];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
        // `derive` steps every entry; only the predicate decides what is counted.
        assert_eq!(result.entries_processed, 5);
    }

    #[test]
    fn test_common_fold_any() {
        let fold = CommonFold::<i32>::any(|value: &i32| *value > 2);
        let hit = fold.derive([1, 3].iter(), &FoldContext::new());
        assert_eq!(hit.state, CommonFoldState::Any(true));
        let miss = fold.derive([1, 2].iter(), &FoldContext::new());
        assert_eq!(miss.state, CommonFoldState::Any(false));
    }

    #[test]
    fn count_folds_saturate_on_overflow() {
        let context = FoldContext::new();
        let entry = 1;
        let count = CountFold::new();
        assert_eq!(count.step(usize::MAX, &entry, &context), usize::MAX);
        let filtered = FilterCountFold::new(|_: &i32| true);
        assert_eq!(filtered.step(usize::MAX, &entry, &context), usize::MAX);
    }

    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let context = FoldContext::new();
        let fold = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(fold.step(i64::MAX, &1, &context), i64::MAX);
    }

    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::count();
        let err = TryFold::try_step(&fold, CommonFoldState::SumI64(0), &1, &context).unwrap_err();
        assert_eq!(
            err,
            FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64"
            }
        );
    }

    #[test]
    #[should_panic(expected = "Fold state mismatch: expected Count, got SumI64")]
    fn common_fold_step_panics_on_state_mismatch() {
        let fold = CommonFold::<i32>::count();
        let _ = fold.step(CommonFoldState::SumI64(0), &1, &FoldContext::new());
    }

    #[test]
    fn test_any_fold() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 7, 9];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(result.state);
    }
}