use crate::model::builder::SeparableModelBuilder;
use crate::model::errors::ModelError;
use crate::prelude::*;
use crate::test_helpers;
use assert_matches::assert_matches;
use nalgebra::DMatrix;
use nalgebra::DVector;
// Test double for the model trait, generated by `mockall`.
// `mock!` emits a struct named `MockSeparableNonlinearModel`. The `pub fn` items
// declared in the struct section below become mocked inherent methods that are
// driven by per-test expectations. Their signatures mirror the
// `SeparableNonlinearModel` trait, using `MockModelError` as the error type.
mockall::mock! {
pub SeparableNonlinearModel {
pub fn parameter_count(&self) -> usize;
pub fn base_function_count(&self) -> usize;
pub fn output_len(&self) -> usize;
pub fn set_params(&mut self, parameters : DVector<f64>) -> Result<(),MockModelError>;
pub fn params(&self) -> DVector<f64>;
pub fn eval(
&self,
) -> Result<DMatrix<f64>, MockModelError>;
pub fn eval_partial_deriv(
&self,
derivative_index: usize,
) -> Result<DMatrix<f64>, MockModelError>;
}
// The mock also implements `Clone`, with `clone` itself being a mocked method.
impl Clone for SeparableNonlinearModel {
fn clone(&self) -> Self;
}
}
#[derive(Debug, thiserror::Error)]
pub enum MockModelError {
#[error("MockModelError: {}", 0)]
Error(String),
}
impl<S> From<S> for MockModelError
where
S: Into<String>,
{
fn from(s: S) -> Self {
MockModelError::Error(s.into())
}
}
/// Implements the crate's `SeparableNonlinearModel` trait for the `mockall`-generated mock,
/// so the mock can stand in for a real model wherever the trait is required.
///
/// Every trait method forwards to the identically named method generated by `mock!`.
/// Those generated methods are inherent methods, and Rust's method resolution prefers
/// inherent methods over trait methods for the same receiver type. So `self.eval()` inside
/// this impl calls the mock (and its configured expectations) rather than recursing into
/// the trait method being defined here.
impl SeparableNonlinearModel for MockSeparableNonlinearModel {
// Errors produced by the mock's expectations.
type Error = MockModelError;
// The mock always works on `f64` data.
type ScalarType = f64;
fn parameter_count(&self) -> usize {
self.parameter_count()
}
fn base_function_count(&self) -> usize {
self.base_function_count()
}
fn output_len(&self) -> usize {
self.output_len()
}
fn set_params(&mut self, parameters: DVector<f64>) -> Result<(), Self::Error> {
self.set_params(parameters)
}
fn params(&self) -> DVector<f64> {
self.params()
}
fn eval(&self) -> Result<DMatrix<f64>, Self::Error> {
self.eval()
}
fn eval_partial_deriv(&self, derivative_index: usize) -> Result<DMatrix<f64>, Self::Error> {
self.eval_partial_deriv(derivative_index)
}
}
/// The double exponential helper model must expose exactly two nonlinear parameters,
/// named `tau1` and `tau2` in that order.
#[test]
fn model_gets_initialized_with_correct_parameter_names_and_count() {
    let model = test_helpers::get_double_exponential_model_with_constant_offset(
        DVector::zeros(10),
        vec![1., 2.],
    );
    assert_eq!(
        model.parameter_count(),
        2,
        "Double exponential model has 2 parameters"
    );
    // Distinct message so a failure here is not mistaken for a parameter-count failure.
    assert_eq!(
        model.parameters(),
        &["tau1", "tau2"],
        "Double exponential model has parameters named tau1 and tau2, in that order"
    );
}
/// Evaluating the double exponential model with constant offset must produce one column
/// per base function: the `tau2` decay, then the `tau1` decay, then a constant column of ones.
#[test]
fn model_function_eval_produces_correct_result() {
    let tvec = DVector::from(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]);
    let (tau1, tau2) = (1., 3.);
    let model = test_helpers::get_double_exponential_model_with_constant_offset(
        tvec.clone(),
        vec![tau1, tau2],
    );

    let actual = model.eval().expect("Model evaluation should not fail");

    // Start from zeros shaped like the actual result and fill in the known columns.
    let mut expected = DMatrix::zeros(actual.nrows(), actual.ncols());
    let known_columns = [
        test_helpers::exp_decay(&tvec, tau2),
        test_helpers::exp_decay(&tvec, tau1),
        DVector::from_element(tvec.len(), 1.),
    ];
    for (col_idx, column) in known_columns.iter().enumerate() {
        expected.set_column(col_idx, column);
    }

    assert_eq!(
        actual, expected,
        "Model evaluation should produce the expected evaluation"
    );
}
/// A base function whose output length differs from the independent variable's length
/// must make `eval` fail with `UnexpectedFunctionOutput`, reporting the offending length.
#[test]
fn model_function_eval_fails_for_invalid_length_of_return_value_in_base_function() {
    let tvec = DVector::from(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]);

    // Building succeeds; the malformed `tau1` function is only caught when evaluated.
    let model = SeparableModelBuilder::<f64>::new(&["tau1", "tau2"])
        .function(&["tau2"], test_helpers::exp_decay)
        .partial_deriv("tau2", test_helpers::exp_decay_dtau)
        // Yields 4 samples, while `tvec` has 12.
        .function(&["tau1"], |_t: &DVector<_>, _tau| {
            DVector::from(vec![1., 3., 3., 7.])
        })
        .partial_deriv("tau1", test_helpers::exp_decay_dtau)
        .initial_parameters(vec![2., 4.])
        .independent_variable(tvec)
        .build()
        .expect("Model function creation should not fail, although function is bad");

    let outcome = model.eval();
    assert_matches!(
        outcome,
        Err(ModelError::UnexpectedFunctionOutput {
            actual_length: 4,
            ..
        }),
        "Model must report an error when evaluated with a function that does not return the same length vector as independent variable"
    );
}
/// `set_params` must reject a parameter vector whose length differs from the
/// model's parameter count.
#[test]
fn model_function_parameter_setting_fails_for_incorrect_number_of_parameters() {
    let tvec = DVector::from(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]);
    let mut model =
        test_helpers::get_double_exponential_model_with_constant_offset(tvec, vec![1., 2.]);

    assert_eq!(
        model.parameter_count(),
        2,
        "double exponential model should have 2 params"
    );

    // One value for a two-parameter model.
    let too_few_params = DVector::from_vec(vec![1.]);
    assert_matches!(
        model.set_params(too_few_params),
        Err(ModelError::IncorrectParameterCount { .. })
    );
}
/// Partial derivatives of a model that combines an exponential decay in `tau`, an
/// invariant constant function, and a sine depending on `omega` and `tau` must place each
/// base function's derivative in that base function's column. Base functions that do not
/// depend on the differentiated parameter (including the invariant one) get zero columns.
#[test]
fn model_derivative_evaluation_produces_correct_result() {
    let ones = |t: &DVector<_>| DVector::from_element(t.len(), 1.);
    let tvec = DVector::from(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]);
    let tau = 3.;
    let omega = 1.5;
    let params = &[tau, omega];
    let model = SeparableModelBuilder::<f64>::new(&["tau", "omega"])
        .independent_variable(tvec.clone())
        .initial_parameters(params.to_vec())
        .function(&["tau"], test_helpers::exp_decay)
        .partial_deriv("tau", test_helpers::exp_decay_dtau)
        .invariant_function(ones)
        .function(&["omega", "tau"], test_helpers::sin_ometa_t_plus_phi)
        .partial_deriv("tau", test_helpers::sin_ometa_t_plus_phi_dphi)
        .partial_deriv("omega", test_helpers::sin_ometa_t_plus_phi_domega)
        .build()
        .expect("Valid model creation should not fail");

    // Derivative indices follow the order of names given to `new`: 0 = tau, 1 = omega.
    let deriv_tau = model
        .eval_partial_deriv(0)
        .expect("Derivative eval must not fail");
    let deriv_omega = model
        .eval_partial_deriv(1)
        .expect("Derivative eval must not fail");
    let zeros = DVector::from_element(tvec.len(), 0.);

    // Comparing (rows, cols) tuples reports which dimension is wrong on failure.
    assert_eq!(
        (deriv_tau.nrows(), deriv_tau.ncols()),
        (tvec.len(), 3),
        "Deriv tau matrix does not have correct dimensions"
    );
    assert_eq!(
        DVector::from(deriv_tau.column(0)),
        test_helpers::exp_decay_dtau(&tvec, tau)
    );
    assert_eq!(DVector::from(deriv_tau.column(1)), zeros);
    assert_eq!(
        DVector::from(deriv_tau.column(2)),
        test_helpers::sin_ometa_t_plus_phi_dphi(&tvec, omega, tau)
    );

    assert_eq!(
        (deriv_omega.nrows(), deriv_omega.ncols()),
        (tvec.len(), 3),
        "Deriv omega matrix does not have correct dimensions"
    );
    assert_eq!(DVector::from(deriv_omega.column(0)), zeros);
    assert_eq!(DVector::from(deriv_omega.column(1)), zeros);
    assert_eq!(
        DVector::from(deriv_omega.column(2)),
        test_helpers::sin_ometa_t_plus_phi_domega(&tvec, omega, tau)
    );
}
/// Error handling of `eval_partial_deriv`:
/// * a partial derivative returning fewer samples than the independent variable must yield
///   `UnexpectedFunctionOutput` carrying the offending length;
/// * a well-formed partial derivative must still evaluate successfully;
/// * any derivative index at or beyond the parameter count, including the first invalid
///   index, must yield `DerivativeIndexOutOfBounds`.
#[test]
fn model_derivative_evaluation_error_cases() {
    let tvec = DVector::from(vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]);
    // Building succeeds; the malformed derivative is only caught when it is evaluated.
    let model_with_bad_function = SeparableModelBuilder::<f64>::new(&["tau1", "tau2"])
        .independent_variable(tvec)
        .function(&["tau2"], test_helpers::exp_decay)
        .partial_deriv("tau2", test_helpers::exp_decay_dtau)
        .function(&["tau1"], test_helpers::exp_decay)
        // Returns 4 samples instead of the 12 in `tvec`.
        .partial_deriv("tau1", |_t: &DVector<_>, _tau| {
            DVector::from(vec![1., 3., 3., 7.])
        })
        .initial_parameters(vec![2., 4.])
        .build()
        .expect("Model function creation should not fail, although function is bad");

    // Index 0 is `tau1`, whose partial derivative is malformed.
    assert_matches!(
        model_with_bad_function.eval_partial_deriv(0),
        Err(ModelError::UnexpectedFunctionOutput {
            actual_length: 4,
            ..
        }),
        "Derivative for invalid function must fail with correct error"
    );
    // Index 1 is `tau2`, whose partial derivative is well formed.
    assert!(
        model_with_bad_function.eval_partial_deriv(1).is_ok(),
        "Derivative eval for valid function should return Ok result"
    );

    // Cover the first invalid index (== parameter count) as well as indices far beyond it.
    let parameter_count = model_with_bad_function.parameter_count();
    for &index in &[parameter_count, 3, 100] {
        assert_matches!(
            model_with_bad_function.eval_partial_deriv(index),
            Err(ModelError::DerivativeIndexOutOfBounds { .. }),
            "Derivative index {} is out of bounds for a model with {} parameters and must fail with correct error",
            index,
            parameter_count
        );
    }
}