use super::*;
use crate::{Evaluate, VariableIDSet};
/// [`Evaluate`] for [`NamedFunction`]: every numeric operation is delegated to the wrapped
/// `function`, and the results are returned together with a copy of this function's metadata
/// (`id`, `name`, `subscripts`, `parameters`, `description`).
impl Evaluate for NamedFunction {
    type Output = EvaluatedNamedFunction;
    type SampledOutput = SampledNamedFunction;

    /// Evaluates the wrapped function on `solution` and packages the value with the metadata.
    ///
    /// `used_decision_variable_ids` is filled with whatever `required_ids` reports for the
    /// wrapped function.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the wrapped function's `evaluate`, unchanged.
    fn evaluate(
        &self,
        solution: &crate::v1::State,
        atol: crate::ATol,
    ) -> anyhow::Result<Self::Output> {
        // Exhaustive destructuring: adding a field to `NamedFunction` becomes a compile error
        // here, forcing a decision on whether it belongs in the evaluated output.
        let Self {
            id,
            function,
            name,
            subscripts,
            parameters,
            description,
        } = self;

        // Evaluate first so that a failure returns before any metadata is cloned.
        let value = function.evaluate(solution, atol)?;

        Ok(EvaluatedNamedFunction {
            id: *id,
            evaluated_value: value,
            name: name.clone(),
            subscripts: subscripts.clone(),
            parameters: parameters.clone(),
            description: description.clone(),
            used_decision_variable_ids: function.required_ids(),
        })
    }

    /// Partially evaluates the wrapped function in place using `state`.
    ///
    /// Only `function` is touched; the metadata fields are left as they are.
    fn partial_evaluate(
        &mut self,
        state: &crate::v1::State,
        atol: crate::ATol,
    ) -> anyhow::Result<()> {
        let inner = &mut self.function;
        inner.partial_evaluate(state, atol)
    }

    /// Returns the variable IDs reported by the wrapped function's `required_ids`.
    fn required_ids(&self) -> VariableIDSet {
        let inner = &self.function;
        inner.required_ids()
    }

    /// Evaluates the wrapped function on every sample in `samples` and packages the converted
    /// values with the metadata.
    ///
    /// # Errors
    ///
    /// Returns an error if the wrapped function's `evaluate_samples` fails, or if converting its
    /// result (via `try_into`) into the type stored in `SampledNamedFunction` fails.
    fn evaluate_samples(
        &self,
        samples: &crate::v1::Samples,
        atol: crate::ATol,
    ) -> anyhow::Result<Self::SampledOutput> {
        let Self {
            id,
            function,
            name,
            subscripts,
            parameters,
            description,
        } = self;

        let raw_values = function.evaluate_samples(samples, atol)?;

        Ok(SampledNamedFunction {
            id: *id,
            // The conversion target is inferred from the `evaluated_values` field type.
            evaluated_values: raw_values.try_into()?,
            name: name.clone(),
            subscripts: subscripts.clone(),
            parameters: parameters.clone(),
            description: description.clone(),
            used_decision_variable_ids: function.required_ids(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{coeff, linear, Coefficient, Evaluate, Function, VariableID};
    use maplit::btreeset;

    /// Builds the linear function `2 * x1 + 3 * x2` shared by the linear-function tests.
    fn two_term_linear() -> Function {
        Function::Linear(coeff!(2.0) * linear!(1) + coeff!(3.0) * linear!(2))
    }

    /// A constant function evaluates to its constant, keeps its metadata, and uses no variables.
    #[test]
    fn test_evaluate_constant_function() {
        let constant = Function::Constant(Coefficient::try_from(42.0).unwrap());
        let named = NamedFunction {
            id: NamedFunctionID::from(1),
            function: constant,
            name: Some("my_func".to_string()),
            subscripts: vec![1, 2],
            parameters: Default::default(),
            description: Some("constant function".to_string()),
        };

        let empty_state = crate::v1::State::default();
        let evaluated = named
            .evaluate(&empty_state, crate::ATol::default())
            .unwrap();

        assert_eq!(evaluated.id(), NamedFunctionID::from(1));
        assert_eq!(evaluated.evaluated_value(), 42.0);
        assert_eq!(*evaluated.name(), Some("my_func".to_string()));
        assert_eq!(*evaluated.subscripts(), vec![1, 2]);
        assert_eq!(
            *evaluated.description(),
            Some("constant function".to_string())
        );
        assert!(evaluated.used_decision_variable_ids().is_empty());
    }

    /// With x1 = 5 and x2 = 10, `2 * x1 + 3 * x2` evaluates to 40 and reports both variables.
    #[test]
    fn test_evaluate_linear_function() {
        let named = NamedFunction {
            id: NamedFunctionID::from(2),
            function: two_term_linear(),
            name: Some("linear_func".to_string()),
            subscripts: vec![],
            parameters: Default::default(),
            description: None,
        };

        let entries = [(1, 5.0), (2, 10.0)].into_iter().collect();
        let assignment = crate::v1::State { entries };
        let evaluated = named
            .evaluate(&assignment, crate::ATol::default())
            .unwrap();

        assert_eq!(evaluated.evaluated_value(), 40.0);
        let expected_ids = btreeset! { VariableID::from(1), VariableID::from(2) };
        assert_eq!(*evaluated.used_decision_variable_ids(), expected_ids);
    }

    /// `required_ids` reports exactly the variables appearing in the wrapped function.
    #[test]
    fn test_required_ids() {
        let named = NamedFunction {
            id: NamedFunctionID::from(3),
            function: two_term_linear(),
            name: None,
            subscripts: vec![],
            parameters: Default::default(),
            description: None,
        };

        let expected_ids = btreeset! { VariableID::from(1), VariableID::from(2) };
        assert_eq!(named.required_ids(), expected_ids);
    }
}