1use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::{FoldContext, FoldOutcome};
7
8pub trait Fold<L, S>: Send + Sync {
17 fn init(&self, context: &FoldContext) -> S;
19
20 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S;
22
23 #[inline]
27 fn finalize(&self, state: S, _context: &FoldContext) -> S {
28 state
29 }
30
31 fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
35 where
36 Self: Sized,
37 I: IntoIterator<Item = &'a L>,
38 L: 'a,
39 {
40 let mut state = self.init(context);
41 let mut count = 0;
42
43 for entry in entries {
44 state = self.reduce(state, entry, context);
45 count += 1;
46 }
47
48 FoldOutcome::new(self.finalize(state, context), count)
49 }
50
51 fn derive_filtered<'a, I, F>(
53 &self,
54 entries: I,
55 context: &FoldContext,
56 filter: F,
57 ) -> FoldOutcome<S>
58 where
59 Self: Sized,
60 I: IntoIterator<Item = &'a L>,
61 L: 'a,
62 F: Fn(&L) -> bool,
63 {
64 let mut state = self.init(context);
65 let mut count = 0;
66
67 for entry in entries {
68 if filter(entry) {
69 state = self.reduce(state, entry, context);
70 count += 1;
71 }
72 }
73
74 FoldOutcome::new(self.finalize(state, context), count)
75 }
76}
77
/// Error returned by [`TryFold::try_step`] when a fold step cannot be applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum FoldFailure {
    /// The state handed to the fold is a different kind from the one the fold
    /// produces (e.g. a `SumI64` state given to a counting `CommonFold`).
    #[error("Fold state mismatch: expected {expected}, got {actual}")]
    StateMismatch {
        /// Name of the state kind the fold expected.
        expected: &'static str,
        /// Name of the state kind that was actually supplied.
        actual: &'static str,
    },
}
90
/// A [`Fold`] that can also take a single fallible step.
pub trait TryFold<L, S>: Fold<L, S> {
    /// Fallible counterpart of [`Fold::reduce`]: folds `entry` into `state`,
    /// or returns a [`FoldFailure`] when the step cannot be applied (for
    /// example, a `CommonFold` handed a state of the wrong kind).
    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure>;
}
96
97impl<L, S, T> Fold<L, S> for Box<T>
98where
99 T: Fold<L, S> + ?Sized,
100{
101 #[inline]
102 fn init(&self, context: &FoldContext) -> S {
103 (**self).init(context)
104 }
105
106 #[inline]
107 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
108 (**self).reduce(state, entry, context)
109 }
110
111 #[inline]
112 fn finalize(&self, state: S, context: &FoldContext) -> S {
113 (**self).finalize(state, context)
114 }
115}
116
117impl<L, S, T> TryFold<L, S> for Box<T>
118where
119 T: TryFold<L, S> + ?Sized,
120{
121 #[inline]
122 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
123 (**self).try_step(state, entry, context)
124 }
125}
126
127impl<L, S, T> Fold<L, S> for Arc<T>
128where
129 T: Fold<L, S> + ?Sized,
130{
131 #[inline]
132 fn init(&self, context: &FoldContext) -> S {
133 (**self).init(context)
134 }
135
136 #[inline]
137 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
138 (**self).reduce(state, entry, context)
139 }
140
141 #[inline]
142 fn finalize(&self, state: S, context: &FoldContext) -> S {
143 (**self).finalize(state, context)
144 }
145}
146
147impl<L, S, T> TryFold<L, S> for Arc<T>
148where
149 T: TryFold<L, S> + ?Sized,
150{
151 #[inline]
152 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
153 (**self).try_step(state, entry, context)
154 }
155}
156
/// An owned, type-erased fold: any `Fold<L, S>` implementor behind one concrete type.
///
/// The `+ Send + Sync` repeats bounds `Fold` already has as supertraits.
pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Send + Sync>;
159
160pub struct FnFold<L, S, I, St, F>
162where
163 I: Fn(&FoldContext) -> S,
164 St: Fn(S, &L, &FoldContext) -> S,
165 F: Fn(S, &FoldContext) -> S,
166{
167 initial_fn: I,
168 step_fn: St,
169 finalize_fn: F,
170 _phantom: PhantomData<(L, S)>,
171}
172
173impl<L, S, I, St, F> FnFold<L, S, I, St, F>
174where
175 I: Fn(&FoldContext) -> S,
176 St: Fn(S, &L, &FoldContext) -> S,
177 F: Fn(S, &FoldContext) -> S,
178{
179 pub fn new(initial: I, step: St, finalize: F) -> Self {
181 Self {
182 initial_fn: initial,
183 step_fn: step,
184 finalize_fn: finalize,
185 _phantom: PhantomData,
186 }
187 }
188}
189
190impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
191where
192 L: Send + Sync,
193 S: Send + Sync,
194 I: Fn(&FoldContext) -> S + Send + Sync,
195 St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
196 F: Fn(S, &FoldContext) -> S + Send + Sync,
197{
198 #[inline]
199 fn init(&self, context: &FoldContext) -> S {
200 (self.initial_fn)(context)
201 }
202
203 #[inline]
204 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
205 (self.step_fn)(state, entry, context)
206 }
207
208 #[inline]
209 fn finalize(&self, state: S, context: &FoldContext) -> S {
210 (self.finalize_fn)(state, context)
211 }
212}
213
214impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
215where
216 L: Send + Sync,
217 S: Send + Sync,
218 I: Fn(&FoldContext) -> S + Send + Sync,
219 St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
220 F: Fn(S, &FoldContext) -> S + Send + Sync,
221{
222 #[inline]
223 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
224 Ok((self.step_fn)(state, entry, context))
225 }
226}
227
228pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
230where
231 L: Send + Sync,
232 S: Send + Sync,
233 I: Fn(&FoldContext) -> S + Send + Sync,
234 St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
235{
236 FnFold::new(initial, step, |s, _| s)
237}
238
/// Counts every entry it sees, saturating at `usize::MAX`.
///
/// `Debug`, `Clone` and `Copy` are implemented by hand. `#[derive]` would add
/// `L: Debug` / `L: Clone` / `L: Copy` bounds even though no `L` is ever
/// stored, which would make e.g. `CountFold<String>` non-`Copy`.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape a `#[derive(Debug)]` would print, without the `L: Debug` bound.
        f.debug_struct("CountFold")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<L> Clone for CountFold<L> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> CountFold<L> {
    /// Creates a counter; its state starts at zero when folded.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<L> Default for CountFold<L> {
    fn default() -> Self {
        Self::new()
    }
}
260
261impl<L> Fold<L, usize> for CountFold<L> {
262 #[inline]
263 fn init(&self, _context: &FoldContext) -> usize {
264 0
265 }
266
267 #[inline]
268 fn reduce(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
269 state.saturating_add(1)
270 }
271}
272
273impl<L> TryFold<L, usize> for CountFold<L> {
274 #[inline]
275 fn try_step(
276 &self,
277 state: usize,
278 entry: &L,
279 context: &FoldContext,
280 ) -> Result<usize, FoldFailure> {
281 Ok(self.reduce(state, entry, context))
282 }
283}
284
/// Counts the entries for which `predicate` returns `true`, saturating at
/// `usize::MAX`.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand. `#[derive]` would
/// require `L: Clone` / `L: Copy` / `L: Debug` although only a function
/// pointer is stored.
pub struct FilterCountFold<L> {
    predicate: fn(&L) -> bool,
}

impl<L> Clone for FilterCountFold<L> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}

impl<L> std::fmt::Debug for FilterCountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The predicate is an opaque function pointer; only the type name is shown.
        f.debug_struct("FilterCountFold").finish()
    }
}

impl<L> FilterCountFold<L> {
    /// Creates a fold counting the entries that satisfy `predicate`.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}
304
305impl<L> Fold<L, usize> for FilterCountFold<L> {
306 #[inline]
307 fn init(&self, _context: &FoldContext) -> usize {
308 0
309 }
310
311 #[inline]
312 fn reduce(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
313 if (self.predicate)(entry) {
314 state.saturating_add(1)
315 } else {
316 state
317 }
318 }
319}
320
321impl<L> TryFold<L, usize> for FilterCountFold<L> {
322 #[inline]
323 fn try_step(
324 &self,
325 state: usize,
326 entry: &L,
327 context: &FoldContext,
328 ) -> Result<usize, FoldFailure> {
329 Ok(self.reduce(state, entry, context))
330 }
331}
332
/// Sums an `i64` projection of each entry, saturating at the `i64` bounds.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand. `#[derive]` would
/// require `L: Clone` / `L: Copy` / `L: Debug` although only a function
/// pointer is stored.
pub struct SumI64Fold<L> {
    project: fn(&L) -> i64,
}

impl<L> Clone for SumI64Fold<L> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}

impl<L> std::fmt::Debug for SumI64Fold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The projection is an opaque function pointer; only the type name is shown.
        f.debug_struct("SumI64Fold").finish()
    }
}

impl<L> SumI64Fold<L> {
    /// Creates a fold summing `project(entry)` over all entries.
    #[must_use]
    pub fn new(project: fn(&L) -> i64) -> Self {
        Self { project }
    }
}
352
353impl<L> Fold<L, i64> for SumI64Fold<L> {
354 #[inline]
355 fn init(&self, _context: &FoldContext) -> i64 {
356 0
357 }
358
359 #[inline]
360 fn reduce(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
361 state.saturating_add((self.project)(entry))
362 }
363}
364
365impl<L> TryFold<L, i64> for SumI64Fold<L> {
366 #[inline]
367 fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
368 Ok(self.reduce(state, entry, context))
369 }
370}
371
/// Reports whether any entry satisfies `predicate`.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand. `#[derive]` would
/// require `L: Clone` / `L: Copy` / `L: Debug` although only a function
/// pointer is stored.
pub struct AnyFold<L> {
    predicate: fn(&L) -> bool,
}

impl<L> Clone for AnyFold<L> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}

impl<L> std::fmt::Debug for AnyFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The predicate is an opaque function pointer; only the type name is shown.
        f.debug_struct("AnyFold").finish()
    }
}

impl<L> AnyFold<L> {
    /// Creates a fold that becomes `true` once any entry matches `predicate`.
    #[must_use]
    pub fn new(predicate: fn(&L) -> bool) -> Self {
        Self { predicate }
    }
}
391
392impl<L> Fold<L, bool> for AnyFold<L> {
393 #[inline]
394 fn init(&self, _context: &FoldContext) -> bool {
395 false
396 }
397
398 #[inline]
399 fn reduce(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
400 state || (self.predicate)(entry)
401 }
402}
403
404impl<L> TryFold<L, bool> for AnyFold<L> {
405 #[inline]
406 fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
407 Ok(self.reduce(state, entry, context))
408 }
409}
410
/// State carried by a [`CommonFold`]; the variant records which kind of fold
/// produced it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommonFoldState {
    /// Running count, shared by the `Count` and `FilterCount` folds.
    Count(usize),
    /// Running saturating sum of the `SumI64` fold.
    SumI64(i64),
    /// Whether any entry has matched the `Any` fold's predicate so far.
    Any(bool),
}

impl CommonFoldState {
    /// Name of this state's variant, as used in [`FoldFailure`] messages.
    #[inline]
    fn kind(self) -> &'static str {
        let name = match self {
            Self::Count(..) => "Count",
            Self::SumI64(..) => "SumI64",
            Self::Any(..) => "Any",
        };
        name
    }
}
432
433#[derive(Clone)]
438pub enum CommonFold<L> {
439 Count(CountFold<L>),
441 FilterCount(FilterCountFold<L>),
443 SumI64(SumI64Fold<L>),
445 Any(AnyFold<L>),
447}
448
449impl<L> std::fmt::Debug for CommonFold<L> {
450 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
451 match self {
452 Self::Count(_) => f.write_str("CommonFold::Count"),
453 Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
454 Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
455 Self::Any(_) => f.write_str("CommonFold::Any"),
456 }
457 }
458}
459
460impl<L> CommonFold<L> {
461 #[must_use]
463 pub fn count() -> Self {
464 Self::Count(CountFold::new())
465 }
466
467 #[must_use]
469 pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
470 Self::FilterCount(FilterCountFold::new(predicate))
471 }
472
473 #[must_use]
475 pub fn sum_i64(project: fn(&L) -> i64) -> Self {
476 Self::SumI64(SumI64Fold::new(project))
477 }
478
479 #[must_use]
481 pub fn any(predicate: fn(&L) -> bool) -> Self {
482 Self::Any(AnyFold::new(predicate))
483 }
484
485 #[inline]
486 fn expected_state_kind(&self) -> &'static str {
487 match self {
488 Self::Count(_) | Self::FilterCount(_) => "Count",
489 Self::SumI64(_) => "SumI64",
490 Self::Any(_) => "Any",
491 }
492 }
493
494 pub fn try_step(
496 &self,
497 state: CommonFoldState,
498 entry: &L,
499 context: &FoldContext,
500 ) -> Result<CommonFoldState, FoldFailure> {
501 match (self, state) {
502 (Self::Count(inner), CommonFoldState::Count(count)) => {
503 Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
504 }
505 (Self::FilterCount(inner), CommonFoldState::Count(count)) => {
506 Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
507 }
508 (Self::SumI64(inner), CommonFoldState::SumI64(sum)) => {
509 Ok(CommonFoldState::SumI64(inner.reduce(sum, entry, context)))
510 }
511 (Self::Any(inner), CommonFoldState::Any(any)) => {
512 Ok(CommonFoldState::Any(inner.reduce(any, entry, context)))
513 }
514 (kind, state) => Err(FoldFailure::StateMismatch {
515 expected: kind.expected_state_kind(),
516 actual: state.kind(),
517 }),
518 }
519 }
520}
521
522impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
523 #[inline]
524 fn init(&self, _context: &FoldContext) -> CommonFoldState {
525 match self {
526 Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
527 Self::SumI64(_) => CommonFoldState::SumI64(0),
528 Self::Any(_) => CommonFoldState::Any(false),
529 }
530 }
531
532 #[inline]
537 fn reduce(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
538 self.try_step(state, entry, context)
539 .unwrap_or_else(|err| panic!("{err}"))
540 }
541}
542
543impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
544 #[inline]
545 fn try_step(
546 &self,
547 state: CommonFoldState,
548 entry: &L,
549 context: &FoldContext,
550 ) -> Result<CommonFoldState, FoldFailure> {
551 CommonFold::try_step(self, state, entry, context)
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fold_fn() {
        let entries = [1, 2, 3, 4, 5];
        let counter = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let outcome = counter.derive(&entries, &FoldContext::new());
        assert_eq!(outcome.state, 5);
        assert_eq!(outcome.entries_processed, 5);
    }

    #[test]
    fn test_fold_fn_sum() {
        let entries = [1, 2, 3, 4, 5];
        let summer = fold_fn(|_ctx| 0i32, |acc, value: &i32, _ctx| acc + value);
        assert_eq!(summer.derive(&entries, &FoldContext::new()).state, 15);
    }

    #[test]
    fn test_fold_filtered() {
        let entries = [1, 2, 3, 4, 5, 6];
        let summer = fold_fn(|_ctx| 0i32, |acc, value: &i32, _ctx| acc + value);
        // Only the even entries (2, 4, 6) are reduced and counted.
        let outcome =
            summer.derive_filtered(&entries, &FoldContext::new(), |value| *value % 2 == 0);
        assert_eq!(outcome.state, 12);
        assert_eq!(outcome.entries_processed, 3);
    }

    #[test]
    fn test_boxed_fold_derive() {
        let entries = [1, 2, 3, 4];
        let counter: BoxedFold<i32, usize> = Box::<CountFold<i32>>::default();
        assert_eq!(counter.derive(&entries, &FoldContext::new()).state, 4);
    }

    #[test]
    fn test_common_fold_count() {
        let entries = [1, 2, 3];
        let outcome = CommonFold::<i32>::count().derive(&entries, &FoldContext::new());
        assert_eq!(outcome.state, CommonFoldState::Count(3));
    }

    #[test]
    fn test_common_fold_sum() {
        let entries = [1, 2, 3];
        let fold = CommonFold::<i32>::sum_i64(|value: &i32| i64::from(*value));
        let outcome = fold.derive(&entries, &FoldContext::new());
        assert_eq!(outcome.state, CommonFoldState::SumI64(6));
    }

    #[test]
    fn count_folds_saturate_on_overflow() {
        let context = FoldContext::new();
        let entry = 1;

        let plain = CountFold::new();
        assert_eq!(plain.reduce(usize::MAX, &entry, &context), usize::MAX);

        let always = FilterCountFold::new(|_: &i32| true);
        assert_eq!(always.reduce(usize::MAX, &entry, &context), usize::MAX);
    }

    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let fold = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(fold.reduce(i64::MAX, &1, &FoldContext::new()), i64::MAX);
    }

    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let context = FoldContext::new();
        let fold = CommonFold::<i32>::count();
        // Call through the trait on purpose to exercise the `TryFold` impl.
        let result = TryFold::try_step(&fold, CommonFoldState::SumI64(0), &1, &context);
        assert_eq!(
            result,
            Err(FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64",
            })
        );
    }

    #[test]
    fn test_any_fold() {
        let entries = [1, 2, 7, 9];
        let fold = AnyFold::new(|value: &i32| *value == 7);
        assert!(fold.derive(&entries, &FoldContext::new()).state);
    }

    #[test]
    fn fold_is_deterministic_no_timing() {
        let entries = [1, 2, 3];
        let context = FoldContext::new();
        let fold = fold_fn(|_ctx| 0usize, |n, _: &i32, _ctx| n + 1);
        let first = fold.derive(&entries, &context);
        let second = fold.derive(&entries, &context);
        assert_eq!(first, second);
    }
}