fn parse_solution(module: &ModuleSource, item_struct: &ItemStruct) -> Result<SolutionMetadata> {
if let Some(attr) = get_attribute(&item_struct.attrs, "planning_solution") {
validate_planning_solution_attribute(attr)?;
}
let fields = named_fields(item_struct, "#[planning_solution] requires named fields")?;
validate_solution_fields(item_struct, fields)?;
let mut collections = Vec::new();
let mut collection_field_names = BTreeSet::new();
let mut descriptor_index = 0usize;
for field in fields {
let Some(field_ident) = field.ident.clone() else {
continue;
};
let field_name = field_ident.to_string();
if has_attribute(&field.attrs, "planning_entity_collection")
|| has_attribute(&field.attrs, "problem_fact_collection")
|| has_attribute(&field.attrs, "planning_list_element_collection")
{
collection_field_names.insert(field_name.clone());
}
if has_attribute(&field.attrs, "planning_entity_collection") {
let type_name = collection_type_name(&field.ty).ok_or_else(|| {
Error::new_spanned(
field,
"#[planning_entity_collection] requires a Vec<T> field",
)
})?;
collections.push(SolutionCollection {
field_ident,
field_name,
type_name,
descriptor_index: Some(descriptor_index),
});
descriptor_index += 1;
} else if has_attribute(&field.attrs, "problem_fact_collection") {
let type_name = collection_type_name(&field.ty).ok_or_else(|| {
Error::new_spanned(field, "#[problem_fact_collection] requires a Vec<T> field")
})?;
collections.push(SolutionCollection {
field_ident,
field_name,
type_name,
descriptor_index: None,
});
}
}
Ok(SolutionMetadata {
module_ident: module.ident.clone(),
ident: item_struct.ident.clone(),
collections,
collection_field_names,
shadow_config: parse_shadow_config(&item_struct.attrs)?,
scalar_groups_path: parse_solution_scalar_groups_path(module, item_struct)?,
coverage_groups_path: parse_solution_coverage_groups_path(module, item_struct)?,
})
}
/// Resolves the `scalar_groups` hook path declared on `#[planning_solution]`.
///
/// Returns `Ok(None)` when the struct carries no `#[planning_solution]`
/// attribute; otherwise defers to `parse_hook_path`, using the struct itself
/// as the span for any error.
fn parse_solution_scalar_groups_path(
    module: &ModuleSource,
    item_struct: &ItemStruct,
) -> Result<Option<syn::Path>> {
    match get_attribute(&item_struct.attrs, "planning_solution") {
        Some(solution_attr) => {
            parse_hook_path(solution_attr, "scalar_groups", &module.ident, item_struct)
        }
        None => Ok(None),
    }
}
/// Resolves the `coverage_groups` hook path declared on `#[planning_solution]`.
///
/// Returns `Ok(None)` when the struct carries no `#[planning_solution]`
/// attribute; otherwise defers to `parse_hook_path`, using the struct itself
/// as the span for any error.
fn parse_solution_coverage_groups_path(
    module: &ModuleSource,
    item_struct: &ItemStruct,
) -> Result<Option<syn::Path>> {
    match get_attribute(&item_struct.attrs, "planning_solution") {
        Some(solution_attr) => {
            parse_hook_path(solution_attr, "coverage_groups", &module.ident, item_struct)
        }
        None => Ok(None),
    }
}
/// Reads the `#[shadow_variable_updates(...)]` attribute into a `ShadowConfig`.
///
/// When the attribute is absent the default configuration is returned
/// unchanged. When present it is validated first, then each recognised key is
/// copied into the matching field of the configuration.
///
/// # Errors
///
/// Propagates the error from `validate_shadow_updates_attribute`.
fn parse_shadow_config(attrs: &[Attribute]) -> Result<ShadowConfig> {
    let mut shadow = ShadowConfig::default();
    let Some(updates) = get_attribute(attrs, "shadow_variable_updates") else {
        return Ok(shadow);
    };
    validate_shadow_updates_attribute(updates)?;

    // Single-valued keys all go through the same string lookup.
    let text = |key: &str| parse_attribute_string(updates, key);
    shadow.list_owner = text("list_owner");
    shadow.inverse_field = text("inverse_field");
    shadow.previous_field = text("previous_field");
    shadow.next_field = text("next_field");
    shadow.cascading_listener = text("cascading_listener");
    shadow.post_update_listener = text("post_update_listener");

    // Keys read through the list parser.
    shadow.entity_aggregates = parse_attribute_list(updates, "entity_aggregate");
    shadow.entity_computes = parse_attribute_list(updates, "entity_compute");
    Ok(shadow)
}
/// Builds the `EntityMetadata` for a struct treated as a planning entity.
///
/// The `#[planning_entity]` attribute is validated when present, the struct
/// must have named fields, and every field attribute is checked by
/// `validate_entity_fields` before anything is collected.
///
/// * `#[planning_variable]` fields are recorded as scalar variables only when
///   `field_is_option_usize` accepts their type and `chained` is not set to
///   `true`. Any other `#[planning_variable]` field is skipped entirely,
///   including its `#[planning_list_variable]` check.
/// * `#[planning_list_variable]` fields set the list variable name and the
///   required `element_collection`. If several fields carry the attribute,
///   the last one wins.
///
/// # Errors
///
/// Propagates validation and hook-path errors, and fails when a
/// `#[planning_list_variable]` has no `element_collection` string.
fn parse_entity(module: &ModuleSource, item_struct: &ItemStruct) -> Result<EntityMetadata> {
    if let Some(attr) = get_attribute(&item_struct.attrs, "planning_entity") {
        validate_planning_entity_attribute(attr)?;
    }
    let fields = named_fields(item_struct, "#[planning_entity] requires named fields")?;
    validate_entity_fields(fields)?;

    let mut scalar_variables = Vec::new();
    let mut list_variable_name = None;
    let mut list_element_collection = None;

    for field in fields {
        // Bind the attribute directly instead of `has_attribute` + `unwrap()`,
        // so a disagreement between the two helpers can never panic.
        if let Some(attr) = get_attribute(&field.attrs, "planning_variable") {
            let Some(field_ident) = field.ident.as_ref() else {
                continue;
            };
            if !field_is_option_usize(&field.ty) {
                continue;
            }
            if parse_attribute_bool(attr, "chained").unwrap_or(false) {
                continue;
            }
            // Every hook is resolved the same way; only the key differs.
            let hook = |key: &str| parse_hook_path(attr, key, &module.ident, field);
            scalar_variables.push(ScalarVariableMetadata {
                field_name: field_ident.to_string(),
                hooks: HookPaths {
                    candidate_values: hook("candidate_values")?,
                    nearby_value_candidates: hook("nearby_value_candidates")?,
                    nearby_entity_candidates: hook("nearby_entity_candidates")?,
                    nearby_value_distance_meter: hook("nearby_value_distance_meter")?,
                    nearby_entity_distance_meter: hook("nearby_entity_distance_meter")?,
                    construction_entity_order_key: hook("construction_entity_order_key")?,
                    construction_value_order_key: hook("construction_value_order_key")?,
                },
            });
        }

        if let Some(attr) = get_attribute(&field.attrs, "planning_list_variable") {
            if let Some(field_ident) = field.ident.as_ref() {
                list_variable_name = Some(field_ident.to_string());
            }
            let element_collection =
                parse_attribute_string(attr, "element_collection").ok_or_else(|| {
                    Error::new_spanned(
                        field,
                        "#[planning_list_variable] requires `element_collection = \"solution_field\"`",
                    )
                })?;
            list_element_collection = Some(element_collection);
        }
    }

    Ok(EntityMetadata {
        type_name: item_struct.ident.to_string(),
        scalar_variables,
        list_variable_name,
        list_element_collection,
    })
}
/// Checks the argument syntax of the attributes found on a solution struct
/// and on each of its fields.
///
/// The struct-level `#[shadow_variable_updates]` attribute is validated first.
/// Then, per field, the argument-less marker attributes are checked in a fixed
/// order, followed by `#[planning_list_element_collection]`.
///
/// # Errors
///
/// Returns the first validation error encountered.
fn validate_solution_fields(
    item_struct: &ItemStruct,
    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
) -> Result<()> {
    // Field attributes that must be written without arguments, in check order.
    const ARGLESS_ATTRIBUTES: [&str; 4] = [
        "planning_entity_collection",
        "problem_fact_collection",
        "planning_score",
        "value_range_provider",
    ];

    if let Some(updates) = get_attribute(&item_struct.attrs, "shadow_variable_updates") {
        validate_shadow_updates_attribute(updates)?;
    }
    for field in fields {
        for name in ARGLESS_ATTRIBUTES {
            if let Some(marker) = get_attribute(&field.attrs, name) {
                validate_no_attribute_args(marker, name)?;
            }
        }
        if let Some(element_attr) =
            get_attribute(&field.attrs, "planning_list_element_collection")
        {
            validate_list_element_collection_attribute(element_attr)?;
        }
    }
    Ok(())
}
/// Checks the argument syntax of every recognised attribute on the fields of
/// a planning entity.
///
/// Per field the checks run in this order: `planning_id`, `planning_pin`,
/// `planning_variable`, `planning_list_variable`, the three element/inverse
/// shadow-variable attributes, and finally `cascading_update_shadow_variable`.
///
/// # Errors
///
/// Returns the first validation error encountered.
fn validate_entity_fields(
    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
) -> Result<()> {
    for field in fields {
        // Markers that take no arguments.
        for name in ["planning_id", "planning_pin"] {
            if let Some(marker) = get_attribute(&field.attrs, name) {
                validate_no_attribute_args(marker, name)?;
            }
        }

        if let Some(variable) = get_attribute(&field.attrs, "planning_variable") {
            validate_planning_variable_attribute(variable)?;
        }
        if let Some(list_variable) = get_attribute(&field.attrs, "planning_list_variable") {
            validate_planning_list_variable_attribute(list_variable)?;
        }

        // Shadow variables that share one validator, keyed by attribute name.
        for name in [
            "inverse_relation_shadow_variable",
            "previous_element_shadow_variable",
            "next_element_shadow_variable",
        ] {
            if let Some(shadow) = get_attribute(&field.attrs, name) {
                validate_shadow_variable_attribute(shadow, name)?;
            }
        }

        if let Some(cascading) = get_attribute(&field.attrs, "cascading_update_shadow_variable") {
            validate_no_attribute_args(cascading, "cascading_update_shadow_variable")?;
        }
    }
    Ok(())
}
/// Validates the fields of a `#[problem_fact]` struct.
///
/// The struct must have named fields, and any `#[planning_id]` attribute on a
/// field must be written without arguments.
///
/// # Errors
///
/// Fails when the struct has no named fields or a `#[planning_id]` attribute
/// is rejected by `validate_no_attribute_args`.
fn validate_problem_fact_fields(item_struct: &ItemStruct) -> Result<()> {
    let fields = named_fields(item_struct, "#[problem_fact] requires named fields")?;
    let id_attrs = fields
        .iter()
        .filter_map(|field| get_attribute(&field.attrs, "planning_id"));
    for id_attr in id_attrs {
        validate_no_attribute_args(id_attr, "planning_id")?;
    }
    Ok(())
}
fn named_fields<'a>(
item_struct: &'a ItemStruct,
message: &'static str,
) -> Result<&'a syn::punctuated::Punctuated<syn::Field, syn::token::Comma>> {
let Fields::Named(fields) = &item_struct.fields else {
return Err(Error::new_spanned(item_struct, message));
};
Ok(&fields.named)
}
/// Reads the string stored under `key` in `attr` and parses it as a Rust path.
///
/// A path consisting of a single segment with no leading `::` is qualified
/// with `module_ident` (`name` becomes `module_ident::name`); any other path
/// is returned as written.
///
/// Returns `Ok(None)` when `parse_attribute_string` finds no value for `key`.
///
/// # Errors
///
/// Fails with an error spanned over `span` when the string is not a valid path.
fn parse_hook_path(
    attr: &Attribute,
    key: &str,
    module_ident: &Ident,
    span: &impl ToTokens,
) -> Result<Option<syn::Path>> {
    let Some(raw) = parse_attribute_string(attr, key) else {
        return Ok(None);
    };
    let parsed = syn::parse_str::<syn::Path>(&raw)
        .map_err(|_| Error::new_spanned(span, format!("{key} must be a valid Rust path")))?;

    // Already-qualified paths are used verbatim.
    if parsed.leading_colon.is_some() || parsed.segments.len() != 1 {
        return Ok(Some(parsed));
    }
    let qualified: syn::Path = syn::parse_quote! { #module_ident::#parsed };
    Ok(Some(qualified))
}
fn validate_collections(
solution: &SolutionMetadata,
entities: &BTreeMap<String, EntityMetadata>,
facts: &BTreeSet<String>,
aliases: &BTreeMap<String, String>,
) -> Result<()> {
for collection in &solution.collections {
let resolved_type_name = canonical_type_name(aliases, &collection.type_name);
if collection.descriptor_index.is_some() {
if !entities.contains_key(resolved_type_name) {
return Err(Error::new_spanned(
&collection.field_ident,
format!(
"planning_model! entity collection `{}` references unknown #[planning_entity] type `{}`",
collection.field_name, collection.type_name,
),
));
}
} else if !facts.contains(resolved_type_name) && !entities.contains_key(resolved_type_name)
{
return Err(Error::new_spanned(
&collection.field_ident,
format!(
"planning_model! problem fact collection `{}` references unknown #[problem_fact] type `{}`",
collection.field_name, collection.type_name,
),
));
}
}
Ok(())
}
/// Verifies that each list entity's `element_collection` names a collection
/// field that exists on the solution.
///
/// Only collections with a descriptor index are inspected. Collections whose
/// resolved type is not in `entities`, or whose entity declares no list
/// element collection, are skipped.
///
/// # Errors
///
/// Returns an error spanned over the collection field when the named element
/// collection is not among `solution.collection_field_names`.
fn validate_list_element_sources(
    solution: &SolutionMetadata,
    entities: &BTreeMap<String, EntityMetadata>,
    aliases: &BTreeMap<String, String>,
) -> Result<()> {
    for collection in &solution.collections {
        if collection.descriptor_index.is_none() {
            continue;
        }
        let resolved = canonical_type_name(aliases, &collection.type_name);
        let Some(entity) = entities.get(resolved) else {
            continue;
        };
        let Some(required_field) = entity.list_element_collection.as_deref() else {
            continue;
        };
        if solution.collection_field_names.contains(required_field) {
            continue;
        }
        return Err(Error::new_spanned(
            &collection.field_ident,
            format!(
                "planning_model! list entity `{}` requires a solution collection field named `{}`",
                entity.type_name, required_field,
            ),
        ));
    }
    Ok(())
}
fn collection_type_name(ty: &Type) -> Option<String> {
let inner = collection_inner_type(ty)?;
type_name(inner)
}
/// Extracts `T` from a type whose last path segment is `Vec<T>`.
///
/// Only the final segment is inspected, so `Vec<T>` and `std::vec::Vec<T>`
/// are both accepted. Returns `None` for non-path types, for other
/// identifiers, and when the first generic argument is not a type.
fn collection_inner_type(ty: &Type) -> Option<&Type> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let vec_segment = type_path
        .path
        .segments
        .last()
        .filter(|segment| segment.ident == "Vec")?;
    match &vec_segment.arguments {
        syn::PathArguments::AngleBracketed(generics) => match generics.args.first() {
            Some(syn::GenericArgument::Type(element)) => Some(element),
            _ => None,
        },
        _ => None,
    }
}
/// Returns the identifier of the last path segment of `ty` as a string, or
/// `None` when `ty` is not a path type or its path has no segments.
fn type_name(ty: &Type) -> Option<String> {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map(|segment| segment.ident.to_string()),
        _ => None,
    }
}
/// Reports whether `ty` looks like `Option<usize>`.
///
/// The check is purely syntactic. The last path segment must be `Option` with
/// angle-bracketed arguments, and the first argument must be a path type whose
/// last segment is `usize`. Leading path qualifiers are ignored on both levels.
fn field_is_option_usize(ty: &Type) -> bool {
    let Type::Path(outer) = ty else {
        return false;
    };
    let wrapper = match outer.path.segments.last() {
        Some(segment) if segment.ident == "Option" => segment,
        _ => return false,
    };
    let syn::PathArguments::AngleBracketed(generics) = &wrapper.arguments else {
        return false;
    };
    match generics.args.first() {
        Some(syn::GenericArgument::Type(Type::Path(inner))) => {
            matches!(inner.path.segments.last(), Some(last) if last.ident == "usize")
        }
        _ => false,
    }
}