solverforge-macros 0.8.5

Derive macros for the SolverForge constraint solver.
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Error, Fields};

use crate::attr_parse::has_attribute;

use super::config::{
    parse_config_path, parse_constraints_path, parse_shadow_config, parse_solver_toml_path,
};
use super::list_operations::generate_list_operations;
use super::runtime::{
    generate_runtime_phase_support, generate_runtime_solve_internal, generate_solvable_solution,
};
use super::shadow::generate_shadow_support;
use super::stream_extensions::generate_constraint_stream_extensions;
use super::type_helpers::{extract_collection_inner_type, extract_option_inner_type};

pub(crate) fn expand_derive(input: DeriveInput) -> Result<TokenStream, Error> {
    let name = &input.ident;
    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let fields = match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => &fields.named,
            _ => {
                return Err(Error::new_spanned(
                    &input,
                    "#[planning_solution] requires named fields",
                ))
            }
        },
        _ => {
            return Err(Error::new_spanned(
                &input,
                "#[planning_solution] only works on structs",
            ))
        }
    };

    let score_field = fields
        .iter()
        .find(|f| has_attribute(&f.attrs, "planning_score"))
        .ok_or_else(|| {
            Error::new_spanned(
                &input,
                "#[planning_solution] requires a #[planning_score] field",
            )
        })?;

    let score_field_name = score_field.ident.as_ref().unwrap();
    let score_type = extract_option_inner_type(&score_field.ty)?;

    let entity_descriptors: Vec<_> = fields
        .iter()
        .filter(|f| has_attribute(&f.attrs, "planning_entity_collection"))
        .filter_map(|f| {
            let field_name = f.ident.as_ref().unwrap();
            let field_name_str = field_name.to_string();
            let element_type = extract_collection_inner_type(&f.ty)?;
            Some(quote! {
                .with_entity(#element_type::entity_descriptor(#field_name_str).with_extractor(
                    Box::new(::solverforge::__internal::EntityCollectionExtractor::new(
                        stringify!(#element_type),
                        #field_name_str,
                        |s: &#name| &s.#field_name,
                        |s: &mut #name| &mut s.#field_name,
                    ))
                ))
            })
        })
        .collect();

    let fact_descriptors: Vec<_> = fields
        .iter()
        .filter(|f| has_attribute(&f.attrs, "problem_fact_collection"))
        .filter_map(|f| {
            let field_name = f.ident.as_ref().unwrap();
            let field_name_str = field_name.to_string();
            let element_type = extract_collection_inner_type(&f.ty)?;
            Some(quote! {
                .with_problem_fact(#element_type::problem_fact_descriptor(#field_name_str).with_extractor(
                    Box::new(::solverforge::__internal::EntityCollectionExtractor::new(
                        stringify!(#element_type),
                        #field_name_str,
                        |s: &#name| &s.#field_name,
                        |s: &mut #name| &mut s.#field_name,
                    ))
                ))
            })
        })
        .collect();

    let name_str = name.to_string();
    let score_field_str = score_field_name.to_string();

    let shadow_config = parse_shadow_config(&input.attrs);
    let shadow_support_impl = generate_shadow_support(&shadow_config, fields, name)?;
    let constraints_path = parse_constraints_path(&input.attrs);
    let config_path = parse_config_path(&input.attrs);
    let solver_toml_path = parse_solver_toml_path(&input.attrs);
    let entity_count_arms: Vec<_> = fields
        .iter()
        .filter(|f| has_attribute(&f.attrs, "planning_entity_collection"))
        .enumerate()
        .map(|(idx, f)| {
            let field_name = f.ident.as_ref().unwrap();
            quote! { #idx => this.#field_name.len(), }
        })
        .collect();

    let list_operations = generate_list_operations(fields);
    let runtime_phase_support = generate_runtime_phase_support(fields, &constraints_path, name);
    let runtime_solve_internal =
        generate_runtime_solve_internal(&constraints_path, &config_path, &solver_toml_path);
    let solvable_solution_impl = generate_solvable_solution(name, &constraints_path);

    let stream_extensions = generate_constraint_stream_extensions(fields, name);

    let expanded = quote! {
        impl #impl_generics ::solverforge::__internal::PlanningSolution for #name #ty_generics #where_clause {
            type Score = #score_type;
            fn score(&self) -> Option<Self::Score> { self.#score_field_name.clone() }
            fn set_score(&mut self, score: Option<Self::Score>) { self.#score_field_name = score; }
        }

        impl #impl_generics #name #ty_generics #where_clause {
            pub fn descriptor() -> ::solverforge::__internal::SolutionDescriptor {
                ::solverforge::__internal::SolutionDescriptor::new(
                    #name_str,
                    ::std::any::TypeId::of::<Self>(),
                )
                .with_score_field(#score_field_str)
                #(#entity_descriptors)*
                #(#fact_descriptors)*
            }

            #[inline]
            pub fn entity_count(this: &Self, descriptor_index: usize) -> usize {
                match descriptor_index {
                    #(#entity_count_arms)*
                    _ => 0,
                }
            }

            #list_operations
            #runtime_solve_internal
        }

        #runtime_phase_support
        #shadow_support_impl

        #solvable_solution_impl

        #stream_extensions
    };

    Ok(expanded)
}