solverforge-macros 0.15.0

Derive macros for SolverForge constraint solver
Documentation
use quote::{format_ident, quote};
use syn::{GenericArgument, PathArguments, ReturnType, Type, TypeParamBound};

use super::ast::{
    ConstraintFunction, ConstraintProgram, SharedGroupedProgram, TailMember, TerminalConstraint,
    TerminalKind,
};

/// One tail member of a shared grouped program after lowering to tokens.
///
/// The tokens stored here refer to `let` bindings that the matching preamble
/// (produced alongside each member) introduces in the generated function body.
struct EmittedMember<'a> {
    /// The parsed terminal this member came from; `None` for opaque
    /// (`TailMember::Other`) members.
    terminal: Option<&'a TerminalConstraint>,
    /// Tokens naming the local variable that holds this member's constraint set.
    set: proc_macro2::TokenStream,
    /// Tokens evaluating to the member's constraint count: the literal `1usize`
    /// for terminals, a generated local for opaque members.
    count: proc_macro2::TokenStream,
    /// Tokens evaluating to the member's metadata entry count: the literal
    /// `1usize` for terminals, a generated local for opaque members.
    metadata_entry_count: proc_macro2::TokenStream,
}

/// Lowers a parsed constraint program into the tokens of the rewritten function.
///
/// Passthrough programs go to `emit_passthrough`; shared grouped programs go to
/// `emit_shared_grouped`. Any error from the chosen emitter is returned as-is.
pub(crate) fn emit(program: ConstraintProgram) -> syn::Result<proc_macro2::TokenStream> {
    match program {
        ConstraintProgram::SharedGrouped(grouped) => emit_shared_grouped(grouped),
        ConstraintProgram::Passthrough(plain) => emit_passthrough(plain),
    }
}

/// Re-emits a function whose body is rebuilt from its parsed pieces.
///
/// The new body consists of the prefix statements followed by the tail, when
/// one is present. Everything else about the item is left as parsed.
fn emit_passthrough(function: ConstraintFunction) -> syn::Result<proc_macro2::TokenStream> {
    let ConstraintFunction {
        mut item,
        prefix_statements,
        tail,
        ..
    } = function;

    let prefix = prefix_statements.iter();
    // Exactly one arm runs, so the single iterator can be consumed in either.
    *item.block = match tail.as_ref() {
        Some(tail_tokens) => syn::parse_quote!({
            #(#prefix)*
            #tail_tokens
        }),
        None => syn::parse_quote!({
            #(#prefix)*
        }),
    };

    Ok(quote! { #item })
}

fn emit_shared_grouped(program: SharedGroupedProgram) -> syn::Result<proc_macro2::TokenStream> {
    let SharedGroupedProgram {
        mut item,
        prefix_statements,
        bindings,
        tail_members,
    } = program;
    let return_types = has_other_member(&tail_members)
        .then(|| constraint_set_return_types(&item.sig.output))
        .transpose()?;
    let emitted_prefix_statements = prefix_statements.iter();
    let (emitted_members, emitted_member_preambles) =
        emit_members(&tail_members, &bindings, return_types.as_ref());
    let member_refs = emitted_members.iter().collect::<Vec<_>>();
    let combined = emit_member_set(&member_refs, &bindings, &bindings);

    *item.block = syn::parse_quote!({
        #(#emitted_prefix_statements)*
        #(#emitted_member_preambles)*
        #combined
    });
    Ok(quote! { #item })
}

/// Returns `true` when at least one tail member is a `TailMember::Other`.
fn has_other_member(tail_members: &[TailMember]) -> bool {
    for member in tail_members {
        if let TailMember::Other { .. } = member {
            return true;
        }
    }
    false
}

/// Lowers every tail member, in order, into an `EmittedMember` and its preamble.
///
/// Returns two parallel vectors: entry `i` of each corresponds to
/// `tail_members[i]`. A single list of already-initialized shared bindings is
/// threaded through all terminal members so each one can tell whether it is
/// the first to use its binding.
fn emit_members<'a>(
    tail_members: &'a [TailMember],
    bindings: &[String],
    return_types: Option<&(Type, Type)>,
) -> (Vec<EmittedMember<'a>>, Vec<proc_macro2::TokenStream>) {
    let mut seen_shared_bindings = Vec::new();

    tail_members
        .iter()
        .enumerate()
        .map(|(position, tail_member)| match tail_member {
            TailMember::Terminal(terminal) => {
                emit_terminal_member(position, terminal, bindings, &mut seen_shared_bindings)
            }
            TailMember::Other { tokens } => emit_other_member(position, tokens, return_types),
        })
        .unzip()
}

/// Lowers one terminal constraint into an `EmittedMember` plus its preamble.
///
/// * If the terminal's source binding is not in `bindings`, the terminal gets
///   its own `__solverforge_member_{index}` local.
/// * Otherwise the terminal extends the shared local for that binding. The
///   first terminal seen for a binding starts the chain from the user's
///   binding; later ones continue from the shared local. The binding is
///   recorded in `initialized_shared_bindings` the first time it is used.
///
/// Both counts of a terminal member are the literal `1usize`.
fn emit_terminal_member<'a>(
    index: usize,
    terminal: &'a TerminalConstraint,
    bindings: &[String],
    initialized_shared_bindings: &mut Vec<String>,
) -> (EmittedMember<'a>, proc_macro2::TokenStream) {
    match shared_binding_index(bindings, &terminal.source_binding) {
        None => {
            let member_ident = format_ident!("__solverforge_member_{index}");
            let terminal_set = emit_terminal(terminal);
            let member = EmittedMember {
                terminal: Some(terminal),
                set: quote! { #member_ident },
                count: quote! { 1usize },
                metadata_entry_count: quote! { 1usize },
            };
            let preamble = quote! {
                let #member_ident = #terminal_set;
            };
            (member, preamble)
        }
        Some(shared_index) => {
            let shared_ident = shared_binding_ident(shared_index);
            let already_started = initialized_shared_bindings.contains(&terminal.source_binding);
            let chain_source = if already_started {
                quote! { #shared_ident }
            } else {
                initialized_shared_bindings.push(terminal.source_binding.clone());
                let binding_ident = format_ident!("{}", terminal.source_binding);
                quote! { #binding_ident }
            };
            let terminal_chain = emit_terminal_chain(chain_source, terminal);

            let member = EmittedMember {
                terminal: Some(terminal),
                set: quote! { #shared_ident },
                count: quote! { 1usize },
                metadata_entry_count: quote! { 1usize },
            };
            let preamble = quote! {
                let #shared_ident = #terminal_chain;
            };
            (member, preamble)
        }
    }
}

/// Lowers an opaque tail member into an `EmittedMember` plus its preamble.
///
/// The preamble binds the member's tokens to `__solverforge_member_{index}`
/// and then binds two more locals holding its `constraint_count` and the
/// length of its `constraint_metadata_entries`, both called through
/// `ConstraintSet<S, Sc>` using the supplied return types.
///
/// # Panics
///
/// Panics if `return_types` is `None`.
fn emit_other_member<'a>(
    index: usize,
    tokens: &proc_macro2::TokenStream,
    return_types: Option<&(Type, Type)>,
) -> (EmittedMember<'a>, proc_macro2::TokenStream) {
    let (solution_type, score_type) =
        return_types.expect("opaque constraint members require parsed return ConstraintSet types");

    let set_ident = format_ident!("__solverforge_member_{index}");
    let constraint_count_ident = format_ident!("__solverforge_member_count_{index}");
    let metadata_count_ident = format_ident!("__solverforge_member_metadata_entry_count_{index}");

    // Spelled once; interpolated twice below so both calls name the same trait.
    let constraint_set_trait = quote! {
        ::solverforge::__internal::ConstraintSet<#solution_type, #score_type>
    };

    let preamble = quote! {
        let #set_ident = #tokens;
        let #constraint_count_ident =
            <_ as #constraint_set_trait>::constraint_count(&#set_ident);
        let #metadata_count_ident =
            <_ as #constraint_set_trait>::constraint_metadata_entries(&#set_ident).len();
    };
    let member = EmittedMember {
        terminal: None,
        set: quote! { #set_ident },
        count: quote! { #constraint_count_ident },
        metadata_entry_count: quote! { #metadata_count_ident },
    };

    (member, preamble)
}

/// Builds the expression that combines `members` into one constraint set.
///
/// `bindings` is the list of shared bindings still to be handled and
/// `all_bindings` is the full list (used to look up shared local names).
///
/// * With no bindings left, all member sets are chained with
///   `combine_constraint_sets`.
/// * Otherwise the first binding's shared local becomes the left side. Each
///   member is tagged `Left` when its terminal's source binding equals that
///   binding, or `Right { .. }` with its counts otherwise. If nothing is
///   tagged `Right`, the left side alone is returned; else the right-hand
///   members are combined recursively over the remaining bindings and both
///   halves are wrapped in an `OrderedConstraintSetChain` with the tag list.
fn emit_member_set(
    members: &[&EmittedMember<'_>],
    bindings: &[String],
    all_bindings: &[String],
) -> proc_macro2::TokenStream {
    let (binding, remaining_bindings) = match bindings.split_first() {
        Some(split) => split,
        None => {
            return combine_constraint_sets(members.iter().map(|member| &member.set).collect());
        }
    };

    let left = emit_shared_binding_set(binding, all_bindings);

    let mut order = Vec::new();
    let mut right_members = Vec::new();
    for member in members.iter().copied() {
        if member_matches_binding(member, binding) {
            order.push(quote! { ::solverforge::__internal::ConstraintSetSource::Left });
            continue;
        }

        let count = &member.count;
        let metadata_entry_count = &member.metadata_entry_count;
        order.push(quote! {
            ::solverforge::__internal::ConstraintSetSource::Right {
                constraint_count: #count,
                metadata_entry_count: #metadata_entry_count,
            }
        });
        right_members.push(member);
    }

    if right_members.is_empty() {
        return left;
    }

    let right = emit_member_set(&right_members, remaining_bindings, all_bindings);
    quote! {{
        let mut __solverforge_order = ::std::vec::Vec::new();
        #(__solverforge_order.push(#order);)*
        ::solverforge::__internal::OrderedConstraintSetChain::new(
            #left,
            #right,
            __solverforge_order,
        )
    }}
}

fn emit_shared_binding_set(binding: &str, all_bindings: &[String]) -> proc_macro2::TokenStream {
    let shared_index = shared_binding_index(all_bindings, binding)
        .expect("shared binding should be present in the planned binding list");
    let shared_ident = shared_binding_ident(shared_index);
    quote! { #shared_ident }
}

/// Returns `true` when `member` comes from a terminal whose source binding is
/// `binding`. Members without a terminal never match.
fn member_matches_binding(member: &EmittedMember<'_>, binding: &str) -> bool {
    match member.terminal {
        Some(terminal) => terminal.source_binding.as_str() == binding,
        None => false,
    }
}

/// Chains the given set expressions left-to-right with `ConstraintSetChain::new`.
///
/// An empty input yields the unit expression `()`; a single set is returned
/// unchanged; otherwise each further set wraps the accumulated expression as
/// `ConstraintSetChain::new(<accumulated>, <next>)`.
fn combine_constraint_sets(sets: Vec<&proc_macro2::TokenStream>) -> proc_macro2::TokenStream {
    let mut remaining = sets.into_iter();
    match remaining.next() {
        None => quote! { () },
        Some(first) => remaining.fold(quote! { #first }, |chained, next| {
            quote! {
                ::solverforge::__internal::ConstraintSetChain::new(#chained, #next)
            }
        }),
    }
}

/// Emits the terminal's method chain, starting from an identifier built from
/// the terminal's own source binding name.
fn emit_terminal(terminal: &TerminalConstraint) -> proc_macro2::TokenStream {
    let source_ident = format_ident!("{}", terminal.source_binding);
    let chain_start = quote! { #source_ident };
    emit_terminal_chain(chain_start, terminal)
}

/// Returns the index of the first entry in `bindings` equal to `binding`, or
/// `None` when there is no such entry.
fn shared_binding_index(bindings: &[String], binding: &str) -> Option<usize> {
    bindings
        .iter()
        .enumerate()
        .find_map(|(index, candidate)| (candidate == binding).then_some(index))
}

/// Builds the identifier `__solverforge_shared_{index}` used for the shared
/// local at the given index.
fn shared_binding_ident(index: usize) -> proc_macro2::Ident {
    format_ident!("__solverforge_shared_{}", index)
}

/// Appends the terminal's scoring call and name to `chain`.
///
/// Produces `<chain>.<impact method>(<weight>).named(<name>)`. In debug builds
/// it asserts that the terminal's kind is `TerminalKind::GroupedScore`.
fn emit_terminal_chain(
    chain: proc_macro2::TokenStream,
    terminal: &TerminalConstraint,
) -> proc_macro2::TokenStream {
    debug_assert_eq!(terminal.kind, TerminalKind::GroupedScore);
    let impact_method = terminal.impact.method_ident();
    let (weight, name) = (&terminal.weight, &terminal.name);
    // `order` is read but deliberately not emitted here.
    let _order = terminal.order;
    quote! {
        #chain.#impact_method(#weight).named(#name)
    }
}

fn constraint_set_return_types(output: &ReturnType) -> syn::Result<(Type, Type)> {
    let ReturnType::Type(_, return_type) = output else {
        return Err(syn::Error::new_spanned(
            quote! { #output },
            "#[solverforge_constraints] shared functions must return impl ConstraintSet<S, Sc>",
        ));
    };
    let return_type = peel_parens(return_type.as_ref());
    let Type::ImplTrait(impl_trait) = return_type else {
        return Err(syn::Error::new_spanned(
            return_type,
            "#[solverforge_constraints] shared functions must return impl ConstraintSet<S, Sc>",
        ));
    };

    let mut fallback = None;
    for bound in &impl_trait.bounds {
        let Some((types, is_constraint_set)) = constraint_set_bound_types(bound) else {
            continue;
        };
        if is_constraint_set {
            return Ok(types);
        }
        fallback.get_or_insert(types);
    }

    fallback.ok_or_else(|| {
        syn::Error::new_spanned(
            return_type,
            "#[solverforge_constraints] shared functions must return impl ConstraintSet<S, Sc>",
        )
    })
}

/// Strips any number of enclosing `Type::Paren` layers and returns the
/// innermost type.
fn peel_parens(ty: &Type) -> &Type {
    let mut current = ty;
    while let Type::Paren(paren) = current {
        current = paren.elem.as_ref();
    }
    current
}

/// Reads the first two type arguments from a trait bound's last path segment.
///
/// Returns `None` for non-trait bounds, for paths with no segments, for
/// segments without angle-bracketed arguments, and when fewer than two of the
/// arguments are types. On success the pair is returned together with a flag
/// telling whether the segment's identifier is `ConstraintSet`.
fn constraint_set_bound_types(bound: &TypeParamBound) -> Option<((Type, Type), bool)> {
    let last_segment = match bound {
        TypeParamBound::Trait(trait_bound) => trait_bound.path.segments.last()?,
        _ => return None,
    };
    let generic_args = match &last_segment.arguments {
        PathArguments::AngleBracketed(angle_bracketed) => &angle_bracketed.args,
        _ => return None,
    };

    // Non-type arguments are skipped.
    let mut types = generic_args.iter().filter_map(|argument| {
        if let GenericArgument::Type(ty) = argument {
            Some(ty.clone())
        } else {
            None
        }
    });
    let solution_type = types.next()?;
    let score_type = types.next()?;

    let is_constraint_set = last_segment.ident == "ConstraintSet";
    Some(((solution_type, score_type), is_constraint_set))
}