use proc_macro2::{Ident, TokenStream};
use quote::{quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{parse_quote, Expr, ExprCall};
use super::{
DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints,
OperatorInstance, OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_0, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};
pub const JOIN_FUSED: OperatorConstraints = OperatorConstraints {
name: "join_fused",
categories: &[OperatorCategory::MultiIn],
hard_range_inn: &(2..=2),
soft_range_inn: &(2..=2),
hard_range_out: RANGE_1,
soft_range_out: RANGE_1,
num_args: 2,
persistence_args: &(0..=2),
type_args: RANGE_0,
is_external_input: false,
has_singleton_output: false,
ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
ports_out: None,
input_delaytype_fn: |_| Some(DelayType::Stratum),
write_fn: |wc @ &WriteContextArgs {
context,
op_span,
ident,
inputs,
is_pull,
op_inst:
OperatorInstance {
generics:
OpInstGenerics {
persistence_args, ..
},
..
},
arguments,
..
},
diagnostics| {
assert!(is_pull);
let persistences = parse_persistences(persistence_args);
let lhs_join_options =
parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
let rhs_join_options =
parse_argument(&arguments[1]).map_err(|err| diagnostics.push(err))?;
let (lhs_joindata_ident, lhs_borrow_ident, lhs_prologue, lhs_borrow) =
make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
.map_err(|err| diagnostics.push(err))?;
let (rhs_joindata_ident, rhs_borrow_ident, rhs_prologue, rhs_borrow) =
make_joindata(wc, persistences[1], &rhs_join_options, "rhs")
.map_err(|err| diagnostics.push(err))?;
let write_prologue = quote_spanned! {op_span=>
#lhs_prologue
#rhs_prologue
};
let lhs = &inputs[0];
let rhs = &inputs[1];
let arg0_span = arguments[0].span();
let arg1_span = arguments[1].span();
let lhs_tokens = match lhs_join_options {
JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
#lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
},
JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
#lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
},
JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
#lhs_borrow.reduce_into(#lhs, #lhs_reduce);
},
};
let rhs_tokens = match rhs_join_options {
JoinOptions::FoldFrom(rhs_from, rhs_fold) => quote_spanned! {arg0_span=>
#rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_from);
},
JoinOptions::Fold(rhs_default, rhs_fold) => quote_spanned! {arg1_span=>
#rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_default);
},
JoinOptions::Reduce(rhs_reduce) => quote_spanned! {arg1_span=>
#rhs_borrow.reduce_into(#rhs, #rhs_reduce);
},
};
let write_iterator = quote_spanned! {op_span=>
let mut #lhs_borrow_ident = #context.state_ref(#lhs_joindata_ident).borrow_mut();
let mut #rhs_borrow_ident = #context.state_ref(#rhs_joindata_ident).borrow_mut();
let #ident = {
#lhs_tokens
#rhs_tokens
#[allow(clippy::clone_on_copy)]
#[allow(suspicious_double_ref_op)]
#rhs_borrow
.table
.iter()
.filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
};
};
let write_iterator_after =
if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
quote_spanned! {op_span=>
#context.schedule_subgraph(#context.current_subgraph(), false);
}
} else {
quote_spanned! {op_span=>}
};
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
write_iterator_after,
})
},
};
pub(crate) enum JoinOptions<'a> {
FoldFrom(&'a Expr, &'a Expr),
Fold(&'a Expr, &'a Expr),
Reduce(&'a Expr),
}
pub(crate) fn parse_argument(arg: &Expr) -> Result<JoinOptions, Diagnostic> {
let Expr::Call(ExprCall {
attrs: _,
func,
paren_token: _,
args,
}) = arg
else {
return Err(Diagnostic::spanned(
arg.span(),
Level::Error,
format!("Argument must be a function call: {arg:?}"),
));
};
let mut elems = args.iter();
let func_name = func.to_token_stream().to_string();
match func_name.as_str() {
"Fold" => match (elems.next(), elems.next()) {
(Some(default), Some(fold)) => Ok(JoinOptions::Fold(default, fold)),
_ => {
Err(Diagnostic::spanned(
args.span(),
Level::Error,
format!("Fold requires two arguments, first is the default function, second is the folding function: {func:?}"),
))
}
},
"FoldFrom" => match (elems.next(), elems.next()) {
(Some(from), Some(fold)) => Ok(JoinOptions::FoldFrom(from, fold)),
_ => {
Err(Diagnostic::spanned(
args.span(),
Level::Error,
format!("FoldFrom requires two arguments, first is the From function, second is the folding function: {func:?}"),
))
}
},
"Reduce" => match elems.next() {
Some(reduce) => Ok(JoinOptions::Reduce(reduce)),
_ => Err(Diagnostic::spanned(
args.span(),
Level::Error,
format!("Reduce requires one argument, the reducing function: {func:?}"),
)),
},
_ => Err(Diagnostic::spanned(
func.span(),
Level::Error,
format!("Unknown summarizing function: {func:?}"),
)),
}
}
pub(crate) fn make_joindata(
wc: &WriteContextArgs,
persistence: Persistence,
join_options: &JoinOptions<'_>,
side: &str,
) -> Result<(Ident, Ident, TokenStream, TokenStream), Diagnostic> {
let joindata_ident = wc.make_ident(format!("joindata_{}", side));
let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));
let context = wc.context;
let hydroflow = wc.hydroflow;
let root = wc.root;
let op_span = wc.op_span;
let join_type = match *join_options {
JoinOptions::FoldFrom(_, _) => {
quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFoldFrom)
}
JoinOptions::Fold(_, _) => {
quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFold)
}
JoinOptions::Reduce(_) => {
quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateReduce)
}
};
let (prologue, borrow) = match persistence {
Persistence::Tick => (
quote_spanned! {op_span=>
let #joindata_ident = #hydroflow.add_state(std::cell::RefCell::new(
#root::util::monotonic_map::MonotonicMap::new_init(
#join_type::default()
)
));
},
quote_spanned! {op_span=>
#borrow_ident.get_mut_clear(#context.current_tick())
},
),
Persistence::Static => (
quote_spanned! {op_span=>
let #joindata_ident = #hydroflow.add_state(std::cell::RefCell::new(
#join_type::default()
));
},
quote_spanned! {op_span=>
#borrow_ident
},
),
Persistence::Mutable => {
return Err(Diagnostic::spanned(
op_span,
Level::Error,
"An implementation of 'mutable does not exist",
));
}
};
Ok((joindata_ident, borrow_ident, prologue, borrow))
}
pub(crate) fn parse_persistences(persistences: &[Persistence]) -> [Persistence; 2] {
match persistences {
[] => [Persistence::Tick, Persistence::Tick],
[a] => [*a, *a],
[a, b] => [*a, *b],
_ => panic!("Too many persistences: {persistences:?}"),
}
}