// gam_sae/k_selection.rs

//! # Automatic dictionary-size (`K`) selection from the EV-vs-`K` frontier (#1026).
//!
//! The manifold-SAE fit takes a dictionary size `K` (the atom count). Today `K`
//! is user-specified; this module turns the **EV-vs-`K` frontier** measured by
//! the OLMo research battery (`tests/sae/olmo_research_battery.py`, the #1026
//! data) into a principled automatic choice.
//!
//! ## Why a knee/MDL criterion (and not REML) here
//!
//! `K` is a **discrete structure** choice — like the topology race in
//! [`gam_solve::structure_search`] picks between manifold topologies, not a
//! continuous smoothing parameter. The REML-always / no-GCV-BIC policy governs
//! *fitting* (the continuous `ρ`/`λ` smoothing tier); discrete structure
//! selection legitimately uses a knee/penalized-fit criterion (the topology
//! search already uses `score='bic'` for the same reason). So the dictionary
//! size is selected by either:
//!
//! * an **elbow / kneedle** criterion — pick the `K` at the saturation knee of
//!   the explained-variance curve, where the marginal EV gain per added atom
//!   first drops below a principled fraction of the early (steep-regime) slope,
//!   or equivalently the point of maximum curvature of the normalized curve; or
//! * a **penalized-EV / MDL** stop — maximize `EV(K) − γ · (K / K_max)`, a
//!   description-length-style trade of reconstruction gain against dictionary
//!   complexity.
//!
//! ## The manifold-vs-linear advantage
//!
//! The frontier carries *two* curves: the manifold-SAE EV-vs-`K` and a
//! linear-SAE baseline EV-vs-`K`. [`ManifoldVsLinearAdvantage`] quantifies the
//! parameter-efficiency win: manifold reaches a target EV at `K_m` atoms vs
//! linear at `K_l`, and the ratio `K_l / K_m > 1` is the compression factor —
//! the manifold representation buys the same reconstruction quality with fewer
//! atoms.
//!
//! ## Degenerate curves
//!
//! Real frontiers are not always knee-shaped. The selector classifies the
//! curve and reports the verdict through [`KSelectionFlag`] so callers can act
//! on it rather than silently trusting a spurious knee:
//!
//! * **`Knee`** — a clear saturation knee was found.
//! * **`NoKnee`** — the curve keeps climbing (still steep at `K_max`): the
//!   selector returns the largest `K`, flagged, because more atoms would still
//!   help.
//! * **`Linear`** — EV grows ~linearly in `K` with no curvature: no knee
//!   exists; the largest `K` is returned, flagged.
//! * **`Flat`** — EV is already saturated at the smallest `K`: the smallest
//!   `K` is returned, flagged.
49
50use ndarray::ArrayView2;
51
/// One `(K, EV)` sample on the EV-vs-`K` frontier: dictionary size `k` and the
/// explained variance / R² of the reconstruction at that size.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EvVsKPoint {
    /// Dictionary size (atom count) `K ≥ 1`.
    pub k: usize,
    /// Explained variance (R²) of the reconstruction at this `K`, in `(-∞, 1]`.
    pub ev: f64,
}

impl EvVsKPoint {
    /// Pair a dictionary size with its explained variance.
    ///
    /// No validation happens here; [`EvVsKCurve::new`] is what rejects `K = 0`
    /// and non-finite EV.
    pub fn new(k: usize, ev: f64) -> Self {
        Self { k, ev }
    }
}

/// An EV-vs-`K` frontier: a set of `(K, EV)` samples, e.g. from a fit sweep or
/// the OLMo research battery.
///
/// Construction sorts by `K` ascending and rejects duplicate / empty / non-
/// finite input so downstream slope math is well-defined. A value built through
/// [`EvVsKCurve::new`] therefore always holds at least one sample, with
/// strictly increasing `K`.
#[derive(Debug, Clone)]
pub struct EvVsKCurve {
    points: Vec<EvVsKPoint>,
}

impl EvVsKCurve {
    /// Build a curve from `(K, EV)` samples given in any order.
    ///
    /// # Errors
    ///
    /// Returns a message when `points` is empty, when any sample has `K == 0`
    /// or a non-finite EV (checked in input order), or when two samples share
    /// the same `K`.
    pub fn new(mut points: Vec<EvVsKPoint>) -> Result<Self, String> {
        if points.is_empty() {
            return Err("EvVsKCurve::new: at least one (K, EV) sample required".into());
        }
        for p in &points {
            if p.k == 0 {
                return Err("EvVsKCurve::new: K must be >= 1".into());
            }
            if !p.ev.is_finite() {
                return Err(format!("EvVsKCurve::new: non-finite EV at K={}", p.k));
            }
        }
        // Equal keys are rejected just below, so a stable sort buys nothing.
        points.sort_unstable_by_key(|p| p.k);
        if let Some(pair) = points.windows(2).find(|w| w[0].k == w[1].k) {
            return Err(format!("EvVsKCurve::new: duplicate K={}", pair[0].k));
        }
        Ok(Self { points })
    }

    /// Number of samples on the curve.
    pub fn len(&self) -> usize {
        self.points.len()
    }

    /// Whether the curve holds no samples. [`EvVsKCurve::new`] never produces
    /// such a curve; the method exists to pair with [`EvVsKCurve::len`].
    pub fn is_empty(&self) -> bool {
        self.points.is_empty()
    }

    /// The samples, sorted by `K` ascending.
    pub fn points(&self) -> &[EvVsKPoint] {
        &self.points
    }

    /// Smallest `K` on the curve.
    ///
    /// # Panics
    ///
    /// Panics if the curve is empty, which [`EvVsKCurve::new`] rules out.
    pub fn k_min(&self) -> usize {
        self.points
            .first()
            .map(|p| p.k)
            .expect("EvVsKCurve holds at least one sample")
    }

    /// Largest `K` on the curve.
    ///
    /// # Panics
    ///
    /// Panics if the curve is empty, which [`EvVsKCurve::new`] rules out.
    pub fn k_max(&self) -> usize {
        self.points
            .last()
            .map(|p| p.k)
            .expect("EvVsKCurve holds at least one sample")
    }

    /// Smallest `K` on the curve whose EV is at least `target_ev`, if any.
    ///
    /// The samples are scanned in ascending `K` order and the first hit wins,
    /// so this is the *cheapest* dictionary reaching the target (on a
    /// non-monotone curve a larger `K` may fall back below it). Returns `None`
    /// when no sample reaches the target, which includes a NaN `target_ev`.
    pub fn k_reaching(&self, target_ev: f64) -> Option<usize> {
        self.points.iter().find(|p| p.ev >= target_ev).map(|p| p.k)
    }
}
131
/// How the dictionary size is selected from the frontier.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KSelectionMode {
    /// Kneedle-style: pick the `K` of maximum curvature of the min-max
    /// normalized curve (the saturation knee), accepting it only when the
    /// post-knee marginal slope has decayed below `knee_slope_fraction` of the
    /// initial (steep-regime) slope.
    Kneedle,
    /// Penalized-EV / MDL: maximize `EV(K) − complexity_penalty · (K / K_max)`.
    PenalizedMdl,
}

impl KSelectionMode {
    /// Parse a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepted spellings: `kneedle` / `knee` / `elbow` for
    /// [`KSelectionMode::Kneedle`], and `mdl` / `penalized` / `penalized_mdl`
    /// for [`KSelectionMode::PenalizedMdl`].
    ///
    /// # Errors
    ///
    /// Returns a message quoting the normalized (trimmed, lowercased) input
    /// when it matches none of the accepted spellings.
    pub fn parse(value: &str) -> Result<Self, String> {
        match value.trim().to_ascii_lowercase().as_str() {
            "kneedle" | "knee" | "elbow" => Ok(Self::Kneedle),
            "mdl" | "penalized" | "penalized_mdl" => Ok(Self::PenalizedMdl),
            other => Err(format!(
                "K-selection mode must be 'kneedle' or 'mdl'; got {other:?}"
            )),
        }
    }

    /// Canonical name of the mode; [`KSelectionMode::parse`] accepts it back.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Kneedle => "kneedle",
            Self::PenalizedMdl => "mdl",
        }
    }
}

/// Standard-trait entry point so callers can write
/// `"mdl".parse::<KSelectionMode>()`; delegates to [`KSelectionMode::parse`].
impl std::str::FromStr for KSelectionMode {
    type Err = String;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        Self::parse(value)
    }
}
162
163/// Tuning for [`select_k`].
164#[derive(Debug, Clone, Copy)]
165pub struct KSelectionConfig {
166    pub mode: KSelectionMode,
167    /// Kneedle: the post-knee marginal slope must fall below this fraction of
168    /// the initial slope for a point to count as a saturation knee. A curve
169    /// whose slope never decays this far is classified [`KSelectionFlag::NoKnee`]
170    /// (still climbing) or [`KSelectionFlag::Linear`].
171    pub knee_slope_fraction: f64,
172    /// MDL: the complexity weight `γ` on the normalized size `K / K_max`.
173    pub complexity_penalty: f64,
174    /// Below this total EV span (`max EV − min EV` across the curve) the curve
175    /// is treated as already saturated ([`KSelectionFlag::Flat`]) and the
176    /// smallest `K` is returned.
177    pub flat_span_tol: f64,
178}
179
180impl Default for KSelectionConfig {
181    fn default() -> Self {
182        Self {
183            mode: KSelectionMode::Kneedle,
184            // Knee = where marginal gain has decayed to 10% of the steep slope.
185            knee_slope_fraction: 0.10,
186            complexity_penalty: 0.05,
187            flat_span_tol: 1.0e-6,
188        }
189    }
190}
191
/// Verdict on the curve shape that accompanies the chosen `K`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KSelectionFlag {
    /// A clear saturation knee was found; `K` is the knee.
    Knee,
    /// The curve is still climbing at `K_max`; returned `K = K_max`.
    NoKnee,
    /// EV grows ~linearly in `K` (no curvature); returned `K = K_max`.
    Linear,
    /// EV is already saturated at the smallest `K`; returned `K = K_min`.
    Flat,
}

impl KSelectionFlag {
    /// Stable snake_case label for the flag.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Flat => "flat",
            Self::Linear => "linear",
            Self::NoKnee => "no_knee",
            Self::Knee => "knee",
        }
    }

    /// `true` only for [`KSelectionFlag::Knee`]; every other flag marks a
    /// fallback to one end of the sampled range.
    pub const fn is_knee(self) -> bool {
        match self {
            Self::Knee => true,
            Self::NoKnee | Self::Linear | Self::Flat => false,
        }
    }
}
221
222/// Result of [`select_k`].
223#[derive(Debug, Clone, Copy, PartialEq)]
224pub struct KSelection {
225    /// The selected dictionary size.
226    pub k: usize,
227    /// The explained variance at the selected `K`.
228    pub ev: f64,
229    /// Curve-shape classification of the selection.
230    pub flag: KSelectionFlag,
231    /// The strength score driving the choice: for Kneedle this is the post-knee
232    /// slope-decay fraction (smaller = sharper knee); for MDL it is the
233    /// penalized objective value at the selected `K`.
234    pub score: f64,
235}
236
237/// Select the dictionary size `K` at the saturation knee of an EV-vs-`K`
238/// frontier.
239///
240/// The discrete analogue of the topology race in
241/// [`gam_solve::structure_search`]: a knee/MDL criterion over a discrete
242/// structure parameter, *not* a REML smoothing choice (see module docs).
243///
244/// On a curve with no usable knee the largest (or, when already saturated,
245/// smallest) `K` is returned with the corresponding [`KSelectionFlag`] so the
246/// caller can decide whether to widen the sweep rather than trusting a spurious
247/// elbow.
248pub fn select_k(curve: &EvVsKCurve, config: &KSelectionConfig) -> KSelection {
249    let pts = curve.points();
250    let n = pts.len();
251
252    // Single point: nothing to choose.
253    if n == 1 {
254        return KSelection {
255            k: pts[0].k,
256            ev: pts[0].ev,
257            flag: KSelectionFlag::Flat,
258            score: 0.0,
259        };
260    }
261
262    // Already-saturated curve (EV barely moves): smallest K wins.
263    let ev_min = pts.iter().map(|p| p.ev).fold(f64::INFINITY, f64::min);
264    let ev_max = pts.iter().map(|p| p.ev).fold(f64::NEG_INFINITY, f64::max);
265    let span = ev_max - ev_min;
266    if span <= config.flat_span_tol {
267        return KSelection {
268            k: pts[0].k,
269            ev: pts[0].ev,
270            flag: KSelectionFlag::Flat,
271            score: span,
272        };
273    }
274
275    match config.mode {
276        KSelectionMode::Kneedle => select_kneedle(curve, config, span),
277        KSelectionMode::PenalizedMdl => select_mdl(curve, config),
278    }
279}
280
281/// Per-segment marginal EV gain *per added atom*: `(EV_{i+1} − EV_i) / (K_{i+1} − K_i)`.
282fn marginal_slopes(pts: &[EvVsKPoint]) -> Vec<f64> {
283    pts.windows(2)
284        .map(|w| {
285            let dk = (w[1].k - w[0].k) as f64;
286            (w[1].ev - w[0].ev) / dk
287        })
288        .collect()
289}
290
291fn select_kneedle(curve: &EvVsKCurve, config: &KSelectionConfig, span: f64) -> KSelection {
292    let pts = curve.points();
293    let n = pts.len();
294    let slopes = marginal_slopes(pts);
295
296    // Initial (steep-regime) slope: the first positive segment slope. If the
297    // curve never rises, treat it as flat (handled by the span gate upstream,
298    // but guard anyway).
299    let init_slope = slopes.iter().copied().find(|s| *s > 0.0).unwrap_or(0.0);
300    if init_slope <= 0.0 {
301        return KSelection {
302            k: pts[0].k,
303            ev: pts[0].ev,
304            flag: KSelectionFlag::Flat,
305            score: 0.0,
306        };
307    }
308
309    // Linearity test: a straight EV-vs-K line has near-constant marginal slope.
310    // Compare the slope range to the mean slope; a small relative spread means
311    // there is no curvature, hence no knee.
312    let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
313    let slope_lo = slopes.iter().copied().fold(f64::INFINITY, f64::min);
314    let slope_hi = slopes.iter().copied().fold(f64::NEG_INFINITY, f64::max);
315    let slope_spread = slope_hi - slope_lo;
316    if mean_slope > 0.0 && slope_spread <= LINEARITY_SLOPE_REL_TOL * mean_slope {
317        return KSelection {
318            k: curve.k_max(),
319            ev: pts[n - 1].ev,
320            flag: KSelectionFlag::Linear,
321            score: slope_spread / mean_slope.max(MIN_DENOM),
322        };
323    }
324
325    // Kneedle: on the min-max normalized curve, the knee is the point of
326    // greatest drop of the curve below the chord from first to last point.
327    // Equivalently the point maximizing the normalized-EV minus normalized-K
328    // difference d_i = ev_hat_i − k_hat_i. We locate that candidate, then
329    // accept it only if the *post-knee* marginal slope has decayed below
330    // `knee_slope_fraction` of the initial slope (the saturation test).
331    let k_first = pts[0].k as f64;
332    let k_last = pts[n - 1].k as f64;
333    let k_range = (k_last - k_first).max(MIN_DENOM);
334
335    let mut best_idx = 0usize;
336    let mut best_diff = f64::NEG_INFINITY;
337    for (i, p) in pts.iter().enumerate() {
338        let ev_hat = (p.ev - pts[0].ev) / span;
339        let k_hat = (p.k as f64 - k_first) / k_range;
340        let diff = ev_hat - k_hat;
341        if diff > best_diff {
342            best_diff = diff;
343            best_idx = i;
344        }
345    }
346
347    // Saturation acceptance: the marginal slope of the segment *after* the
348    // candidate knee must have decayed below the fraction of the initial slope.
349    // `best_idx` indexes a point; the post-knee slope is segment `best_idx`
350    // (between best_idx and best_idx+1) when it exists.
351    let post_slope = if best_idx < slopes.len() {
352        slopes[best_idx]
353    } else {
354        // Knee at the very last point => everything saturated up to here.
355        0.0
356    };
357    let decay_fraction = (post_slope / init_slope).max(0.0);
358
359    if decay_fraction <= config.knee_slope_fraction {
360        KSelection {
361            k: pts[best_idx].k,
362            ev: pts[best_idx].ev,
363            flag: KSelectionFlag::Knee,
364            score: decay_fraction,
365        }
366    } else {
367        // The "knee" candidate is still on a steep stretch: the curve has not
368        // saturated within the sampled range. Return the largest K, flagged.
369        KSelection {
370            k: curve.k_max(),
371            ev: pts[n - 1].ev,
372            flag: KSelectionFlag::NoKnee,
373            score: decay_fraction,
374        }
375    }
376}
377
378fn select_mdl(curve: &EvVsKCurve, config: &KSelectionConfig) -> KSelection {
379    let pts = curve.points();
380    let k_max = curve.k_max() as f64;
381    let gamma = config.complexity_penalty;
382
383    let mut best_idx = 0usize;
384    let mut best_obj = f64::NEG_INFINITY;
385    for (i, p) in pts.iter().enumerate() {
386        let obj = p.ev - gamma * (p.k as f64 / k_max.max(MIN_DENOM));
387        if obj > best_obj {
388            best_obj = obj;
389            best_idx = i;
390        }
391    }
392
393    // Classify the MDL pick the same way the Kneedle path reports its endpoints
394    // so callers get a consistent flag vocabulary: interior pick => Knee,
395    // endpoint picks => the endpoint reason.
396    let flag = if best_idx == 0 {
397        KSelectionFlag::Flat
398    } else if best_idx == pts.len() - 1 {
399        KSelectionFlag::NoKnee
400    } else {
401        KSelectionFlag::Knee
402    };
403
404    KSelection {
405        k: pts[best_idx].k,
406        ev: pts[best_idx].ev,
407        flag,
408        score: best_obj,
409    }
410}
411
/// Parameter-efficiency comparison of the manifold and linear representations
/// at one target EV.
///
/// Manifold reaches `target_ev` at `k_manifold` atoms, linear at `k_linear`
/// atoms. A compression ratio `k_linear / k_manifold` above 1 reads as the
/// number of linear atoms one manifold atom is worth at equal reconstruction
/// quality.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ManifoldVsLinearAdvantage {
    /// Target explained variance both representations are compared at.
    pub target_ev: f64,
    /// Smallest manifold `K` reaching `target_ev`, if any.
    pub k_manifold: Option<usize>,
    /// Smallest linear `K` reaching `target_ev`, if any.
    pub k_linear: Option<usize>,
    /// `k_linear / k_manifold` when both reach the target, else `None`.
    pub compression_ratio: Option<f64>,
}

impl ManifoldVsLinearAdvantage {
    /// Whether the manifold representation is strictly more
    /// parameter-efficient at the target EV: both sizes are known and
    /// `k_manifold < k_linear`. A missing size on either side yields `false`.
    pub fn manifold_dominates(&self) -> bool {
        matches!(
            (self.k_manifold, self.k_linear),
            (Some(km), Some(kl)) if km < kl
        )
    }
}
440
441/// Compute the manifold-vs-linear advantage at `target_ev`: the cheapest `K`
442/// each representation needs to reach the target, and their ratio.
443pub fn manifold_vs_linear_advantage(
444    manifold: &EvVsKCurve,
445    linear: &EvVsKCurve,
446    target_ev: f64,
447) -> ManifoldVsLinearAdvantage {
448    let k_manifold = manifold.k_reaching(target_ev);
449    let k_linear = linear.k_reaching(target_ev);
450    let compression_ratio = match (k_manifold, k_linear) {
451        (Some(km), Some(kl)) if km > 0 => Some(kl as f64 / km as f64),
452        _ => None,
453    };
454    ManifoldVsLinearAdvantage {
455        target_ev,
456        k_manifold,
457        k_linear,
458        compression_ratio,
459    }
460}
461
462/// The full auto-`K` recommendation the OLMo research battery / hillclimb
463/// driver consumes: the knee-selected dictionary size on the manifold frontier,
464/// plus the manifold-vs-linear advantage measured at the EV the auto-`K` choice
465/// reaches.
466#[derive(Debug, Clone, Copy, PartialEq)]
467pub struct AutoKRecommendation {
468    /// Knee-selected dictionary size on the manifold frontier.
469    pub selection: KSelection,
470    /// Manifold-vs-linear advantage at the auto-`K` operating EV.
471    pub advantage: ManifoldVsLinearAdvantage,
472}
473
474/// One-call auto-`K`: knee-select `K` on the manifold EV-vs-`K` frontier, then
475/// report how many linear atoms would be needed to match the EV the selected
476/// `K` achieves.
477///
478/// This is the intended battery entry point: feed it the manifold and linear
479/// EV-vs-`K` sweeps measured on real OLMo L25 activations and the returned
480/// [`AutoKRecommendation::selection`] is the engine's auto-`K`, ready to be
481/// checked against the human-chosen `K`.
482pub fn recommend_auto_k(
483    manifold: &EvVsKCurve,
484    linear: &EvVsKCurve,
485    config: &KSelectionConfig,
486) -> AutoKRecommendation {
487    let selection = select_k(manifold, config);
488    let advantage = manifold_vs_linear_advantage(manifold, linear, selection.ev);
489    AutoKRecommendation {
490        selection,
491        advantage,
492    }
493}
494
495/// Build an [`EvVsKCurve`] from explicit `(K, EV)` pairs, e.g. the columns the
496/// OLMo research battery emits per sweep. Convenience wrapper over
497/// [`EvVsKCurve::new`].
498pub fn curve_from_pairs(pairs: &[(usize, f64)]) -> Result<EvVsKCurve, String> {
499    EvVsKCurve::new(
500        pairs
501            .iter()
502            .map(|&(k, ev)| EvVsKPoint::new(k, ev))
503            .collect(),
504    )
505}
506
507/// Explained variance (R²) of `fitted` against `x`, total-SS normalized and
508/// column-mean centered. Shared definition with the linear-dictionary lane so
509/// the manifold and linear EV-vs-`K` curves are on the same scale.
510pub fn explained_variance(x: ArrayView2<'_, f64>, fitted: ArrayView2<'_, f64>) -> f64 {
511    assert_eq!(
512        x.dim(),
513        fitted.dim(),
514        "explained_variance: x {:?} != fitted {:?}",
515        x.dim(),
516        fitted.dim()
517    );
518    let n = x.nrows();
519    if n == 0 || x.ncols() == 0 {
520        return 0.0;
521    }
522    let mut rss = 0.0;
523    for row in 0..n {
524        for col in 0..x.ncols() {
525            let r = x[[row, col]] - fitted[[row, col]];
526            rss += r * r;
527        }
528    }
529    let means = x
530        .mean_axis(ndarray::Axis(0))
531        .expect("non-empty input has means");
532    let mut tss = 0.0;
533    for row in 0..n {
534        for col in 0..x.ncols() {
535            let c = x[[row, col]] - means[col];
536            tss += c * c;
537        }
538    }
539    if tss <= MIN_DENOM {
540        if rss <= MIN_DENOM { 1.0 } else { 0.0 }
541    } else {
542        1.0 - rss / tss
543    }
544}
545
/// Linearity threshold for the Kneedle path: when the spread of the segment
/// slopes (max − min) is at most this fraction of their mean, the EV-vs-K
/// curve is deemed straight (no curvature => no knee).
const LINEARITY_SLOPE_REL_TOL: f64 = 0.05;

/// Smallest value a denominator is allowed to take; also the threshold below
/// which a sum of squares is treated as zero.
const MIN_DENOM: f64 = 1.0e-12;
552
#[cfg(test)]
mod k_selection_tests {
    use super::*;
    use ndarray::array;

    /// Saturating fixture: steep rise up to K=4 (EV ~0.9), then a near-flat
    /// tail. The knee sits at K=4.
    fn knee_curve() -> EvVsKCurve {
        curve_from_pairs(&[
            (1, 0.40),
            (2, 0.65),
            (3, 0.82),
            (4, 0.90),
            (8, 0.915),
            (16, 0.92),
            (32, 0.922),
        ])
        .expect("knee curve")
    }

    /// Default configuration switched to MDL mode with the given weight.
    fn mdl_config(complexity_penalty: f64) -> KSelectionConfig {
        KSelectionConfig {
            mode: KSelectionMode::PenalizedMdl,
            complexity_penalty,
            ..KSelectionConfig::default()
        }
    }

    #[test]
    fn kneedle_picks_the_elbow() {
        let picked = select_k(&knee_curve(), &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Knee);
        assert_eq!(picked.k, 4, "knee should sit at the saturation corner K=4");
        assert!((picked.ev - 0.90).abs() < 1e-9);
    }

    #[test]
    fn linear_curve_returns_full_k_with_flag() {
        // Constant marginal gain of 0.10 per atom: no curvature, no knee.
        let straight = curve_from_pairs(&[(1, 0.10), (2, 0.20), (3, 0.30), (4, 0.40), (5, 0.50)])
            .expect("linear curve");
        let picked = select_k(&straight, &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Linear);
        assert_eq!(picked.k, 5, "linear curve returns the largest K");
    }

    #[test]
    fn still_climbing_curve_returns_full_k_no_knee() {
        // Concave, but the last segment is still steep: not yet saturated.
        let climbing = curve_from_pairs(&[(1, 0.10), (2, 0.30), (3, 0.48), (4, 0.64), (5, 0.78)])
            .expect("climbing curve");
        // Demand a decay to 1% of the initial slope, which this curve cannot meet.
        let strict = KSelectionConfig {
            knee_slope_fraction: 0.01,
            ..KSelectionConfig::default()
        };
        let picked = select_k(&climbing, &strict);
        assert_eq!(picked.flag, KSelectionFlag::NoKnee);
        assert_eq!(picked.k, 5);
    }

    #[test]
    fn flat_curve_returns_smallest_k() {
        let plateau =
            curve_from_pairs(&[(1, 0.900), (2, 0.9000001), (4, 0.9000002)]).expect("flat curve");
        let picked = select_k(&plateau, &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Flat);
        assert_eq!(picked.k, 1, "already-saturated curve returns smallest K");
    }

    #[test]
    fn single_point_curve_is_flat() {
        let lone = curve_from_pairs(&[(7, 0.5)]).expect("single point");
        let picked = select_k(&lone, &KSelectionConfig::default());
        assert_eq!(picked.flag, KSelectionFlag::Flat);
        assert_eq!(picked.k, 7);
    }

    #[test]
    fn mdl_picks_interior_knee_on_saturating_curve() {
        let picked = select_k(&knee_curve(), &mdl_config(0.05));
        // The size penalty is γ·K/32. Past the knee the EV gain is tiny while
        // the penalty keeps growing, so the optimum must be interior.
        assert!(
            picked.k <= 8,
            "MDL should not chase the saturated tail, got {}",
            picked.k
        );
        assert!(
            picked.k >= 3,
            "MDL should not under-fit the steep rise, got {}",
            picked.k
        );
    }

    #[test]
    fn mdl_penalty_zero_takes_full_k() {
        // Without a complexity penalty the highest EV (largest K) wins.
        let picked = select_k(&knee_curve(), &mdl_config(0.0));
        assert_eq!(picked.k, 32);
    }

    #[test]
    fn advantage_metric_rewards_manifold_compression() {
        // Manifold reaches EV 0.90 at K=4; linear needs K=16 for the same.
        let manifold = curve_from_pairs(&[(1, 0.40), (2, 0.65), (4, 0.90), (8, 0.93), (16, 0.94)])
            .expect("manifold curve");
        let linear = curve_from_pairs(&[(1, 0.20), (2, 0.35), (4, 0.55), (8, 0.78), (16, 0.91)])
            .expect("linear curve");
        let gain = manifold_vs_linear_advantage(&manifold, &linear, 0.90);
        assert_eq!(gain.k_manifold, Some(4));
        assert_eq!(gain.k_linear, Some(16));
        assert!(gain.manifold_dominates());
        let ratio = gain.compression_ratio.expect("both reach target");
        assert!(
            (ratio - 4.0).abs() < 1e-12,
            "expected 16/4 = 4x, got {ratio}"
        );
    }

    #[test]
    fn advantage_metric_handles_unreached_target() {
        // Neither sweep gets anywhere near EV 0.90.
        let manifold = curve_from_pairs(&[(1, 0.40), (2, 0.55)]).expect("manifold curve");
        let linear = curve_from_pairs(&[(1, 0.20), (2, 0.35)]).expect("linear curve");
        let gain = manifold_vs_linear_advantage(&manifold, &linear, 0.90);
        assert_eq!(gain.k_manifold, None);
        assert_eq!(gain.k_linear, None);
        assert!(gain.compression_ratio.is_none());
        assert!(!gain.manifold_dominates());
    }

    #[test]
    fn recommend_auto_k_combines_knee_and_advantage() {
        // Manifold knees at K=4 (EV 0.90); linear needs K=16 for EV 0.90.
        let manifold = curve_from_pairs(&[
            (1, 0.40),
            (2, 0.65),
            (3, 0.82),
            (4, 0.90),
            (8, 0.915),
            (16, 0.92),
            (32, 0.922),
        ])
        .expect("manifold curve");
        let linear = curve_from_pairs(&[
            (1, 0.20),
            (2, 0.35),
            (4, 0.55),
            (8, 0.78),
            (16, 0.91),
            (32, 0.93),
        ])
        .expect("linear curve");
        let AutoKRecommendation {
            selection,
            advantage,
        } = recommend_auto_k(&manifold, &linear, &KSelectionConfig::default());
        assert_eq!(selection.k, 4);
        assert_eq!(selection.flag, KSelectionFlag::Knee);
        // At the auto-K EV (0.90) linear needs K=16 => 4x compression.
        assert_eq!(advantage.k_manifold, Some(4));
        assert_eq!(advantage.k_linear, Some(16));
        assert!(advantage.manifold_dominates());
        let ratio = advantage.compression_ratio.expect("both reach EV");
        assert!((ratio - 4.0).abs() < 1e-12);
    }

    #[test]
    fn explained_variance_matches_perfect_and_mean_baselines() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Reconstructing x with itself explains everything.
        let ev_perfect = explained_variance(x.view(), x.view());
        assert!((ev_perfect - 1.0).abs() < 1e-12);

        // Predicting the column means everywhere explains nothing.
        let means = x.mean_axis(ndarray::Axis(0)).expect("means");
        let mean_fit = array![
            [means[0], means[1]],
            [means[0], means[1]],
            [means[0], means[1]]
        ];
        let ev_mean = explained_variance(x.view(), mean_fit.view());
        assert!(
            ev_mean.abs() < 1e-12,
            "mean baseline EV should be 0, got {ev_mean}"
        );
    }

    #[test]
    fn curve_rejects_bad_input() {
        // Empty, K = 0, NaN EV, duplicate K.
        assert!(EvVsKCurve::new(vec![]).is_err());
        assert!(curve_from_pairs(&[(0, 0.5)]).is_err());
        assert!(curve_from_pairs(&[(1, f64::NAN)]).is_err());
        assert!(curve_from_pairs(&[(2, 0.5), (2, 0.6)]).is_err());
    }

    #[test]
    fn curve_sorts_by_k() {
        let shuffled = curve_from_pairs(&[(8, 0.9), (1, 0.4), (4, 0.8)]).expect("curve");
        assert_eq!(shuffled.k_min(), 1);
        assert_eq!(shuffled.k_max(), 8);
        assert_eq!(shuffled.points()[0].k, 1);
        assert_eq!(shuffled.points()[2].k, 8);
    }

    #[test]
    fn mode_parse_roundtrips() {
        assert_eq!(
            KSelectionMode::parse("elbow").expect("parse"),
            KSelectionMode::Kneedle
        );
        assert_eq!(
            KSelectionMode::parse("MDL").expect("parse"),
            KSelectionMode::PenalizedMdl
        );
        assert_eq!(KSelectionMode::Kneedle.as_str(), "kneedle");
        assert!(KSelectionMode::parse("nonsense").is_err());
    }
}
771}