use twine_core::{EquationProblem, Model, OptimizationProblem};
use twine_solvers::{equation::bisection, optimization::golden_section, transient::euler};
/// Exposes a scalar residual as an `f64`.
///
/// In this module it is implemented for [`bisection::Event`]. There, an
/// `Evaluated` event yields its point's residual. A `ModelFailed` or
/// `ProblemFailed` event yields `f64::NAN`.
pub trait HasResidual {
    /// Returns the residual associated with `self`.
    fn residual(&self) -> f64;
}
/// Exposes a scalar objective value as an `f64`.
///
/// In this module it is implemented for [`golden_section::Event`]. There, an
/// `Evaluated` event yields its point's objective. A `ModelFailed` or
/// `ProblemFailed` event yields `f64::NAN`.
pub trait HasObjective {
    /// Returns the objective value associated with `self`.
    fn objective(&self) -> f64;
}
/// Constructs the "stop early" value of the implementing type.
///
/// Implemented in this module for the `Action` types of the bisection,
/// golden-section, and Euler solvers. Each returns its `StopEarly` variant.
pub trait CanStopEarly {
    /// Returns the `StopEarly` value of the implementing type.
    fn stop_early() -> Self;
}
/// Constructs the "assume worse" value of the implementing type.
///
/// Implemented in this module for `golden_section::Action`, where it returns
/// the `AssumeWorse` variant.
// NOTE(review): the exact meaning of "assume worse" is defined by the
// golden-section solver, not by this file — confirm against its docs.
pub trait CanAssumeWorse {
    /// Returns the `AssumeWorse` value of the implementing type.
    fn assume_worse() -> Self;
}
/// Residual extraction for bisection events.
///
/// Only an `Evaluated` event carries a residual (`point.residual`). If the
/// model or the problem failed, there is no residual and `f64::NAN` is
/// reported instead.
impl<M, P> HasResidual for bisection::Event<'_, M, P>
where
    M: Model,
    P: EquationProblem<1, Input = M::Input, Output = M::Output>,
{
    fn residual(&self) -> f64 {
        match self {
            // Failed evaluations have no residual to report.
            Self::ModelFailed { .. } | Self::ProblemFailed { .. } => f64::NAN,
            Self::Evaluated { point, .. } => point.residual,
        }
    }
}
/// Objective extraction for golden-section events.
///
/// Only an `Evaluated` event carries an objective (`point.objective`). If the
/// model or the problem failed, there is no objective and `f64::NAN` is
/// reported instead.
impl<M, P> HasObjective for golden_section::Event<'_, M, P>
where
    M: Model,
    P: OptimizationProblem<1, Input = M::Input, Output = M::Output>,
{
    fn objective(&self) -> f64 {
        match self {
            // Failed evaluations have no objective to report.
            Self::ModelFailed { .. } | Self::ProblemFailed { .. } => f64::NAN,
            Self::Evaluated { point, .. } => point.objective,
        }
    }
}
impl CanStopEarly for bisection::Action {
    /// Yields `bisection::Action::StopEarly`.
    fn stop_early() -> Self {
        bisection::Action::StopEarly
    }
}
impl CanStopEarly for golden_section::Action {
    /// Yields `golden_section::Action::StopEarly`.
    fn stop_early() -> Self {
        golden_section::Action::StopEarly
    }
}
impl CanStopEarly for euler::Action {
    /// Yields `euler::Action::StopEarly`.
    fn stop_early() -> Self {
        euler::Action::StopEarly
    }
}
impl CanAssumeWorse for golden_section::Action {
    /// Yields `golden_section::Action::AssumeWorse`.
    fn assume_worse() -> Self {
        golden_section::Action::AssumeWorse
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{convert::Infallible, error::Error, fmt};
    use approx::assert_relative_eq;
    use twine_core::{EquationProblem, Model, OptimizationProblem};
    use twine_solvers::{
        equation::bisection,
        optimization::golden_section::{self, Point},
        transient::euler,
    };

    /// Model that returns its input unchanged and never fails.
    struct Identity;
    impl Model for Identity {
        type Input = f64;
        type Output = f64;
        type Error = Infallible;
        fn call(&self, input: &f64) -> Result<f64, Infallible> {
            Ok(*input)
        }
    }

    /// Error type used by the failing model/problem fixtures.
    #[derive(Debug)]
    struct Failure;
    impl fmt::Display for Failure {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "failure")
        }
    }
    impl Error for Failure {}

    /// Model whose every call returns `Err(Failure)`.
    struct FailingModel;
    impl Model for FailingModel {
        type Input = f64;
        type Output = f64;
        type Error = Failure;
        fn call(&self, _: &f64) -> Result<f64, Failure> {
            Err(Failure)
        }
    }

    /// Problem that uses `x[0]` as input and the model output as both the
    /// residual (equation) and the objective (optimization).
    struct LinearProblem;
    impl EquationProblem<1> for LinearProblem {
        type Input = f64;
        type Output = f64;
        type Error = Infallible;
        fn input(&self, x: &[f64; 1]) -> Result<f64, Infallible> {
            Ok(x[0])
        }
        fn residuals(&self, _: &f64, output: &f64) -> Result<[f64; 1], Infallible> {
            Ok([*output])
        }
    }
    impl OptimizationProblem<1> for LinearProblem {
        type Input = f64;
        type Output = f64;
        type Error = Infallible;
        fn input(&self, x: &[f64; 1]) -> Result<f64, Infallible> {
            Ok(x[0])
        }
        fn objective(&self, _: &f64, output: &f64) -> Result<f64, Infallible> {
            Ok(*output)
        }
    }

    /// Equation problem whose residual computation always fails.
    struct FailingEqProblem;
    impl EquationProblem<1> for FailingEqProblem {
        type Input = f64;
        type Output = f64;
        type Error = Failure;
        fn input(&self, x: &[f64; 1]) -> Result<f64, Failure> {
            Ok(x[0])
        }
        fn residuals(&self, _: &f64, _: &f64) -> Result<[f64; 1], Failure> {
            Err(Failure)
        }
    }

    /// Optimization problem whose objective computation always fails.
    struct FailingOptProblem;
    impl OptimizationProblem<1> for FailingOptProblem {
        type Input = f64;
        type Output = f64;
        type Error = Failure;
        fn input(&self, x: &[f64; 1]) -> Result<f64, Failure> {
            Ok(x[0])
        }
        fn objective(&self, _: &f64, _: &f64) -> Result<f64, Failure> {
            Err(Failure)
        }
    }

    /// Bracket over `[0, 1]` with a sign change, shared by the bisection tests.
    fn test_bracket() -> bisection::Bracket {
        bisection::Bracket::new(
            (0.0, bisection::Sign::Negative),
            (1.0, bisection::Sign::Positive),
        )
        .unwrap()
    }

    #[test]
    fn bisection_residual_evaluated() {
        let input = 1.0_f64;
        let output = 1.0_f64;
        let bracket = test_bracket();
        let event: bisection::Event<'_, Identity, LinearProblem> = bisection::Event::Evaluated {
            point: bisection::Point::new(1.0, 0.5),
            input: &input,
            output: &output,
            bracket: &bracket,
        };
        assert_relative_eq!(event.residual(), 0.5);
    }

    #[test]
    fn bisection_residual_nan_on_model_failed() {
        let error = Failure;
        let bracket = test_bracket();
        let event: bisection::Event<'_, FailingModel, LinearProblem> =
            bisection::Event::ModelFailed {
                x: 0.5,
                error: &error,
                bracket: &bracket,
            };
        assert!(event.residual().is_nan());
    }

    #[test]
    fn bisection_residual_nan_on_problem_failed() {
        let error = Failure;
        let bracket = test_bracket();
        let event: bisection::Event<'_, Identity, FailingEqProblem> =
            bisection::Event::ProblemFailed {
                x: 0.5,
                error: &error,
                bracket: &bracket,
            };
        assert!(event.residual().is_nan());
    }

    #[test]
    fn golden_section_objective_evaluated() {
        let input = 1.0_f64;
        let output = 1.0_f64;
        let event: golden_section::Event<'_, Identity, LinearProblem> =
            golden_section::Event::Evaluated {
                point: Point::new(1.0, 7.5),
                input: &input,
                output: &output,
                other: Point::new(0.5, 4.0),
            };
        assert_relative_eq!(event.objective(), 7.5);
    }

    #[test]
    fn golden_section_objective_nan_on_model_failed() {
        let error = Failure;
        let event: golden_section::Event<'_, FailingModel, LinearProblem> =
            golden_section::Event::ModelFailed {
                x: 0.5,
                other: Point::new(0.5, 1.0),
                error: &error,
            };
        assert!(event.objective().is_nan());
    }

    #[test]
    fn golden_section_objective_nan_on_problem_failed() {
        let error = Failure;
        let event: golden_section::Event<'_, Identity, FailingOptProblem> =
            golden_section::Event::ProblemFailed {
                x: 0.5,
                other: Point::new(0.5, 1.0),
                error: &error,
            };
        assert!(event.objective().is_nan());
    }

    // The action constructors previously had no coverage. Fully-qualified
    // calls make sure the trait impls (not any inherent fn) are exercised, and
    // `matches!` avoids requiring `PartialEq`/`Debug` on the solver actions.
    #[test]
    fn stop_early_yields_stop_early_variant() {
        assert!(matches!(
            <bisection::Action as CanStopEarly>::stop_early(),
            bisection::Action::StopEarly
        ));
        assert!(matches!(
            <golden_section::Action as CanStopEarly>::stop_early(),
            golden_section::Action::StopEarly
        ));
        assert!(matches!(
            <euler::Action as CanStopEarly>::stop_early(),
            euler::Action::StopEarly
        ));
    }

    #[test]
    fn assume_worse_yields_assume_worse_variant() {
        assert!(matches!(
            <golden_section::Action as CanAssumeWorse>::assume_worse(),
            golden_section::Action::AssumeWorse
        ));
    }
}