use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Expr, Token};
use crate::IndexBind;
use crate::oximo_root;
/// Parsed arguments of a `sum!` invocation.
///
/// The accepted shape, as implemented by the `Parse` impl for this type, is
/// `<body> for <bind>, <bind>, ... [if <cond>]`.
struct SumInput {
    /// Expression that appears before the `for` keyword; it is the term
    /// emitted for every index combination.
    body: Expr,
    /// One or more comma-separated index bindings following `for`, kept in
    /// source order.
    binds: Vec<IndexBind>,
    /// Expression following a trailing `if`, or `None` when no `if` clause
    /// was written.
    cond: Option<Expr>,
}
impl Parse for SumInput {
    /// Parses `<body> for <bind>, <bind>, ... [if <cond>]`.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` when any component fails to parse, when no
    /// binding follows `for`, or when tokens remain after the last clause.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let body: Expr = input.parse()?;
        let _for_kw: Token![for] = input.parse()?;

        // At least one binding is required; a trailing comma is not consumed.
        let binds: Vec<IndexBind> =
            Punctuated::<IndexBind, Token![,]>::parse_separated_nonempty(input)?
                .into_iter()
                .collect();

        // `Option<Token>` peeks for the keyword and consumes it only if present.
        let cond = match input.parse::<Option<Token![if]>>()? {
            Some(_) => Some(input.parse::<Expr>()?),
            None => None,
        };

        if input.is_empty() {
            Ok(Self { body, binds, cond })
        } else {
            Err(input.error("unexpected tokens after `sum!` clauses"))
        }
    }
}
/// Expands the token stream of a `sum!` invocation.
///
/// The input is first passed through `crate::index::rewrite_index_subscripts`
/// and then parsed as a [`SumInput`].
///
/// * Without an `if` clause the result is a chain of nested
///   `__macro_support::sum_over(&(domain), |param| ...)` calls, one per
///   binding, with the body expression innermost.
/// * With an `if` clause the result is a block that collects terms into a
///   local `__terms` vector using nested `for` loops over
///   `__macro_support::keys_of(&(domain))`, pushes the body whenever the
///   condition holds, asserts at runtime that at least one term was pushed,
///   and hands the vector to `__macro_support::sum_terms`.
///
/// In both forms the bindings are wrapped from last to first, so the first
/// binding written by the user ends up outermost.
///
/// # Errors
///
/// Propagates the `syn::Error` produced when the input does not parse as a
/// [`SumInput`].
pub(crate) fn expand(input: TokenStream2) -> syn::Result<TokenStream2> {
    let rewritten = crate::index::rewrite_index_subscripts(input);
    let SumInput {
        body: term,
        binds: indices,
        cond: filter,
    } = syn::parse2(rewritten)?;
    let root = oximo_root();

    match filter {
        // Unfiltered: fold the bindings into nested `sum_over` closures.
        None => {
            let nested = indices.iter().rev().fold(quote!(#term), |acc, bind| {
                let param = bind.closure_param();
                let domain = &bind.domain;
                quote!( #root::__macro_support::sum_over(&(#domain), |#param| #acc) )
            });
            Ok(nested)
        }
        // Filtered: explicit loops so the condition can skip individual terms.
        Some(filter) => {
            let guarded_push = quote! {
                if #filter {
                    __terms.push(#term);
                }
            };
            let loops = indices.iter().rev().fold(guarded_push, |inner, bind| {
                let pat = &bind.pat;
                let domain = &bind.domain;
                // Pin the key type explicitly when the binding supplies one.
                let keys = match bind.keys_of_type() {
                    Some(ty) => {
                        quote!( #root::__macro_support::keys_of::<#ty, _>(&(#domain)) )
                    }
                    None => quote!( #root::__macro_support::keys_of(&(#domain)) ),
                };
                quote! {
                    for #pat in #keys {
                        #inner
                    }
                }
            });
            Ok(quote! {{
                let mut __terms = ::std::vec::Vec::new();
                #loops
                ::core::assert!(
                    !__terms.is_empty(),
                    "sum! with an `if` filter produced no terms"
                );
                #root::__macro_support::sum_terms(__terms)
            }})
        }
    }
}