#![allow(dead_code)]
use crate::knee_locator::{ValidCurve, ValidDirection};
use ndarray::{s, stack, Array, Array1, Axis};
use ndarray_linalg::Solve;
/// `(direction, curvature)` classification returned by `find_shape`.
type Shape = (ValidDirection, ValidCurve);
/// Classifies sampled data by overall direction and curvature.
///
/// A straight line `y = p[0] * x + p[1]` is fitted to all points by least
/// squares, solving the normal equations `AᵀA p = Aᵀy` where `A` has the
/// columns `x` and `1`. The sign of the slope `p[0]` gives the direction. The
/// curvature is decided by `q`, the mean residual of `y` against that line over
/// the middle samples (indices from 20% up to 80% of the length):
///
/// * increasing fit: `q >= 0` is concave, otherwise convex;
/// * decreasing fit: `q > 0` is concave, otherwise convex.
///
/// A slope with magnitude below `1e-10` is treated as flat and reported as
/// `(Decreasing, Convex)`. The same result is returned when no line can be
/// fitted: fewer than two points, or a singular or non-finite least-squares
/// solution (for example when every `x` is equal).
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
fn find_shape(x: &Array1<f64>, y: &Array1<f64>) -> Shape {
    // Slopes smaller than this in magnitude are considered flat.
    const SLOPE_EPSILON: f64 = 1e-10;
    // Result for flat data and for data that no line can be fitted to.
    const FLAT: Shape = (ValidDirection::Decreasing, ValidCurve::Convex);

    assert_eq!(
        x.len(),
        y.len(),
        "find_shape: x and y must have the same length"
    );
    // Fewer than two points cannot define a line, and the middle window
    // below would be empty (its mean undefined).
    if x.len() < 2 {
        return FLAT;
    }

    // Least-squares line fit via the normal equations.
    let a = stack![Axis(1), x.view(), Array::ones(x.len())];
    let p = match a.t().dot(&a).solve(&a.t().dot(y)) {
        Ok(p) if p.iter().all(|c| c.is_finite()) => p,
        // Singular normal matrix (e.g. all x equal) or a numerically unusable fit.
        _ => return FLAT,
    };

    if p[0].abs() < SLOPE_EPSILON {
        return FLAT;
    }

    // Mean residual over the middle 20%..80% of samples (by index). With at
    // least two samples x1 < x2, so the window is non-empty.
    let x1 = (x.len() as f64 * 0.2) as usize;
    let x2 = (x.len() as f64 * 0.8) as usize;
    let middle_x = x.slice(s![x1..x2]);
    let middle_y = y.slice(s![x1..x2]);
    let q = middle_y.mean().expect("middle window is non-empty")
        - middle_x
            .mapv(|xi| xi * p[0] + p[1])
            .mean()
            .expect("middle window is non-empty");

    // A zero mean residual counts as concave when increasing but convex when
    // decreasing; the tests in this module pin that asymmetry.
    if p[0] > 0.0 {
        let curve = if q >= 0.0 {
            ValidCurve::Concave
        } else {
            ValidCurve::Convex
        };
        (ValidDirection::Increasing, curve)
    } else {
        let curve = if q > 0.0 {
            ValidCurve::Concave
        } else {
            ValidCurve::Convex
        };
        (ValidDirection::Decreasing, curve)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        data_generator::DataGenerator,
        knee_locator::{ValidCurve, ValidDirection},
    };
    use ndarray::{array, Array1};

    /// Asserts that `find_shape` classifies the points `(x, y)` as `expected`.
    ///
    /// `#[track_caller]` makes a failing assertion report the line of the
    /// calling test case rather than a line inside this helper.
    #[track_caller]
    fn assert_shape(x: &Array1<f64>, y: &Array1<f64>, expected: Shape) {
        assert_eq!(find_shape(x, y), expected);
    }

    #[test]
    fn test_curve_and_direction() {
        // Sample positions shared by all five-point cases.
        let xs = array![1.0, 2.0, 3.0, 4.0, 5.0];

        // Curved data.
        assert_shape(
            &xs,
            &array![1.0, 3.0, 6.0, 10.0, 15.0],
            (ValidDirection::Increasing, ValidCurve::Convex),
        );
        assert_shape(
            &xs,
            &array![1.0, 1.5, 1.8, 1.9, 2.0],
            (ValidDirection::Increasing, ValidCurve::Concave),
        );
        assert_shape(
            &xs,
            &array![15.0, 10.0, 6.0, 3.0, 1.0],
            (ValidDirection::Decreasing, ValidCurve::Convex),
        );
        assert_shape(
            &xs,
            &array![2.0, 1.9, 1.8, 1.5, 1.0],
            (ValidDirection::Decreasing, ValidCurve::Concave),
        );

        // Straight lines.
        assert_shape(
            &xs,
            &array![1.0, 2.0, 3.0, 4.0, 5.0],
            (ValidDirection::Increasing, ValidCurve::Concave),
        );
        assert_shape(
            &xs,
            &array![5.0, 4.0, 3.0, 2.0, 1.0],
            (ValidDirection::Decreasing, ValidCurve::Convex),
        );

        // Constant data (zero slope).
        assert_shape(
            &xs,
            &array![2.0, 2.0, 2.0, 2.0, 2.0],
            (ValidDirection::Decreasing, ValidCurve::Convex),
        );

        // Noisy, mostly increasing data over ten samples.
        assert_shape(
            &array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &array![1.1, 1.0, 1.2, 1.3, 1.25, 1.4, 1.5, 1.6, 1.7, 1.8],
            (ValidDirection::Increasing, ValidCurve::Convex),
        );
    }

    #[test]
    fn test_find_shape() {
        let (xs, ys) = DataGenerator::concave_increasing();
        assert_shape(&xs, &ys, (ValidDirection::Increasing, ValidCurve::Concave));

        let (xs, ys) = DataGenerator::concave_decreasing();
        assert_shape(&xs, &ys, (ValidDirection::Decreasing, ValidCurve::Concave));

        let (xs, ys) = DataGenerator::convex_increasing();
        assert_shape(&xs, &ys, (ValidDirection::Increasing, ValidCurve::Convex));

        let (xs, ys) = DataGenerator::convex_decreasing();
        assert_shape(&xs, &ys, (ValidDirection::Decreasing, ValidCurve::Convex));
    }
}