use ndarray::ArrayView2;
/// One sample of an explained-variance-versus-K curve.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EvVsKPoint {
    /// K value the sample was taken at (the curve's x-coordinate).
    pub k: usize,
    /// Explained variance recorded for `k` (the curve's y-coordinate).
    pub ev: f64,
}

impl EvVsKPoint {
    /// Pairs a K value with its explained variance.
    ///
    /// This constructor performs no validation of either argument.
    pub fn new(k: usize, ev: f64) -> Self {
        EvVsKPoint { k, ev }
    }
}
/// Explained-variance-versus-K curve.
///
/// A curve built through [`EvVsKCurve::new`] holds at least one sample, sorted by
/// ascending K, with every K distinct.
#[derive(Clone, Debug)]
pub struct EvVsKCurve {
    points: Vec<EvVsKPoint>,
}

impl EvVsKCurve {
    /// Validates `points` and stores them sorted by ascending K.
    ///
    /// # Errors
    ///
    /// Returns a message when `points` is empty, when a sample has `k == 0` or a
    /// non-finite `ev` (the first offending sample in input order is reported), or
    /// when two samples share the same K.
    pub fn new(mut points: Vec<EvVsKPoint>) -> Result<Self, String> {
        if points.is_empty() {
            return Err("EvVsKCurve::new: at least one (K, EV) sample required".into());
        }

        // Per-sample checks run in caller order, before any sorting.
        points.iter().try_for_each(|sample| {
            if sample.k == 0 {
                Err(String::from("EvVsKCurve::new: K must be >= 1"))
            } else if !sample.ev.is_finite() {
                Err(format!("EvVsKCurve::new: non-finite EV at K={}", sample.k))
            } else {
                Ok(())
            }
        })?;

        // Equal keys are rejected right below, so an unstable sort cannot change
        // either the stored order or the reported duplicate.
        points.sort_unstable_by_key(|sample| sample.k);

        // After sorting, duplicates are necessarily adjacent.
        if let Some(pair) = points.windows(2).find(|pair| pair[0].k == pair[1].k) {
            return Err(format!("EvVsKCurve::new: duplicate K={}", pair[0].k));
        }

        Ok(Self { points })
    }

    /// Number of samples on the curve.
    pub fn len(&self) -> usize {
        self.points.len()
    }

    /// Whether the curve holds no samples.
    pub fn is_empty(&self) -> bool {
        self.points.is_empty()
    }

    /// Samples in ascending-K order.
    pub fn points(&self) -> &[EvVsKPoint] {
        self.points.as_slice()
    }

    /// K of the first stored sample, i.e. the smallest K on the curve.
    pub fn k_min(&self) -> usize {
        self.points[0].k
    }

    /// K of the last stored sample, i.e. the largest K on the curve.
    pub fn k_max(&self) -> usize {
        let last_idx = self.points.len() - 1;
        self.points[last_idx].k
    }

    /// K of the first sample (in ascending-K order) whose EV is at least
    /// `target_ev`, or `None` when no sample reaches it.
    pub fn k_reaching(&self, target_ev: f64) -> Option<usize> {
        self.points
            .iter()
            .filter(|sample| sample.ev >= target_ev)
            .map(|sample| sample.k)
            .next()
    }
}
/// Strategy used to pick K from an EV-vs-K curve.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum KSelectionMode {
    /// Knee detection; canonical name `"kneedle"`.
    Kneedle,
    /// Complexity-penalized selection; canonical name `"mdl"`.
    PenalizedMdl,
}

impl KSelectionMode {
    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepted spellings: `kneedle`, `knee`, `elbow` for [`Self::Kneedle`];
    /// `mdl`, `penalized`, `penalized_mdl` for [`Self::PenalizedMdl`].
    ///
    /// # Errors
    ///
    /// Any other input yields a message quoting the trimmed, lower-cased value.
    pub fn parse(value: &str) -> Result<Self, String> {
        let normalized = value.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "kneedle" | "knee" | "elbow" => Self::Kneedle,
            "mdl" | "penalized" | "penalized_mdl" => Self::PenalizedMdl,
            other => {
                return Err(format!(
                    "K-selection mode must be 'kneedle' or 'mdl'; got {other:?}"
                ));
            }
        };
        Ok(mode)
    }

    /// Canonical name of the mode; each returned string is accepted by [`Self::parse`].
    pub const fn as_str(self) -> &'static str {
        match self {
            KSelectionMode::Kneedle => "kneedle",
            KSelectionMode::PenalizedMdl => "mdl",
        }
    }
}
/// Tunable parameters for `select_k`.
#[derive(Clone, Copy, Debug)]
pub struct KSelectionConfig {
    /// Which selection routine `select_k` dispatches to.
    pub mode: KSelectionMode,
    /// Kneedle mode: a knee is accepted when the slope just after the candidate,
    /// divided by the initial positive slope, is at most this value.
    pub knee_slope_fraction: f64,
    /// MDL mode: weight applied to `k / k_max` and subtracted from EV.
    pub complexity_penalty: f64,
    /// Curves whose EV span (max minus min) is at most this value are treated as flat.
    pub flat_span_tol: f64,
}

impl Default for KSelectionConfig {
    /// Kneedle mode with a 10% slope fraction, a 0.05 complexity penalty and a
    /// flat-span tolerance of 1e-6.
    fn default() -> Self {
        let mode = KSelectionMode::Kneedle;
        let knee_slope_fraction = 0.10;
        let complexity_penalty = 0.05;
        let flat_span_tol = 1.0e-6;
        KSelectionConfig {
            mode,
            knee_slope_fraction,
            complexity_penalty,
            flat_span_tol,
        }
    }
}
/// Qualitative outcome attached to a K selection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KSelectionFlag {
    /// A knee was identified; label `"knee"`.
    Knee,
    /// No knee was identified; label `"no_knee"`.
    NoKnee,
    /// The curve was classified as linear; label `"linear"`.
    Linear,
    /// The curve was classified as flat; label `"flat"`.
    Flat,
}

impl KSelectionFlag {
    /// Snake-case label for the flag.
    pub const fn as_str(self) -> &'static str {
        match self {
            KSelectionFlag::Knee => "knee",
            KSelectionFlag::NoKnee => "no_knee",
            KSelectionFlag::Linear => "linear",
            KSelectionFlag::Flat => "flat",
        }
    }

    /// `true` only for [`KSelectionFlag::Knee`].
    pub const fn is_knee(self) -> bool {
        match self {
            KSelectionFlag::Knee => true,
            KSelectionFlag::NoKnee | KSelectionFlag::Linear | KSelectionFlag::Flat => false,
        }
    }
}
/// Result of selecting K from an EV-vs-K curve.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct KSelection {
    /// Selected K; always the K of one of the curve's samples.
    pub k: usize,
    /// EV of the selected sample.
    pub ev: f64,
    /// Qualitative classification of the outcome.
    pub flag: KSelectionFlag,
    /// Diagnostic whose meaning depends on the code path that produced it:
    /// `0.0` for a single-sample curve or a curve with no positive slope, the EV
    /// span for the flat-span shortcut, the relative slope spread for a linear
    /// curve, the slope decay fraction for kneedle knee / no-knee results, and the
    /// winning objective value in MDL mode.
    pub score: f64,
}
/// Picks a K from `curve` according to `config`.
///
/// Two shortcuts run before the configured mode is consulted, both returning the
/// first (smallest-K) sample flagged [`KSelectionFlag::Flat`]:
///
/// * a single-sample curve, with score `0.0`;
/// * a curve whose EV span (max minus min) is at most `config.flat_span_tol`, with
///   the span as score.
///
/// Otherwise the result comes from the kneedle or MDL routine chosen by
/// `config.mode`.
pub fn select_k(curve: &EvVsKCurve, config: &KSelectionConfig) -> KSelection {
    let samples = curve.points();

    // Both shortcuts report the first sample; only the score differs.
    let flat_at_first = |score: f64| KSelection {
        k: samples[0].k,
        ev: samples[0].ev,
        flag: KSelectionFlag::Flat,
        score,
    };

    if samples.len() == 1 {
        return flat_at_first(0.0);
    }

    let (ev_min, ev_max) = samples.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), sample| (lo.min(sample.ev), hi.max(sample.ev)),
    );
    let span = ev_max - ev_min;
    if span <= config.flat_span_tol {
        return flat_at_first(span);
    }

    match config.mode {
        KSelectionMode::PenalizedMdl => select_mdl(curve, config),
        KSelectionMode::Kneedle => select_kneedle(curve, config, span),
    }
}
/// EV gained per unit of K between each pair of consecutive samples.
///
/// Returns one slope per adjacent pair, so the result is one element shorter than
/// `pts` (and empty for fewer than two samples). The K difference is computed in
/// `usize`, so `pts` must be in ascending-K order.
fn marginal_slopes(pts: &[EvVsKPoint]) -> Vec<f64> {
    pts.iter()
        .zip(pts.iter().skip(1))
        .map(|(left, right)| {
            let run = (right.k - left.k) as f64;
            let rise = right.ev - left.ev;
            rise / run
        })
        .collect()
}
/// Knee-based selection.
///
/// `span` is the curve's EV span as computed by `select_k`; it normalises the EV
/// axis. The routine proceeds in four steps:
///
/// 1. With no positive marginal slope, the first sample is returned as `Flat`
///    (score `0.0`).
/// 2. If the mean slope is positive and the gap between the largest and smallest
///    slope is within `LINEARITY_SLOPE_REL_TOL` of that mean, the last sample is
///    returned as `Linear` (score = spread / mean).
/// 3. Otherwise the candidate is the sample maximising normalised EV minus
///    normalised K; ties keep the smallest K.
/// 4. The slope of the segment after the candidate, relative to the first positive
///    slope and clamped at zero, is the decay fraction. At or below
///    `config.knee_slope_fraction` the candidate is a `Knee`; above it the last
///    sample is returned as `NoKnee`. The decay fraction is the score either way.
fn select_kneedle(curve: &EvVsKCurve, config: &KSelectionConfig, span: f64) -> KSelection {
    let samples = curve.points();
    let first = samples[0];
    let last = samples[samples.len() - 1];
    let slopes = marginal_slopes(samples);

    // Reference slope: the first strictly positive step anywhere on the curve.
    let init_slope = match slopes.iter().copied().find(|&slope| slope > 0.0) {
        Some(slope) => slope,
        None => {
            return KSelection {
                k: first.k,
                ev: first.ev,
                flag: KSelectionFlag::Flat,
                score: 0.0,
            };
        }
    };

    let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
    let (slope_lo, slope_hi) = slopes.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &slope| (lo.min(slope), hi.max(slope)),
    );
    let slope_spread = slope_hi - slope_lo;
    if mean_slope > 0.0 && slope_spread <= LINEARITY_SLOPE_REL_TOL * mean_slope {
        return KSelection {
            k: last.k,
            ev: last.ev,
            flag: KSelectionFlag::Linear,
            score: slope_spread / mean_slope.max(MIN_DENOM),
        };
    }

    let k_first = first.k as f64;
    let k_range = (last.k as f64 - k_first).max(MIN_DENOM);

    // Arg-max of (normalised EV - normalised K); strict `>` keeps the earliest
    // sample on ties.
    let (best_idx, _best_diff) = samples.iter().enumerate().fold(
        (0usize, f64::NEG_INFINITY),
        |(idx, best), (i, sample)| {
            let ev_hat = (sample.ev - first.ev) / span;
            let k_hat = (sample.k as f64 - k_first) / k_range;
            let diff = ev_hat - k_hat;
            if diff > best {
                (i, diff)
            } else {
                (idx, best)
            }
        },
    );

    // slopes[i] is the step leaving sample i; a candidate with no following
    // segment is treated as having slope 0.
    //
    // NOTE(review): when the candidate is the very first sample and the step after
    // it is flat or falling, the decay fraction is 0, so for any non-negative
    // `knee_slope_fraction` the smallest K is reported as a `Knee` even though the
    // reference slope comes from a later segment. Confirm this is intended.
    let post_slope = slopes.get(best_idx).copied().unwrap_or(0.0);
    let decay_fraction = (post_slope / init_slope).max(0.0);

    let (chosen, flag) = if decay_fraction <= config.knee_slope_fraction {
        (samples[best_idx], KSelectionFlag::Knee)
    } else {
        (last, KSelectionFlag::NoKnee)
    };
    KSelection {
        k: chosen.k,
        ev: chosen.ev,
        flag,
        score: decay_fraction,
    }
}
/// Penalized selection: maximises `ev - complexity_penalty * (k / k_max)`.
///
/// Ties keep the smallest K. The flag is derived from the winner's position:
/// first sample -> `Flat`, last sample -> `NoKnee`, anything in between -> `Knee`.
/// The score is the winning objective value.
fn select_mdl(curve: &EvVsKCurve, config: &KSelectionConfig) -> KSelection {
    let samples = curve.points();
    let penalty = config.complexity_penalty;
    // Loop-invariant denominator, floored so the division is always defined.
    let k_scale = (curve.k_max() as f64).max(MIN_DENOM);

    let (best_idx, best_obj) = samples.iter().enumerate().fold(
        (0usize, f64::NEG_INFINITY),
        |(idx, best), (i, sample)| {
            let objective = sample.ev - penalty * (sample.k as f64 / k_scale);
            if objective > best {
                (i, objective)
            } else {
                (idx, best)
            }
        },
    );

    let last_idx = samples.len() - 1;
    let flag = match best_idx {
        0 => KSelectionFlag::Flat,
        i if i == last_idx => KSelectionFlag::NoKnee,
        _ => KSelectionFlag::Knee,
    };

    let winner = samples[best_idx];
    KSelection {
        k: winner.k,
        ev: winner.ev,
        flag,
        score: best_obj,
    }
}
/// Comparison of how many K a manifold curve and a linear curve need to reach the
/// same target EV.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ManifoldVsLinearAdvantage {
    /// EV level both curves were asked to reach.
    pub target_ev: f64,
    /// K at which the manifold curve reaches the target, if it does.
    pub k_manifold: Option<usize>,
    /// K at which the linear curve reaches the target, if it does.
    pub k_linear: Option<usize>,
    /// Ratio relating the two K values, when available.
    pub compression_ratio: Option<f64>,
}

impl ManifoldVsLinearAdvantage {
    /// `true` when both K values are present and `k_manifold < k_linear`.
    ///
    /// Returns `false` whenever either K is missing, including the case where only
    /// the linear curve failed to reach the target.
    pub fn manifold_dominates(&self) -> bool {
        matches!(
            (self.k_manifold, self.k_linear),
            (Some(km), Some(kl)) if km < kl
        )
    }
}
/// Compares the K each curve needs to reach `target_ev`.
///
/// `k_manifold` and `k_linear` come from `EvVsKCurve::k_reaching` on the respective
/// curve. `compression_ratio` is `k_linear / k_manifold` when both curves reach the
/// target and `k_manifold` is non-zero; otherwise it is `None`.
pub fn manifold_vs_linear_advantage(
    manifold: &EvVsKCurve,
    linear: &EvVsKCurve,
    target_ev: f64,
) -> ManifoldVsLinearAdvantage {
    let k_manifold = manifold.k_reaching(target_ev);
    let k_linear = linear.k_reaching(target_ev);

    // The zero check guards the division.
    let compression_ratio = k_manifold
        .filter(|&km| km > 0)
        .and_then(|km| k_linear.map(|kl| kl as f64 / km as f64));

    ManifoldVsLinearAdvantage {
        target_ev,
        k_manifold,
        k_linear,
        compression_ratio,
    }
}
/// A K selection bundled with the manifold-vs-linear comparison evaluated at the
/// selected EV.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AutoKRecommendation {
    /// K chosen from the manifold curve.
    pub selection: KSelection,
    /// Manifold-vs-linear comparison for that selection.
    pub advantage: ManifoldVsLinearAdvantage,
}
/// Selects K on the `manifold` curve, then measures how the `linear` curve compares
/// at the EV of that selection.
///
/// The advantage target is `selection.ev`, the EV of the sample chosen from
/// `manifold`.
pub fn recommend_auto_k(
    manifold: &EvVsKCurve,
    linear: &EvVsKCurve,
    config: &KSelectionConfig,
) -> AutoKRecommendation {
    let selection = select_k(manifold, config);
    let target_ev = selection.ev;
    AutoKRecommendation {
        selection,
        advantage: manifold_vs_linear_advantage(manifold, linear, target_ev),
    }
}
/// Builds a curve from `(k, ev)` tuples.
///
/// # Errors
///
/// Forwards any validation error from `EvVsKCurve::new`.
pub fn curve_from_pairs(pairs: &[(usize, f64)]) -> Result<EvVsKCurve, String> {
    let mut samples = Vec::with_capacity(pairs.len());
    for &(k, ev) in pairs {
        samples.push(EvVsKPoint::new(k, ev));
    }
    EvVsKCurve::new(samples)
}
/// Fraction of the variance of `x` explained by `fitted`: `1 - RSS / TSS`.
///
/// RSS is the sum of squared element-wise differences between `x` and `fitted`.
/// TSS is the sum of squared deviations of `x` from its per-column means (means
/// taken along axis 0).
///
/// Special cases:
///
/// * an input with zero rows or zero columns yields `0.0`;
/// * when TSS is at most `MIN_DENOM`, the result is `1.0` if RSS is also at most
///   `MIN_DENOM` and `0.0` otherwise.
///
/// # Panics
///
/// Panics when `x` and `fitted` have different dimensions.
pub fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
    assert_eq!(
        x.dim(),
        fitted.dim(),
        "explained_variance: x {:?} != fitted {:?}",
        x.dim(),
        fitted.dim()
    );

    let (n_rows, n_cols) = x.dim();
    if n_rows == 0 || n_cols == 0 {
        return 0.0;
    }

    // Both views are walked in logical (row-major) order, so elements pair up by
    // index and the accumulation order matches a row-then-column loop.
    let mut rss = 0.0;
    for (&observed, &predicted) in x.iter().zip(fitted.iter()) {
        let residual = observed - predicted;
        rss += residual * residual;
    }

    let means = x
        .mean_axis(ndarray::Axis(0))
        .expect("non-empty input has means");

    let mut tss = 0.0;
    for ((_, col), &observed) in x.indexed_iter() {
        let centered = observed - means[col];
        tss += centered * centered;
    }

    // Keep the `<=` direction: a NaN `tss` must fall through to the division.
    if tss <= MIN_DENOM {
        return if rss <= MIN_DENOM { 1.0 } else { 0.0 };
    }
    1.0 - rss / tss
}
/// Relative tolerance for the kneedle linearity check: when the gap between the
/// largest and smallest marginal slope is at most this fraction of a positive mean
/// slope, the curve is classified as linear.
const LINEARITY_SLOPE_REL_TOL: f64 = 0.05;
/// Floor applied to denominators before dividing, and the threshold at or below
/// which the sums of squares in `explained_variance` are treated as zero.
const MIN_DENOM: f64 = 1.0e-12;
#[cfg(test)]
mod k_selection_tests {
    use super::*;
    use ndarray::array;

    /// Saturating fixture: steep rise through K=4 followed by a nearly flat tail.
    fn knee_curve() -> EvVsKCurve {
        let samples: [(usize, f64); 7] = [
            (1, 0.40),
            (2, 0.65),
            (3, 0.82),
            (4, 0.90),
            (8, 0.915),
            (16, 0.92),
            (32, 0.922),
        ];
        curve_from_pairs(&samples).expect("knee curve")
    }

    /// Default configuration switched to MDL mode with the given penalty.
    fn mdl_config(complexity_penalty: f64) -> KSelectionConfig {
        KSelectionConfig {
            mode: KSelectionMode::PenalizedMdl,
            complexity_penalty,
            ..KSelectionConfig::default()
        }
    }

    #[test]
    fn kneedle_picks_the_elbow() {
        let picked = select_k(&knee_curve(), &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Knee);
        assert_eq!(picked.k, 4, "knee should sit at the saturation corner K=4");
        assert!((picked.ev - 0.90).abs() < 1e-9);
    }

    #[test]
    fn linear_curve_returns_full_k_with_flag() {
        let samples: [(usize, f64); 5] = [(1, 0.10), (2, 0.20), (3, 0.30), (4, 0.40), (5, 0.50)];
        let curve = curve_from_pairs(&samples).expect("linear curve");
        let picked = select_k(&curve, &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Linear);
        assert_eq!(picked.k, 5, "linear curve returns the largest K");
    }

    #[test]
    fn still_climbing_curve_returns_full_k_no_knee() {
        let samples: [(usize, f64); 5] = [(1, 0.10), (2, 0.30), (3, 0.48), (4, 0.64), (5, 0.78)];
        let curve = curve_from_pairs(&samples).expect("climbing curve");
        // A 1% threshold is strict enough that the gentle slope decay does not count.
        let strict = KSelectionConfig {
            knee_slope_fraction: 0.01,
            ..KSelectionConfig::default()
        };
        let picked = select_k(&curve, &strict);
        assert_eq!(picked.flag, KSelectionFlag::NoKnee);
        assert_eq!(picked.k, 5);
    }

    #[test]
    fn flat_curve_returns_smallest_k() {
        let samples: [(usize, f64); 3] = [(1, 0.900), (2, 0.9000001), (4, 0.9000002)];
        let curve = curve_from_pairs(&samples).expect("flat curve");
        let picked = select_k(&curve, &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Flat);
        assert_eq!(picked.k, 1, "already-saturated curve returns smallest K");
    }

    #[test]
    fn single_point_curve_is_flat() {
        let curve = curve_from_pairs(&[(7, 0.5)]).expect("single point");
        let picked = select_k(&curve, &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Flat);
        assert_eq!(picked.k, 7);
    }

    #[test]
    fn mdl_picks_interior_knee_on_saturating_curve() {
        let picked = select_k(&knee_curve(), &mdl_config(0.05));
        assert!(
            picked.k <= 8,
            "MDL should not chase the saturated tail, got {}",
            picked.k
        );
        assert!(
            picked.k >= 3,
            "MDL should not under-fit the steep rise, got {}",
            picked.k
        );
    }

    #[test]
    fn mdl_penalty_zero_takes_full_k() {
        let picked = select_k(&knee_curve(), &mdl_config(0.0));
        assert_eq!(picked.k, 32);
    }

    #[test]
    fn advantage_metric_rewards_manifold_compression() {
        let manifold_samples: [(usize, f64); 5] =
            [(1, 0.40), (2, 0.65), (4, 0.90), (8, 0.93), (16, 0.94)];
        let linear_samples: [(usize, f64); 5] =
            [(1, 0.20), (2, 0.35), (4, 0.55), (8, 0.78), (16, 0.91)];
        let manifold = curve_from_pairs(&manifold_samples).expect("manifold curve");
        let linear = curve_from_pairs(&linear_samples).expect("linear curve");

        let adv = manifold_vs_linear_advantage(&manifold, &linear, 0.90);
        assert_eq!(adv.k_manifold, Some(4));
        assert_eq!(adv.k_linear, Some(16));
        assert!(adv.manifold_dominates());
        let ratio = adv.compression_ratio.expect("both reach target");
        assert!(
            (ratio - 4.0).abs() < 1e-12,
            "expected 16/4 = 4x, got {ratio}"
        );
    }

    #[test]
    fn advantage_metric_handles_unreached_target() {
        let manifold = curve_from_pairs(&[(1, 0.40), (2, 0.55)]).expect("manifold curve");
        let linear = curve_from_pairs(&[(1, 0.20), (2, 0.35)]).expect("linear curve");

        let adv = manifold_vs_linear_advantage(&manifold, &linear, 0.90);
        assert_eq!(adv.k_manifold, None);
        assert_eq!(adv.k_linear, None);
        assert!(adv.compression_ratio.is_none());
        assert!(!adv.manifold_dominates());
    }

    #[test]
    fn recommend_auto_k_combines_knee_and_advantage() {
        let manifold_samples: [(usize, f64); 7] = [
            (1, 0.40),
            (2, 0.65),
            (3, 0.82),
            (4, 0.90),
            (8, 0.915),
            (16, 0.92),
            (32, 0.922),
        ];
        let linear_samples: [(usize, f64); 6] = [
            (1, 0.20),
            (2, 0.35),
            (4, 0.55),
            (8, 0.78),
            (16, 0.91),
            (32, 0.93),
        ];
        let manifold = curve_from_pairs(&manifold_samples).expect("manifold curve");
        let linear = curve_from_pairs(&linear_samples).expect("linear curve");

        let rec = recommend_auto_k(&manifold, &linear, &KSelectionConfig::default());
        assert_eq!(rec.selection.k, 4);
        assert_eq!(rec.selection.flag, KSelectionFlag::Knee);
        assert_eq!(rec.advantage.k_manifold, Some(4));
        assert_eq!(rec.advantage.k_linear, Some(16));
        assert!(rec.advantage.manifold_dominates());
        let ratio = rec.advantage.compression_ratio.expect("both reach EV");
        assert!((ratio - 4.0).abs() < 1e-12);
    }

    #[test]
    fn explained_variance_matches_perfect_and_mean_baselines() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Fitting the data with itself explains everything.
        let ev_perfect = explained_variance(x.view(), x.view());
        assert!((ev_perfect - 1.0).abs() < 1e-12);

        // Fitting every row with the column means explains nothing.
        let means = x.mean_axis(ndarray::Axis(0)).expect("means");
        let mean_fit = array![
            [means[0], means[1]],
            [means[0], means[1]],
            [means[0], means[1]]
        ];
        let ev_mean = explained_variance(x.view(), mean_fit.view());
        assert!(
            ev_mean.abs() < 1e-12,
            "mean baseline EV should be 0, got {ev_mean}"
        );
    }

    #[test]
    fn curve_rejects_bad_input() {
        // Empty input, K == 0, non-finite EV, duplicate K.
        assert!(EvVsKCurve::new(vec![]).is_err());
        assert!(curve_from_pairs(&[(0, 0.5)]).is_err());
        assert!(curve_from_pairs(&[(1, f64::NAN)]).is_err());
        assert!(curve_from_pairs(&[(2, 0.5), (2, 0.6)]).is_err());
    }

    #[test]
    fn curve_sorts_by_k() {
        let curve = curve_from_pairs(&[(8, 0.9), (1, 0.4), (4, 0.8)]).expect("curve");
        assert_eq!(curve.k_min(), 1);
        assert_eq!(curve.k_max(), 8);
        let stored = curve.points();
        assert_eq!(stored[0].k, 1);
        assert_eq!(stored[2].k, 8);
    }

    #[test]
    fn mode_parse_roundtrips() {
        assert_eq!(
            KSelectionMode::parse("elbow").expect("parse"),
            KSelectionMode::Kneedle
        );
        assert_eq!(
            KSelectionMode::parse("MDL").expect("parse"),
            KSelectionMode::PenalizedMdl
        );
        assert_eq!(KSelectionMode::Kneedle.as_str(), "kneedle");
        assert!(KSelectionMode::parse("nonsense").is_err());
    }
}