use rmpfit::{MPError, MPFitter, MPPar, MPResult};
use crate::model::*;
/// Adapter that lets `rmpfit` fit `offset + slope * model.complexity(x)` (or just
/// `offset` when the model is constant) to the observations `(xs[i], ys[i])`.
struct FitToResults<'a> {
    /// Model whose complexity curve is being scaled.
    pub model: &'a dyn Model,
    /// Abscissae of the observations (presumably input sizes — TODO confirm).
    xs: &'a [f64],
    /// Observed values, paired element-wise with `xs`.
    ys: &'a [f64],
    /// Per-parameter constraints handed to the solver.
    param_bounds: Vec<MPPar>,
}

impl MPFitter for FitToResults<'_> {
    /// Writes the residual `y - f(x)` of every observation into `deviates`, where `f`
    /// is the scaled model evaluated with the current `params`.
    fn eval(&mut self, params: &[f64], deviates: &mut [f64]) -> MPResult<()> {
        let observations = self.xs.iter().zip(self.ys.iter());
        for (residual, (&x, &y)) in deviates.iter_mut().zip(observations) {
            let predicted = if !self.model.constant() {
                params[0] + params[1] * self.model.complexity(x)
            } else {
                params[0]
            };
            *residual = y - predicted;
        }
        Ok(())
    }

    /// One residual is produced per observation.
    fn number_of_points(&self) -> usize {
        self.xs.len()
    }

    /// Exposes the parameter constraints to the solver.
    fn parameters(&self) -> Option<&[rmpfit::MPPar]> {
        Some(self.param_bounds.as_slice())
    }
}
/// Error returned when `rmpfit` fails to fit a model to the observations.
#[derive(Debug)]
pub struct FittingError {
    /// Underlying solver error.
    error: MPError,
}

impl std::fmt::Display for FittingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.error)
    }
}

// `Debug` and `Display` are already provided, so the standard error trait can be
// implemented with its default methods. This lets callers box the error as
// `dyn std::error::Error` or propagate it with `?` into generic error types.
impl std::error::Error for FittingError {}
/// Akaike information criterion of a least-squares fit; lower is better.
///
/// Computes `2k + n * ln(RSS / n)` with `k = num_parameters` and `n = num_points`.
///
/// A perfect fit (`rss == 0`) would make the logarithm `-inf`. That wipes out the
/// `2k` penalty, so all perfect fits tie regardless of parameter count and strict
/// comparisons between them fail. The mean squared residual is therefore floored at
/// `f64::MIN_POSITIVE`, which keeps the result finite and lets the parameter penalty
/// break such ties. A NaN `rss` still yields NaN.
fn aic(num_parameters: usize, num_points: usize, rss: f64) -> f64 {
    let num_points = num_points as f64;
    let num_parameters = num_parameters as f64;
    let mean_squared = rss / num_points;
    // `NaN < x` is false, so NaN passes through unchanged instead of being masked.
    let mean_squared = if mean_squared < f64::MIN_POSITIVE {
        f64::MIN_POSITIVE
    } else {
        mean_squared
    };
    2.0 * num_parameters + num_points * mean_squared.ln()
}
/// A model, the coefficients fitted to it, and the resulting AIC.
pub(crate) struct Fitting {
    /// Fitted coefficients: `[offset]` for a constant model, `[offset, slope]` otherwise.
    pub scaling_params: Vec<f64>,
    /// Akaike information criterion of the fit; lower is better.
    pub aic: f64,
    /// The model that was fitted.
    pub model: BoxedModel,
}

#[cfg(feature = "plots")]
impl Fitting {
    /// Evaluates the fitted curve `offset + slope * complexity(n)` at `n`, or just
    /// `offset` for a constant model.
    pub(crate) fn scaled_complexity(&self, n: f64) -> f64 {
        if !self.model.constant() {
            return self.scaling_params[0] + self.scaling_params[1] * self.model.complexity(n);
        }
        self.scaling_params[0]
    }
}

impl std::fmt::Debug for Fitting {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The model is rendered through its `Display` text, which `{:?}` then quotes.
        write!(
            f,
            "Fitting(model={:?}, params={:?}, aic={})",
            self.model.to_string(),
            self.scaling_params,
            self.aic
        )
    }
}
fn fit(model: BoxedModel, xs: &[f64], ys: &[f64]) -> Result<Fitting, FittingError> {
assert_eq!(xs.len(), ys.len());
let mut fitter = FitToResults {
model: &model as &dyn Model,
xs,
ys,
param_bounds: vec![],
};
let num_params = if fitter.model.constant() { 1 } else { 2 };
for _ in 0..num_params {
let param = MPPar {
limited_low: true,
limit_low: 0.0,
..Default::default()
};
fitter.param_bounds.push(param)
}
let mut params = vec![1.0; num_params];
match fitter.mpfit(&mut params) {
Ok(status) => Ok(Fitting {
scaling_params: params,
aic: aic(
num_params,
xs.len(),
status.resid.iter().map(|x| x * x).sum(),
),
model,
}),
Err(error) => Err(FittingError { error }),
}
}
/// Fits `expected_model` and every model in `known_models`, keeping the lowest AIC.
///
/// Returns `Ok` with the expected model's fitting when no known model achieves a
/// strictly lower AIC. Otherwise returns `Err` with the best known model's fitting.
/// Known models whose fit fails are skipped.
///
/// # Panics
/// Panics if fitting the expected model fails, or if `xs` and `ys` differ in length.
fn expected_or_other_best_fit<M: Model + 'static>(
    expected_model: M,
    xs: &[f64],
    ys: &[f64],
    known_models: KnownModels,
) -> Result<Fitting, Fitting> {
    let expected_fitting =
        fit(BoxedModel::new(expected_model), xs, ys).expect("Failed to fit to the model.");
    // `Ok` while the expected model leads, `Err` once a known model has overtaken it.
    let mut outcome: Result<Fitting, Fitting> = Ok(expected_fitting);
    for model in known_models.into_iter() {
        let Ok(candidate) = fit(model, xs, ys) else {
            continue;
        };
        let leading_aic = match &outcome {
            Ok(leader) | Err(leader) => leader.aic,
        };
        if candidate.aic < leading_aic {
            outcome = Err(candidate);
        }
    }
    outcome
}
/// Checks that `fitting` is not outperformed by its model multiplied by `Log(N)`.
///
/// Returns `true` when the `Log(N)`-scaled variant cannot be fitted, or when its AIC
/// is strictly higher than `fitting.aic`. Returns `false` otherwise.
fn validate_fit(fitting: &Fitting, xs: &[f64], ys: &[f64]) -> bool {
    let log_scaled = BoxedModel::new(Log(N) * fitting.model.clone());
    fit(log_scaled, xs, ys).map_or(true, |alternative| alternative.aic > fitting.aic)
}
/// Outcome of `find_best_fit`.
#[derive(Debug)]
pub(crate) enum BestFit {
    /// The expected model had the lowest AIC and its fit was accepted.
    ExpectedModel,
    /// A known model had a strictly lower AIC than the expected one, and its fit was accepted.
    AnotherKnownModel { best_known: Fitting },
    /// The lowest-AIC fit (expected or known) failed `validate_fit`.
    NotFound { best_known: Fitting },
}
/// Determines whether `expected_model` best explains the observations `(xs, ys)`.
///
/// Selection works in two steps:
/// 1. The lowest-AIC fit among `expected_model` and `known_models` is chosen, using
///    `expected_or_other_best_fit`.
/// 2. If `known_models.fastest_growing()` is `None`, or the chosen model equals it, the
///    choice is also checked with `validate_fit`. A failed check yields
///    `BestFit::NotFound`.
///
/// # Panics
/// Panics if fitting the expected model fails, or if `xs` and `ys` differ in length.
pub(crate) fn find_best_fit<M: Model + 'static>(
    expected_model: M,
    xs: &[f64],
    ys: &[f64],
    known_models: KnownModels,
) -> BestFit {
    let fastest_growing = known_models.fastest_growing();
    let result = expected_or_other_best_fit(expected_model, xs, ys, known_models);
    let winner = match &result {
        Ok(fitting) | Err(fitting) => fitting,
    };
    // Validate only when the winner is the fastest-growing known model, or when no
    // such model is reported. A faster-growing candidate losing is itself evidence.
    let needs_validation = match fastest_growing.as_ref() {
        None => true,
        Some(fastest) => winner.model == *fastest,
    };
    let validates = !needs_validation || validate_fit(winner, xs, ys);
    match result {
        Ok(_) if validates => BestFit::ExpectedModel,
        Err(best_known) if validates => BestFit::AnotherKnownModel { best_known },
        Ok(best_known) | Err(best_known) => BestFit::NotFound { best_known },
    }
}
// Unit tests for the fitting pipeline: `expected_or_other_best_fit`, `validate_fit`
// and `find_best_fit` on small synthetic data sets.
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // Asserts that `$expression` matches `$pattern` (via `matches!`).
    // NOTE(review): `#[macro_export]` also exports this macro at the crate root in test
    // builds; confirm whether other modules rely on that before removing the attribute.
    #[macro_export]
    macro_rules! assert_matches {
        ($expression:expr, $pattern:pat) => {
            assert!(matches!($expression, $pattern));
        };
    }

    // Strictly decreasing data. `fit` bounds every coefficient below by 0
    // (`limit_low: 0.0`), so no model can use a negative slope. The constant model is
    // expected to win.
    #[test]
    fn downslope_disallowed() {
        let xs = [10., 20., 30., 40.];
        let ys = [100., 90., 80., 70.];
        assert!(expected_or_other_best_fit(Constant, &xs, &ys, KnownModels::default()).is_ok());
        assert_matches!(
            find_best_fit(Constant, &xs, &ys, KnownModels::default()),
            BestFit::ExpectedModel
        );
    }

    // Data fluctuating around a constant level should be attributed to `Constant`.
    #[test]
    fn constant_matching() {
        let xs = [10., 20., 30., 40.];
        let ys = [100., 90., 80., 110.];
        assert!(expected_or_other_best_fit(Constant, &xs, &ys, KnownModels::default()).is_ok());
        assert_matches!(
            find_best_fit(Constant, &xs, &ys, KnownModels::default()),
            BestFit::ExpectedModel
        );
    }

    // Data is generated from `a + b * model.complexity(x)` and perturbed by alternating
    // multiplicative factors. It must be recognised as `model` itself, and that fit
    // must pass `validate_fit`.
    #[rstest]
    #[case(N)]
    #[case(Log(N))]
    #[case(Sqrt(N))]
    #[case(N * Log(N))]
    #[case(Log(Log(N)))]
    #[case(N2)]
    #[case(N3)]
    fn model_output_fits_itself<M: Model + PartialEq + Clone + 'static>(#[case] model: M) {
        let xs = [
            100.0,
            140.,
            200.,
            300.,
            500.,
            700.,
            1000.0,
            1400.,
            2000.,
            3000.,
            5000.,
            7000.,
            10_000.0f64,
        ];
        // Noise factors applied point by point, alternating +0.01% and -0.1%.
        let errors = [1.0001, 0.999].iter().cycle();
        // (offset `a`, slope `b`) pairs used to generate the data.
        for (a, b) in [(50., 2.), (500_000., 1_000_000.)] {
            let ys = xs
                .iter()
                .zip(errors.clone())
                .map(|(x, err)| (a + b * model.complexity(*x)) * err)
                .collect::<Vec<f64>>();
            let best_fit =
                expected_or_other_best_fit(model.clone(), &xs, &ys, KnownModels::default());
            assert!(best_fit.is_ok());
            assert!(validate_fit(&best_fit.unwrap(), &xs, &ys));
            assert_matches!(
                find_best_fit(model.clone(), &xs, &ys, KnownModels::default()),
                BestFit::ExpectedModel
            );
        }
    }

    // `x^4` data, checked in three setups:
    // - With the default known models, the best fit is `N3` (asserted below); it must
    //   fail validation, giving `NotFound`.
    // - Expecting `Pow(4.0)` must yield `ExpectedModel`.
    // - Adding `Pow(4.)` to the known models while expecting `Constant` must report it
    //   as `AnotherKnownModel`.
    #[test]
    fn unknown_model() {
        let xs: Vec<f64> = (1..20).map(|i| 10. * 1.4_f64.powi(i)).collect();
        let ys: Vec<f64> = xs.iter().map(|x| x.powi(4)).collect();
        let best_fit =
            expected_or_other_best_fit(Constant, &xs, &ys, KnownModels::default()).unwrap_err();
        assert_eq!(best_fit.model, BoxedModel::new(N3));
        assert!(!validate_fit(&best_fit, &xs, &ys));
        assert_matches!(
            find_best_fit(Constant, &xs, &ys, KnownModels::default()),
            BestFit::NotFound { .. }
        );
        let best_fit =
            expected_or_other_best_fit(Pow(4.0), &xs, &ys, KnownModels::default()).unwrap();
        assert_eq!(best_fit.model, BoxedModel::new(Pow(4.)));
        assert!(validate_fit(&best_fit, &xs, &ys));
        assert_matches!(
            find_best_fit(Pow(4.), &xs, &ys, KnownModels::default()),
            BestFit::ExpectedModel
        );
        let BestFit::AnotherKnownModel { best_known } =
            find_best_fit(Constant, &xs, &ys, KnownModels::default().with(Pow(4.)))
        else {
            panic!("oops")
        };
        assert_eq!(best_known.model, BoxedModel::new(Pow(4.)));
    }
}