use crate::circuit::checkpointer::Checkpoint;
use crate::{
Circuit, NumEntries, RootCircuit, Stream,
circuit::OwnershipPreference,
operator::{Z1, z1::DelayedId},
};
use size_of::SizeOf;
impl<T> Stream<RootCircuit, T>
where
    T: Clone + 'static,
{
    /// Folds the values of this stream into an accumulator of type `A`,
    /// starting from `init`, and returns the stream of accumulator values.
    ///
    /// This is a convenience wrapper around
    /// [`stream_fold_persistent`](Self::stream_fold_persistent) that passes
    /// `None` as the persistent id.
    #[track_caller]
    pub fn stream_fold<A, F>(&self, init: A, fold_func: F) -> Stream<RootCircuit, A>
    where
        F: Fn(A, &T) -> A + 'static,
        A: Checkpoint + Eq + Clone + SizeOf + NumEntries + 'static,
    {
        let no_persistent_id: Option<&str> = None;
        self.stream_fold_persistent(no_persistent_id, init, fold_func)
    }

    /// Same as [`stream_fold`](Self::stream_fold), but lets the caller
    /// supply an optional persistent id for the feedback operator.
    ///
    /// # Arguments
    ///
    /// * `persistent_id` - when `Some(name)`, the feedback operator is
    ///   registered under the id `"{name}.fold"`; when `None`, no id is
    ///   passed on.
    /// * `init` - value used to construct the `Z1` operator that sits in the
    ///   feedback loop (presumably its initial output -- confirm against
    ///   `Z1::new`).
    /// * `fold_func` - combines the accumulator coming out of the feedback
    ///   loop (taken by value) with a reference to the current input and
    ///   returns the new accumulator.
    ///
    /// # Returns
    ///
    /// The stream produced by applying `fold_func` to the feedback stream and
    /// this stream.
    #[track_caller]
    pub fn stream_fold_persistent<A, F>(
        &self,
        persistent_id: Option<&str>,
        init: A,
        fold_func: F,
    ) -> Stream<RootCircuit, A>
    where
        F: Fn(A, &T) -> A + 'static,
        A: Checkpoint + Eq + Clone + SizeOf + NumEntries + 'static,
    {
        // Derive the id of the feedback operator from the caller-supplied one.
        let feedback_id = persistent_id.map(|name| format!("{name}.fold"));
        let delay = Z1::new(init);

        // Open the feedback loop: `accumulated_so_far` is the output side of
        // the loop, `loopback` is the handle used to close it below.
        let (accumulated_so_far, loopback) = self
            .circuit()
            .add_feedback_persistent(feedback_id.as_deref(), delay);

        // Combine the fed-back accumulator with the current input.
        let folded = accumulated_so_far.apply2_owned(self, fold_func);

        // Close the loop, asking for the accumulator to be handed over by
        // value where possible.
        loopback.connect_with_preference(&folded, OwnershipPreference::STRONGLY_PREFER_OWNED);

        // Register the feedback stream in the circuit cache, keyed by the
        // output stream's id (presumably so a later delay of the output can
        // reuse it instead of creating a second operator -- confirm).
        let cache_key = DelayedId::new(folded.stream_id());
        self.circuit().cache_insert(cache_key, accumulated_so_far);

        folded
    }
}