use super::*;
use crate::{
substitute_acyclic, substitute_one_via_acyclic, Function, Substitute, SubstitutionError,
VariableID,
};
impl Substitute for Instance {
    type Output = Self;

    /// Applies every assignment in `acyclic` to this instance and returns the rewritten instance.
    ///
    /// All structural checks run before anything is rewritten, so a rejected request performs
    /// no substitution work. On success:
    ///
    /// - the objective is always substituted;
    /// - the function of every *active* regular or indicator constraint that mentions at least
    ///   one assigned variable is substituted. Constraints that mention none are left untouched,
    ///   and non-active constraints are never visited;
    /// - `decision_variable_dependency` is updated through its own `Substitute` impl.
    ///
    /// # Errors
    ///
    /// - [`SubstitutionError::IndicatorVariableSubstitution`] if an assigned variable is the
    ///   indicator variable of an active indicator constraint.
    /// - [`SubstitutionError::OneHotVariableSubstitution`] if an assigned variable is one of the
    ///   variables of an active one-hot constraint.
    /// - [`SubstitutionError::Sos1VariableSubstitution`] if an assigned variable is one of the
    ///   variables of an active SOS1 constraint.
    /// - Any error propagated from substituting into a function or into the dependency map.
    fn substitute_acyclic(
        mut self,
        acyclic: &crate::AcyclicAssignments,
    ) -> Result<Self::Output, crate::SubstitutionError> {
        // IDs of every variable being replaced; used both for validation and to skip
        // constraints that cannot be affected.
        let substituted_variables: std::collections::BTreeSet<VariableID> =
            acyclic.iter().map(|(var_id, _)| *var_id).collect();

        // ---- Validation: reject forbidden targets before mutating anything. ----

        // Indicator variables of active indicator constraints may not be substituted.
        for (&cid, ic) in self.indicator_constraint_collection.active().iter() {
            if substituted_variables.contains(&ic.indicator_variable) {
                return Err(SubstitutionError::IndicatorVariableSubstitution {
                    indicator_variable: ic.indicator_variable,
                    constraint_id: cid,
                });
            }
        }
        // Variables of active one-hot constraints may not be substituted.
        for (&cid, oh) in self.one_hot_constraint_collection.active().iter() {
            for var_id in &oh.variables {
                if substituted_variables.contains(var_id) {
                    return Err(SubstitutionError::OneHotVariableSubstitution {
                        variable: *var_id,
                        constraint_id: cid,
                    });
                }
            }
        }
        // Variables of active SOS1 constraints may not be substituted.
        for (&cid, sos1) in self.sos1_constraint_collection.active().iter() {
            for var_id in &sos1.variables {
                if substituted_variables.contains(var_id) {
                    return Err(SubstitutionError::Sos1VariableSubstitution {
                        variable: *var_id,
                        constraint_id: cid,
                    });
                }
            }
        }

        // ---- Substitution: the request is valid, rewrite in place. ----

        substitute_acyclic(&mut self.objective, acyclic)?;

        // Single pass over active constraints (same shape as the indicator loop below):
        // only functions that mention an assigned variable are rewritten.
        for constraint in self.constraint_collection.active_mut().values_mut() {
            if !constraint.required_ids().is_disjoint(&substituted_variables) {
                substitute_acyclic(&mut constraint.stage.function, acyclic)?;
            }
        }
        for ic in self
            .indicator_constraint_collection
            .active_mut()
            .values_mut()
        {
            if !ic
                .stage
                .function
                .required_ids()
                .is_disjoint(&substituted_variables)
            {
                substitute_acyclic(&mut ic.stage.function, acyclic)?;
            }
        }

        substitute_acyclic(&mut self.decision_variable_dependency, acyclic)?;
        Ok(self)
    }

    /// Substitutes the single variable `assigned` with `f`.
    ///
    /// Delegates to [`substitute_one_via_acyclic`], so it is subject to the same checks and
    /// errors as [`Substitute::substitute_acyclic`].
    fn substitute_one(
        self,
        assigned: VariableID,
        f: &Function,
    ) -> Result<Self::Output, SubstitutionError> {
        substitute_one_via_acyclic(self, assigned, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{coeff, constraint::Equality, linear, DecisionVariable, Sense};
    use std::collections::BTreeMap;

    /// Substituting x1 := x3 + 1 must rewrite the objective and the constraint, and record
    /// x1 in the dependency map.
    #[test]
    fn test_instance_substitute() {
        let mut decision_variables = BTreeMap::new();
        decision_variables.insert(
            VariableID::from(1),
            DecisionVariable::continuous(VariableID::from(1)),
        );
        decision_variables.insert(
            VariableID::from(2),
            DecisionVariable::continuous(VariableID::from(2)),
        );
        let objective = Function::from(linear!(1) + coeff!(2.0) * linear!(2));
        let constraint_function = Function::from(linear!(1) + linear!(2) + coeff!(-10.0));
        let mut constraints = BTreeMap::new();
        let constraint = Constraint {
            equality: Equality::LessThanOrEqualToZero,
            stage: crate::constraint::CreatedData {
                function: constraint_function,
            },
        };
        constraints.insert(ConstraintID::from(1), constraint);
        let instance =
            Instance::new(Sense::Minimize, objective, decision_variables, constraints).unwrap();
        let substitution = Function::from(linear!(3) + coeff!(1.0));
        let result = instance
            .substitute_one(VariableID::from(1), &substitution)
            .unwrap();

        assert_eq!(result.decision_variable_dependency.len(), 1);
        assert!(result
            .decision_variable_dependency
            .get(&VariableID::from(1))
            .is_some());

        // The objective no longer mentions x1 and now mentions x3.
        let objective_ids = result.objective.required_ids();
        assert!(!objective_ids.contains(&VariableID::from(1)));
        assert!(objective_ids.contains(&VariableID::from(3)));

        // The only constraint mentioned x1, so it must have been rewritten too.
        let (_, constraint) = result
            .constraint_collection
            .active()
            .iter()
            .next()
            .unwrap();
        let constraint_ids = constraint.required_ids();
        assert!(!constraint_ids.contains(&VariableID::from(1)));
        assert!(constraint_ids.contains(&VariableID::from(3)));
    }

    /// Substituting a variable used inside an indicator constraint's function rewrites that
    /// function but leaves the indicator variable untouched.
    #[test]
    fn test_substitute_indicator_function() {
        let mut decision_variables = BTreeMap::new();
        decision_variables.insert(
            VariableID::from(1),
            DecisionVariable::continuous(VariableID::from(1)),
        );
        decision_variables.insert(
            VariableID::from(2),
            DecisionVariable::continuous(VariableID::from(2)),
        );
        decision_variables.insert(
            VariableID::from(10),
            DecisionVariable::binary(VariableID::from(10)),
        );
        let objective = Function::from(linear!(1));
        let mut indicator_constraints = BTreeMap::new();
        indicator_constraints.insert(
            crate::IndicatorConstraintID::from(1),
            crate::IndicatorConstraint::new(
                VariableID::from(10),
                Equality::LessThanOrEqualToZero,
                Function::from(linear!(1) + coeff!(-5.0)),
            ),
        );
        let instance = Instance::builder()
            .sense(Sense::Minimize)
            .objective(objective)
            .decision_variables(decision_variables)
            .constraints(BTreeMap::new())
            .indicator_constraints(indicator_constraints)
            .build()
            .unwrap();
        let assignments = crate::AcyclicAssignments::new(vec![(
            VariableID::from(1),
            Function::from(linear!(2) + coeff!(1.0)),
        )])
        .unwrap();
        let result = instance.substitute_acyclic(&assignments).unwrap();
        assert_eq!(result.indicator_constraints().len(), 1);

        let (_, ic) = result
            .indicator_constraint_collection
            .active()
            .iter()
            .next()
            .unwrap();
        assert_eq!(ic.indicator_variable, VariableID::from(10));
        let function_ids = ic.stage.function.required_ids();
        assert!(!function_ids.contains(&VariableID::from(1)));
        assert!(function_ids.contains(&VariableID::from(2)));
    }

    /// Substituting the indicator variable itself is rejected with a descriptive error.
    #[test]
    fn test_substitute_indicator_variable_fails() {
        let mut decision_variables = BTreeMap::new();
        decision_variables.insert(
            VariableID::from(1),
            DecisionVariable::continuous(VariableID::from(1)),
        );
        decision_variables.insert(
            VariableID::from(10),
            DecisionVariable::binary(VariableID::from(10)),
        );
        let objective = Function::from(linear!(1));
        let mut indicator_constraints = BTreeMap::new();
        indicator_constraints.insert(
            crate::IndicatorConstraintID::from(1),
            crate::IndicatorConstraint::new(
                VariableID::from(10),
                Equality::LessThanOrEqualToZero,
                Function::from(linear!(1) + coeff!(-5.0)),
            ),
        );
        let instance = Instance::builder()
            .sense(Sense::Minimize)
            .objective(objective)
            .decision_variables(decision_variables)
            .constraints(BTreeMap::new())
            .indicator_constraints(indicator_constraints)
            .build()
            .unwrap();
        let assignments = crate::AcyclicAssignments::new(vec![(
            VariableID::from(10),
            Function::from(coeff!(1.0)),
        )])
        .unwrap();
        let err = instance.substitute_acyclic(&assignments).unwrap_err();
        assert!(matches!(
            err,
            SubstitutionError::IndicatorVariableSubstitution {
                indicator_variable,
                constraint_id,
            } if indicator_variable == VariableID::from(10)
                && constraint_id == crate::IndicatorConstraintID::from(1)
        ));
    }
}