1use std::marker::PhantomData;
4use std::sync::Arc;
5
6use crate::{FoldContext, FoldOutcome};
7
8pub trait Fold<L, S>: Send + Sync {
10 fn init(&self, context: &FoldContext) -> S;
12
13 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S;
15
16 #[inline]
18 fn finalize(&self, state: S, _context: &FoldContext) -> S {
19 state
20 }
21
22 fn derive<'a, I>(&self, entries: I, context: &FoldContext) -> FoldOutcome<S>
24 where
25 Self: Sized,
26 I: IntoIterator<Item = &'a L>,
27 L: 'a,
28 {
29 let mut state = self.init(context);
30 let mut count = 0;
31
32 for entry in entries {
33 state = self.reduce(state, entry, context);
34 count += 1;
35 }
36
37 FoldOutcome::new(self.finalize(state, context), count)
38 }
39
40 fn derive_filtered<'a, I, F>(
42 &self,
43 entries: I,
44 context: &FoldContext,
45 filter: F,
46 ) -> FoldOutcome<S>
47 where
48 Self: Sized,
49 I: IntoIterator<Item = &'a L>,
50 L: 'a,
51 F: Fn(&L) -> bool,
52 {
53 let mut state = self.init(context);
54 let mut count = 0;
55
56 for entry in entries {
57 if filter(entry) {
58 state = self.reduce(state, entry, context);
59 count += 1;
60 }
61 }
62
63 FoldOutcome::new(self.finalize(state, context), count)
64 }
65}
66
/// Error reported when a fold step cannot be applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum FoldFailure {
    /// The state handed to a step is not the kind of state the fold works on.
    #[error("Fold state mismatch: expected {expected}, got {actual}")]
    StateMismatch {
        /// Name of the state kind the fold expected.
        expected: &'static str,
        /// Name of the state kind that was actually supplied.
        actual: &'static str,
    },
}
79
/// A [`Fold`] whose per-entry step can report a [`FoldFailure`].
pub trait TryFold<L, S>: Fold<L, S> {
    /// Folds `entry` into `state`, returning the updated state or the reason
    /// the step could not be applied.
    fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure>;
}
85
86impl<L, S, T> Fold<L, S> for Box<T>
87where
88 T: Fold<L, S> + ?Sized,
89{
90 #[inline]
91 fn init(&self, context: &FoldContext) -> S {
92 (**self).init(context)
93 }
94
95 #[inline]
96 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
97 (**self).reduce(state, entry, context)
98 }
99
100 #[inline]
101 fn finalize(&self, state: S, context: &FoldContext) -> S {
102 (**self).finalize(state, context)
103 }
104}
105
106impl<L, S, T> TryFold<L, S> for Box<T>
107where
108 T: TryFold<L, S> + ?Sized,
109{
110 #[inline]
111 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
112 (**self).try_step(state, entry, context)
113 }
114}
115
116impl<L, S, T> Fold<L, S> for Arc<T>
117where
118 T: Fold<L, S> + ?Sized,
119{
120 #[inline]
121 fn init(&self, context: &FoldContext) -> S {
122 (**self).init(context)
123 }
124
125 #[inline]
126 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
127 (**self).reduce(state, entry, context)
128 }
129
130 #[inline]
131 fn finalize(&self, state: S, context: &FoldContext) -> S {
132 (**self).finalize(state, context)
133 }
134}
135
136impl<L, S, T> TryFold<L, S> for Arc<T>
137where
138 T: TryFold<L, S> + ?Sized,
139{
140 #[inline]
141 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
142 (**self).try_step(state, entry, context)
143 }
144}
145
/// Heap-allocated, type-erased [`Fold`] over entries `L` with state `S`.
pub type BoxedFold<L, S> = Box<dyn Fold<L, S> + Send + Sync>;
148
/// A fold assembled from three callables: one producing the initial state,
/// one applying a single entry, and one finalizing the accumulated state.
///
/// The struct itself carries no trait bounds; the `Fn` requirements live on
/// the constructor and on the trait implementations that actually call the
/// stored values. This lets other code name `FnFold<..>` (for example as a
/// field type) without restating every closure bound.
pub struct FnFold<L, S, I, St, F> {
    /// Produces the starting state.
    initial_fn: I,
    /// Applies one entry to the state.
    step_fn: St,
    /// Post-processes the accumulated state.
    finalize_fn: F,
    /// Ties the otherwise unused `L` and `S` parameters to the type.
    _phantom: PhantomData<(L, S)>,
}
161
162impl<L, S, I, St, F> FnFold<L, S, I, St, F>
163where
164 I: Fn(&FoldContext) -> S,
165 St: Fn(S, &L, &FoldContext) -> S,
166 F: Fn(S, &FoldContext) -> S,
167{
168 pub fn new(initial: I, step: St, finalize: F) -> Self {
170 Self {
171 initial_fn: initial,
172 step_fn: step,
173 finalize_fn: finalize,
174 _phantom: PhantomData,
175 }
176 }
177}
178
179impl<L, S, I, St, F> Fold<L, S> for FnFold<L, S, I, St, F>
180where
181 L: Send + Sync,
182 S: Send + Sync,
183 I: Fn(&FoldContext) -> S + Send + Sync,
184 St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
185 F: Fn(S, &FoldContext) -> S + Send + Sync,
186{
187 #[inline]
188 fn init(&self, context: &FoldContext) -> S {
189 (self.initial_fn)(context)
190 }
191
192 #[inline]
193 fn reduce(&self, state: S, entry: &L, context: &FoldContext) -> S {
194 (self.step_fn)(state, entry, context)
195 }
196
197 #[inline]
198 fn finalize(&self, state: S, context: &FoldContext) -> S {
199 (self.finalize_fn)(state, context)
200 }
201}
202
203impl<L, S, I, St, F> TryFold<L, S> for FnFold<L, S, I, St, F>
204where
205 L: Send + Sync,
206 S: Send + Sync,
207 I: Fn(&FoldContext) -> S + Send + Sync,
208 St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
209 F: Fn(S, &FoldContext) -> S + Send + Sync,
210{
211 #[inline]
212 fn try_step(&self, state: S, entry: &L, context: &FoldContext) -> Result<S, FoldFailure> {
213 Ok((self.step_fn)(state, entry, context))
214 }
215}
216
217pub fn fold_fn<L, S, I, St>(initial: I, step: St) -> impl Fold<L, S>
219where
220 L: Send + Sync,
221 S: Send + Sync,
222 I: Fn(&FoldContext) -> S + Send + Sync,
223 St: Fn(S, &L, &FoldContext) -> S + Send + Sync,
224{
225 FnFold::new(initial, step, |s, _| s)
226}
227
/// Fold marker for counting entries of type `L`.
///
/// The type stores no data, only a marker for `L`. `Clone`, `Copy` and
/// `Debug` are implemented by hand because the derives would add
/// `L: Clone` / `L: Copy` / `L: Debug` bounds even though no `L` value is
/// ever stored.
pub struct CountFold<L> {
    _phantom: PhantomData<fn(&L)>,
}

impl<L> Clone for CountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for CountFold<L> {}

impl<L> std::fmt::Debug for CountFold<L> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape the derive produced: struct name plus the marker field.
        f.debug_struct("CountFold")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
233
234impl<L> CountFold<L> {
235 #[must_use]
237 pub fn new() -> Self {
238 Self {
239 _phantom: PhantomData,
240 }
241 }
242}
243
244impl<L> Default for CountFold<L> {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250impl<L> Fold<L, usize> for CountFold<L> {
251 #[inline]
252 fn init(&self, _context: &FoldContext) -> usize {
253 0
254 }
255
256 #[inline]
257 fn reduce(&self, state: usize, _entry: &L, _context: &FoldContext) -> usize {
258 state.saturating_add(1)
259 }
260}
261
262impl<L> TryFold<L, usize> for CountFold<L> {
263 #[inline]
264 fn try_step(
265 &self,
266 state: usize,
267 entry: &L,
268 context: &FoldContext,
269 ) -> Result<usize, FoldFailure> {
270 Ok(self.reduce(state, entry, context))
271 }
272}
273
/// Fold that holds a predicate over entries of type `L`.
///
/// Only a function pointer is stored, so the type is `Clone` and `Copy` for
/// every `L`. Both traits are implemented by hand because the derives would
/// add `L: Clone` / `L: Copy` bounds.
pub struct FilterCountFold<L> {
    /// Decides whether an entry is counted.
    predicate: fn(&L) -> bool,
}

impl<L> Clone for FilterCountFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for FilterCountFold<L> {}
279
280impl<L> std::fmt::Debug for FilterCountFold<L> {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 f.debug_struct("FilterCountFold").finish()
283 }
284}
285
286impl<L> FilterCountFold<L> {
287 #[must_use]
289 pub fn new(predicate: fn(&L) -> bool) -> Self {
290 Self { predicate }
291 }
292}
293
294impl<L> Fold<L, usize> for FilterCountFold<L> {
295 #[inline]
296 fn init(&self, _context: &FoldContext) -> usize {
297 0
298 }
299
300 #[inline]
301 fn reduce(&self, state: usize, entry: &L, _context: &FoldContext) -> usize {
302 if (self.predicate)(entry) {
303 state.saturating_add(1)
304 } else {
305 state
306 }
307 }
308}
309
310impl<L> TryFold<L, usize> for FilterCountFold<L> {
311 #[inline]
312 fn try_step(
313 &self,
314 state: usize,
315 entry: &L,
316 context: &FoldContext,
317 ) -> Result<usize, FoldFailure> {
318 Ok(self.reduce(state, entry, context))
319 }
320}
321
/// Fold that holds a projection from entries of type `L` to `i64`.
///
/// Only a function pointer is stored, so the type is `Clone` and `Copy` for
/// every `L`. Both traits are implemented by hand because the derives would
/// add `L: Clone` / `L: Copy` bounds.
pub struct SumI64Fold<L> {
    /// Extracts the value an entry contributes.
    project: fn(&L) -> i64,
}

impl<L> Clone for SumI64Fold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for SumI64Fold<L> {}
327
328impl<L> std::fmt::Debug for SumI64Fold<L> {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 f.debug_struct("SumI64Fold").finish()
331 }
332}
333
334impl<L> SumI64Fold<L> {
335 #[must_use]
337 pub fn new(project: fn(&L) -> i64) -> Self {
338 Self { project }
339 }
340}
341
342impl<L> Fold<L, i64> for SumI64Fold<L> {
343 #[inline]
344 fn init(&self, _context: &FoldContext) -> i64 {
345 0
346 }
347
348 #[inline]
349 fn reduce(&self, state: i64, entry: &L, _context: &FoldContext) -> i64 {
350 state.saturating_add((self.project)(entry))
351 }
352}
353
354impl<L> TryFold<L, i64> for SumI64Fold<L> {
355 #[inline]
356 fn try_step(&self, state: i64, entry: &L, context: &FoldContext) -> Result<i64, FoldFailure> {
357 Ok(self.reduce(state, entry, context))
358 }
359}
360
/// Fold that holds a predicate over entries of type `L`.
///
/// Only a function pointer is stored, so the type is `Clone` and `Copy` for
/// every `L`. Both traits are implemented by hand because the derives would
/// add `L: Clone` / `L: Copy` bounds.
pub struct AnyFold<L> {
    /// Test applied to entries.
    predicate: fn(&L) -> bool,
}

impl<L> Clone for AnyFold<L> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<L> Copy for AnyFold<L> {}
366
367impl<L> std::fmt::Debug for AnyFold<L> {
368 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 f.debug_struct("AnyFold").finish()
370 }
371}
372
373impl<L> AnyFold<L> {
374 #[must_use]
376 pub fn new(predicate: fn(&L) -> bool) -> Self {
377 Self { predicate }
378 }
379}
380
381impl<L> Fold<L, bool> for AnyFold<L> {
382 #[inline]
383 fn init(&self, _context: &FoldContext) -> bool {
384 false
385 }
386
387 #[inline]
388 fn reduce(&self, state: bool, entry: &L, _context: &FoldContext) -> bool {
389 state || (self.predicate)(entry)
390 }
391}
392
393impl<L> TryFold<L, bool> for AnyFold<L> {
394 #[inline]
395 fn try_step(&self, state: bool, entry: &L, context: &FoldContext) -> Result<bool, FoldFailure> {
396 Ok(self.reduce(state, entry, context))
397 }
398}
399
/// State value threaded through a [`CommonFold`]; one variant per kind of
/// accumulated value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommonFoldState {
    /// A running count.
    Count(usize),
    /// A running `i64` sum.
    SumI64(i64),
    /// A running boolean flag.
    Any(bool),
}
410
411impl CommonFoldState {
412 #[inline]
413 fn kind(self) -> &'static str {
414 match self {
415 Self::Count(_) => "Count",
416 Self::SumI64(_) => "SumI64",
417 Self::Any(_) => "Any",
418 }
419 }
420}
421
422#[derive(Clone)]
424pub enum CommonFold<L> {
425 Count(CountFold<L>),
427 FilterCount(FilterCountFold<L>),
429 SumI64(SumI64Fold<L>),
431 Any(AnyFold<L>),
433}
434
435impl<L> std::fmt::Debug for CommonFold<L> {
436 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437 match self {
438 Self::Count(_) => f.write_str("CommonFold::Count"),
439 Self::FilterCount(_) => f.write_str("CommonFold::FilterCount"),
440 Self::SumI64(_) => f.write_str("CommonFold::SumI64"),
441 Self::Any(_) => f.write_str("CommonFold::Any"),
442 }
443 }
444}
445
446impl<L> CommonFold<L> {
447 #[must_use]
449 pub fn count() -> Self {
450 Self::Count(CountFold::new())
451 }
452
453 #[must_use]
455 pub fn filter_count(predicate: fn(&L) -> bool) -> Self {
456 Self::FilterCount(FilterCountFold::new(predicate))
457 }
458
459 #[must_use]
461 pub fn sum_i64(project: fn(&L) -> i64) -> Self {
462 Self::SumI64(SumI64Fold::new(project))
463 }
464
465 #[must_use]
467 pub fn any(predicate: fn(&L) -> bool) -> Self {
468 Self::Any(AnyFold::new(predicate))
469 }
470
471 #[inline]
472 fn expected_state_kind(&self) -> &'static str {
473 match self {
474 Self::Count(_) | Self::FilterCount(_) => "Count",
475 Self::SumI64(_) => "SumI64",
476 Self::Any(_) => "Any",
477 }
478 }
479
480 pub fn try_step(
482 &self,
483 state: CommonFoldState,
484 entry: &L,
485 context: &FoldContext,
486 ) -> Result<CommonFoldState, FoldFailure> {
487 match (self, state) {
488 (Self::Count(inner), CommonFoldState::Count(count)) => {
489 Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
490 }
491 (Self::FilterCount(inner), CommonFoldState::Count(count)) => {
492 Ok(CommonFoldState::Count(inner.reduce(count, entry, context)))
493 }
494 (Self::SumI64(inner), CommonFoldState::SumI64(sum)) => {
495 Ok(CommonFoldState::SumI64(inner.reduce(sum, entry, context)))
496 }
497 (Self::Any(inner), CommonFoldState::Any(any)) => {
498 Ok(CommonFoldState::Any(inner.reduce(any, entry, context)))
499 }
500 (kind, state) => Err(FoldFailure::StateMismatch {
501 expected: kind.expected_state_kind(),
502 actual: state.kind(),
503 }),
504 }
505 }
506}
507
508impl<L> Fold<L, CommonFoldState> for CommonFold<L> {
509 #[inline]
510 fn init(&self, _context: &FoldContext) -> CommonFoldState {
511 match self {
512 Self::Count(_) | Self::FilterCount(_) => CommonFoldState::Count(0),
513 Self::SumI64(_) => CommonFoldState::SumI64(0),
514 Self::Any(_) => CommonFoldState::Any(false),
515 }
516 }
517
518 #[inline]
520 fn reduce(&self, state: CommonFoldState, entry: &L, context: &FoldContext) -> CommonFoldState {
521 self.try_step(state, entry, context)
522 .unwrap_or_else(|err| panic!("{err}"))
523 }
524}
525
526impl<L> TryFold<L, CommonFoldState> for CommonFold<L> {
527 #[inline]
528 fn try_step(
529 &self,
530 state: CommonFoldState,
531 entry: &L,
532 context: &FoldContext,
533 ) -> Result<CommonFoldState, FoldFailure> {
534 CommonFold::try_step(self, state, entry, context)
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// A closure-built counter reports both the final state and the entry count.
    #[test]
    fn test_fold_fn() {
        let fold = fold_fn(|_ctx| 0usize, |seen, _entry: &i32, _ctx| seen + 1);
        let data = [1, 2, 3, 4, 5];
        let outcome = fold.derive(data.iter(), &FoldContext::new());
        assert_eq!(outcome.state, 5);
        assert_eq!(outcome.entries_processed, 5);
    }

    /// A closure-built fold can accumulate a sum.
    #[test]
    fn test_fold_fn_sum() {
        let fold = fold_fn(|_ctx| 0i32, |total, entry: &i32, _ctx| total + entry);
        let data = [1, 2, 3, 4, 5];
        let outcome = fold.derive(data.iter(), &FoldContext::new());
        assert_eq!(outcome.state, 15);
    }

    /// Only entries accepted by the filter are reduced and counted.
    #[test]
    fn test_fold_filtered() {
        let fold = fold_fn(|_ctx| 0i32, |total, entry: &i32, _ctx| total + entry);
        let data = [1, 2, 3, 4, 5, 6];
        let outcome = fold.derive_filtered(data.iter(), &FoldContext::new(), |e| *e % 2 == 0);
        assert_eq!(outcome.state, 12);
        assert_eq!(outcome.entries_processed, 3);
    }

    /// A boxed trait object can drive `derive`.
    #[test]
    fn test_boxed_fold_derive() {
        // The test deliberately goes through `CountFold::new`, so the
        // `Box::default()` suggestion is silenced.
        #[allow(clippy::box_default)]
        let boxed: BoxedFold<i32, usize> = Box::new(CountFold::new());
        let data = [1, 2, 3, 4];
        let outcome = boxed.derive(data.iter(), &FoldContext::new());
        assert_eq!(outcome.state, 4);
    }

    /// The `Count` variant yields a `Count` state.
    #[test]
    fn test_common_fold_count() {
        let counter = CommonFold::<i32>::count();
        let data = [1, 2, 3];
        let outcome = counter.derive(data.iter(), &FoldContext::new());
        assert_eq!(outcome.state, CommonFoldState::Count(3));
    }

    /// The `SumI64` variant yields a `SumI64` state.
    #[test]
    fn test_common_fold_sum() {
        let summer = CommonFold::<i32>::sum_i64(|value: &i32| *value as i64);
        let data = [1, 2, 3];
        let outcome = summer.derive(data.iter(), &FoldContext::new());
        assert_eq!(outcome.state, CommonFoldState::SumI64(6));
    }

    /// Both counting folds stay at `usize::MAX` rather than overflowing.
    #[test]
    fn count_folds_saturate_on_overflow() {
        let ctx = FoldContext::new();
        let sample = 1;

        let plain = CountFold::new();
        assert_eq!(plain.reduce(usize::MAX, &sample, &ctx), usize::MAX);

        let accepting = FilterCountFold::new(|_: &i32| true);
        assert_eq!(accepting.reduce(usize::MAX, &sample, &ctx), usize::MAX);
    }

    /// The sum fold stays at `i64::MAX` rather than overflowing.
    #[test]
    fn sum_i64_fold_saturates_on_overflow() {
        let ctx = FoldContext::new();
        let summer = SumI64Fold::new(|value: &i64| *value);
        assert_eq!(summer.reduce(i64::MAX, &1, &ctx), i64::MAX);
    }

    /// Feeding a `SumI64` state to a counting fold is an error, not a panic.
    #[test]
    fn common_fold_try_step_mismatch_returns_error() {
        let ctx = FoldContext::new();
        let counter = CommonFold::<i32>::count();
        let failure =
            TryFold::try_step(&counter, CommonFoldState::SumI64(0), &1, &ctx).unwrap_err();
        assert_eq!(
            failure,
            FoldFailure::StateMismatch {
                expected: "Count",
                actual: "SumI64"
            }
        );
    }

    /// `AnyFold` reports `true` when at least one entry matches.
    #[test]
    fn test_any_fold() {
        let detector = AnyFold::new(|value: &i32| *value == 7);
        let data = [1, 2, 7, 9];
        let outcome = detector.derive(data.iter(), &FoldContext::new());
        assert!(outcome.state);
    }

    /// Two runs over the same input produce equal outcomes.
    #[test]
    fn fold_is_deterministic_no_timing() {
        let counter = fold_fn(|_ctx| 0usize, |seen, _: &i32, _ctx| seen + 1);
        let data = [1, 2, 3];
        let ctx = FoldContext::new();
        let first = counter.derive(data.iter(), &ctx);
        let second = counter.derive(data.iter(), &ctx);
        assert_eq!(first, second);
    }
}