use quote::{quote_spanned, ToTokens};
use super::{
DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints,
OperatorInstance, OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};
/// > 1 input stream of type `(K, V)`, 1 output stream of type `(K, V)`
///
/// Groups the input pairs by key and reduces the values of each key with the single argument,
/// a closure of the form `|accum: &mut V, item: V| { ... }` (its return value is ignored).
/// The first value seen for a key becomes that key's accumulator. Every later value for the
/// same key is combined into it in place. Both `K` and `V` must implement `Clone`.
///
/// Up to two optional type arguments give the key and value types of the internal hash map.
/// Each defaults to `_` (inferred).
///
/// Persistence:
/// - `'tick` (default): the map is drained as it is emitted, so it starts empty each tick.
/// - `'static`: the map is never cleared. On the first run of each tick, a clone of every
///   accumulated `(K, V)` entry is emitted.
/// - `'mutable`: not supported; reported as an error.
pub const REDUCE_KEYED: OperatorConstraints = OperatorConstraints {
    name: "reduce_keyed",
    categories: &[OperatorCategory::KeyedFold],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    hard_range_out: RANGE_1,
    soft_range_out: RANGE_1,
    num_args: 1,
    persistence_args: &(0..=1),
    type_args: &(0..=2),
    is_external_input: false,
    has_singleton_output: false,
    ports_inn: None,
    ports_out: None,
    // The input edge is a stratum delay — presumably so all of a tick's input is grouped
    // before anything is emitted (TODO confirm against `DelayType` docs).
    input_delaytype_fn: |_| Some(DelayType::Stratum),
    write_fn: |wc @ &WriteContextArgs {
                   hydroflow,
                   context,
                   op_span,
                   ident,
                   inputs,
                   is_pull,
                   root,
                   op_inst:
                       OperatorInstance {
                           generics:
                               OpInstGenerics {
                                   persistence_args,
                                   type_args,
                                   ..
                               },
                           ..
                       },
                   arguments,
                   ..
               },
               diagnostics| {
        assert!(is_pull);

        // Defaults to `'tick`; `persistence_args: &(0..=1)` above allows at most one argument.
        let persistence = match persistence_args[..] {
            [] => Persistence::Tick,
            [a] => a,
            _ => unreachable!(),
        };

        // Optional key/value type arguments for the map; `_` lets rustc infer the missing ones.
        // `unwrap_or_else` only builds the placeholder when it is actually needed.
        let generic_type_args = [
            type_args
                .first()
                .map(ToTokens::to_token_stream)
                .unwrap_or_else(|| quote_spanned!(op_span=> _)),
            type_args
                .get(1)
                .map(ToTokens::to_token_stream)
                .unwrap_or_else(|| quote_spanned!(op_span=> _)),
        ];
        let input = &inputs[0];
        let aggfn = &arguments[0];

        // Only `'tick` and `'static` are implemented; `true` selects the `'static` emit path.
        let is_static = match persistence {
            Persistence::Tick => false,
            Persistence::Static => true,
            Persistence::Mutable => {
                diagnostics.push(Diagnostic::spanned(
                    op_span,
                    Level::Error,
                    "An implementation of 'mutable does not exist",
                ));
                return Err(());
            }
        };

        let groupbydata_ident = wc.make_ident("groupbydata");
        let hashtable_ident = wc.make_ident("hashtable");

        // Both modes keep their groups in one `FxHashMap<K, V>` registered as operator state.
        let write_prologue = quote_spanned! {op_span=>
            let #groupbydata_ident = #hydroflow.add_state(::std::cell::RefCell::new(#root::rustc_hash::FxHashMap::<#( #generic_type_args ),*>::default()));
        };

        // Folds this run's input into the map (identical for both modes). The `RefMut` is bound
        // outside the inner block so the output iterator emitted after it can borrow from it.
        let accumulate = quote_spanned! {op_span=>
            let mut #hashtable_ident = #context.state_ref(#groupbydata_ident).borrow_mut();
            {
                #[inline(always)]
                fn check_input<Iter: ::std::iter::Iterator<Item = (A, B)>, A: ::std::clone::Clone, B: ::std::clone::Clone>(iter: Iter)
                    -> impl ::std::iter::Iterator<Item = (A, B)> { iter }
                #[inline(always)]
                fn call_comb_type<A, O>(acc: &mut A, item: A, f: impl Fn(&mut A, A) -> O) -> O {
                    f(acc, item)
                }
                for kv in check_input(#input) {
                    match #hashtable_ident.entry(kv.0) {
                        ::std::collections::hash_map::Entry::Vacant(vacant) => {
                            vacant.insert(kv.1);
                        }
                        ::std::collections::hash_map::Entry::Occupied(mut occupied) => {
                            #[allow(clippy::redundant_closure_call)] call_comb_type(occupied.get_mut(), kv.1, #aggfn);
                        }
                    }
                }
            }
        };

        let (emit, write_iterator_after) = if is_static {
            (
                // Emit a clone of the whole persisted map, only on the first run of the tick.
                quote_spanned! {op_span=>
                    let #ident = #context.is_first_run_this_tick()
                        .then_some(#hashtable_ident.iter())
                        .into_iter()
                        .flatten()
                        .map(
                            #[allow(unknown_lints, suspicious_double_ref_op, clippy::clone_on_copy)]
                            |(k, v)| (
                                ::std::clone::Clone::clone(k),
                                ::std::clone::Clone::clone(v),
                            )
                        );
                },
                // Schedules this same subgraph again so the persisted map can be re-emitted
                // later. NOTE(review): confirm the meaning of the `false` flag.
                quote_spanned! {op_span=>
                    #context.schedule_subgraph(#context.current_subgraph(), false);
                },
            )
        } else {
            (
                // Move every group downstream, leaving the map empty for the next tick.
                quote_spanned! {op_span=>
                    let #ident = #hashtable_ident.drain();
                },
                Default::default(),
            )
        };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_iterator: quote_spanned! {op_span=> #accumulate #emit },
            write_iterator_after,
        })
    },
};