1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use quote::quote_spanned;
use super::{
DelayType, OperatorCategory, OperatorConstraints, OperatorWriteOutput, Persistence, RANGE_0,
RANGE_1, WriteContextArgs,
};
/// > 1 input stream, 1 output stream
///
/// > Arguments: two arguments, both closures. The first closure is used to create the initial
/// > value for the accumulator, and the second is used to combine new items with the existing
/// > accumulator value. The second closure takes two arguments: an `&mut Accum` accumulated
/// > value, and an `Item`.
///
/// Akin to Rust's built-in [`fold`](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.fold)
/// operator, except that it takes the accumulator by `&mut` instead of by value. Folds every item
/// into an accumulator by applying a closure, returning the final result.
///
/// > Note: The closures have access to the [`context` object](surface_flows.mdx#the-context-object).
///
/// `fold` can also be provided with one generic lifetime persistence argument, either
/// `'tick` or `'static`, to specify how data persists. With `'tick`, items will only be collected
/// within the same tick. With `'static`, the accumulated value will be remembered across ticks and
/// will be aggregated with items arriving in later ticks. When not explicitly specified,
/// persistence defaults to `'tick`.
///
/// ```dfir
/// // collects all items into a single vector: [1, 2, 3, 4, 5]
/// source_iter([1,2,3,4,5])
///     -> fold::<'tick>(Vec::new, |accum: &mut Vec<_>, elem| {
///         accum.push(elem);
///     })
///     -> assert_eq([vec![1, 2, 3, 4, 5]]);
/// ```
pub const FOLD: OperatorConstraints = OperatorConstraints {
    name: "fold",
    categories: &[OperatorCategory::Fold],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    // At most one output; in push position there must be none (asserted in `write_fn`).
    hard_range_out: &(0..=1),
    soft_range_out: &(0..=1),
    num_args: 2,
    persistence_args: &(0..=1),
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: true,
    flo_type: None,
    ports_inn: None,
    ports_out: None,
    // The fold must see its whole input before emitting, so it requests a stratum delay.
    input_delaytype_fn: |_| Some(DelayType::Stratum),
    write_fn: |wc @ &WriteContextArgs {
                   root,
                   context,
                   df_ident,
                   op_span,
                   work_fn,
                   work_fn_async,
                   ident,
                   is_pull,
                   inputs,
                   outputs,
                   singleton_output_ident,
                   arguments,
                   ..
               },
               diagnostics| {
        // `arguments[0]` builds the initial accumulator; `arguments[1]` folds one item into it.
        let init_fn = &arguments[0];
        let func = &arguments[1];

        let initializer_func_ident = wc.make_ident("initializer_func");
        // Expression producing a fresh accumulator; used both for the initial state and for
        // lifespan resets below.
        let init = quote_spanned! {op_span=>
            (#initializer_func_ident)()
        };

        let [persistence] = wc.persistence_args_disallow_mutable(diagnostics);

        let input = &inputs[0];
        let accumulator_ident = wc.make_ident("accumulator");
        let item_ident = wc.make_ident("item");

        // Keep the accumulator in graph state (a `RefCell`) so it outlives a single subgraph run.
        let write_prologue = quote_spanned! {op_span=>
            #[allow(unused_mut, reason = "for if `Fn` instead of `FnMut`.")]
            let mut #initializer_func_ident = #init_fn;
            #[allow(clippy::redundant_closure_call)]
            let #singleton_output_ident = #df_ident.add_state(::std::cell::RefCell::new(#init));
        };

        // When the persistence maps to a state lifespan, replace the accumulator with a fresh
        // initial value each time that lifespan ends.
        let write_prologue_after = wc
            .persistence_as_state_lifespan(persistence)
            .map(|lifespan| {
                quote_spanned! {op_span=>
                    #[allow(clippy::redundant_closure_call)]
                    #df_ident.set_state_lifespan_hook(
                        #singleton_output_ident, #lifespan, move |rcell| { rcell.replace(#init); },
                    );
                }
            })
            .unwrap_or_default();

        // Mutably borrow the accumulator through the state handle created in `write_prologue`.
        let assign_accum_ident = quote_spanned! {op_span=>
            #[allow(unused_mut)]
            let mut #accumulator_ident = unsafe {
                // SAFETY: handle from `#df_ident.add_state(..)`.
                #context.state_ref_unchecked(#singleton_output_ident)
            }.borrow_mut();
        };

        // Fold one item into the accumulator. The generic helper pins the user closure to the
        // `FnMut(&mut Accum, Item)` shape (presumably to aid inference of its parameter types).
        let foreach_body = quote_spanned! {op_span=>
            #[inline(always)]
            fn call_comb_type<Accum, Item>(
                accum: &mut Accum,
                item: Item,
                mut func: impl FnMut(&mut Accum, Item),
            ) {
                (func)(accum, item);
            }
            #[allow(clippy::redundant_closure_call)]
            call_comb_type(&mut *#accumulator_ident, #item_ident, #func);
        };

        let write_iterator = if is_pull {
            // Pull: drain the entire input into the accumulator, then emit a single clone of it.
            quote_spanned! {op_span=>
                #assign_accum_ident
                // Eagerly consume input to ensure updated state.
                {
                    let __fut = #root::dfir_pipes::pull::Pull::for_each(#input, |#item_ident| {
                        #foreach_body
                    });
                    let () = #work_fn_async(__fut).await;
                }
                let #ident = #work_fn(
                    || #root::dfir_pipes::pull::once(
                        ::std::clone::Clone::clone(&*#accumulator_ident)
                    )
                );
            }
        } else {
            // Push: `fold` is a sink with no downstream output; its value lives in the singleton
            // state.
            assert_eq!(0, outputs.len());
            quote_spanned! {op_span=>
                let #ident = #root::dfir_pipes::push::for_each(|#item_ident| {
                    #assign_accum_ident
                    #foreach_body
                });
            }
        };

        // Only `'static` folds keep their accumulator across ticks, so only they must be re-run on
        // later ticks to re-emit the (possibly unchanged) value. `'tick` folds collect within a
        // single tick; rescheduling them would re-run the subgraph every tick and emit the freshly
        // reset initial value even when no items arrived.
        let write_iterator_after = if matches!(persistence, Persistence::Static) {
            quote_spanned! {op_span=>
                // Reschedule this subgraph (`false`: presumably a non-external/lazy schedule).
                #context.schedule_subgraph(#context.current_subgraph(), false);
            }
        } else {
            Default::default()
        };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_prologue_after,
            write_iterator,
            write_iterator_after,
        })
    },
};