use super::*;
use assert_matches::assert_matches;
use nalgebra::DVector;
/// Checks three invalid builder setups and the error each one must produce:
/// duplicate parameter names, an empty parameter list, and a parameter list
/// without any model functions.
#[test]
fn builder_fails_for_invalid_model_parameters() {
    // The name "b" is listed twice.
    let duplicated_names = [String::from("a"), String::from("b"), String::from("b")];
    let duplicate_outcome = SeparableModelBuilder::<f64>::new(duplicated_names).build();
    assert_matches!(
        duplicate_outcome,
        Err(ModelBuildError::DuplicateParameterNames { .. }),
        "Duplicate parameter error must be emitted when creating model with duplicate params"
    );

    // No parameter names at all.
    let no_names: Vec<String> = Vec::new();
    let empty_outcome = SeparableModelBuilder::<f64>::new(no_names).build();
    assert_matches!(
        empty_outcome,
        Err(ModelBuildError::EmptyParameters),
        "Creating model with empty parameters must fail with correct error"
    );

    // Distinct names, but no function is ever added before building.
    let distinct_names = [String::from("a"), String::from("b"), String::from("c")];
    let functionless_outcome = SeparableModelBuilder::<f64>::new(distinct_names).build();
    assert_matches!(
        functionless_outcome,
        Err(ModelBuildError::EmptyModel),
        "Creating model without functions must fail with correct error"
    );
}
/// Checks that `build()` fails with `UnusedParameter` when the model declares
/// the parameters "a", "b" and "c" but its only parameter-dependent function
/// depends on "a" alone.
#[test]
fn builder_fails_when_not_all_model_parameters_are_depended_on_by_the_modelfunctions() {
    let result =
        SeparableModelBuilder::<f64>::new(["a".to_string(), "b".to_string(), "c".to_string()])
            .invariant_function(|_| unimplemented!())
            .function(
                ["a".to_string()],
                |_: &DVector<f64>, _: f64| unimplemented!(),
            )
            .partial_deriv("a", |_: &DVector<f64>, _: f64| unimplemented!())
            .build();
    // The failure message now names the error this test is actually about
    // (it previously repeated the duplicate-parameter message).
    assert_matches!(
        result,
        Err(ModelBuildError::UnusedParameter { .. }),
        "Unused parameter error must be emitted when a model parameter is not used by any function"
    );
}
/// Checks that `build()` fails with `MissingDerivative` when a function
/// depends on "a" and "b" but only the partial derivative for "a" is given.
#[test]
fn builder_fails_when_not_all_required_partial_derivatives_are_given_for_function() {
    let result = SeparableModelBuilder::<f64>::new(["a".to_string(), "b".to_string()])
        .function(
            ["a".to_string(), "b".to_string()],
            |_: &DVector<f64>, _: f64, _: f64| unimplemented!(),
        )
        .partial_deriv("a", |_: &DVector<f64>, _: f64, _: f64| unimplemented!())
        .build();
    // The failure message now names the error this test is actually about
    // (it previously repeated the duplicate-parameter message).
    assert_matches!(
        result,
        Err(ModelBuildError::MissingDerivative { .. }),
        "Missing derivative error must be emitted when a function lacks a required partial derivative"
    );
}
/// Returns an owned clone of the value behind `x`.
fn identity_function<T>(x: &T) -> T
where
    T: Clone,
{
    T::clone(x)
}
/// Builds a model from two invariant functions and three parameter-dependent
/// functions, then checks each resulting base function (and each of its
/// derivatives) against the helper it was built from.
#[test]
fn builder_produces_correct_model_from_functions() {
    let ts = DVector::<f64>::from(vec![
        1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
    ]);
    let t0 = 3.;
    let tau = 2.;
    let omega1 = std::f64::consts::FRAC_1_PI * 3.;
    let omega2 = std::f64::consts::FRAC_1_PI * 2.;
    // Same order as the parameter names handed to the builder below.
    let params = vec![t0, tau, omega1, omega2];

    let model = SeparableModelBuilder::<f64>::new([
        "t0".to_string(),
        "tau".to_string(),
        "omega1".to_string(),
        "omega2".to_string(),
    ])
    .invariant_function(|x| 2. * identity_function(x))
    .function(["t0".to_string(), "tau".to_string()], exponential_decay)
    // Derivatives are registered in a different order than the function's
    // parameter list; the assertions below expect t0 under key 0 and tau
    // under key 1.
    .partial_deriv("tau", exponential_decay_dtau)
    .partial_deriv("t0", exponential_decay_dt0)
    .invariant_function(identity_function)
    .function(["omega1".to_string()], sinusoid_omega)
    .partial_deriv("omega1", sinusoid_omega_domega)
    .function(["omega2".to_string()], sinusoid_omega)
    .partial_deriv("omega2", sinusoid_omega_domega)
    .independent_variable(ts.clone())
    .initial_parameters(params.clone())
    .build()
    .expect("Valid builder calls should produce a valid model function.");

    assert_eq!(
        model.basefunctions.len(),
        5,
        "Number of functions in model is incorrect"
    );

    // Base function 0: invariant f(x) = 2x.
    let func = &model.basefunctions[0];
    assert!(
        func.derivatives.is_empty(),
        "This function should have no derivatives"
    );
    assert_eq!(
        (func.function)(&ts, &params),
        2. * ts.clone(),
        "Function should be f(x)=2x"
    );

    // Base function 1: exponential decay depending on t0 and tau.
    let func = &model.basefunctions[1];
    assert_eq!(func.derivatives.len(), 2, "Incorrect number of derivatives");
    assert_eq!(
        (func.function)(&ts, &params),
        exponential_decay(&ts, t0, tau),
        "Incorrect function value"
    );
    assert_eq!(
        (func.derivatives.get(&0).unwrap())(&ts, &params),
        exponential_decay_dt0(&ts, t0, tau),
        "Incorrect first derivative value"
    );
    assert_eq!(
        (func.derivatives.get(&1).unwrap())(&ts, &params),
        exponential_decay_dtau(&ts, t0, tau),
        "Incorrect second derivative value"
    );

    // Base function 2: invariant f(x) = x.
    let func = &model.basefunctions[2];
    assert!(
        func.derivatives.is_empty(),
        "This function should have no derivatives"
    );
    assert_eq!(
        (func.function)(&ts, &params),
        ts.clone(),
        "Function should be f(x)=x"
    );

    // Base function 3: sinusoid depending on omega1 (derivative key 2).
    let func = &model.basefunctions[3];
    assert_eq!(func.derivatives.len(), 1, "Incorrect number of derivatives");
    assert_eq!(
        (func.function)(&ts, &params),
        sinusoid_omega(&ts, omega1),
        "Incorrect function value"
    );
    assert_eq!(
        (func.derivatives.get(&2).unwrap())(&ts, &params),
        sinusoid_omega_domega(&ts, omega1),
        "Incorrect first derivative value"
    );

    // Base function 4: sinusoid depending on omega2 (derivative key 3).
    let func = &model.basefunctions[4];
    assert_eq!(func.derivatives.len(), 1, "Incorrect number of derivatives");
    assert_eq!(
        (func.function)(&ts, &params),
        sinusoid_omega(&ts, omega2),
        "Incorrect function value"
    );
    assert_eq!(
        (func.derivatives.get(&3).unwrap())(&ts, &params),
        sinusoid_omega_domega(&ts, omega2),
        "Incorrect first derivative value"
    );
}
/// Checks that `build()` reports `MissingX` when functions, derivatives and
/// initial parameters are supplied but `independent_variable` is never called.
#[test]
fn test_model_builder_fails_when_x_data_is_missing() {
    // Placeholder with the two-parameter signature the builder calls need.
    fn stub(_: &DVector<f64>, _: f64, _: f64) -> DVector<f64> {
        todo!()
    }

    let builder = SeparableModelBuilder::<f64>::new(&["a", "b", "c"])
        .function(&["a", "c"], stub)
        .partial_deriv("a", stub)
        // One derivative is passed as a closure instead of a fn item.
        .partial_deriv("c", |_x: &DVector<f64>, _a: f64, _c: f64| todo!())
        .function(&["b", "c"], stub)
        .partial_deriv("b", stub)
        .partial_deriv("c", stub)
        .initial_parameters(vec![1., 2., 3.]);

    let outcome = builder.build();
    assert_matches!(outcome, Err(ModelBuildError::MissingX));
}
/// Checks that `build()` reports `MissingInitialParameters` when functions,
/// derivatives and the independent variable are supplied but
/// `initial_parameters` is never called.
#[test]
fn test_model_builder_fails_when_initial_parameters_are_missing() {
    // Placeholder with the two-parameter signature the builder calls need.
    fn stub(_: &DVector<f64>, _: f64, _: f64) -> DVector<f64> {
        todo!()
    }

    let builder = SeparableModelBuilder::<f64>::new(&["a", "b", "c"])
        .function(&["a", "c"], stub)
        .partial_deriv("a", stub)
        .partial_deriv("c", stub)
        .function(&["b", "c"], stub)
        .partial_deriv("b", stub)
        .partial_deriv("c", stub)
        .independent_variable(DVector::from_vec(vec![1., 2., 3.]));

    let outcome = builder.build();
    assert_matches!(outcome, Err(ModelBuildError::MissingInitialParameters));
}
/// Checks that `build()` reports `IncorrectParameterCount` when the model
/// declares three parameter names but only two initial values are supplied.
#[test]
fn test_model_builder_fails_when_initial_parameters_have_incorrect_parameter_count() {
    // Placeholder with the two-parameter signature the builder calls need.
    fn stub(_: &DVector<f64>, _: f64, _: f64) -> DVector<f64> {
        todo!()
    }

    let builder = SeparableModelBuilder::<f64>::new(&["a", "b", "c"])
        .function(&["a", "c"], stub)
        .partial_deriv("a", stub)
        .partial_deriv("c", stub)
        .function(&["b", "c"], stub)
        .partial_deriv("b", stub)
        .partial_deriv("c", stub)
        .independent_variable(DVector::from_vec(vec![1., 2., 3., 4., 5.]))
        // Two values for three declared parameters.
        .initial_parameters(vec![1., 3.]);

    let outcome = builder.build();
    assert_matches!(outcome, Err(ModelBuildError::IncorrectParameterCount { .. }));
}
/// Checks that a builder can be extended one parameter at a time, adding a
/// function and its partial derivative per parameter name, and that the
/// result builds successfully.
#[test]
fn model_builder_can_be_used_programmatically_in_a_loop() {
    // Placeholders with the one-parameter signature the builder calls need.
    fn stub_function(_x: &DVector<f64>, _param: f64) -> DVector<f64> {
        todo!()
    }
    fn stub_derivative(_x: &DVector<f64>, _param: f64) -> DVector<f64> {
        todo!()
    }

    let param_names = &["param1", "param2", "param3"];
    let seed = SeparableModelBuilder::<f64>::new(param_names)
        .independent_variable(nalgebra::dvector![1., 2., 3., 4.])
        .initial_parameters(vec![1., 2., 3.]);

    // Thread the builder through every parameter name in order.
    let complete = param_names.iter().fold(seed, |partial, name| {
        partial
            .function(&[*name], stub_function)
            .partial_deriv(*name, stub_derivative)
    });

    let _ = complete.build().unwrap();
}
/// Evaluates `exp(-(t - t0) / tau)` for every element `t` of `tvec`.
///
/// # Panics
///
/// Panics if `tau` is not greater than zero.
pub fn exponential_decay(tvec: &DVector<f64>, t0: f64, tau: f64) -> DVector<f64> {
    assert!(tau > 0.0, "Parameter tau must be greater than zero");
    tvec.map(|t| {
        let exponent = -(t - t0) / tau;
        exponent.exp()
    })
}
/// Partial derivative of [`exponential_decay`] with respect to `t0`:
/// the decay value divided by `tau`, element-wise.
///
/// # Panics
///
/// Panics if `tau` is not greater than zero.
pub fn exponential_decay_dt0(tvec: &DVector<f64>, t0: f64, tau: f64) -> DVector<f64> {
    assert!(tau > 0.0, "Parameter tau must be greater than zero");
    let decay = exponential_decay(tvec, t0, tau);
    decay.map(|value| value / tau)
}
/// Partial derivative of [`exponential_decay`] with respect to `tau`,
/// evaluated element-wise as `exp(-(t - t0) / tau) * (t - t0) / tau^2`.
///
/// # Panics
///
/// Panics if `tau` is not greater than zero.
pub fn exponential_decay_dtau(tvec: &DVector<f64>, t0: f64, tau: f64) -> DVector<f64> {
    assert!(tau > 0f64, "Parameter tau must be greater than zero");
    // With u = -(t - t0) / tau we have du/dtau = (t - t0) / tau^2, so the
    // exponent must match `exponential_decay` and the factor is (t - t0).
    tvec.map(|t| (-(t - t0) / tau).exp() * (t - t0) / tau.powi(2))
}
/// Evaluates `sin(omega * t)` for every element `t` of `tvec`.
pub fn sinusoid_omega(tvec: &DVector<f64>, omega: f64) -> DVector<f64> {
    tvec.map(|t| f64::sin(omega * t))
}
/// Partial derivative of [`sinusoid_omega`] with respect to `omega`,
/// evaluated element-wise as `t * cos(omega * t)`.
pub fn sinusoid_omega_domega(tvec: &DVector<f64>, omega: f64) -> DVector<f64> {
    // Chain rule: d/d(omega) sin(omega * t) = cos(omega * t) * t; the inner
    // derivative is `t`, not `omega`.
    tvec.map(|t| (omega * t).cos() * t)
}