solverforge-macros 0.8.4

Derive macros for SolverForge constraint solver
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Error, Ident, Type};

use crate::attr_parse::{get_attribute, parse_attribute_string};
use crate::list_registry::record_list_entity_metadata;

pub(super) fn generate_list_metadata(
    entity_name: &Ident,
    list_variables: &[&syn::Field],
) -> Result<TokenStream, Error> {
    let Some(field) = list_variables.first().copied() else {
        return Ok(TokenStream::new());
    };

    let field_name = field.ident.as_ref().unwrap();
    let field_name_str = field_name.to_string();
    let attr = get_attribute(&field.attrs, "planning_list_variable").unwrap();
    let element_collection = parse_attribute_string(attr, "element_collection").ok_or_else(|| {
        Error::new_spanned(
            field,
            "#[planning_list_variable] requires `element_collection = \"solution_field\"` for stock solving",
        )
    })?;

    ensure_vec_usize(&field.ty, field)?;

    let cross_dm_ty = parse_type_or_default(
        parse_attribute_string(attr, "distance_meter"),
        "::solverforge::__internal::DefaultCrossEntityDistanceMeter",
        "distance_meter",
        field,
    )?;
    let intra_dm_ty = parse_type_or_default(
        parse_attribute_string(attr, "intra_distance_meter"),
        "::solverforge::__internal::DefaultCrossEntityDistanceMeter",
        "intra_distance_meter",
        field,
    )?;
    let cross_dm_expr = parse_default_expr(
        parse_attribute_string(attr, "distance_meter"),
        "::solverforge::__internal::DefaultCrossEntityDistanceMeter",
        "distance_meter",
        field,
    )?;
    let intra_dm_expr = parse_default_expr(
        parse_attribute_string(attr, "intra_distance_meter"),
        "::solverforge::__internal::DefaultCrossEntityDistanceMeter",
        "intra_distance_meter",
        field,
    )?;
    let solution_trait_bound = parse_solution_trait_bound(attr, field)?;
    let solution_where_clause = solution_trait_bound
        .as_ref()
        .map(|bound| quote! { where Solution: #bound });
    let merge_feasible = option_fn_expr(
        parse_attribute_string(attr, "merge_feasible_fn"),
        "merge_feasible_fn",
        field,
    )?;
    let cw_depot = option_fn_expr(
        parse_attribute_string(attr, "cw_depot_fn"),
        "cw_depot_fn",
        field,
    )?;
    let cw_dist = option_fn_expr(
        parse_attribute_string(attr, "cw_distance_fn"),
        "cw_distance_fn",
        field,
    )?;
    let cw_load = option_fn_expr(
        parse_attribute_string(attr, "cw_element_load_fn"),
        "cw_element_load_fn",
        field,
    )?;
    let cw_cap = option_fn_expr(
        parse_attribute_string(attr, "cw_capacity_fn"),
        "cw_capacity_fn",
        field,
    )?;
    let cw_assign = option_fn_expr(
        parse_attribute_string(attr, "cw_assign_route_fn"),
        "cw_assign_route_fn",
        field,
    )?;
    let k_opt_get = option_fn_expr(
        parse_attribute_string(attr, "k_opt_get_route"),
        "k_opt_get_route",
        field,
    )?;
    let k_opt_set = option_fn_expr(
        parse_attribute_string(attr, "k_opt_set_route"),
        "k_opt_set_route",
        field,
    )?;
    let k_opt_depot = option_fn_expr(
        parse_attribute_string(attr, "k_opt_depot_fn"),
        "k_opt_depot_fn",
        field,
    )?;
    let k_opt_dist = option_fn_expr(
        parse_attribute_string(attr, "k_opt_distance_fn"),
        "k_opt_distance_fn",
        field,
    )?;
    let k_opt_feasible = option_fn_expr(
        parse_attribute_string(attr, "k_opt_feasible_fn"),
        "k_opt_feasible_fn",
        field,
    )?;

    record_list_entity_metadata(&entity_name.to_string(), element_collection.clone());

    Ok(quote! {
        pub const __SOLVERFORGE_LIST_VARIABLE_COUNT: usize = 1;
        pub const __SOLVERFORGE_LIST_VARIABLE_NAME: &'static str = #field_name_str;
        pub const __SOLVERFORGE_LIST_ELEMENT_COLLECTION: &'static str = #element_collection;

        #[inline]
        pub fn __solverforge_list_field(entity: &Self) -> &[usize] {
            &entity.#field_name
        }

        #[inline]
        pub fn __solverforge_list_field_mut(entity: &mut Self) -> &mut ::std::vec::Vec<usize> {
            &mut entity.#field_name
        }

        #[inline]
        pub fn __solverforge_list_metadata<Solution>() -> ::solverforge::__internal::ListVariableMetadata<
            Solution,
            #cross_dm_ty,
            #intra_dm_ty,
        >
        #solution_where_clause
        {
            let _ = stringify!(#entity_name);
            let _ = #element_collection;
            ::solverforge::__internal::ListVariableMetadata::new(
                #cross_dm_expr,
                #intra_dm_expr,
                #merge_feasible,
                #cw_depot,
                #cw_dist,
                #cw_load,
                #cw_cap,
                #cw_assign,
                #k_opt_get,
                #k_opt_set,
                #k_opt_depot,
                #k_opt_dist,
                #k_opt_feasible,
            )
        }

    })
}

pub(super) fn generate_list_trait_impl(
    entity_name: &Ident,
    list_variables: &[&syn::Field],
) -> Result<TokenStream, Error> {
    let Some(field) = list_variables.first().copied() else {
        return Ok(quote! {
            impl<Solution> ::solverforge::__internal::ListVariableEntity<Solution> for #entity_name
            where
                Solution: ::solverforge::__internal::PlanningSolution,
            {
                type CrossDistanceMeter = ::solverforge::__internal::DefaultCrossEntityDistanceMeter;
                type IntraDistanceMeter = ::solverforge::__internal::DefaultCrossEntityDistanceMeter;

                const HAS_STOCK_LIST_VARIABLE: bool = false;
                const STOCK_LIST_VARIABLE_NAME: &'static str = "";
                const STOCK_LIST_ELEMENT_SOURCE: ::core::option::Option<&'static str> =
                    ::core::option::Option::None;

                #[inline]
                fn list_field(_entity: &Self) -> &[usize] {
                    panic!("ListVariableEntity::list_field called on an entity without #[planning_list_variable]");
                }

                #[inline]
                fn list_field_mut(_entity: &mut Self) -> &mut ::std::vec::Vec<usize> {
                    panic!("ListVariableEntity::list_field_mut called on an entity without #[planning_list_variable]");
                }

                #[inline]
                fn list_metadata() -> ::solverforge::__internal::ListVariableMetadata<
                    Solution,
                    Self::CrossDistanceMeter,
                    Self::IntraDistanceMeter,
                > {
                    ::solverforge::__internal::ListVariableMetadata::new(
                        ::solverforge::__internal::DefaultCrossEntityDistanceMeter::default(),
                        ::solverforge::__internal::DefaultCrossEntityDistanceMeter::default(),
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                        ::core::option::Option::None,
                    )
                }
            }
        });
    };

    let attr = get_attribute(&field.attrs, "planning_list_variable").unwrap();
    let cross_dm_ty = parse_type_or_default(
        parse_attribute_string(attr, "distance_meter"),
        "::solverforge::__internal::DefaultCrossEntityDistanceMeter",
        "distance_meter",
        field,
    )?;
    let intra_dm_ty = parse_type_or_default(
        parse_attribute_string(attr, "intra_distance_meter"),
        "::solverforge::__internal::DefaultCrossEntityDistanceMeter",
        "intra_distance_meter",
        field,
    )?;
    let solution_trait_bound = parse_solution_trait_bound(attr, field)?;
    let element_source = parse_attribute_string(attr, "element_collection").ok_or_else(|| {
        Error::new_spanned(
            field,
            "#[planning_list_variable] requires `element_collection = \"solution_collection\"` for stock solving",
        )
    })?;
    let solution_bound = solution_trait_bound
        .as_ref()
        .map(|bound| quote! { + #bound })
        .unwrap_or_default();

    Ok(quote! {
        impl<Solution> ::solverforge::__internal::ListVariableEntity<Solution> for #entity_name
        where
            Solution: ::solverforge::__internal::PlanningSolution #solution_bound,
        {
            type CrossDistanceMeter = #cross_dm_ty;
            type IntraDistanceMeter = #intra_dm_ty;

            const HAS_STOCK_LIST_VARIABLE: bool = true;
            const STOCK_LIST_VARIABLE_NAME: &'static str = Self::__SOLVERFORGE_LIST_VARIABLE_NAME;
            const STOCK_LIST_ELEMENT_SOURCE: ::core::option::Option<&'static str> =
                ::core::option::Option::Some(#element_source);

            #[inline]
            fn list_field(entity: &Self) -> &[usize] {
                Self::__solverforge_list_field(entity)
            }

            #[inline]
            fn list_field_mut(entity: &mut Self) -> &mut ::std::vec::Vec<usize> {
                Self::__solverforge_list_field_mut(entity)
            }

            #[inline]
            fn list_metadata() -> ::solverforge::__internal::ListVariableMetadata<
                Solution,
                Self::CrossDistanceMeter,
                Self::IntraDistanceMeter,
            > {
                Self::__solverforge_list_metadata::<Solution>()
            }
        }
    })
}

fn ensure_vec_usize(ty: &Type, field: &syn::Field) -> Result<(), Error> {
    let Some(inner) = field_vec_inner_type(ty) else {
        return Err(Error::new_spanned(
            field,
            "#[planning_list_variable] requires a field of type Vec<usize>",
        ));
    };
    let Type::Path(type_path) = inner else {
        return Err(Error::new_spanned(
            field,
            "#[planning_list_variable] requires a field of type Vec<usize>",
        ));
    };
    let Some(segment) = type_path.path.segments.last() else {
        return Err(Error::new_spanned(
            field,
            "#[planning_list_variable] requires a field of type Vec<usize>",
        ));
    };
    if segment.ident != "usize" {
        return Err(Error::new_spanned(
            field,
            "#[planning_list_variable] stock solving currently requires Vec<usize>",
        ));
    }
    Ok(())
}

fn field_vec_inner_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let segment = type_path.path.segments.last()?;
    if segment.ident != "Vec" {
        return None;
    }
    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
        return None;
    };
    let Some(syn::GenericArgument::Type(inner)) = args.args.first() else {
        return None;
    };
    Some(inner)
}

fn parse_type_or_default(
    path: Option<String>,
    default: &str,
    label: &str,
    span: &impl quote::ToTokens,
) -> Result<Type, Error> {
    let raw = path.unwrap_or_else(|| default.to_string());
    syn::parse_str(&raw)
        .map_err(|_| Error::new_spanned(span, format!("{label} must be a valid type path")))
}

fn parse_default_expr(
    path: Option<String>,
    default: &str,
    label: &str,
    span: &impl quote::ToTokens,
) -> Result<syn::Expr, Error> {
    if let Some(path) = path {
        let parsed: syn::Path = syn::parse_str(&path)
            .map_err(|_| Error::new_spanned(span, format!("{label} must be a valid path")))?;
        Ok(syn::parse_quote! { #parsed::default() })
    } else {
        syn::parse_str(default)
            .map_err(|_| Error::new_spanned(span, format!("{label} must be a valid path")))
    }
}

fn parse_solution_trait_bound(
    attr: &syn::Attribute,
    span: &impl quote::ToTokens,
) -> Result<Option<syn::TypeParamBound>, Error> {
    parse_attribute_string(attr, "solution_trait")
        .map(|path| {
            syn::parse_str(&path)
                .map_err(|_| Error::new_spanned(span, "solution_trait must be a valid trait path"))
        })
        .transpose()
}

fn option_fn_expr(
    path: Option<String>,
    label: &str,
    span: &impl quote::ToTokens,
) -> Result<syn::Expr, Error> {
    if let Some(path) = path {
        let parsed: syn::Path = syn::parse_str(&path)
            .map_err(|_| Error::new_spanned(span, format!("{label} must be a valid path")))?;
        Ok(syn::parse_quote! { ::core::option::Option::Some(#parsed) })
    } else {
        Ok(syn::parse_quote! { ::core::option::Option::None })
    }
}