use proc_macro2::{Spacing, TokenStream as TokenStream2, TokenTree};
use quote::{ToTokens, quote};
use syn::Pat;
use crate::bind::filtered_set;
use crate::{
IndexBind, Named, RelOp, build_set, oximo_root, parse_named, split_relops, split_top_commas,
};
/// Expands `variable!(model, spec, trailing...)` into a `let` binding of a built variable.
///
/// The input is split on top-level commas into:
/// 1. `model`: any expression; the generated code calls `__var` / `__indexed_var` on it.
/// 2. `spec`: `name`, `name >= lb`, `name <= ub`, or `lb <= name <= ub`. `name` is parsed
///    by `parse_named`; when it carries index bindings the result is an indexed family.
/// 3. Optional trailing options (see `parse_trailing`): at most one positional domain
///    token and/or `lb =`, `ub =`, `domain =`, `initial =`, `fix =`.
///
/// Scalar output:
/// `let name = (model).__var("name") <domain> <bounds> <initial/fix> .build();`
///
/// Indexed output: the index set is built with `build_set`, optionally filtered by the
/// binding condition via `filtered_set`, and then
/// `let name = { let __set = <set>; (model).__indexed_var("name", &__set) <domain> <bounds> .build() };`
///
/// # Errors
///
/// The checks below run in source order, so the first failing one is reported:
/// the model is not an expression; the spec segment is missing; `parse_trailing` or
/// `domain_method` fails; the spec has an unsupported relational shape; a bound is given
/// both relationally and by keyword; `fix` is combined with `lb`/`ub`; `parse_named`
/// fails; `initial`/`fix` is used on an indexed family.
pub(crate) fn expand(input: TokenStream2) -> syn::Result<TokenStream2> {
    let mut parts = split_top_commas(input).into_iter();
    // First segment: the model expression the builder methods are called on.
    let model = parts.next().expect("split always yields at least one segment");
    let model: syn::Expr = syn::parse2(model)?;
    // Second segment: the `name` / relational-bounds spec. It is mandatory.
    let spec = parts.next().ok_or_else(|| {
        syn::Error::new(proc_macro2::Span::call_site(), "variable! needs a `name`/bounds spec")
    })?;
    let root = oximo_root();
    // Everything after the spec is keyword options or a positional domain.
    let Trailing { domain: domain_ts, lb: kw_lb, ub: kw_ub, initial: kw_initial, fix: kw_fix } =
        parse_trailing(parts)?;
    // No domain given means no builder call is emitted for it.
    let domain_method = match domain_ts {
        None => quote!(),
        Some(ts) => domain_method(ts, &root)?,
    };
    // Decode the relational form of the spec. In the chained form
    // `lb <= name <= ub` the variable is the middle segment.
    let (segs, ops) = split_relops(spec);
    let (core, rel_lb, rel_ub) = match (segs.len(), ops.as_slice()) {
        (1, []) => (segs[0].clone(), None, None),
        (2, [RelOp::Le]) => (segs[0].clone(), None, Some(segs[1].clone())),
        (2, [RelOp::Ge]) => (segs[0].clone(), Some(segs[1].clone()), None),
        (3, [RelOp::Le, RelOp::Le]) => {
            (segs[1].clone(), Some(segs[0].clone()), Some(segs[2].clone()))
        }
        _ => {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "variable! bounds must be `name`, `name >= lb`, `name <= ub`, or `lb <= name <= ub`",
            ));
        }
    };
    // Each side may come from the spec or from a keyword, but not both.
    let lb = merge_bound(rel_lb, kw_lb, "lb")?;
    let ub = merge_bound(rel_ub, kw_ub, "ub")?;
    // `fix` conflicts with any bound; the error points at `lb` if present, else `ub`.
    if kw_fix.is_some() {
        if let Some(b) = lb.as_ref().or(ub.as_ref()) {
            return Err(syn::Error::new_spanned(
                b,
                "`fix` sets both bounds. Do not combine it with `lb`/`ub`",
            ));
        }
    }
    // Bound tokens pass through `crate::index::rewrite_index_subscripts` before being
    // emitted (its exact rewrite rules live in `crate::index`).
    let lb = lb.map(crate::index::rewrite_index_subscripts);
    let ub = ub.map(crate::index::rewrite_index_subscripts);
    let Named { name, binds, cond } = parse_named(core)?;
    let name_str = name.to_string();
    // Indexed families only support bounds; per-element starting/fixed values
    // must be set after construction.
    if binds.is_some() {
        if let Some(kw) = kw_initial.as_ref().or(kw_fix.as_ref()) {
            return Err(syn::Error::new_spanned(
                kw,
                "`initial`/`fix` is not supported on an indexed family. Use `m.set_initial` / \
                 `m.fix` per element",
            ));
        }
    }
    // Every identifier appearing in the binding patterns. A bound that mentions any
    // of them depends on the element and must be evaluated per element.
    let mut idents = Vec::new();
    if let Some(binds) = &binds {
        for b in binds {
            collect_idents(&b.pat.to_token_stream(), &mut idents);
        }
    }
    let binds_slice = binds.as_deref();
    // Emits `.lb_by(closure)` / `.ub_by(closure)` for element-dependent bounds on an
    // indexed family, and a constant `.lb(..)` / `.ub(..)` otherwise.
    let bound_method = |val: &TokenStream2, kind: BoundKind| -> TokenStream2 {
        match binds_slice {
            Some(bs) if references_any(val, &idents) => {
                let param = bound_closure_param(bs, val);
                match kind {
                    BoundKind::Lb => quote!(.lb_by(move |#param| f64::from(#val))),
                    BoundKind::Ub => quote!(.ub_by(move |#param| f64::from(#val))),
                }
            }
            _ => match kind {
                BoundKind::Lb => quote!(.lb(f64::from(#val))),
                BoundKind::Ub => quote!(.ub(f64::from(#val))),
            },
        }
    };
    let mut bounds = TokenStream2::new();
    if let Some(lb) = &lb {
        bounds.extend(bound_method(lb, BoundKind::Lb));
    }
    if let Some(ub) = &ub {
        bounds.extend(bound_method(ub, BoundKind::Ub));
    }
    // Only interpolated in the scalar branch; the indexed branch already rejected
    // `initial`/`fix` above.
    let extras = scalar_extras(kw_initial, kw_fix);
    let expanded = match binds {
        None => quote! {
            let #name = (#model).__var(#name_str) #domain_method #bounds #extras .build();
        },
        Some(binds) => {
            let set = build_set(&binds, &root);
            let set = filtered_set(set, &binds, cond.as_ref(), &root);
            quote! {
                let #name = {
                    let __set = #set;
                    (#model).__indexed_var(#name_str, &__set) #domain_method #bounds .build()
                };
            }
        }
    };
    Ok(expanded)
}
/// Builds the optional `.initial(..)` and `.fix(..)` calls for a scalar variable.
///
/// Each present value is passed through `crate::index::rewrite_index_subscripts` and
/// wrapped in `f64::from`. `initial` is emitted before `fix`; absent values emit nothing.
fn scalar_extras(initial: Option<TokenStream2>, fix: Option<TokenStream2>) -> TokenStream2 {
    let initial_call = initial
        .map(crate::index::rewrite_index_subscripts)
        .map(|value| quote!(.initial(f64::from(#value))));
    let fix_call = fix
        .map(crate::index::rewrite_index_subscripts)
        .map(|value| quote!(.fix(f64::from(#value))));
    quote!(#initial_call #fix_call)
}
/// Options collected from the segments that follow the `name`/bounds spec of `variable!`.
///
/// Each field holds the raw value tokens of its option, or `None` when it was not given.
struct Trailing {
    /// Value of `lb = ...`.
    lb: Option<TokenStream2>,
    /// Value of `ub = ...`.
    ub: Option<TokenStream2>,
    /// Value of `initial = ...`.
    initial: Option<TokenStream2>,
    /// Value of `fix = ...`.
    fix: Option<TokenStream2>,
    /// Domain tokens, given either positionally or as `domain = ...`.
    domain: Option<TokenStream2>,
}
fn parse_trailing(parts: impl Iterator<Item = TokenStream2>) -> syn::Result<Trailing> {
let mut positional_domain: Option<TokenStream2> = None;
let (mut kw_domain, mut lb, mut ub, mut initial, mut fix) = (None, None, None, None, None);
for seg in parts {
if seg.is_empty() {
continue; }
if let Some((kw, val)) = parse_keyword(&seg) {
let slot = match kw.to_string().as_str() {
"lb" => &mut lb,
"ub" => &mut ub,
"domain" => &mut kw_domain,
"initial" => &mut initial,
"fix" => &mut fix,
_ => unreachable!("parse_keyword only returns known keywords"),
};
if slot.is_some() {
return Err(syn::Error::new_spanned(&kw, format!("`{kw}` specified twice")));
}
*slot = Some(val);
} else if positional_domain.is_some() {
return Err(syn::Error::new_spanned(
&seg,
"unexpected trailing tokens in variable! (only one positional domain token is \
allowed; use `lb =`/`ub =`/`domain =`/`initial =`/`fix =` for the rest)",
));
} else {
positional_domain = Some(seg);
}
}
let domain = match (positional_domain, kw_domain) {
(Some(_), Some(d)) => return Err(syn::Error::new_spanned(d, "domain specified twice")),
(Some(d), None) | (None, Some(d)) => Some(d),
(None, None) => None,
};
Ok(Trailing { domain, lb, ub, initial, fix })
}
/// Recognizes a `keyword = value` trailing option of `variable!`.
///
/// Returns the keyword identifier and the (possibly empty) value tokens when `seg`
/// starts with one of `lb`, `ub`, `domain`, `initial`, or `fix` followed by a single `=`.
/// Returns `None` for anything else, including `kw == ...` and `kw => ...`, so the
/// caller can treat the segment as a positional domain.
fn parse_keyword(seg: &TokenStream2) -> Option<(proc_macro2::Ident, TokenStream2)> {
    let tts: Vec<TokenTree> = seg.clone().into_iter().collect();
    let TokenTree::Ident(kw) = tts.first()? else {
        return None;
    };
    if !matches!(kw.to_string().as_str(), "lb" | "ub" | "domain" | "initial" | "fix") {
        return None;
    }
    let TokenTree::Punct(eq) = tts.get(1)? else {
        return None;
    };
    if eq.as_char() != '=' {
        return None;
    }
    // When parsing source, the compiler marks `=` as `Joint` whenever another punctuation
    // character follows without whitespace, so `lb=-1.0` arrives as a Joint `=` followed
    // by `-`. Only reject the two operators that really start with `=`: `==` and `=>`.
    let starts_operator = eq.spacing() == Spacing::Joint
        && matches!(tts.get(2), Some(TokenTree::Punct(next)) if matches!(next.as_char(), '=' | '>'));
    if starts_operator {
        return None;
    }
    Some((kw.clone(), tts[2..].iter().cloned().collect()))
}
fn merge_bound(
rel: Option<TokenStream2>,
kw: Option<TokenStream2>,
which: &str,
) -> syn::Result<Option<TokenStream2>> {
match (rel, kw) {
(Some(_), Some(kw)) => {
Err(syn::Error::new_spanned(kw, format!("`{which}` specified twice")))
}
(Some(b), None) | (None, Some(b)) => Ok(Some(b)),
(None, None) => Ok(None),
}
}
fn domain_method(ts: TokenStream2, root: &TokenStream2) -> syn::Result<TokenStream2> {
const HELP: &str = "domain must be `Bin`, `Int`, `Real`, `SemiCont(thr)`, or `SemiInt(thr)`";
match syn::parse2::<syn::Expr>(ts)? {
syn::Expr::Path(p) => {
let id = p.path.get_ident().ok_or_else(|| syn::Error::new_spanned(&p, HELP))?;
match id.to_string().as_str() {
"Bin" | "Binary" => Ok(quote!(.binary())),
"Int" | "Integer" => Ok(quote!(.integer())),
"Real" | "Cont" | "Continuous" => Ok(quote!()),
"SemiCont" | "SemiContinuous" | "SemiInt" | "SemiInteger" => {
Err(syn::Error::new_spanned(
id,
format!("`{id}` needs a threshold, e.g. `{id}(1.0)`"),
))
}
_ => Err(syn::Error::new_spanned(id, HELP)),
}
}
syn::Expr::Call(call) => {
let syn::Expr::Path(fp) = &*call.func else {
return Err(syn::Error::new_spanned(&call.func, HELP));
};
let func = fp.path.get_ident().ok_or_else(|| syn::Error::new_spanned(fp, HELP))?;
let variant = match func.to_string().as_str() {
"SemiCont" | "SemiContinuous" => quote!(SemiContinuous),
"SemiInt" | "SemiInteger" => quote!(SemiInteger),
_ => return Err(syn::Error::new_spanned(func, HELP)),
};
let [thr] = call.args.iter().collect::<Vec<_>>()[..] else {
return Err(syn::Error::new_spanned(
&call,
format!("`{func}` takes exactly one threshold argument, e.g. `{func}(1.0)`"),
));
};
Ok(quote! {
.domain(#root::__macro_support::Domain::#variant { threshold: f64::from(#thr) })
})
}
other => Err(syn::Error::new_spanned(other, HELP)),
}
}
/// Which side of a variable's range a bound applies to.
#[derive(Clone, Copy)]
enum BoundKind {
    /// Lower bound, emitted as `.lb(..)` or `.lb_by(..)`.
    Lb,
    /// Upper bound, emitted as `.ub(..)` or `.ub_by(..)`.
    Ub,
}
/// Builds the closure parameter for an element-dependent bound (`lb_by` / `ub_by`).
///
/// A single binding yields its pattern directly, and several bindings yield a tuple
/// pattern. Each pattern first goes through `mask_pat`, so bindings not mentioned by
/// `bound` become `_`. When every binding declares a type, the parameter is annotated:
/// with the single type, or with a tuple of the types. Otherwise it is left unannotated.
fn bound_closure_param(binds: &[IndexBind], bound: &TokenStream2) -> TokenStream2 {
    let pattern = if let [single] = binds {
        mask_pat(&single.pat, bound)
    } else {
        let parts = binds.iter().map(|b| mask_pat(&b.pat, bound));
        quote!( (#(#parts),*) )
    };

    let declared: Option<Vec<&syn::Type>> = binds.iter().map(|b| b.ty.as_ref()).collect();
    let Some(types) = declared else {
        return pattern;
    };
    if let [ty] = types.as_slice() {
        quote!(#pattern: #ty)
    } else {
        quote!( #pattern: (#(#types),*) )
    }
}
/// Rewrites an index-binding pattern so that only the bindings used by `bound` stay named.
///
/// A plain identifier binding (no `ref`, no `@` sub-pattern) is kept if `bound`
/// mentions it and replaced by `_` otherwise; a `mut` on a masked binding is dropped
/// along with it. Tuple patterns are processed element by element. All other pattern
/// kinds are emitted unchanged.
fn mask_pat(pat: &Pat, bound: &TokenStream2) -> TokenStream2 {
    match pat {
        Pat::Tuple(t) => {
            let elems = t.elems.iter().map(|e| mask_pat(e, bound));
            // Keep the source's trailing comma. Without it, a one-element tuple pattern
            // `(x,)` would become the parenthesized pattern `(x)`, which binds the whole
            // tuple instead of its only field.
            let trailing = t.elems.trailing_punct().then(|| quote!(,));
            quote!( (#(#elems),* #trailing) )
        }
        Pat::Ident(pi) if pi.subpat.is_none() && pi.by_ref.is_none() => {
            if references_any(bound, &[pi.ident.to_string()]) { quote!(#pat) } else { quote!(_) }
        }
        _ => quote!(#pat),
    }
}
/// Appends the text of every identifier in `ts` to `out`, descending into delimited
/// groups. Identifiers are pushed in the order they appear; punctuation and literals
/// are ignored.
fn collect_idents(ts: &TokenStream2, out: &mut Vec<String>) {
    for tree in ts.clone() {
        if let TokenTree::Group(group) = &tree {
            collect_idents(&group.stream(), out);
        } else if let TokenTree::Ident(ident) = &tree {
            out.push(ident.to_string());
        }
    }
}
/// Reports whether any identifier in `ts`, including inside delimited groups, has
/// text equal to one of `idents`.
fn references_any(ts: &TokenStream2, idents: &[String]) -> bool {
    for tree in ts.clone() {
        let found = match &tree {
            TokenTree::Group(group) => references_any(&group.stream(), idents),
            TokenTree::Ident(ident) => {
                let text = ident.to_string();
                idents.iter().any(|candidate| *candidate == text)
            }
            _ => false,
        };
        if found {
            return true;
        }
    }
    false
}