1use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::{FoldContext, FoldOutcome};
7
8pub trait Fold<L, S> {
17 fn initial(&self, context: &FoldContext) -> S;
19
20 fn step(&self, state: S, entry: &L, context: &FoldContext) -> S;
24
25 #[inline]
29 fn finalize(&self, state: S, _context: &FoldContext) -> S {
30 state
31 }
32
33 fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
37 where
38 Self: Sized,
39 I: IntoIterator<Item = &'a L>,
40 L: 'a,
41 {
42 let started_at = chrono::Utc::now();
43 let mut state = self.initial(context);
44 let mut count = 0;
45
46 for entry in entries {
47 state = self.step(state, entry, context);
48 count += 1;
49 }
50
51 state = self.finalize(state, context);
52
53 FoldOutcome::with_timing(state, count, context.clone(), started_at)
54 }
55
56 fn derive_filtered<'a, I, F>(
58 &self,
59 entries: I,
60 context: &FoldContext,
61 filter: F,
62 ) -> FoldOutcome<S>
63 where
64 Self: Sized,
65 I: IntoIterator<Item = &'a L>,
66 L: 'a,
67 F: Fn(&L) -> bool,
68 {
69 let started_at = chrono::Utc::now();
70 let mut state = self.initial(context);
71 let mut count = 0;
72
73 for entry in entries {
74 if filter(entry) {
75 state = self.step(state, entry, context);
76 count += 1;
77 }
78 }
79
80 state = self.finalize(state, context);
81
82 FoldOutcome::with_timing(state, count, context.clone(), started_at)
83 }
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
88pub enum FoldFailure {
89 #[error("Fold state mismatch: expected {expected}, got {actual}")]
91 StateMismatch {
92 expected: &'static str,
94 actual: &'static str,
96 },
97}
98
99pub trait TryFold<L, S>: Fold<L, S> {
101 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure>;
103}
104
105impl<L, S, T> Fold<L, S> for Box<T>
106where
107 T: Fold<L, S> + ?Sized,
108{
109 #[inline]
110 fn initial(&self, context: &FoldContext) -> S {
111 (**self).initial(context)
112 }
113
114 #[inline]
115 fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
116 (**self).step(state, entry, context)
117 }
118
119 #[inline]
120 fn finalize(&self, state: S, context: &FoldContext) -> S {
121 (**self).finalize(state, context)
122 }
123}
124
125impl<L, S, T> TryFold<L, S> for Box<T>
126where
127 T: TryFold<L, S> + ?Sized,
128{
129 #[inline]
130 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
131 (**self).try_step(state, entry, context)
132 }
133}
134
135impl<L, S, T> Fold<L, S> for Arc<T>
136where
137 T: Fold<L, S> + ?Sized,
138{
139 #[inline]
140 fn initial(&self, context: &FoldContext) -> S {
141 (**self).initial(context)
142 }
143
144 #[inline]
145 fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
146 (**self).step(state, entry, context)
147 }
148
149 #[inline]
150 fn finalize(&self, state: S, context: &FoldContext) -> S {
151 (**self).finalize(state, context)
152 }
153}
154
155impl<L, S, T> TryFold<L, S> for Arc<T>
156where
157 T: TryFold<L, S> + ?Sized,
158{
159 #[inline]
160 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
161 (**self).try_step(state, entry, context)
162 }
163}
164
165pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Send + Sync>;
167
168pub struct FnFold<L, S, I, St, F>
170where
171 I: Fn(&FoldContext) -> S,
172 St: Fn(S, &L, &FoldContext) -> S,
173 F: Fn(S, &FoldContext) -> S,
174{
175 initial_fn: I,
176 step_fn: St,
177 finalize_fn: F,
178 _phantom: PhantomData<(L, S)>,
179}
180
181impl<L, S, I, St, F> FnFold<L, S, I, St, F>
182where
183 I: Fn(&FoldContext) -> S,
184 St: Fn(S, &L, &FoldContext) -> S,
185 F: Fn(S, &FoldContext) -> S,
186{
187 pub fn new(initial: I, step: St, finalize: F) -> Self {
189 Self {
190 initial_fn: initial,
191 step_fn: step,
192 finalize_fn: finalize,
193 _phantom: PhantomData,
194 }
195 }
196}
197
198impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
199where
200 I: Fn(&FoldContext) -> S,
201 St: Fn(S, &L, &FoldContext) -> S,
202 F: Fn(S, &FoldContext) -> S,
203{
204 #[inline]
205 fn initial(&self, context: &FoldContext) -> S {
206 (self.initial_fn)(context)
207 }
208
209 #[inline]
210 fn step(&self, state: S, entry: &L, context: &FoldContext) -> S {
211 (self.step_fn)(state, entry, context)
212 }
213
214 #[inline]
215 fn finalize(&self, state: S, context: &FoldContext) -> S {
216 (self.finalize_fn)(state, context)
217 }
218}
219
220impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
221where
222 I: Fn(&FoldContext) -> S,
223 St: Fn(S, &L, &FoldContext) -> S,
224 F: Fn(S, &FoldContext) -> S,
225{
226 #[inline]
227 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
228 Ok((self.step_fn)(state, entry, context))
229 }
230}
231
232pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
234where
235 I: Fn(&FoldContext) -> S,
236 St: Fn(S, &L, &FoldContext) -> S,
237{
238 FnFold::new(initial, step, |s, _| s)
239}
240
/// Counts every entry it is stepped over, saturating at `usize::MAX`.
///
/// `L` appears only in a `PhantomData<fn(&L)>` marker; no `L` is ever stored.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

// `Clone`, `Copy` and `Debug` are implemented by hand: `#[derive]` would add
// `L: Clone` / `L: Copy` / `L: Debug` bounds even though no `L` is stored.
impl<L> Clone for CountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output: `CountFold { _phantom: PhantomData<..> }`.
        f.debug_struct("CountFold")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<L> CountFold<L> {
    /// Creates a counting fold.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<L> Default for CountFold<L> {
    fn default() -> Self {
        Self::new()
    }
}
262
263impl<L> Fold<L, usize> for CountFold<L> {
264 #[inline]
265 fn initial(&self, _context: &FoldContext) -> usize {
266 0
267 }
268
269 #[inline]
270 fn step(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
271 state.saturating_add(1)
272 }
273}
274
275impl<L> TryFold<L, usize> for CountFold<L> {
276 #[inline]
277 fn try_step(
278 &self,
279 state: usize,
280 entry: &L,
281 context: &FoldContext,
282 ) -> Result<usize, FoldFailure> {
283 Ok(self.step(state, entry, context))
284 }
285}
286
/// Counts the entries for which `predicate` returns `true`, saturating at `usize::MAX`.
pub struct FilterCountFold<L> {
    predicate: fn(&L) -> bool,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `L: Clone` / `L: Copy`,
// but a function pointer is `Copy` for every `L`.
impl<L> Clone for FilterCountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}

impl<L> std::fmt::Debug for FilterCountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FilterCountFold").finish()
    }
}

impl<L> FilterCountFold<L> {
    /// Creates a fold counting entries that satisfy `predicate`.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}
306
307impl<L> Fold<L, usize> for FilterCountFold<L> {
308 #[inline]
309 fn initial(&self, _context: &FoldContext) -> usize {
310 0
311 }
312
313 #[inline]
314 fn step(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
315 if (self.predicate)(entry) {
316 state.saturating_add(1)
317 } else {
318 state
319 }
320 }
321}
322
323impl<L> TryFold<L, usize> for FilterCountFold<L> {
324 #[inline]
325 fn try_step(
326 &self,
327 state: usize,
328 entry: &L,
329 context: &FoldContext,
330 ) -> Result<usize, FoldFailure> {
331 Ok(self.step(state, entry, context))
332 }
333}
334
/// Sums an `i64` projected from each entry, saturating at the `i64` bounds.
pub struct SumI64Fold<L> {
    project: fn(&L) -> i64,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `L: Clone` / `L: Copy`,
// but a function pointer is `Copy` for every `L`.
impl<L> Clone for SumI64Fold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}

impl<L> std::fmt::Debug for SumI64Fold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SumI64Fold").finish()
    }
}

impl<L> SumI64Fold<L> {
    /// Creates a fold summing `project(entry)` over all entries.
    #[must_use]
    pub fn new(project: fn(&L) -> i64) -> Self {
        Self { project }
    }
}
354
355impl<L> Fold<L, i64> for SumI64Fold<L> {
356 #[inline]
357 fn initial(&self, _context: &FoldContext) -> i64 {
358 0
359 }
360
361 #[inline]
362 fn step(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
363 state.saturating_add((self.project)(entry))
364 }
365}
366
367impl<L> TryFold<L, i64> for SumI64Fold<L> {
368 #[inline]
369 fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
370 Ok(self.step(state, entry, context))
371 }
372}
373
/// Reports whether any entry satisfies `predicate`.
pub struct AnyFold<L> {
    predicate: fn(&L) -> bool,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `L: Clone` / `L: Copy`,
// but a function pointer is `Copy` for every `L`.
impl<L> Clone for AnyFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}

impl<L> std::fmt::Debug for AnyFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnyFold").finish()
    }
}

impl<L> AnyFold<L> {
    /// Creates a fold that becomes `true` once any entry matches `predicate`.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}
393
394impl<L> Fold<L, bool> for AnyFold<L> {
395 #[inline]
396 fn initial(&self, _context: &FoldContext) -> bool {
397 false
398 }
399
400 #[inline]
401 fn step(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
402 state || (self.predicate)(entry)
403 }
404}
405
406impl<L> TryFold<L, bool> for AnyFold<L> {
407 #[inline]
408 fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
409 Ok(self.step(state, entry, context))
410 }
411}
412
/// State carried by a [`CommonFold`]; the variant mirrors the kind of fold.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CommonFoldState {
    /// Running count, shared by the `Count` and `FilterCount` folds.
    Count(usize),
    /// Running saturating sum.
    SumI64(i64),
    /// Whether any entry has matched so far.
    Any(bool),
}

impl CommonFoldState {
    /// Variant name, as reported in [`FoldFailure::StateMismatch`].
    #[inline]
    fn kind(self) -> &'static str {
        match self {
            Self::Any(..) => "Any",
            Self::SumI64(..) => "SumI64",
            Self::Count(..) => "Count",
        }
    }
}
434
435#[derive(Clone)]
440pub enum CommonFold<L> {
441 Count(CountFold<L>),
443 FilterCount(FilterCountFold<L>),
445 SumI64(SumI64Fold<L>),
447 Any(AnyFold<L>),
449}
450
451impl<L> std::fmt::Debug for CommonFold<L> {
452 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453 match self {
454 Self::Count(_) => f.write_str("CommonFold::Count"),
455 Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
456 Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
457 Self::Any(_) => f.write_str("CommonFold::Any"),
458 }
459 }
460}
461
462impl<L> CommonFold<L> {
463 #[must_use]
465 pub fn count() -> Self {
466 Self::Count(CountFold::new())
467 }
468
469 #[must_use]
471 pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
472 Self::FilterCount(FilterCountFold::new(predicate))
473 }
474
475 #[must_use]
477 pub fn sum_i64(project: fn(&L) -> i64) -> Self {
478 Self::SumI64(SumI64Fold::new(project))
479 }
480
481 #[must_use]
483 pub fn any(predicate: fn(&L) -> bool) -> Self {
484 Self::Any(AnyFold::new(predicate))
485 }
486
487 #[inline]
488 fn expected_state_kind(&self) -> &'static str {
489 match self {
490 Self::Count(_) | Self::FilterCount(_) => "Count",
491 Self::SumI64(_) => "SumI64",
492 Self::Any(_) => "Any",
493 }
494 }
495
496 pub fn try_step(
498 &self,
499 state: CommonFoldState,
500 entry: &L,
501 context: &FoldContext,
502 ) -> Result<CommonFoldState, FoldFailure> {
503 match (self, state) {
504 (Self::Count(inner), CommonFoldState::Count(count)) => {
505 Ok(CommonFoldState::Count(inner.step(count, entry, context)))
506 }
507 (Self::FilterCount(inner), CommonFoldState::Count(count)) => {
508 Ok(CommonFoldState::Count(inner.step(count, entry, context)))
509 }
510 (Self::SumI64(inner), CommonFoldState::SumI64(sum)) => {
511 Ok(CommonFoldState::SumI64(inner.step(sum, entry, context)))
512 }
513 (Self::Any(inner), CommonFoldState::Any(any)) => {
514 Ok(CommonFoldState::Any(inner.step(any, entry, context)))
515 }
516 (kind, state) => Err(FoldFailure::StateMismatch {
517 expected: kind.expected_state_kind(),
518 actual: state.kind(),
519 }),
520 }
521 }
522}
523
524impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
525 #[inline]
526 fn initial(&self, _context: &FoldContext) -> CommonFoldState {
527 match self {
528 Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
529 Self::SumI64(_) => CommonFoldState::SumI64(0),
530 Self::Any(_) => CommonFoldState::Any(false),
531 }
532 }
533
534 #[inline]
539 fn step(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
540 self.try_step(state, entry, context)
541 .unwrap_or_else(|err| panic!("{err}"))
542 }
543}
544
545impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
546 #[inline]
547 fn try_step(
548 &self,
549 state: CommonFoldState,
550 entry: &L,
551 context: &FoldContext,
552 ) -> Result<CommonFoldState, FoldFailure> {
553 CommonFold::try_step(self, state, entry, context)
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_fn() {
        let counter = fold_fn(|_ctx| 0usize, |count, _entry: &i32, _ctx| count + 1);
        let entries = [1, 2, 3, 4, 5];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 5);
        assert_eq!(result.entries_processed, 5);
    }

    #[test]
    fn test_fold_fn_sum() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5];
        let result = summer.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 15);
    }

    #[test]
    fn test_fold_filtered() {
        let summer = fold_fn(|_ctx| 0i32, |sum, entry: &i32, _ctx| sum + entry);
        let entries = [1, 2, 3, 4, 5, 6];
        let result = summer.derive_filtered(entries.iter(), &FoldContext::new(), |e| *e % 2 == 0);
        assert_eq!(result.state, 12);
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_fn_fold_try_step_and_finalize() {
        let context = FoldContext::new();
        let fold = FnFold::new(
            |_ctx: &FoldContext| 1i32,
            |acc: i32, entry: &i32, _ctx: &FoldContext| acc * entry,
            |acc: i32, _ctx: &FoldContext| acc + 100,
        );
        assert_eq!(TryFold::try_step(&fold, 2, &3, &context), Ok(6));

        // 1 * 2 * 3 * 4 = 24, then finalize adds 100.
        let entries = [2, 3, 4];
        let result = fold.derive(entries.iter(), &context);
        assert_eq!(result.state, 124);
    }

    #[test]
    fn test_boxed_fold_derive() {
        // NOTE(review): presumably silences clippy's `box_default` suggestion for
        // `Box::new(CountFold::new())`; confirm it is still needed.
        #[allow(clippy::box_default)]
        let counter: BoxedFold<i32, usize> = Box::new(CountFold::new());
        let entries = [1, 2, 3, 4];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 4);
    }

    #[test]
    fn test_arc_fold_derive() {
        let counter = Arc::new(CountFold::<i32>::new());
        let entries = [1, 2, 3];
        let result = counter.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, 3);
    }

    #[test]
    fn test_common_fold_count() {
        let fold = CommonFold::<i32>::count();
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(3));
    }

    #[test]
    fn test_common_fold_filter_count() {
        let fold = CommonFold::<i32>::filter_count(|value: &i32| *value > 1);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::Count(2));
        // `derive` steps every entry; the predicate only decides what is counted.
        assert_eq!(result.entries_processed, 3);
    }

    #[test]
    fn test_common_fold_sum() {
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| *value as i64);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert_eq!(result.state, CommonFoldState::SumI64(6));
    }

    #[test]
    fn count_folds_saturate_on_overflow() {
        let context = FoldContext::new();
        let entry = 1;

        let count = CountFold::new();
        assert_eq!(count.step(usize::MAX, &entry, &context), usize::MAX);

        let filtered = FilterCountFold::new(|_: &i32| true);
        assert_eq!(filtered.step(usize::MAX, &entry, &context), usize::MAX);
    }

    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let context = FoldContext::new();
        let fold = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(fold.step(i64::MAX, &1, &context), i64::MAX);
    }

    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::count();
        let err = TryFold::try_step(&fold, CommonFoldState::SumI64(0), &1, &context).unwrap_err();
        assert_eq!(
            err,
            FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64"
            }
        );
    }

    #[test]
    #[should_panic(expected = "Fold state mismatch")]
    fn common_fold_step_mismatch_panics() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::any(|value: &i32| *value > 0);
        let _ = fold.step(CommonFoldState::Count(0), &1, &context);
    }

    #[test]
    fn test_any_fold() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 7, 9];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(result.state);
    }

    #[test]
    fn test_any_fold_without_match() {
        let fold = AnyFold::new(|value: &i32| *value == 7);
        let entries = [1, 2, 3];
        let result = fold.derive(entries.iter(), &FoldContext::new());
        assert!(!result.state);
    }
}