solverforge-macros 0.9.0

Derive macros for SolverForge constraint solver
Documentation
fn generate_scalar_runtime_setup(
    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
    solution_name: &Ident,
) -> TokenStream {
    let entity_fields: Vec<_> = fields
        .iter()
        .filter(|f| has_attribute(&f.attrs, "planning_entity_collection"))
        .enumerate()
        .filter_map(|(idx, field)| {
            let field_name = field.ident.as_ref()?;
            let field_type = extract_collection_inner_type(&field.ty)?;
            let syn::Type::Path(type_path) = field_type else {
                return None;
            };
            let _ = type_path.path.segments.last()?;
            Some((idx, field_name, field_type))
        })
        .collect();

    let provider_fields: Vec<_> = fields
        .iter()
        .filter(|f| {
            has_attribute(&f.attrs, "planning_entity_collection")
                || has_attribute(&f.attrs, "problem_fact_collection")
        })
        .filter_map(|field| field.ident.as_ref())
        .collect();

    let provider_names: Vec<_> = provider_fields
        .iter()
        .map(|field_name| field_name.to_string())
        .collect();
    let provider_count_arms: Vec<_> = provider_fields
        .iter()
        .enumerate()
        .map(|(idx, field_name)| {
            quote! { #idx => solution.#field_name.len(), }
        })
        .collect();

    let entity_helpers: Vec<_> = entity_fields
        .iter()
        .map(|(_, field_name, field_type)| {
            let count_fn_ident = format_ident!("__solverforge_scalar_count_{}", field_name);
            let getter_ident = format_ident!("__solverforge_scalar_get_{}", field_name);
            let setter_ident = format_ident!("__solverforge_scalar_set_{}", field_name);
            let values_ident = format_ident!("__solverforge_scalar_values_{}", field_name);
            quote! {
                fn #count_fn_ident(solution: &#solution_name) -> usize {
                    solution.#field_name.len()
                }

                fn #getter_ident(
                    solution: &#solution_name,
                    entity_index: usize,
                    variable_index: usize,
                ) -> ::core::option::Option<usize> {
                    <#field_type>::__solverforge_scalar_get_by_index(
                        &solution.#field_name[entity_index],
                        variable_index,
                    )
                }

                fn #setter_ident(
                    solution: &mut #solution_name,
                    entity_index: usize,
                    variable_index: usize,
                    value: ::core::option::Option<usize>,
                ) {
                    <#field_type>::__solverforge_scalar_set_by_index(
                        &mut solution.#field_name[entity_index],
                        variable_index,
                        value,
                    );
                }

                fn #values_ident(
                    solution: &#solution_name,
                    entity_index: usize,
                    variable_index: usize,
                ) -> &[usize] {
                    <#field_type>::__solverforge_scalar_values_by_index(
                        &solution.#field_name[entity_index],
                        variable_index,
                    )
                }
            }
        })
        .collect();

    let scalar_context_pushes: Vec<_> = entity_fields
        .iter()
        .map(|(descriptor_index, field_name, field_type)| {
            let entity_count_fn_ident = format_ident!("__solverforge_scalar_count_{}", field_name);
            let getter_ident = format_ident!("__solverforge_scalar_get_{}", field_name);
            let setter_ident = format_ident!("__solverforge_scalar_set_{}", field_name);
            let values_ident = format_ident!("__solverforge_scalar_values_{}", field_name);
            quote! {
                {
                    let __solverforge_descriptor_index = #descriptor_index;
                    let __solverforge_entity_descriptor = descriptor
                        .entity_descriptors
                        .get(__solverforge_descriptor_index)
                        .expect("entity descriptor missing for scalar runtime setup");
                    for __solverforge_variable_index in 0..<#field_type>::__solverforge_scalar_variable_count() {
                        let Some(__solverforge_variable_name) =
                            <#field_type>::__solverforge_scalar_variable_name_by_index(
                                __solverforge_variable_index,
                            )
                        else {
                            continue;
                        };
                        let Some(__solverforge_variable_descriptor) = __solverforge_entity_descriptor
                            .genuine_variable_descriptors()
                            .find(|variable| {
                                variable.name == __solverforge_variable_name
                                    && variable.usize_getter.is_some()
                                    && variable.usize_setter.is_some()
                            })
                        else {
                            continue;
                        };

                        let __solverforge_value_source = if __solverforge_variable_descriptor
                            .entity_value_provider
                            .is_some()
                            || <#field_type>::__solverforge_scalar_provider_is_entity_field_by_index(
                                __solverforge_variable_index,
                            )
                        {
                            ::solverforge::__internal::ValueSource::EntitySlice {
                                values_for_entity: #values_ident,
                            }
                        } else {
                            match &__solverforge_variable_descriptor.value_range_type {
                                ::solverforge::__internal::ValueRangeType::CountableRange { from, to } => {
                                    let from = usize::try_from(*from).expect(
                                        "countable_range start must be non-negative for canonical scalar solving",
                                    );
                                    let to = usize::try_from(*to).expect(
                                        "countable_range end must be non-negative for canonical scalar solving",
                                    );
                                    ::solverforge::__internal::ValueSource::CountableRange { from, to }
                                }
                                _ => {
                                    if let Some(provider_name) =
                                        __solverforge_variable_descriptor.value_range_provider
                                    {
                                        let provider_index = __solverforge_scalar_provider_fields
                                            .iter()
                                            .position(|field| *field == provider_name)
                                            .expect("scalar value range provider must be a solution collection");
                                        ::solverforge::__internal::ValueSource::SolutionCount {
                                            count_fn: __solverforge_scalar_collection_count,
                                            provider_index,
                                        }
                                    } else {
                                        ::solverforge::__internal::ValueSource::Empty
                                    }
                                }
                            }
                        };

                        let __solverforge_context =
                            ::solverforge::__internal::ScalarVariableContext::new(
                                __solverforge_descriptor_index,
                                __solverforge_variable_index,
                                __solverforge_entity_descriptor.type_name,
                                #entity_count_fn_ident,
                                __solverforge_variable_name,
                                #getter_ident,
                                #setter_ident,
                                __solverforge_value_source,
                                __solverforge_variable_descriptor.allows_unassigned,
                            );
                        let __solverforge_context =
                            <#solution_name as ::solverforge::__internal::PlanningModelSupport>::attach_runtime_scalar_hooks(
                                __solverforge_context,
                            );
                        __solverforge_variables.push(
                            ::solverforge::__internal::VariableContext::Scalar(
                                __solverforge_context,
                            )
                        );
                    }
                }
            }
        })
        .collect();

    quote! {
        let mut __solverforge_variables = ::std::vec::Vec::new();
        let __solverforge_scalar_provider_fields: &[&str] = &[
            #(#provider_names),*
        ];
        #(#entity_helpers)*
        fn __solverforge_scalar_collection_count(
            solution: &#solution_name,
            provider_index: usize,
        ) -> usize {
            match provider_index {
                #(#provider_count_arms)*
                _ => 0,
            }
        }
        #(#scalar_context_pushes)*
    }
}