solverforge-macros 0.8.5

Derive macros for SolverForge constraint solver
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::Error;
use syn::Ident;

use super::config::ShadowConfig;
use super::list_runtime::{
    find_list_owner_config, find_list_runtime_config, shadow_updates_requested,
    ListElementCollectionKind,
};

pub(super) fn generate_shadow_support(
    config: &ShadowConfig,
    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
    solution_name: &Ident,
) -> Result<TokenStream, Error> {
    if !shadow_updates_requested(config) {
        return Ok(quote! {
            impl ::solverforge::__internal::ShadowVariableSupport for #solution_name {
                #[inline]
                fn update_entity_shadows(&mut self, _entity_idx: usize) {}
            }
        });
    }

    let Some(list_owner) = find_list_owner_config(config, fields)? else {
        return Err(Error::new(
            proc_macro2::Span::call_site(),
            "#[shadow_variable_updates(...)] requires `list_owner = \"entity_collection_field\"` when shadow updates are configured",
        ));
    };

    let Some(runtime_config) = find_list_runtime_config(fields)? else {
        return Err(Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "planning solution with list owner `{}` requires a `#[planning_entity_collection]` or `#[problem_fact_collection]` field named `{}`",
                list_owner.field_ident,
                list_owner.field_ident,
            ),
        ));
    };
    if runtime_config.list_owner.field_ident != list_owner.field_ident {
        return Err(Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "#[shadow_variable_updates(list_owner = \"{}\")] does not match the inferred list owner `{}`",
                list_owner.field_ident,
                runtime_config.list_owner.field_ident,
            ),
        ));
    }
    if runtime_config.element_collection.kind == ListElementCollectionKind::LegacyListCollection {
        return Err(Error::new(
            proc_macro2::Span::call_site(),
            format!(
                "planning solution with list owner `{}` requires a matching `#[planning_entity_collection]` or `#[problem_fact_collection]` field for shadow updates",
                list_owner.field_ident,
            ),
        ));
    }

    let list_owner_ident = list_owner.field_ident;
    let element_collection_ident = runtime_config.element_collection.field_ident;
    let list_owner_type = list_owner.entity_type;
    let list_trait =
        quote! { <#list_owner_type as ::solverforge::__internal::ListVariableEntity<Self>> };

    let inverse_update = config.inverse_field.as_ref().map(|field| {
        let field_ident = Ident::new(field, proc_macro2::Span::call_site());
        quote! {
            for &element_idx in &element_indices {
                self.#element_collection_ident[element_idx].#field_ident = Some(entity_idx);
            }
        }
    });

    let previous_update = config.previous_field.as_ref().map(|field| {
        let field_ident = Ident::new(field, proc_macro2::Span::call_site());
        quote! {
            let mut prev_idx: Option<usize> = None;
            for &element_idx in &element_indices {
                self.#element_collection_ident[element_idx].#field_ident = prev_idx;
                prev_idx = Some(element_idx);
            }
        }
    });

    let next_update = config.next_field.as_ref().map(|field| {
        let field_ident = Ident::new(field, proc_macro2::Span::call_site());
        quote! {
            let len = element_indices.len();
            for (i, &element_idx) in element_indices.iter().enumerate() {
                let next_idx = if i + 1 < len { Some(element_indices[i + 1]) } else { None };
                self.#element_collection_ident[element_idx].#field_ident = next_idx;
            }
        }
    });

    let cascading_update = config.cascading_listener.as_ref().map(|method| {
        let method_ident = Ident::new(method, proc_macro2::Span::call_site());
        quote! {
            for &element_idx in &element_indices {
                self.#method_ident(element_idx);
            }
        }
    });

    let post_update = config.post_update_listener.as_ref().map(|method| {
        let method_ident = Ident::new(method, proc_macro2::Span::call_site());
        quote! {
            self.#method_ident(entity_idx);
        }
    });

    let aggregate_updates: Vec<_> = config
        .entity_aggregates
        .iter()
        .filter_map(|spec| {
            let parts: Vec<&str> = spec.split(':').collect();
            if parts.len() != 3 {
                return None;
            }
            let target_field = Ident::new(parts[0], proc_macro2::Span::call_site());
            let aggregation = parts[1];
            let source_field = Ident::new(parts[2], proc_macro2::Span::call_site());

            match aggregation {
                "sum" => Some(quote! {
                    self.#list_owner_ident[entity_idx].#target_field = element_indices
                        .iter()
                        .map(|&idx| self.#element_collection_ident[idx].#source_field)
                        .sum();
                }),
                _ => None,
            }
        })
        .collect();

    let compute_updates: Vec<_> = config
        .entity_computes
        .iter()
        .filter_map(|spec| {
            let parts: Vec<&str> = spec.split(':').collect();
            if parts.len() != 2 {
                return None;
            }
            let target_field = Ident::new(parts[0], proc_macro2::Span::call_site());
            let method_name = Ident::new(parts[1], proc_macro2::Span::call_site());

            Some(quote! {
                self.#list_owner_ident[entity_idx].#target_field = self.#method_name(entity_idx);
            })
        })
        .collect();

    Ok(quote! {
        impl ::solverforge::__internal::ShadowVariableSupport for #solution_name {
            #[inline]
            fn update_entity_shadows(&mut self, entity_idx: usize) {
                let element_indices: Vec<usize> =
                    #list_trait::list_field(&self.#list_owner_ident[entity_idx]).to_vec();

                #inverse_update
                #previous_update
                #next_update
                #cascading_update
                #(#aggregate_updates)*
                #(#compute_updates)*
                #post_update
            }
        }
    })
}