use crate::{
Batch, ChildCircuit, Circuit, SchedulerError, Scope, Stream, Timestamp,
circuit::{
OwnershipPreference,
circuit_builder::{CircuitBase, NonIterativeCircuit},
operator_traits::{ImportOperator, Operator},
runtime::Consensus,
schedule::{CommitProgress, DynamicScheduler, Executor, Scheduler},
},
operator::Generator,
trace::{
Batch as DynBatch, BatchReader as _, BatchReaderFactories, Spine as DynSpine, Trace as _,
},
typed_batch::{Spine, TypedBatch},
};
use impl_trait_for_tuples::impl_for_tuples;
use std::{
borrow::Cow,
cell::{Cell, RefCell},
collections::BTreeSet,
future::Future,
pin::Pin,
rc::Rc,
};
/// Streams, or collections of streams, that can be imported into a
/// [`NonIterativeCircuit`] for use by [`ChildCircuit::non_incremental`].
///
/// Implemented for a single [`Stream`] of batches, for `Vec`s of importable values,
/// and for tuples of up to 14 importable values.
pub trait NonIncrementalInputStreams<C>
where
    C: Circuit,
{
    /// Representation of `Self` inside the child circuit after import.
    type Imported;

    /// Imports `self` into `child_circuit` and returns the child-side stream(s).
    fn import(&self, child_circuit: &mut NonIterativeCircuit<C>) -> Self::Imported;
}
impl<B, C> NonIncrementalInputStreams<C> for Stream<C, B>
where
    C: Circuit,
    B: Batch,
{
    type Imported = Stream<NonIterativeCircuit<C>, TypedBatch<B::Key, B::Val, B::R, B::Inner>>;

    /// Imports this batch stream into `child_circuit`.
    ///
    /// Batches arriving from the parent are collected into a spine by an
    /// [`ImportAccumulator`]. Each spine handed to the child is turned back into a
    /// single typed batch with `Spine::consolidate`. The resulting stream is marked
    /// sharded whenever `self` is.
    fn import(&self, child_circuit: &mut NonIterativeCircuit<C>) -> Self::Imported {
        let spine_stream = child_circuit.import_stream(
            ImportAccumulator::new(&BatchReaderFactories::new::<B::Key, B::Val, B::R>()),
            &self.try_sharded_version().inner(),
        );
        let consolidated = spine_stream
            .typed()
            .apply_owned(|spine: Spine<B>| spine.consolidate());
        consolidated.mark_sharded_if(self);
        consolidated
    }
}
impl<C, S> NonIncrementalInputStreams<C> for Vec<S>
where
    C: Circuit,
    S: NonIncrementalInputStreams<C>,
{
    type Imported = Vec<S::Imported>;

    /// Imports every element into `child_circuit` in order. The output has one
    /// imported value per input element, in the same order.
    fn import(&self, child_circuit: &mut NonIterativeCircuit<C>) -> Self::Imported {
        let mut imported = Vec::with_capacity(self.len());
        for stream in self {
            imported.push(stream.import(child_circuit));
        }
        imported
    }
}
// Implements `NonIncrementalInputStreams` for tuples (up to 14 elements, per
// `impl_for_tuples(14)`) whose element types all implement the trait. Each element is
// imported independently, and the results are returned as a tuple in the same order.
// NOTE(review): `clippy::unused_unit` is presumably allowed because of the empty-tuple
// expansion, whose generated body is a bare `()`; confirm.
#[allow(clippy::unused_unit)]
#[impl_for_tuples(14)]
#[tuple_types_custom_trait_bound(NonIncrementalInputStreams<C>)]
impl<C> NonIncrementalInputStreams<C> for Tuple
where
    C: Circuit,
{
    // The imported type is the tuple of the elements' imported types.
    for_tuples!( type Imported = ( #( Tuple::Imported ),* ); );
    // Imports each element into the same child circuit, left to right.
    fn import(&self, child_circuit: &mut NonIterativeCircuit<C>) -> Self::Imported {
        (for_tuples!( #( self.Tuple.import(child_circuit) ),* ))
    }
}
/// Executor that [`ChildCircuit::non_incremental`] installs for its child circuit.
///
/// Rather than evaluating the child on every call, it runs one transaction of the
/// wrapped [`DynamicScheduler`] only when `flush_consensus` reports agreement on a
/// pending flush request.
struct NonIterativeExecutor {
    /// Scheduler that evaluates the child circuit's operators.
    scheduler: DynamicScheduler,
    /// Set by `flush()`; cleared when a flush is carried out in `transaction`.
    flush: Cell<bool>,
    /// Used to agree on whether to flush. NOTE(review): presumably this is a
    /// cross-worker agreement; confirm against `Consensus`.
    flush_consensus: Consensus,
}
impl NonIterativeExecutor {
    /// Creates an executor with a fresh scheduler and no pending flush request.
    pub fn new() -> Self {
        // Construct the components in the same order as the struct declares them.
        let scheduler = DynamicScheduler::new();
        let flush = Cell::new(false);
        let flush_consensus = Consensus::new();
        Self {
            scheduler,
            flush,
            flush_consensus,
        }
    }
}
impl<C> Executor<C> for NonIterativeExecutor
where
    C: Circuit,
{
    /// Prepares the wrapped scheduler for `circuit`, optionally restricted to `nodes`.
    fn prepare(
        &mut self,
        circuit: &C,
        nodes: Option<&BTreeSet<crate::circuit::NodeId>>,
    ) -> Result<(), SchedulerError> {
        self.scheduler.prepare(circuit, nodes)
    }

    /// Evaluates the child circuit only if the flush consensus is reached.
    ///
    /// Every call takes part in a consensus round, contributing whether this executor
    /// has a pending `flush` request. NOTE(review): presumably every worker must make
    /// this call for the round to complete; confirm. If the consensus yields `true`,
    /// the pending request is cleared and one transaction of the wrapped scheduler
    /// runs. Otherwise nothing is evaluated.
    fn transaction<'a>(
        &'a self,
        circuit: &'a C,
    ) -> Pin<Box<dyn Future<Output = Result<(), SchedulerError>> + 'a>> {
        let circuit = circuit.clone();
        Box::pin(async move {
            let flush_requested = self.flush.get();
            let flush_agreed = self.flush_consensus.check(flush_requested).await?;
            if !flush_agreed {
                return Ok(());
            }
            self.flush.set(false);
            self.scheduler.transaction(&circuit).await?;
            Ok(())
        })
    }

    /// Requests that the next `transaction` evaluate the child circuit.
    fn flush(&self) {
        self.flush.set(true);
    }

    /// Returns `true` when no flush request is pending.
    fn is_flush_complete(&self) -> bool {
        !self.flush.get()
    }

    // Every method below panics. NOTE(review): presumably they are never called on
    // this child-circuit executor, which is driven through `transaction`/`flush`; confirm.

    fn start_transaction<'a>(
        &'a self,
        _circuit: &'a C,
    ) -> Pin<Box<dyn Future<Output = Result<(), SchedulerError>> + 'a>> {
        unimplemented!()
    }

    fn start_commit_transaction(&self) -> Result<(), SchedulerError> {
        unimplemented!()
    }

    fn is_commit_complete(&self) -> bool {
        unimplemented!()
    }

    fn commit_progress(&self) -> CommitProgress {
        unimplemented!()
    }

    fn step<'a>(
        &'a self,
        _circuit: &'a C,
    ) -> Pin<Box<dyn Future<Output = Result<(), SchedulerError>> + 'a>> {
        todo!()
    }
}
/// Import operator that collects every batch imported from the parent circuit into a
/// spine. On `eval` it hands the collected spine to the child circuit and starts
/// over with a new spine.
struct ImportAccumulator<B>
where
    B: DynBatch,
{
    /// Factories used to create replacement spines in `eval`.
    factories: B::Factories,
    /// Batches imported since the last `eval`.
    spine: DynSpine<B>,
}
impl<B> ImportAccumulator<B>
where
    B: DynBatch,
{
    /// Creates an accumulator that owns a copy of `factories` and starts with a
    /// spine created from them.
    pub fn new(factories: &B::Factories) -> Self {
        let owned_factories = factories.clone();
        let initial_spine = DynSpine::<B>::new(factories);
        Self {
            factories: owned_factories,
            spine: initial_spine,
        }
    }
}
impl<B> Operator for ImportAccumulator<B>
where
B: DynBatch,
{
fn name(&self) -> std::borrow::Cow<'static, str> {
Cow::Borrowed("ImportAccumulator")
}
fn fixedpoint(&self, _scope: crate::circuit::Scope) -> bool {
self.spine.is_empty()
}
}
impl<B> ImportOperator<B, DynSpine<B>> for ImportAccumulator<B>
where
    B: DynBatch,
{
    /// Adds a copy of `val` to the pending spine.
    fn import(&mut self, val: &B) {
        let owned = val.clone();
        self.spine.insert(owned);
    }

    /// Adds `val` to the pending spine without copying it.
    fn import_owned(&mut self, val: B) {
        self.spine.insert(val);
    }

    /// Returns every batch accumulated since the previous call and resets the
    /// accumulator to a new spine.
    async fn eval(&mut self) -> DynSpine<B> {
        let replacement = DynSpine::<B>::new(&self.factories);
        std::mem::replace(&mut self.spine, replacement)
    }
}
impl<P, T> ChildCircuit<P, T>
where
    P: 'static,
    T: Timestamp,
    Self: Circuit,
{
    /// Evaluates `f` inside a non-iterative child circuit and exposes its output to
    /// this circuit.
    ///
    /// `input_streams` are imported into the child with
    /// [`NonIncrementalInputStreams::import`]. `f` receives the child circuit and the
    /// imported streams, and returns the child stream whose values should be visible
    /// in the parent. Each value that stream produces is cloned into a cell shared with
    /// a [`Generator`] source in this circuit. The generator emits the cell's contents
    /// and leaves `O::default()` behind, so it yields `O::default()` whenever no new
    /// value has been written since its last evaluation. The child runs under a
    /// `NonIterativeExecutor`. A dependency from the child circuit node to the
    /// generator is registered with `add_dependency` (presumably ordering the generator
    /// after the child).
    ///
    /// # Errors
    ///
    /// Returns an error if creating the child circuit fails, if `f` fails, or if
    /// preparing the child's executor fails.
    #[track_caller]
    pub fn non_incremental<F, I, O>(
        &self,
        input_streams: &I,
        f: F,
    ) -> Result<Stream<Self, O>, SchedulerError>
    where
        F: FnOnce(
            &NonIterativeCircuit<Self>,
            &I::Imported,
        ) -> Result<Stream<NonIterativeCircuit<Self>, O>, SchedulerError>,
        I: NonIncrementalInputStreams<Self>,
        O: Clone + Default + std::fmt::Debug + 'static,
    {
        // Hand-off cell: written inside the child circuit, drained by the generator.
        let latest: Rc<RefCell<O>> = Rc::new(RefCell::new(O::default()));
        let child_writer = Rc::clone(&latest);

        let child_id = self.non_iterative_subcircuit(move |child| {
            let imported = input_streams.import(child);
            let child_output = f(child, &imported)?;
            child_output.apply(move |value| *child_writer.borrow_mut() = value.clone());

            let mut child_executor = NonIterativeExecutor::new();
            child_executor.prepare(child, None)?;
            Ok((child.node_id(), child_executor))
        })?;

        let output = self.add_source(Generator::new(move || {
            std::mem::take(&mut *latest.borrow_mut())
        }));
        self.add_dependency(child_id, output.local_node_id());
        Ok(output)
    }
}
impl<C, D> Stream<C, D>
where
    D: Clone + 'static,
    C: Circuit,
{
    /// Imports this stream into `subcircuit` through a [`Delta0NonIterative`] operator.
    ///
    /// The sharded version of the stream is imported when one is available. The
    /// resulting child stream is marked sharded whenever `self` is.
    #[track_caller]
    pub fn delta0_non_iterative<CC>(&self, subcircuit: &CC) -> Stream<CC, D>
    where
        CC: Circuit<Parent = C>,
    {
        let source = self.try_sharded_version();
        let imported = subcircuit.import_stream(Delta0NonIterative::new(), &source);
        imported.mark_sharded_if(self);
        imported
    }
}
/// Import operator that passes the most recently imported parent value to a child
/// circuit. Each imported value is emitted by exactly one `eval`.
pub struct Delta0NonIterative<D> {
    /// Value stored by the last import and not yet emitted by `eval`.
    val: Option<D>,
    /// Reported by `fixedpoint` for scope 0. Cleared by every import.
    fixedpoint: bool,
}
impl<D> Delta0NonIterative<D> {
pub fn new() -> Self {
Self {
val: None,
fixedpoint: false,
}
}
}
impl<D> Default for Delta0NonIterative<D> {
fn default() -> Self {
Self::new()
}
}
impl<D> Operator for Delta0NonIterative<D>
where
    D: Clone + 'static,
{
    /// Returns the operator's static name.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("delta0")
    }

    /// Every nested scope is always at a fixed point. Scope 0 reports the
    /// `fixedpoint` flag.
    fn fixedpoint(&self, scope: Scope) -> bool {
        scope != 0 || self.fixedpoint
    }
}
impl<D> ImportOperator<D, D> for Delta0NonIterative<D>
where
    D: Clone + 'static,
{
    /// Stores a copy of `val` for the next `eval` and clears the fixed-point flag.
    fn import(&mut self, val: &D) {
        self.val = Some(val.clone());
        self.fixedpoint = false;
    }

    /// Stores `val`, without copying, for the next `eval` and clears the fixed-point flag.
    fn import_owned(&mut self, val: D) {
        self.val = Some(val);
        self.fixedpoint = false;
    }

    /// Emits the value stored by the most recent import.
    ///
    /// # Panics
    ///
    /// Panics if no value is pending, i.e. if `eval` is called twice without an
    /// intervening `import`/`import_owned`. `D` is only required to be `Clone`, so
    /// there is no zero or default value to emit instead.
    async fn eval(&mut self) -> D {
        match self.val.take() {
            Some(val) => val,
            None => {
                // Record that no new input arrived before reporting the invariant violation.
                self.fixedpoint = true;
                panic!("Delta0NonIterative::eval called without a preceding import");
            }
        }
    }

    /// The imported value is stored, so receiving it by value avoids the clone in `import`.
    fn input_preference(&self) -> OwnershipPreference {
        OwnershipPreference::PREFER_OWNED
    }
}
#[cfg(test)]
mod test {
    use crate::{
        OrdZSet, Runtime,
        typed_batch::{IndexedZSetReader, SpineSnapshot},
    };

    /// `differentiate` evaluated inside `non_incremental` sees each transaction's
    /// accumulated input as one value. Its output is therefore the difference between
    /// the current and the previous transaction's input.
    #[test]
    fn test_non_incremental() {
        let (mut dbsp, (input_handle, output_handle)) = Runtime::init_circuit(4, |circuit| {
            let (input_stream, input_handle) = circuit.add_input_zset::<i64>();
            let differentiated_input = circuit
                .non_incremental(&input_stream, |_child_circuit, input_stream| {
                    Ok(input_stream.differentiate())
                })
                .unwrap();
            let output_handle = differentiated_input.accumulate_output();
            Ok((input_handle, output_handle))
        })
        .unwrap();

        // Transaction 1 inserts 5 twice (over two steps). With no previous input, the
        // child emits {5: +2} unchanged.
        dbsp.start_transaction().unwrap();
        input_handle.push(5, 1);
        dbsp.step().unwrap();
        input_handle.push(5, 1);
        dbsp.step().unwrap();
        dbsp.commit_transaction().unwrap();
        let output = SpineSnapshot::<OrdZSet<i64>>::concat(&output_handle.take_from_all())
            .iter()
            .collect::<Vec<_>>();
        // Use `assert_eq!`, not `debug_assert_eq!`: the latter is compiled out when debug
        // assertions are disabled (e.g. `cargo test --release`), leaving the test vacuous.
        assert_eq!(output, vec![(5, (), 2)]);

        // Transaction 2 inserts {5: +1, 2: +1}. Differentiating against the previous
        // {5: +2} gives {2: +1, 5: -1}.
        dbsp.start_transaction().unwrap();
        input_handle.push(5, 1);
        dbsp.step().unwrap();
        input_handle.push(2, 1);
        dbsp.step().unwrap();
        dbsp.commit_transaction().unwrap();
        let output = SpineSnapshot::<OrdZSet<i64>>::concat(&output_handle.take_from_all())
            .iter()
            .collect::<Vec<_>>();
        assert_eq!(output, vec![(2, (), 1), (5, (), -1)]);
    }
}