dfir_lang 0.17.0-alpha.1

Hydro's Dataflow Intermediate Representation (DFIR) implementation
Documentation
use quote::quote_spanned;

use super::{
    DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
    RANGE_1, WriteContextArgs,
};

/// > 1 input stream, 1 output stream
///
/// > Arguments: two arguments, both closures. The first closure is used to create the initial
/// > value for the accumulator, and the second is used to combine new items with the existing
/// > accumulator value. The second closure takes two arguments: an `&mut Accum` accumulated
/// > value, and an `Item`.
///
/// Akin to Rust's built-in [`fold`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
/// operator, except that it takes the accumulator by `&mut` instead of by value. Folds every item
/// into an accumulator by applying a closure, returning the final result. The accumulated value
/// is cloned to produce the output, so the accumulator type must implement `Clone`.
///
/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
///
/// `fold` can also be provided with one generic lifetime persistence argument, either
/// `'tick` or `'static`, to specify how data persists. With `'tick`, items will only be collected
/// within the same tick. With `'static`, the accumulated value will be remembered across ticks and
/// will be aggregated with items arriving in later ticks. When not explicitly specified,
/// persistence defaults to `'tick`.
///
/// ```dfir
/// // should print `Reassembled vector [1,2,3,4,5]`
/// source_iter([1,2,3,4,5])
///     -> fold::<'tick>(Vec::new, |accum: &mut Vec<_>, elem| {
///         accum.push(elem);
///     })
///     -> assert_eq([vec![1, 2, 3, 4, 5]]);
/// ```
pub const FOLD: OperatorConstraints = OperatorConstraints {
    name: "fold",
    categories: &[OperatorCategory::Fold],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    hard_range_out: &(0..=1),
    soft_range_out: &(0..=1),
    num_args: 2,
    persistence_args: &(0..=1),
    type_args: RANGE_0,
    is_external_input: false,
    flo_type: None,
    ports_inn: None,
    ports_out: None,
    input_delaytype_fn: |_| Some(DelayType::Stratum),
    write_fn: |wc @ &WriteContextArgs {
                   root,
                   op_span,
                   work_fn,
                   work_fn_async,
                   ident,
                   is_pull,
                   inputs,
                   outputs,
                   arguments,
                   ..
               },
               diagnostics| {
        // `arguments[0]` builds the initial accumulator; `arguments[1]` merges one item into it.
        let init_fn = &arguments[0];
        let func = &arguments[1];
        let singleton_output_ident = wc.make_ident("singleton_output");

        // The initializer closure is stored once in the prologue and re-invoked whenever the
        // accumulator must be (re)created.
        let initializer_func_ident = wc.make_ident("initializer_func");
        let init = quote_spanned! {op_span=>
            (#initializer_func_ident)()
        };

        // At most one persistence argument is accepted (`persistence_args: 0..=1`); the helper's
        // name indicates `'mutable` is rejected, presumably reported through `diagnostics`.
        let [persistence] = wc.persistence_args_disallow_mutable(diagnostics);

        let input = &inputs[0];
        let accumulator_ident = wc.make_ident("accumulator");
        let item_ident = wc.make_ident("item");

        // Prologue: bind the initializer and create the accumulator state that lives across the
        // generated per-tick code.
        let write_prologue = quote_spanned! {op_span=>
            #[allow(unused_mut, reason = "for if `Fn` instead of `FnMut`.")]
            let mut #initializer_func_ident = #init_fn;

            #[allow(clippy::redundant_closure_call)]
            let mut #singleton_output_ident = #init;
        };

        // With `'tick` persistence the accumulator is reset at the end of every tick; any other
        // (non-mutable) persistence keeps the accumulated value across ticks.
        let write_tick_end = match persistence {
            Persistence::Tick => quote_spanned! {op_span=>
                #[allow(clippy::redundant_closure_call)]
                { #singleton_output_ident = #init; }
            },
            _ => Default::default(),
        };

        let assign_accum_ident = quote_spanned! {op_span=>
            #[allow(unused_mut)]
            let mut #accumulator_ident = &mut #singleton_output_ident;
        };
        // Routing the user closure through `call_comb_type` pins its signature to
        // `FnMut(&mut Accum, Item)` so type inference can resolve the closure's parameter types.
        let foreach_body = quote_spanned! {op_span=>
            #[inline(always)]
            fn call_comb_type<Accum, Item>(
                accum: &mut Accum,
                item: Item,
                mut func: impl FnMut(&mut Accum, Item),
            ) {
                (func)(accum, item);
            }
            #[allow(clippy::redundant_closure_call)]
            call_comb_type(&mut *#accumulator_ident, #item_ident, #func);
        };

        let write_iterator = if is_pull {
            // Pull mode: drain the whole input into the accumulator first, then expose a single
            // cloned copy of the accumulated value as this operator's output.
            quote_spanned! {op_span=>
                #assign_accum_ident

                // Eagerly consume input to ensure updated state.
                {
                    let __fut = #root::dfir_pipes::pull::Pull::for_each(#input, |#item_ident| {
                        #foreach_body
                    });
                    let () = #work_fn_async(__fut).await;
                }

                let #ident = #work_fn(
                    || #root::dfir_pipes::pull::once(
                        ::std::clone::Clone::clone(&*#accumulator_ident)
                    )
                );
            }
        } else {
            // Push mode: the generated code is a terminal `for_each` sink with nowhere to send
            // the accumulated value, so no downstream outputs may be attached.
            assert_eq!(
                0,
                outputs.len(),
                "`fold` in push mode must not have any outputs"
            );
            quote_spanned! {op_span=>
                let #ident = #root::dfir_pipes::push::for_each(|#item_ident| {
                    #assign_accum_ident

                    #foreach_body
                });
            }
        };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_iterator,
            write_iterator_after: Default::default(),
            write_tick_end,
        })
    },
};