use std::{io::IsTerminal, sync::LazyLock};
use textplots::{Chart, ColorPlot, Shape};
use crate::KnownModels;
use crate::measure::measure_complexity;
use crate::model::{BoxedModel, Model};
/// True when chart output makes sense: stdout is a terminal, the `NO_COLOR`
/// environment variable is unset, and `TERM` is not "dumb".
static USE_COLORS: LazyLock<bool> = LazyLock::new(|| {
    let on_terminal = std::io::stdout().is_terminal();
    let no_color_requested = std::env::var("NO_COLOR").is_ok();
    let dumb_terminal = std::env::var("TERM").map(|t| t == "dumb").unwrap_or(false);
    on_terminal && !no_color_requested && !dumb_terminal
});
/// Lazily truncates each `f64` sample toward zero, yielding `i64` values.
fn to_ints(vals: &[f64]) -> impl Iterator<Item = i64> {
    vals.iter().copied().map(|v| v as i64)
}
/// A model fit bundled with the raw measurements it was fitted against.
struct FittingWithMeasurements {
    /// Fitted scaling parameters: index 0 is the additive offset; for
    /// non-constant models index 1 is the multiplier applied to the
    /// model's complexity function.
    pub scaling_params: Vec<f64>,
    /// The model that fitted the measurements best.
    model: BoxedModel,
    /// Input sizes of the measurements (x axis).
    xs: Vec<f64>,
    /// Measured values, one per entry in `xs` (y axis).
    ys: Vec<f64>,
}
/// Produces `num_inputs` pairs of `(input_size, input)` whose sizes grow
/// geometrically from `initial_input_size`, ten steps per decade: each size
/// is the previous one scaled by `10^0.1` (about 1.26), rounded.
///
/// `create_input` is invoked once per step with the rounded size.
pub fn growing_inputs<C: 'static + Fn(usize) -> T, T: Clone>(
    initial_input_size: usize,
    create_input: C,
    num_inputs: usize,
) -> impl IntoIterator<Item = (usize, T)> {
    // Ten geometric steps cover exactly one order of magnitude.
    let step = 10.0f64.powf(0.1);
    (0..num_inputs).map(move |i| {
        let size = (initial_input_size as f64 * step.powi(i as i32)).round() as usize;
        (size, create_input(size))
    })
}
/// Times `function` across `input_sizes_and_values` and asks the fitting
/// code whether any of `known_models` matches the measurements better than
/// `expected_model`. Propagates `None` from `crate::fit::find_better_fit`
/// (no better fit was found); otherwise returns the winning fit together
/// with the measured data points.
fn find_better_fit<M, F, I, T, R>(
    expected_model: M,
    function: F,
    input_sizes_and_values: I,
    known_models: KnownModels,
) -> Option<FittingWithMeasurements>
where
    M: Model + 'static,
    F: Fn(T) -> R,
    I: IntoIterator<Item = (usize, T)>,
    T: Clone,
{
    let measurements = measure_complexity(function, input_sizes_and_values.into_iter());
    let xs = measurements.xs;
    let ys = measurements.ys;
    let better = crate::fit::find_better_fit(expected_model, &xs, &ys, known_models)?;
    Some(FittingWithMeasurements {
        scaling_params: better.scaling_params,
        model: better.model,
        xs,
        ys,
    })
}
/// Asserts that `expected_model` is the best-fitting complexity model for
/// `function`, measured over `input_sizes_and_values`, against the default
/// set of known models. Panics when a different model fits better.
pub fn assert_best_fit<M, F, T, R, I>(expected_model: M, function: F, input_sizes_and_values: I)
where
    M: Model + 'static,
    F: Fn(T) -> R,
    I: IntoIterator<Item = (usize, T)>,
    T: Clone,
{
    let known_models = KnownModels::default();
    assert_best_fit_vs_models(expected_model, function, input_sizes_and_values, known_models);
}
/// Asserts that `expected_model` is the best-fitting complexity model for
/// `function` among `known_models`, measuring over `input_sizes_and_values`.
///
/// When another model fits better, the measurements and the fitted model are
/// charted (only when stdout is a color-capable terminal — see `USE_COLORS`)
/// and the function panics with both model names.
pub fn assert_best_fit_vs_models<M, F, T, R, I>(
    expected_model: M,
    function: F,
    input_sizes_and_values: I,
    known_models: KnownModels,
) where
    M: Model + 'static,
    F: Fn(T) -> R,
    I: IntoIterator<Item = (usize, T)>,
    T: Clone,
{
    // Capture the name up front; `expected_model` is moved into `find_better_fit`.
    let expected_string = expected_model.to_string();
    let maybe_better = find_better_fit(
        expected_model,
        function,
        input_sizes_and_values,
        known_models,
    );
    if let Some(better_fitting) = maybe_better {
        let best_model = better_fitting.model;
        let xs = &better_fitting.xs;
        let ys = &better_fitting.ys;
        if *USE_COLORS {
            // `best_model` is passed to `{}` directly instead of through
            // `.to_string()` (clippy::to_string_in_format_args) — `Display`
            // produces the identical text without the extra allocation.
            println!(
                "\nRed dots: real data\nBlue line: best fitted model {}",
                best_model
            );
            let params = &better_fitting.scaling_params;
            // Evaluate the fitted model at input size `n`: offset only for
            // constant models, offset + scale * complexity(n) otherwise.
            let calc_complexity = |n: f64| -> f64 {
                if best_model.constant() {
                    params[0]
                } else {
                    params[0] + params[1] * best_model.complexity(n)
                }
            };
            let points = xs
                .iter()
                .zip(ys.iter())
                .map(|(x, y)| (*x as f32, *y as f32))
                .collect::<Vec<(f32, f32)>>();
            // Measured values plus fitted values, so the y range covers both.
            let combo_results = to_ints(ys)
                .chain(xs.iter().map(|x| calc_complexity(*x) as i64))
                .collect::<Vec<_>>();
            // NOTE: the `unwrap`s assume at least one measurement was taken;
            // with empty data this panics here instead of reaching the
            // mismatch message below.
            let mut chart = Chart::new_with_y_range(
                120,
                60,
                to_ints(xs).min().unwrap() as f32 * 0.9,
                to_ints(xs).max().unwrap() as f32 * 1.1,
                *combo_results.iter().min().unwrap() as f32 * 0.9,
                *combo_results.iter().max().unwrap() as f32 * 1.1,
            );
            chart.borders();
            chart
                .linecolorplot(
                    &Shape::Continuous(Box::new(|n| calc_complexity(n as f64) as f32)),
                    rgb::RGB8 { r: 0, g: 0, b: 255 },
                )
                .linecolorplot(&Shape::Points(&points), rgb::RGB8 { r: 255, g: 0, b: 0 })
                .display();
        }
        // Message text is relied upon by `#[should_panic(expected = ...)]`
        // tests — keep it byte-for-byte stable.
        panic!(
            "The best matching model is: {}\nThis didn't match the expected model: {}",
            best_model, expected_string
        );
    }
}
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;
    use super::*;
    use crate::model::{Constant, Log, N};

    /// Builds a vector of `n` random `i64` values.
    fn make_vec(n: usize) -> Vec<i64> {
        let mut values = Vec::with_capacity(n);
        values.extend(repeat_with(|| fastrand::i64(..)).take(n));
        values
    }

    /// Sorts a vector in place and hands it back.
    fn sort(mut v: Vec<i64>) -> Vec<i64> {
        v.sort();
        v
    }

    #[test]
    fn check_expected_complexity_sort() {
        check_expected_complexity_on_vecs(sort, N * Log(N), 25);
        assert_best_fit(N * Log(N), sort, growing_inputs(100, make_vec, 25));
    }

    #[test]
    fn check_expected_complexity_search() {
        let linear_search = |v: Vec<i64>| v.into_iter().position(|i| i == 123);
        check_expected_complexity_on_vecs(linear_search, N, 40);
    }

    /// Runs `function` on growing random vectors and asserts that the best
    /// non-constant fit matches `expected_complexity`.
    fn check_expected_complexity_on_vecs<R, F: Fn(Vec<i64>) -> R, M: Model + 'static>(
        function: F,
        expected_complexity: M,
        num_items: usize,
    ) {
        let inputs = growing_inputs(100, make_vec, num_items);
        let fit = find_better_fit(Constant, function, inputs, KnownModels::default())
            .unwrap_or_else(|| {
                panic!(
                    "Unexpectedly estimated constant complexity, i.e. O(1), instead of {}",
                    expected_complexity.to_string()
                )
            });
        let expected_complexity = BoxedModel::new(expected_complexity);
        assert!(
            fit.model == expected_complexity,
            "{:?} != {}",
            fit.model,
            expected_complexity.to_string()
        );
    }

    #[should_panic(
        expected = "The best matching model is: n*log(n)\nThis didn't match the expected model: n"
    )]
    #[test]
    fn assert_panics_on_wrong_options() {
        assert_best_fit(N, sort, growing_inputs(100, make_vec, 25));
    }

    #[should_panic(
        expected = "The best matching model is: n*log(n)\nThis didn't match the expected model: n"
    )]
    #[test]
    fn assert_panics_on_wrong_options_with_models() {
        assert_best_fit_vs_models(
            N,
            sort,
            growing_inputs(100, make_vec, 25),
            KnownModels::default(),
        );
    }

    #[test]
    fn best_we_can_do() {
        assert_best_fit_vs_models(
            N,
            sort,
            growing_inputs(100, make_vec, 25),
            KnownModels::new(),
        );
    }
}