use crate::{
parse::{as_constraint_id, as_variable_id, Parse, ParseError, RawParseError},
v1::{self, State},
ATol, Constraint, ConstraintHintsError, ConstraintID, DecisionVariable, InstanceError,
RemovedConstraint, VariableID,
};
use std::collections::{BTreeMap, BTreeSet};
/// Successful outcome of [`OneHot::partial_evaluate`].
#[derive(Debug, Clone, PartialEq)]
pub(super) enum OneHotPartialEvaluateResult {
/// No variable was fixed to one. Carries the constraint with every variable that the
/// state fixed to zero removed; at least one variable remains.
Updated(OneHot),
/// Exactly one variable was fixed to one. Carries a state that assigns `0.0` to every
/// variable of the constraint that the evaluated state did not fix.
AdditionalFix(State),
}
/// A one-hot constraint hint over a set of decision variables.
///
/// `partial_evaluate` treats the variables as values that must be `0` or `1`, with at most
/// one of them fixed to `1` and not all of them fixed to `0`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OneHot {
/// ID of the constraint this hint is attached to.
pub id: ConstraintID,
/// IDs of the decision variables taking part in the constraint, kept unique and ordered
/// by the set.
pub variables: BTreeSet<VariableID>,
}
impl OneHot {
pub(super) fn partial_evaluate(
mut self,
state: &State,
atol: ATol,
) -> Result<OneHotPartialEvaluateResult, ConstraintHintsError> {
let mut fixed_to_one: Option<VariableID> = None;
let mut variables_to_remove = Vec::new();
for &var_id in &self.variables {
let Some(&value) = state.entries.get(&var_id) else {
continue;
};
if value.abs() < atol {
variables_to_remove.push(var_id);
continue;
}
if (value - 1.0).abs() < atol {
if let Some(first_var) = fixed_to_one {
return Err(ConstraintHintsError::OneHotMultipleNonZeroFixed {
constraint_id: self.id,
variables: vec![(first_var, 1.0), (var_id, value)],
});
}
fixed_to_one = Some(var_id);
variables_to_remove.push(var_id);
continue;
}
return Err(ConstraintHintsError::OneHotInvalidFixedValue {
constraint_id: self.id,
variable_id: var_id,
value,
});
}
for var_id in variables_to_remove {
self.variables.remove(&var_id);
}
if fixed_to_one.is_some() {
let mut additional_fixes = State::default();
for &var_id in &self.variables {
additional_fixes.entries.insert(*var_id, 0.0);
}
Ok(OneHotPartialEvaluateResult::AdditionalFix(additional_fixes))
} else if self.variables.is_empty() {
Err(ConstraintHintsError::OneHotAllVariablesFixedToZero {
constraint_id: self.id,
})
} else {
Ok(OneHotPartialEvaluateResult::Updated(self))
}
}
}
impl Parse for v1::OneHot {
    type Output = OneHot;
    type Context = (
        BTreeMap<VariableID, DecisionVariable>,
        BTreeMap<ConstraintID, Constraint>,
        BTreeMap<ConstraintID, RemovedConstraint>,
    );

    /// Converts the raw message into a validated [`OneHot`].
    ///
    /// The context supplies, in order, the known decision variables, the active constraints
    /// and the removed constraints. The constraint ID is resolved through `as_constraint_id`
    /// first; then each decision variable ID is resolved through `as_variable_id` in message
    /// order.
    ///
    /// # Errors
    ///
    /// Returns the error from either resolver, annotated with the offending field, or
    /// `InstanceError::NonUniqueVariableID` when a variable ID is listed more than once.
    fn parse(self, context: &Self::Context) -> Result<Self::Output, ParseError> {
        const MESSAGE: &str = "ommx.v1.OneHot";
        let (decision_variable, constraints, removed_constraints) = context;

        let id = as_constraint_id(constraints, removed_constraints, self.constraint_id)
            .map_err(|e| e.context(MESSAGE, "constraint_id"))?;

        // Fold the raw IDs into a set, stopping at the first unknown or repeated ID.
        let variables = self.decision_variables.iter().try_fold(
            BTreeSet::new(),
            |mut seen, raw| -> Result<BTreeSet<VariableID>, ParseError> {
                let id = as_variable_id(decision_variable, *raw)
                    .map_err(|e| e.context(MESSAGE, "decision_variables"))?;
                if seen.insert(id) {
                    Ok(seen)
                } else {
                    Err(
                        RawParseError::InstanceError(InstanceError::NonUniqueVariableID { id })
                            .context(MESSAGE, "decision_variables"),
                    )
                }
            },
        )?;

        Ok(OneHot { id, variables })
    }
}
impl From<OneHot> for v1::OneHot {
fn from(value: OneHot) -> Self {
Self {
constraint_id: *value.id,
decision_variables: value.variables.into_iter().map(|v| *v).collect(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the constraint shared by every test: ID 100 over variables 1, 2 and 3.
    fn sample_one_hot() -> OneHot {
        OneHot {
            id: ConstraintID::from(100),
            variables: BTreeSet::from([
                VariableID::from(1),
                VariableID::from(2),
                VariableID::from(3),
            ]),
        }
    }

    /// Builds a state that fixes each listed variable ID to the paired value.
    fn state_of(fixed: &[(u64, f64)]) -> State {
        let mut state = State::default();
        for &(id, value) in fixed {
            state.entries.insert(id, value);
        }
        state
    }

    #[test]
    fn test_partial_evaluate_removes_zero_variables() {
        let state = state_of(&[(2, 0.0)]);

        let result = sample_one_hot()
            .partial_evaluate(&state, ATol::default())
            .unwrap();

        let OneHotPartialEvaluateResult::Updated(updated) = result else {
            panic!("Expected Updated result");
        };
        assert_eq!(updated.variables.len(), 2);
        assert!(updated.variables.contains(&VariableID::from(1)));
        assert!(!updated.variables.contains(&VariableID::from(2)));
        assert!(updated.variables.contains(&VariableID::from(3)));
    }

    #[test]
    fn test_partial_evaluate_fixes_others_when_one_is_fixed() {
        let state = state_of(&[(2, 1.0)]);

        let result = sample_one_hot()
            .partial_evaluate(&state, ATol::default())
            .unwrap();

        let OneHotPartialEvaluateResult::AdditionalFix(fixes) = result else {
            panic!("Expected AdditionalFix result");
        };
        assert_eq!(fixes.entries.len(), 2);
        assert_eq!(fixes.entries.get(&1), Some(&0.0));
        assert_eq!(fixes.entries.get(&3), Some(&0.0));
    }

    #[test]
    fn test_partial_evaluate_error_on_invalid_value() {
        let state = state_of(&[(2, 0.5)]);

        let result = sample_one_hot().partial_evaluate(&state, ATol::default());

        let Err(ConstraintHintsError::OneHotInvalidFixedValue {
            variable_id, value, ..
        }) = result
        else {
            panic!("Expected InvalidFixedValue error");
        };
        assert_eq!(variable_id, VariableID::from(2));
        assert_eq!(value, 0.5);
    }

    #[test]
    fn test_partial_evaluate_error_on_multiple_ones() {
        let state = state_of(&[(1, 1.0), (2, 1.0)]);

        let result = sample_one_hot().partial_evaluate(&state, ATol::default());

        let Err(ConstraintHintsError::OneHotMultipleNonZeroFixed { variables, .. }) = result
        else {
            panic!("Expected MultipleNonZeroFixed error");
        };
        assert_eq!(variables.len(), 2);
    }

    #[test]
    fn test_partial_evaluate_all_zeros_error() {
        let state = state_of(&[(1, 0.0), (2, 0.0), (3, 0.0)]);

        let result = sample_one_hot().partial_evaluate(&state, ATol::default());

        let Err(ConstraintHintsError::OneHotAllVariablesFixedToZero { constraint_id }) = result
        else {
            panic!("Expected AllVariablesFixedToZero error when all variables are 0");
        };
        assert_eq!(constraint_id, ConstraintID::from(100));
    }
}