1use ndarray::ArrayView2;
51
/// One sample of an explained-variance-versus-K curve: the explained
/// variance `ev` measured at K = `k`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EvVsKPoint {
    /// The K value this sample was taken at.
    pub k: usize,
    /// Explained variance observed at `k`.
    pub ev: f64,
}

impl EvVsKPoint {
    /// Bundles a `(k, ev)` sample. No validation happens here; invalid samples
    /// (`k == 0`, non-finite `ev`) are rejected by [`EvVsKCurve::new`].
    pub fn new(k: usize, ev: f64) -> Self {
        EvVsKPoint { ev, k }
    }
}
67
68#[derive(Debug, Clone)]
74pub struct EvVsKCurve {
75 points: Vec<EvVsKPoint>,
76}
77
78impl EvVsKCurve {
79 pub fn new(mut points: Vec<EvVsKPoint>) -> Result<Self, String> {
82 if points.is_empty() {
83 return Err("EvVsKCurve::new: at least one (K, EV) sample required".into());
84 }
85 for p in &points {
86 if p.k == 0 {
87 return Err("EvVsKCurve::new: K must be >= 1".into());
88 }
89 if !p.ev.is_finite() {
90 return Err(format!("EvVsKCurve::new: non-finite EV at K={}", p.k));
91 }
92 }
93 points.sort_by_key(|p| p.k);
94 for w in points.windows(2) {
95 if w[0].k == w[1].k {
96 return Err(format!("EvVsKCurve::new: duplicate K={}", w[0].k));
97 }
98 }
99 Ok(Self { points })
100 }
101
102 pub fn len(&self) -> usize {
103 self.points.len()
104 }
105
106 pub fn is_empty(&self) -> bool {
107 self.points.is_empty()
108 }
109
110 pub fn points(&self) -> &[EvVsKPoint] {
111 &self.points
112 }
113
114 pub fn k_min(&self) -> usize {
116 self.points[0].k
117 }
118
119 pub fn k_max(&self) -> usize {
121 self.points[self.points.len() - 1].k
122 }
123
124 pub fn k_reaching(&self, target_ev: f64) -> Option<usize> {
128 self.points.iter().find(|p| p.ev >= target_ev).map(|p| p.k)
129 }
130}
131
/// Strategy [`select_k`] uses to choose K from an EV-versus-K curve.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KSelectionMode {
    /// Kneedle-style elbow detection on the normalised curve.
    Kneedle,
    /// Maximise `ev - complexity_penalty * k / k_max` over the samples.
    PenalizedMdl,
}

impl KSelectionMode {
    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepted spellings: `kneedle`, `knee`, `elbow` for
    /// [`KSelectionMode::Kneedle`]; `mdl`, `penalized`, `penalized_mdl` for
    /// [`KSelectionMode::PenalizedMdl`].
    ///
    /// # Errors
    ///
    /// Returns a message quoting the normalised input for any other value.
    pub fn parse(value: &str) -> Result<Self, String> {
        let normalized = value.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "kneedle" | "knee" | "elbow" => Self::Kneedle,
            "mdl" | "penalized" | "penalized_mdl" => Self::PenalizedMdl,
            other => {
                return Err(format!(
                    "K-selection mode must be 'kneedle' or 'mdl'; got {other:?}"
                ))
            }
        };
        Ok(mode)
    }

    /// Canonical name of the mode; [`KSelectionMode::parse`] accepts it back.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::PenalizedMdl => "mdl",
            Self::Kneedle => "kneedle",
        }
    }
}
162
163#[derive(Debug, Clone, Copy)]
165pub struct KSelectionConfig {
166 pub mode: KSelectionMode,
167 pub knee_slope_fraction: f64,
172 pub complexity_penalty: f64,
174 pub flat_span_tol: f64,
178}
179
180impl Default for KSelectionConfig {
181 fn default() -> Self {
182 Self {
183 mode: KSelectionMode::Kneedle,
184 knee_slope_fraction: 0.10,
186 complexity_penalty: 0.05,
187 flat_span_tol: 1.0e-6,
188 }
189 }
190}
191
/// Describes how [`select_k`] arrived at the chosen K.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KSelectionFlag {
    /// Kneedle found an accepted knee, or the MDL optimum lies strictly inside
    /// the sampled K range.
    Knee,
    /// No knee was accepted; the largest sampled K was returned.
    NoKnee,
    /// The curve's slopes are (near-)constant; the largest sampled K was
    /// returned.
    Linear,
    /// The curve does not rise enough to choose (or the MDL optimum is the
    /// first sample); the smallest sampled K was returned.
    Flat,
}

impl KSelectionFlag {
    /// Stable lowercase identifier for the flag.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Flat => "flat",
            Self::Linear => "linear",
            Self::NoKnee => "no_knee",
            Self::Knee => "knee",
        }
    }

    /// `true` only for [`KSelectionFlag::Knee`].
    pub const fn is_knee(self) -> bool {
        match self {
            Self::Knee => true,
            Self::NoKnee | Self::Linear | Self::Flat => false,
        }
    }
}
221
222#[derive(Debug, Clone, Copy, PartialEq)]
224pub struct KSelection {
225 pub k: usize,
227 pub ev: f64,
229 pub flag: KSelectionFlag,
231 pub score: f64,
235}
236
237pub fn select_k(curve: &EvVsKCurve, config: &KSelectionConfig) -> KSelection {
249 let pts = curve.points();
250 let n = pts.len();
251
252 if n == 1 {
254 return KSelection {
255 k: pts[0].k,
256 ev: pts[0].ev,
257 flag: KSelectionFlag::Flat,
258 score: 0.0,
259 };
260 }
261
262 let ev_min = pts.iter().map(|p| p.ev).fold(f64::INFINITY, f64::min);
264 let ev_max = pts.iter().map(|p| p.ev).fold(f64::NEG_INFINITY, f64::max);
265 let span = ev_max - ev_min;
266 if span <= config.flat_span_tol {
267 return KSelection {
268 k: pts[0].k,
269 ev: pts[0].ev,
270 flag: KSelectionFlag::Flat,
271 score: span,
272 };
273 }
274
275 match config.mode {
276 KSelectionMode::Kneedle => select_kneedle(curve, config, span),
277 KSelectionMode::PenalizedMdl => select_mdl(curve, config),
278 }
279}
280
281fn marginal_slopes(pts: &[EvVsKPoint]) -> Vec<f64> {
283 pts.windows(2)
284 .map(|w| {
285 let dk = (w[1].k - w[0].k) as f64;
286 (w[1].ev - w[0].ev) / dk
287 })
288 .collect()
289}
290
291fn select_kneedle(curve: &EvVsKCurve, config: &KSelectionConfig, span: f64) -> KSelection {
292 let pts = curve.points();
293 let n = pts.len();
294 let slopes = marginal_slopes(pts);
295
296 let init_slope = slopes.iter().copied().find(|s| *s > 0.0).unwrap_or(0.0);
300 if init_slope <= 0.0 {
301 return KSelection {
302 k: pts[0].k,
303 ev: pts[0].ev,
304 flag: KSelectionFlag::Flat,
305 score: 0.0,
306 };
307 }
308
309 let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
313 let slope_lo = slopes.iter().copied().fold(f64::INFINITY, f64::min);
314 let slope_hi = slopes.iter().copied().fold(f64::NEG_INFINITY, f64::max);
315 let slope_spread = slope_hi - slope_lo;
316 if mean_slope > 0.0 && slope_spread <= LINEARITY_SLOPE_REL_TOL * mean_slope {
317 return KSelection {
318 k: curve.k_max(),
319 ev: pts[n - 1].ev,
320 flag: KSelectionFlag::Linear,
321 score: slope_spread / mean_slope.max(MIN_DENOM),
322 };
323 }
324
325 let k_first = pts[0].k as f64;
332 let k_last = pts[n - 1].k as f64;
333 let k_range = (k_last - k_first).max(MIN_DENOM);
334
335 let mut best_idx = 0usize;
336 let mut best_diff = f64::NEG_INFINITY;
337 for (i, p) in pts.iter().enumerate() {
338 let ev_hat = (p.ev - pts[0].ev) / span;
339 let k_hat = (p.k as f64 - k_first) / k_range;
340 let diff = ev_hat - k_hat;
341 if diff > best_diff {
342 best_diff = diff;
343 best_idx = i;
344 }
345 }
346
347 let post_slope = if best_idx < slopes.len() {
352 slopes[best_idx]
353 } else {
354 0.0
356 };
357 let decay_fraction = (post_slope / init_slope).max(0.0);
358
359 if decay_fraction <= config.knee_slope_fraction {
360 KSelection {
361 k: pts[best_idx].k,
362 ev: pts[best_idx].ev,
363 flag: KSelectionFlag::Knee,
364 score: decay_fraction,
365 }
366 } else {
367 KSelection {
370 k: curve.k_max(),
371 ev: pts[n - 1].ev,
372 flag: KSelectionFlag::NoKnee,
373 score: decay_fraction,
374 }
375 }
376}
377
378fn select_mdl(curve: &EvVsKCurve, config: &KSelectionConfig) -> KSelection {
379 let pts = curve.points();
380 let k_max = curve.k_max() as f64;
381 let gamma = config.complexity_penalty;
382
383 let mut best_idx = 0usize;
384 let mut best_obj = f64::NEG_INFINITY;
385 for (i, p) in pts.iter().enumerate() {
386 let obj = p.ev - gamma * (p.k as f64 / k_max.max(MIN_DENOM));
387 if obj > best_obj {
388 best_obj = obj;
389 best_idx = i;
390 }
391 }
392
393 let flag = if best_idx == 0 {
397 KSelectionFlag::Flat
398 } else if best_idx == pts.len() - 1 {
399 KSelectionFlag::NoKnee
400 } else {
401 KSelectionFlag::Knee
402 };
403
404 KSelection {
405 k: pts[best_idx].k,
406 ev: pts[best_idx].ev,
407 flag,
408 score: best_obj,
409 }
410}
411
/// How large K must be on a manifold curve and on a linear curve to reach a
/// common explained-variance target, as filled in by
/// [`manifold_vs_linear_advantage`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ManifoldVsLinearAdvantage {
    /// EV threshold both curves were checked against.
    pub target_ev: f64,
    /// Smallest sampled manifold K reaching `target_ev`, if any.
    pub k_manifold: Option<usize>,
    /// Smallest sampled linear K reaching `target_ev`, if any.
    pub k_linear: Option<usize>,
    /// `k_linear / k_manifold` when both curves reach the target.
    pub compression_ratio: Option<f64>,
}

impl ManifoldVsLinearAdvantage {
    /// `true` when both curves reach the target and the manifold curve needs a
    /// strictly smaller K. Ties and unreached targets give `false`.
    pub fn manifold_dominates(&self) -> bool {
        matches!(
            (self.k_manifold, self.k_linear),
            (Some(km), Some(kl)) if km < kl
        )
    }
}
440
441pub fn manifold_vs_linear_advantage(
444 manifold: &EvVsKCurve,
445 linear: &EvVsKCurve,
446 target_ev: f64,
447) -> ManifoldVsLinearAdvantage {
448 let k_manifold = manifold.k_reaching(target_ev);
449 let k_linear = linear.k_reaching(target_ev);
450 let compression_ratio = match (k_manifold, k_linear) {
451 (Some(km), Some(kl)) if km > 0 => Some(kl as f64 / km as f64),
452 _ => None,
453 };
454 ManifoldVsLinearAdvantage {
455 target_ev,
456 k_manifold,
457 k_linear,
458 compression_ratio,
459 }
460}
461
462#[derive(Debug, Clone, Copy, PartialEq)]
467pub struct AutoKRecommendation {
468 pub selection: KSelection,
470 pub advantage: ManifoldVsLinearAdvantage,
472}
473
474pub fn recommend_auto_k(
483 manifold: &EvVsKCurve,
484 linear: &EvVsKCurve,
485 config: &KSelectionConfig,
486) -> AutoKRecommendation {
487 let selection = select_k(manifold, config);
488 let advantage = manifold_vs_linear_advantage(manifold, linear, selection.ev);
489 AutoKRecommendation {
490 selection,
491 advantage,
492 }
493}
494
495pub fn curve_from_pairs(pairs: &[(usize, f64)]) -> Result<EvVsKCurve, String> {
499 EvVsKCurve::new(
500 pairs
501 .iter()
502 .map(|&(k, ev)| EvVsKPoint::new(k, ev))
503 .collect(),
504 )
505}
506
507pub fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
511 assert_eq!(
512 x.dim(),
513 fitted.dim(),
514 "explained_variance: x {:?} != fitted {:?}",
515 x.dim(),
516 fitted.dim()
517 );
518 let n = x.nrows();
519 if n == 0 || x.ncols() == 0 {
520 return 0.0;
521 }
522 let mut rss = 0.0;
523 for row in 0..n {
524 for col in 0..x.ncols() {
525 let r = x[[row, col]] - fitted[[row, col]];
526 rss += r * r;
527 }
528 }
529 let means = x
530 .mean_axis(ndarray::Axis(0))
531 .expect("non-empty input has means");
532 let mut tss = 0.0;
533 for row in 0..n {
534 for col in 0..x.ncols() {
535 let c = x[[row, col]] - means[col];
536 tss += c * c;
537 }
538 }
539 if tss <= MIN_DENOM {
540 if rss <= MIN_DENOM { 1.0 } else { 0.0 }
541 } else {
542 1.0 - rss / tss
543 }
544}
545
/// Kneedle linearity test: a curve counts as linear when the spread (max - min)
/// of its segment slopes is at most this fraction of their mean slope.
const LINEARITY_SLOPE_REL_TOL: f64 = 0.05;

/// Small positive floor that keeps denominators away from zero. It also serves
/// as the "effectively zero" threshold for the sums of squares in
/// `explained_variance`.
const MIN_DENOM: f64 = 1e-12;
552
#[cfg(test)]
mod k_selection_tests {
    use super::*;
    use ndarray::array;

    /// Saturating reference curve: steep rise up to K=4, nearly flat after it.
    const SATURATING: [(usize, f64); 7] = [
        (1, 0.40),
        (2, 0.65),
        (3, 0.82),
        (4, 0.90),
        (8, 0.915),
        (16, 0.92),
        (32, 0.922),
    ];

    /// Builds a curve from literal test data that is known to be valid.
    fn make_curve(pairs: &[(usize, f64)]) -> EvVsKCurve {
        curve_from_pairs(pairs).expect("test curve must be valid")
    }

    fn knee_curve() -> EvVsKCurve {
        make_curve(&SATURATING)
    }

    fn select_with_defaults(c: &EvVsKCurve) -> KSelection {
        select_k(c, &KSelectionConfig::default())
    }

    fn mdl_config(complexity_penalty: f64) -> KSelectionConfig {
        KSelectionConfig {
            mode: KSelectionMode::PenalizedMdl,
            complexity_penalty,
            ..KSelectionConfig::default()
        }
    }

    #[test]
    fn kneedle_picks_the_elbow() {
        let sel = select_with_defaults(&knee_curve());
        assert_eq!(sel.flag, KSelectionFlag::Knee);
        assert_eq!(sel.k, 4, "knee should sit at the saturation corner K=4");
        assert!((sel.ev - 0.90).abs() < 1e-9);
    }

    #[test]
    fn linear_curve_returns_full_k_with_flag() {
        let linear = make_curve(&[(1, 0.10), (2, 0.20), (3, 0.30), (4, 0.40), (5, 0.50)]);
        let sel = select_with_defaults(&linear);
        assert_eq!(sel.flag, KSelectionFlag::Linear);
        assert_eq!(sel.k, 5, "linear curve returns the largest K");
    }

    #[test]
    fn still_climbing_curve_returns_full_k_no_knee() {
        let climbing = make_curve(&[(1, 0.10), (2, 0.30), (3, 0.48), (4, 0.64), (5, 0.78)]);
        let strict = KSelectionConfig {
            knee_slope_fraction: 0.01,
            ..KSelectionConfig::default()
        };
        let sel = select_k(&climbing, &strict);
        assert_eq!(sel.flag, KSelectionFlag::NoKnee);
        assert_eq!(sel.k, 5);
    }

    #[test]
    fn flat_curve_returns_smallest_k() {
        let flat = make_curve(&[(1, 0.900), (2, 0.9000001), (4, 0.9000002)]);
        let sel = select_with_defaults(&flat);
        assert_eq!(sel.flag, KSelectionFlag::Flat);
        assert_eq!(sel.k, 1, "already-saturated curve returns smallest K");
    }

    #[test]
    fn single_point_curve_is_flat() {
        let sel = select_with_defaults(&make_curve(&[(7, 0.5)]));
        assert_eq!((sel.flag, sel.k), (KSelectionFlag::Flat, 7));
    }

    #[test]
    fn mdl_picks_interior_knee_on_saturating_curve() {
        let sel = select_k(&knee_curve(), &mdl_config(0.05));
        assert!(
            sel.k <= 8,
            "MDL should not chase the saturated tail, got {}",
            sel.k
        );
        assert!(
            sel.k >= 3,
            "MDL should not under-fit the steep rise, got {}",
            sel.k
        );
    }

    #[test]
    fn mdl_penalty_zero_takes_full_k() {
        let sel = select_k(&knee_curve(), &mdl_config(0.0));
        assert_eq!(sel.k, 32);
    }

    #[test]
    fn advantage_metric_rewards_manifold_compression() {
        let manifold = make_curve(&[(1, 0.40), (2, 0.65), (4, 0.90), (8, 0.93), (16, 0.94)]);
        let linear = make_curve(&[(1, 0.20), (2, 0.35), (4, 0.55), (8, 0.78), (16, 0.91)]);
        let adv = manifold_vs_linear_advantage(&manifold, &linear, 0.90);
        assert_eq!(adv.k_manifold, Some(4));
        assert_eq!(adv.k_linear, Some(16));
        assert!(adv.manifold_dominates());
        let ratio = adv.compression_ratio.expect("both reach target");
        assert!(
            (ratio - 4.0).abs() < 1e-12,
            "expected 16/4 = 4x, got {ratio}"
        );
    }

    #[test]
    fn advantage_metric_handles_unreached_target() {
        let manifold = make_curve(&[(1, 0.40), (2, 0.55)]);
        let linear = make_curve(&[(1, 0.20), (2, 0.35)]);
        let adv = manifold_vs_linear_advantage(&manifold, &linear, 0.90);
        assert_eq!((adv.k_manifold, adv.k_linear), (None, None));
        assert!(adv.compression_ratio.is_none());
        assert!(!adv.manifold_dominates());
    }

    #[test]
    fn recommend_auto_k_combines_knee_and_advantage() {
        // The manifold side is the saturating reference curve.
        let linear = make_curve(&[
            (1, 0.20),
            (2, 0.35),
            (4, 0.55),
            (8, 0.78),
            (16, 0.91),
            (32, 0.93),
        ]);
        let rec = recommend_auto_k(&knee_curve(), &linear, &KSelectionConfig::default());
        assert_eq!(rec.selection.k, 4);
        assert_eq!(rec.selection.flag, KSelectionFlag::Knee);
        // The EV reached at the knee (0.90) is first matched at K=16 linearly.
        assert_eq!(rec.advantage.k_manifold, Some(4));
        assert_eq!(rec.advantage.k_linear, Some(16));
        assert!(rec.advantage.manifold_dominates());
        let ratio = rec.advantage.compression_ratio.expect("both reach EV");
        assert!((ratio - 4.0).abs() < 1e-12);
    }

    #[test]
    fn explained_variance_matches_perfect_and_mean_baselines() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ev_perfect = explained_variance(x.view(), x.view());
        assert!((ev_perfect - 1.0).abs() < 1e-12);

        let means = x.mean_axis(ndarray::Axis(0)).expect("means");
        let (m0, m1) = (means[0], means[1]);
        let mean_fit = array![[m0, m1], [m0, m1], [m0, m1]];
        let ev_mean = explained_variance(x.view(), mean_fit.view());
        assert!(
            ev_mean.abs() < 1e-12,
            "mean baseline EV should be 0, got {ev_mean}"
        );
    }

    #[test]
    fn curve_rejects_bad_input() {
        assert!(EvVsKCurve::new(Vec::new()).is_err());
        let bad_inputs: [&[(usize, f64)]; 3] = [
            &[(0, 0.5)],
            &[(1, f64::NAN)],
            &[(2, 0.5), (2, 0.6)],
        ];
        for pairs in bad_inputs.iter() {
            assert!(curve_from_pairs(pairs).is_err());
        }
    }

    #[test]
    fn curve_sorts_by_k() {
        let sorted = make_curve(&[(8, 0.9), (1, 0.4), (4, 0.8)]);
        assert_eq!(sorted.k_min(), 1);
        assert_eq!(sorted.k_max(), 8);
        let ks: Vec<usize> = sorted.points().iter().map(|p| p.k).collect();
        assert_eq!(ks, vec![1, 4, 8]);
    }

    #[test]
    fn mode_parse_roundtrips() {
        assert_eq!(
            KSelectionMode::parse("elbow").expect("parse"),
            KSelectionMode::Kneedle
        );
        assert_eq!(
            KSelectionMode::parse("MDL").expect("parse"),
            KSelectionMode::PenalizedMdl
        );
        assert_eq!(KSelectionMode::Kneedle.as_str(), "kneedle");
        assert!(KSelectionMode::parse("nonsense").is_err());
    }
}