use super::*;
use crate::{constraint_type::ConstraintCollection, linear, v1, Function, VariableID};
use anyhow::Result;
use num::Zero;
impl Instance {
#[cfg_attr(doc, katexit::katexit)]
pub fn penalty_method(self) -> Result<ParametricInstance> {
anyhow::ensure!(
self.indicator_constraint_collection.active().is_empty(),
"penalty_method does not support indicator constraints. \
Remove or convert indicator constraints before applying penalty method."
);
anyhow::ensure!(
self.one_hot_constraint_collection.active().is_empty(),
"penalty_method does not support one-hot constraints. \
Remove or convert one-hot constraints before applying penalty method."
);
anyhow::ensure!(
self.sos1_constraint_collection.active().is_empty(),
"penalty_method does not support SOS1 constraints. \
Remove or convert SOS1 constraints before applying penalty method."
);
let mut max_id = 0;
for id in self.decision_variables.keys() {
max_id = max_id.max(id.into_inner());
}
if let Some(params) = &self.parameters {
for id in params.entries.keys() {
max_id = max_id.max(*id);
}
}
let id_base = max_id + 1;
let mut objective = self.objective.clone();
let mut parameters = BTreeMap::new();
let mut removed_constraints = BTreeMap::new();
let (active_constraints, existing_removed, constraint_metadata) =
self.constraint_collection.into_parts();
removed_constraints.extend(existing_removed);
for (i, (constraint_id, constraint)) in active_constraints.into_iter().enumerate() {
let parameter_id = VariableID::from(id_base + i as u64);
let parameter = v1::Parameter {
id: parameter_id.into_inner(),
name: Some("penalty_weight".to_string()),
subscripts: vec![constraint_id.into_inner() as i64],
..Default::default()
};
let f = constraint.function().clone();
let penalty_term = Function::from(linear!(parameter_id)) * f.clone() * f;
objective += penalty_term;
let removed_reason = crate::constraint::RemovedReason {
reason: "ommx.Instance.penalty_method".to_string(),
parameters: {
let mut map = fnv::FnvHashMap::default();
map.insert(
"parameter_id".to_string(),
parameter_id.into_inner().to_string(),
);
map
},
};
parameters.insert(parameter_id, parameter);
removed_constraints.insert(constraint_id, (constraint, removed_reason));
}
Ok(ParametricInstance {
sense: self.sense,
objective,
decision_variables: self.decision_variables,
parameters,
variable_metadata: self.variable_metadata,
constraint_collection: ConstraintCollection::with_metadata(
BTreeMap::new(),
removed_constraints,
constraint_metadata,
),
indicator_constraint_collection: self.indicator_constraint_collection,
one_hot_constraint_collection: self.one_hot_constraint_collection,
sos1_constraint_collection: self.sos1_constraint_collection,
decision_variable_dependency: self.decision_variable_dependency,
description: self.description,
named_functions: self.named_functions,
named_function_metadata: self.named_function_metadata,
})
}
#[cfg_attr(doc, katexit::katexit)]
pub fn uniform_penalty_method(self) -> Result<ParametricInstance> {
anyhow::ensure!(
self.indicator_constraint_collection.active().is_empty(),
"uniform_penalty_method does not support indicator constraints. \
Remove or convert indicator constraints before applying penalty method."
);
anyhow::ensure!(
self.one_hot_constraint_collection.active().is_empty(),
"uniform_penalty_method does not support one-hot constraints. \
Remove or convert one-hot constraints before applying penalty method."
);
anyhow::ensure!(
self.sos1_constraint_collection.active().is_empty(),
"uniform_penalty_method does not support SOS1 constraints. \
Remove or convert SOS1 constraints before applying penalty method."
);
if self.constraints().is_empty() {
let (_active, existing_removed, constraint_metadata) =
self.constraint_collection.into_parts();
return Ok(ParametricInstance {
sense: self.sense,
objective: self.objective,
decision_variables: self.decision_variables,
parameters: BTreeMap::new(),
variable_metadata: self.variable_metadata,
constraint_collection: ConstraintCollection::with_metadata(
BTreeMap::new(),
existing_removed,
constraint_metadata,
),
indicator_constraint_collection: self.indicator_constraint_collection,
one_hot_constraint_collection: self.one_hot_constraint_collection,
sos1_constraint_collection: self.sos1_constraint_collection,
decision_variable_dependency: self.decision_variable_dependency,
description: self.description,
named_functions: self.named_functions,
named_function_metadata: self.named_function_metadata,
});
}
let mut max_id = 0;
for id in self.decision_variables.keys() {
max_id = max_id.max(id.into_inner());
}
if let Some(params) = &self.parameters {
for id in params.entries.keys() {
max_id = max_id.max(*id);
}
}
let parameter_id = VariableID::from(max_id + 1);
let mut objective = self.objective.clone();
let parameter = v1::Parameter {
id: parameter_id.into_inner(),
name: Some("uniform_penalty_weight".to_string()),
..Default::default()
};
let mut removed_constraints = BTreeMap::new();
let mut quad_sum = Function::zero();
let (active_constraints, existing_removed, constraint_metadata) =
self.constraint_collection.into_parts();
removed_constraints.extend(existing_removed);
for (constraint_id, constraint) in active_constraints.into_iter() {
let f = constraint.function().clone();
quad_sum += f.clone() * f;
let removed_reason = crate::constraint::RemovedReason {
reason: "ommx.Instance.uniform_penalty_method".to_string(),
parameters: Default::default(),
};
removed_constraints.insert(constraint_id, (constraint, removed_reason));
}
objective += Function::from(linear!(parameter_id)) * quad_sum;
let parameters = BTreeMap::from([(parameter_id, parameter)]);
Ok(ParametricInstance {
sense: self.sense,
objective,
decision_variables: self.decision_variables,
parameters,
variable_metadata: self.variable_metadata,
constraint_collection: ConstraintCollection::with_metadata(
BTreeMap::new(),
removed_constraints,
constraint_metadata,
),
indicator_constraint_collection: self.indicator_constraint_collection,
one_hot_constraint_collection: self.one_hot_constraint_collection,
sos1_constraint_collection: self.sos1_constraint_collection,
decision_variable_dependency: self.decision_variable_dependency,
description: self.description,
named_functions: self.named_functions,
named_function_metadata: self.named_function_metadata,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{coeff, constraint::Equality, linear, DecisionVariable, Sense};
    use std::collections::BTreeMap;

    /// Two continuous variables (IDs 1, 2) with an inequality (ID 1) and an
    /// equality (ID 2) constraint.
    fn create_test_instance_with_constraints() -> Instance {
        let mut decision_variables = BTreeMap::new();
        decision_variables.insert(
            VariableID::from(1),
            DecisionVariable::continuous(VariableID::from(1)),
        );
        decision_variables.insert(
            VariableID::from(2),
            DecisionVariable::continuous(VariableID::from(2)),
        );
        let objective = Function::from(linear!(1) + linear!(2));
        let mut constraints = BTreeMap::new();
        constraints.insert(
            ConstraintID::from(1),
            Constraint {
                equality: Equality::LessThanOrEqualToZero,
                stage: crate::constraint::CreatedData {
                    function: Function::from(linear!(1) + linear!(2) + coeff!(-1.0)),
                },
            },
        );
        constraints.insert(
            ConstraintID::from(2),
            Constraint {
                equality: Equality::EqualToZero,
                stage: crate::constraint::CreatedData {
                    function: Function::from(linear!(1) + coeff!(-1.0) * linear!(2)),
                },
            },
        );
        Instance::new(Sense::Minimize, objective, decision_variables, constraints).unwrap()
    }

    /// Shared checks: all constraints were moved to `removed_constraints`, the
    /// parameters have the expected count and name, their IDs do not collide with
    /// decision variables, and a zero weight restores the original objective.
    fn verify_penalty_method_properties(
        original_objective: Function,
        original_constraint_count: usize,
        parametric_instance: &ParametricInstance,
        expected_param_count: usize,
        expected_param_name: &str,
    ) {
        assert_eq!(parametric_instance.constraints().len(), 0);
        assert_eq!(
            parametric_instance.removed_constraints().len(),
            original_constraint_count
        );
        assert_eq!(parametric_instance.parameters.len(), expected_param_count);
        for parameter in parametric_instance.parameters.values() {
            assert_eq!(parameter.name, Some(expected_param_name.to_string()));
        }
        let dv_ids: std::collections::BTreeSet<_> = parametric_instance
            .decision_variables
            .keys()
            .cloned()
            .collect();
        let p_ids: std::collections::BTreeSet<_> =
            parametric_instance.parameters.keys().cloned().collect();
        assert!(dv_ids.is_disjoint(&p_ids));

        use crate::v1::Parameters;
        use ::approx::AbsDiffEq;
        let parameters = Parameters {
            entries: p_ids.iter().map(|id| (id.into_inner(), 0.0)).collect(),
        };
        let substituted = parametric_instance
            .clone()
            .with_parameters(parameters)
            .unwrap();
        assert!(substituted
            .objective
            .abs_diff_eq(&original_objective, crate::ATol::default()));
        assert_eq!(substituted.constraints().len(), 0);
    }

    #[test]
    fn test_penalty_method() {
        let instance = create_test_instance_with_constraints();
        let original_objective = instance.objective.clone();
        let original_constraint_count = instance.constraints().len();
        let parametric_instance = instance.penalty_method().unwrap();
        verify_penalty_method_properties(
            original_objective,
            original_constraint_count,
            &parametric_instance,
            2,
            "penalty_weight",
        );
    }

    #[test]
    fn test_uniform_penalty_method() {
        let instance = create_test_instance_with_constraints();
        let original_objective = instance.objective.clone();
        let original_constraint_count = instance.constraints().len();
        let parametric_instance = instance.uniform_penalty_method().unwrap();
        verify_penalty_method_properties(
            original_objective,
            original_constraint_count,
            &parametric_instance,
            1,
            "uniform_penalty_weight",
        );
    }

    #[test]
    fn test_penalty_method_parameter_ids_and_subscripts() {
        // Decision variables use IDs 1 and 2, so parameters start at 3 and are
        // assigned in constraint-ID order, each subscripted by its constraint ID.
        let parametric_instance = create_test_instance_with_constraints()
            .penalty_method()
            .unwrap();
        let ids_and_subscripts: Vec<(u64, Vec<i64>)> = parametric_instance
            .parameters
            .values()
            .map(|p| (p.id, p.subscripts.clone()))
            .collect();
        assert_eq!(ids_and_subscripts, vec![(3, vec![1]), (4, vec![2])]);
    }

    #[test]
    fn test_uniform_penalty_method_parameter_id() {
        let parametric_instance = create_test_instance_with_constraints()
            .uniform_penalty_method()
            .unwrap();
        let ids: Vec<u64> = parametric_instance
            .parameters
            .values()
            .map(|p| p.id)
            .collect();
        assert_eq!(ids, vec![3]);
        assert!(parametric_instance
            .parameters
            .values()
            .all(|p| p.subscripts.is_empty()));
    }

    #[test]
    fn test_penalty_methods_with_no_constraints() {
        let mut decision_variables = BTreeMap::new();
        decision_variables.insert(
            VariableID::from(1),
            DecisionVariable::continuous(VariableID::from(1)),
        );
        let objective = Function::from(linear!(1));
        let constraints = BTreeMap::new();
        let instance = Instance::new(
            Sense::Minimize,
            objective.clone(),
            decision_variables,
            constraints,
        )
        .unwrap();

        let parametric_instance = instance.clone().penalty_method().unwrap();
        assert_eq!(parametric_instance.parameters.len(), 0);
        assert_eq!(parametric_instance.constraints().len(), 0);
        assert_eq!(parametric_instance.removed_constraints().len(), 0);
        assert_eq!(parametric_instance.objective, objective);

        let parametric_instance = instance.uniform_penalty_method().unwrap();
        assert_eq!(parametric_instance.parameters.len(), 0);
        assert_eq!(parametric_instance.constraints().len(), 0);
        assert_eq!(parametric_instance.removed_constraints().len(), 0);
        assert_eq!(parametric_instance.objective, objective);
    }

    #[test]
    fn test_penalty_method_preserves_existing_removed_constraints() {
        let mut instance = create_test_instance_with_constraints();
        instance
            .relax_constraint(
                ConstraintID::from(1),
                "pre_existing".to_string(),
                std::iter::empty::<(String, String)>(),
            )
            .unwrap();
        assert_eq!(instance.constraints().len(), 1);
        assert_eq!(instance.removed_constraints().len(), 1);

        let parametric_instance = instance.penalty_method().unwrap();
        assert_eq!(parametric_instance.removed_constraints().len(), 2);
        assert!(parametric_instance
            .removed_constraints()
            .contains_key(&ConstraintID::from(1)));
        assert!(parametric_instance
            .removed_constraints()
            .contains_key(&ConstraintID::from(2)));
    }

    #[test]
    fn test_uniform_penalty_method_preserves_existing_removed_constraints() {
        let mut instance = create_test_instance_with_constraints();
        instance
            .relax_constraint(
                ConstraintID::from(1),
                "pre_existing".to_string(),
                std::iter::empty::<(String, String)>(),
            )
            .unwrap();
        assert_eq!(instance.constraints().len(), 1);
        assert_eq!(instance.removed_constraints().len(), 1);

        let parametric_instance = instance.uniform_penalty_method().unwrap();
        assert_eq!(parametric_instance.removed_constraints().len(), 2);
        assert!(parametric_instance
            .removed_constraints()
            .contains_key(&ConstraintID::from(1)));
        assert!(parametric_instance
            .removed_constraints()
            .contains_key(&ConstraintID::from(2)));
    }
}