use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Error, Fields};
use crate::attr_parse::has_attribute;
use super::config::{
parse_config_path, parse_constraints_path, parse_shadow_config, parse_solver_toml_path,
};
use super::list_operations::generate_list_operations;
use super::runtime::{
generate_runtime_phase_support, generate_runtime_solve_internal, generate_solvable_solution,
};
use super::shadow::generate_shadow_support;
use super::stream_extensions::generate_constraint_stream_extensions;
use super::type_helpers::{extract_collection_inner_type, extract_option_inner_type};
pub(crate) fn expand_derive(input: DeriveInput) -> Result<TokenStream, Error> {
let name = &input.ident;
let generics = &input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return Err(Error::new_spanned(
&input,
"#[planning_solution] requires named fields",
))
}
},
_ => {
return Err(Error::new_spanned(
&input,
"#[planning_solution] only works on structs",
))
}
};
let score_field = fields
.iter()
.find(|f| has_attribute(&f.attrs, "planning_score"))
.ok_or_else(|| {
Error::new_spanned(
&input,
"#[planning_solution] requires a #[planning_score] field",
)
})?;
let score_field_name = score_field.ident.as_ref().unwrap();
let score_type = extract_option_inner_type(&score_field.ty)?;
let entity_descriptors: Vec<_> = fields
.iter()
.filter(|f| has_attribute(&f.attrs, "planning_entity_collection"))
.filter_map(|f| {
let field_name = f.ident.as_ref().unwrap();
let field_name_str = field_name.to_string();
let element_type = extract_collection_inner_type(&f.ty)?;
Some(quote! {
.with_entity(#element_type::entity_descriptor(#field_name_str).with_extractor(
Box::new(::solverforge::__internal::EntityCollectionExtractor::new(
stringify!(#element_type),
#field_name_str,
|s: &#name| &s.#field_name,
|s: &mut #name| &mut s.#field_name,
))
))
})
})
.collect();
let fact_descriptors: Vec<_> = fields
.iter()
.filter(|f| has_attribute(&f.attrs, "problem_fact_collection"))
.filter_map(|f| {
let field_name = f.ident.as_ref().unwrap();
let field_name_str = field_name.to_string();
let element_type = extract_collection_inner_type(&f.ty)?;
Some(quote! {
.with_problem_fact(#element_type::problem_fact_descriptor(#field_name_str).with_extractor(
Box::new(::solverforge::__internal::EntityCollectionExtractor::new(
stringify!(#element_type),
#field_name_str,
|s: &#name| &s.#field_name,
|s: &mut #name| &mut s.#field_name,
))
))
})
})
.collect();
let name_str = name.to_string();
let score_field_str = score_field_name.to_string();
let shadow_config = parse_shadow_config(&input.attrs);
let shadow_support_impl = generate_shadow_support(&shadow_config, fields, name)?;
let constraints_path = parse_constraints_path(&input.attrs);
let config_path = parse_config_path(&input.attrs);
let solver_toml_path = parse_solver_toml_path(&input.attrs);
let entity_count_arms: Vec<_> = fields
.iter()
.filter(|f| has_attribute(&f.attrs, "planning_entity_collection"))
.enumerate()
.map(|(idx, f)| {
let field_name = f.ident.as_ref().unwrap();
quote! { #idx => this.#field_name.len(), }
})
.collect();
let list_operations = generate_list_operations(fields);
let runtime_phase_support = generate_runtime_phase_support(fields, &constraints_path, name);
let runtime_solve_internal =
generate_runtime_solve_internal(&constraints_path, &config_path, &solver_toml_path);
let solvable_solution_impl = generate_solvable_solution(name, &constraints_path);
let stream_extensions = generate_constraint_stream_extensions(fields, name);
let expanded = quote! {
impl #impl_generics ::solverforge::__internal::PlanningSolution for #name #ty_generics #where_clause {
type Score = #score_type;
fn score(&self) -> Option<Self::Score> { self.#score_field_name.clone() }
fn set_score(&mut self, score: Option<Self::Score>) { self.#score_field_name = score; }
}
impl #impl_generics #name #ty_generics #where_clause {
pub fn descriptor() -> ::solverforge::__internal::SolutionDescriptor {
::solverforge::__internal::SolutionDescriptor::new(
#name_str,
::std::any::TypeId::of::<Self>(),
)
.with_score_field(#score_field_str)
#(#entity_descriptors)*
#(#fact_descriptors)*
}
#[inline]
pub fn entity_count(this: &Self, descriptor_index: usize) -> usize {
match descriptor_index {
#(#entity_count_arms)*
_ => 0,
}
}
#list_operations
#runtime_solve_internal
}
#runtime_phase_support
#shadow_support_impl
#solvable_solution_impl
#stream_extensions
};
Ok(expanded)
}