use crate::{
parse::{as_constraint_id, as_variable_id, Parse, ParseError, RawParseError},
v1::{self, State},
ATol, Constraint, ConstraintHintsError, ConstraintID, DecisionVariable, InstanceError,
RemovedConstraint, VariableID,
};
use std::collections::{BTreeMap, BTreeSet};
/// Successful outcome of [`Sos1::partial_evaluate`].
#[derive(Debug, Clone, PartialEq)]
pub(super) enum Sos1PartialEvaluateResult {
    /// No variable of the hint was fixed to a non-zero value. Carries the hint with every
    /// variable that was fixed to (approximately) zero removed.
    Updated(Sos1),
    /// Exactly one variable of the hint was fixed to a non-zero value. Carries a state that
    /// assigns `0.0` to every variable of the hint that was not yet fixed.
    AdditionalFix(State),
}
/// SOS1 (special ordered set of type 1) constraint hint: at most one of `variables` may take a
/// non-zero value (enforced by [`Sos1::partial_evaluate`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Sos1 {
    /// ID of the constraint this hint is anchored to. NOTE(review): presumably the constraint
    /// over the indicator binaries of the SOS1 formulation — confirm against callers.
    pub binary_constraint_id: ConstraintID,
    /// IDs of the associated big-M constraints. NOTE(review): presumably the constraints
    /// linking each variable to its indicator — confirm.
    pub big_m_constraint_ids: BTreeSet<ConstraintID>,
    /// Decision variables of which at most one may be non-zero.
    pub variables: BTreeSet<VariableID>,
}
impl Sos1 {
pub(super) fn partial_evaluate(
mut self,
state: &State,
atol: ATol,
) -> Result<Sos1PartialEvaluateResult, ConstraintHintsError> {
let mut fixed_to_nonzero: Option<(VariableID, f64)> = None;
let mut variables_to_remove = Vec::new();
for &var_id in &self.variables {
let Some(&value) = state.entries.get(&var_id) else {
continue;
};
if value.abs() < atol {
variables_to_remove.push(var_id);
continue;
}
if let Some((first_var, first_value)) = fixed_to_nonzero {
return Err(ConstraintHintsError::Sos1MultipleNonZeroFixed {
binary_constraint_id: self.binary_constraint_id,
variables: vec![(first_var, first_value), (var_id, value)],
});
}
fixed_to_nonzero = Some((var_id, value));
variables_to_remove.push(var_id);
}
for var_id in variables_to_remove {
self.variables.remove(&var_id);
}
if fixed_to_nonzero.is_some() {
let mut additional_fixes = State::default();
for &var_id in &self.variables {
additional_fixes.entries.insert(*var_id, 0.0);
}
Ok(Sos1PartialEvaluateResult::AdditionalFix(additional_fixes))
} else {
Ok(Sos1PartialEvaluateResult::Updated(self))
}
}
}
impl Parse for v1::Sos1 {
    type Output = Sos1;
    type Context = (
        BTreeMap<VariableID, DecisionVariable>,
        BTreeMap<ConstraintID, Constraint>,
        BTreeMap<ConstraintID, RemovedConstraint>,
    );

    /// Validates an `ommx.v1.Sos1` message and converts it into a [`Sos1`] hint.
    ///
    /// Constraint IDs are resolved through `as_constraint_id` against both the active and the
    /// removed constraints. Variable IDs are resolved through `as_variable_id` against the
    /// decision variables.
    ///
    /// # Errors
    ///
    /// - Any error from `as_constraint_id` / `as_variable_id`, annotated with the offending
    ///   field name.
    /// - `InstanceError::NonUniqueConstraintID` when `big_m_constraint_ids` repeats an ID.
    /// - `InstanceError::NonUniqueVariableID` when `decision_variables` repeats an ID.
    fn parse(
        self,
        (decision_variable, constraints, removed_constraints): &Self::Context,
    ) -> Result<Self::Output, ParseError> {
        let message = "ommx.v1.Sos1";

        let binary_constraint_id =
            as_constraint_id(constraints, removed_constraints, self.binary_constraint_id)
                .map_err(|e| e.context(message, "binary_constraint_id"))?;

        // Each entry is resolved first and then checked for uniqueness, and the fold stops at
        // the first failure. The first offending entry in message order decides the error.
        let big_m_constraint_ids = self.big_m_constraint_ids.iter().try_fold(
            BTreeSet::new(),
            |mut seen, &raw_id| -> Result<BTreeSet<ConstraintID>, ParseError> {
                let id = as_constraint_id(constraints, removed_constraints, raw_id)
                    .map_err(|e| e.context(message, "big_m_constraint_ids"))?;
                if seen.insert(id) {
                    Ok(seen)
                } else {
                    Err(
                        RawParseError::InstanceError(InstanceError::NonUniqueConstraintID { id })
                            .context(message, "big_m_constraint_ids"),
                    )
                }
            },
        )?;

        let variables = self.decision_variables.iter().try_fold(
            BTreeSet::new(),
            |mut seen, &raw_id| -> Result<BTreeSet<VariableID>, ParseError> {
                let id = as_variable_id(decision_variable, raw_id)
                    .map_err(|e| e.context(message, "decision_variables"))?;
                if seen.insert(id) {
                    Ok(seen)
                } else {
                    Err(
                        RawParseError::InstanceError(InstanceError::NonUniqueVariableID { id })
                            .context(message, "decision_variables"),
                    )
                }
            },
        )?;

        Ok(Sos1 {
            binary_constraint_id,
            big_m_constraint_ids,
            variables,
        })
    }
}
impl From<Sos1> for v1::Sos1 {
fn from(value: Sos1) -> Self {
Self {
binary_constraint_id: *value.binary_constraint_id,
big_m_constraint_ids: value.big_m_constraint_ids.into_iter().map(|c| *c).collect(),
decision_variables: value.variables.into_iter().map(|v| *v).collect(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hint whose binary constraint is `100`, with the given big-M constraint IDs and
    /// variable IDs.
    fn sos1(big_m_constraint_ids: &[u64], variables: &[u64]) -> Sos1 {
        Sos1 {
            binary_constraint_id: ConstraintID::from(100),
            big_m_constraint_ids: big_m_constraint_ids
                .iter()
                .map(|&id| ConstraintID::from(id))
                .collect(),
            variables: variables.iter().map(|&id| VariableID::from(id)).collect(),
        }
    }

    /// Builds a partial assignment from `(variable id, value)` pairs.
    fn state(entries: &[(u64, f64)]) -> State {
        let mut state = State::default();
        state.entries.extend(entries.iter().copied());
        state
    }

    #[test]
    fn test_partial_evaluate_removes_zero_variables() {
        let hint = sos1(&[101, 102, 103], &[1, 2, 3]);
        let result = hint
            .partial_evaluate(&state(&[(2, 0.0)]), ATol::default())
            .unwrap();
        let Sos1PartialEvaluateResult::Updated(updated) = result else {
            panic!("Expected Updated result");
        };
        assert_eq!(updated.variables.len(), 2);
        assert!(updated.variables.contains(&VariableID::from(1)));
        assert!(!updated.variables.contains(&VariableID::from(2)));
        assert!(updated.variables.contains(&VariableID::from(3)));
        assert_eq!(updated.binary_constraint_id, ConstraintID::from(100));
        assert_eq!(updated.big_m_constraint_ids.len(), 3);
    }

    #[test]
    fn test_partial_evaluate_fixes_others_when_one_is_nonzero() {
        let hint = sos1(&[], &[1, 2, 3]);
        let result = hint
            .partial_evaluate(&state(&[(2, 1.0)]), ATol::default())
            .unwrap();
        let Sos1PartialEvaluateResult::AdditionalFix(fixes) = result else {
            panic!("Expected AdditionalFix result");
        };
        assert_eq!(fixes.entries.len(), 2);
        assert_eq!(fixes.entries.get(&1), Some(&0.0));
        assert_eq!(fixes.entries.get(&3), Some(&0.0));
    }

    #[test]
    fn test_partial_evaluate_error_on_multiple_nonzero() {
        let hint = sos1(&[], &[1, 2, 3]);
        let result = hint.partial_evaluate(&state(&[(1, 1.0), (2, 2.0)]), ATol::default());
        let Err(ConstraintHintsError::Sos1MultipleNonZeroFixed { variables, .. }) = result else {
            panic!("Expected MultipleNonZeroFixed error");
        };
        assert_eq!(variables.len(), 2);
    }

    #[test]
    fn test_partial_evaluate_all_zeros_valid() {
        let hint = sos1(&[], &[1, 2, 3]);
        let result = hint
            .partial_evaluate(&state(&[(1, 0.0), (2, 0.0), (3, 0.0)]), ATol::default())
            .unwrap();
        let Sos1PartialEvaluateResult::Updated(updated) = result else {
            panic!("Expected Updated result when all variables are 0");
        };
        assert_eq!(updated.variables.len(), 0);
    }
}