Skip to main content

sphereql_embed/
tuner.rs

1//! Auto-tuner: search [`PipelineConfig`] space to maximize a [`QualityMetric`].
2//!
3//! This is the first usable rung of the metalearning ladder. Given a corpus
4//! and a scalar objective, the tuner enumerates or samples candidate
5//! configurations, builds a full pipeline for each, and records the score.
6//! Three strategies ship: exhaustive [`SearchStrategy::Grid`], uniform
7//! [`SearchStrategy::Random`], and the axis-parallel TPE-lite
8//! [`SearchStrategy::Bayesian`] acquisition — all reproducible under a
9//! fixed seed, establishing baselines for higher-order tuners (CMA-ES,
10//! meta-learning) to beat.
11//!
//! Projections are fit **once per distinct fit-affecting hyperparameter
//! tuple** from the input corpus and reused across every trial that maps
//! to the same tuple: PCA and Kernel PCA key per kind, Laplacian per
//! `(k_neighbors, active_threshold)`, and UMAP per `(n_neighbors,
//! n_epochs, category_weight, min_dist)` — with UMAP's kNN graph
//! additionally cached per `n_neighbors` (see
//! [`TuneReport::umap_graph_builds`]). Beyond those cached fits, the
//! per-trial work is rebuilding the downstream pipeline from the remaining
//! config knobs (bridge thresholds, inner-sphere gates, domain-group
//! counts, etc.).
21
22use std::collections::HashMap;
23use std::time::Instant;
24
25use crate::config::{
26    BridgeConfig, InnerSphereConfig, LaplacianConfig, PipelineConfig, ProjectionKind,
27    RoutingConfig, UmapConfig,
28};
29use crate::configured_projection::ConfiguredProjection;
30use crate::pipeline::{
31    PipelineError, PipelineInput, SphereQLPipeline, fit_projection_for_config, fit_umap_from_graph,
32};
33use crate::projection::SplitMix64;
34use crate::quality_metric::QualityMetric;
35use crate::types::Embedding;
36
37// ── Search space ─────────────────────────────────────────────────────
38
/// Discrete candidate values for each tunable knob.
///
/// Every field holds the full set of values the tuner will consider for
/// that knob. The space is *kind-conditional*: each projection kind only
/// draws from the kind-agnostic knobs plus its own kind-specific knobs.
/// Grid search enumerates, for each entry of `projection_kinds`, the
/// Cartesian product of exactly those axes and concatenates the per-kind
/// slices (see [`SearchSpace::grid_cardinality`]). Random search samples
/// a kind, then samples each axis that kind uses independently.
///
/// Every axis a listed kind draws from must be non-empty;
/// [`SearchSpace::validate`] rejects a space that violates this.
///
/// Defaults are chosen to bracket the historical hardcoded value on each
/// knob, giving the tuner room to move either direction without being
/// unreasonable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchSpace {
    /// Candidate projection families for the outer sphere. Each kind is
    /// prefit once per distinct fit-affecting hyperparameter tuple in
    /// [`auto_tune`]; trials pick the prefit matching their config.
    pub projection_kinds: Vec<ProjectionKind>,

    // ── Projection-kind-specific knobs ────────────────────────────
    // These only take effect when the trial's projection_kind matches.
    // PCA and Kernel PCA trials ignore them entirely — both grid
    // enumeration and random sampling are kind-conditional, so trials of
    // other kinds never multiply against (or sample) these dimensions.
    /// Candidate values for [`LaplacianConfig::k_neighbors`]. Only
    /// explored when [`ProjectionKind::LaplacianEigenmap`] is in
    /// `projection_kinds`.
    pub laplacian_k_neighbors: Vec<usize>,
    /// Candidate values for [`LaplacianConfig::active_threshold`]. Only
    /// explored when [`ProjectionKind::LaplacianEigenmap`] is in
    /// `projection_kinds`.
    pub laplacian_active_threshold: Vec<f64>,

    /// Candidate values for [`UmapConfig::n_neighbors`]. Only explored
    /// when [`ProjectionKind::UmapSphere`] is in `projection_kinds`.
    /// Also the key of the tuner's kNN-graph cache.
    pub umap_n_neighbors: Vec<usize>,
    /// Candidate values for [`UmapConfig::n_epochs`]. Only explored
    /// when [`ProjectionKind::UmapSphere`] is in `projection_kinds`.
    pub umap_n_epochs: Vec<usize>,
    /// Candidate values for [`UmapConfig::category_weight`]. Only
    /// explored when [`ProjectionKind::UmapSphere`] is in
    /// `projection_kinds`.
    pub umap_category_weight: Vec<f64>,
    /// Candidate values for [`UmapConfig::min_dist`]. Only explored
    /// when [`ProjectionKind::UmapSphere`] is in `projection_kinds`.
    pub umap_min_dist: Vec<f64>,

    // ── Kind-agnostic knobs ───────────────────────────────────────
    // Swept for every projection kind.
    /// Candidate values for [`RoutingConfig::num_domain_groups`].
    pub num_domain_groups: Vec<usize>,
    /// Candidate values for [`RoutingConfig::low_evr_threshold`].
    pub low_evr_threshold: Vec<f64>,
    /// Candidate values for [`BridgeConfig::overlap_artifact_territorial`].
    pub overlap_artifact_territorial: Vec<f64>,
    /// Candidate values for [`BridgeConfig::threshold_base`].
    pub threshold_base: Vec<f64>,
    /// Candidate values for [`BridgeConfig::threshold_evr_penalty`].
    pub threshold_evr_penalty: Vec<f64>,
    /// Candidate values for [`InnerSphereConfig::min_evr_improvement`].
    pub min_evr_improvement: Vec<f64>,
}
97
98impl SearchSpace {
99    /// Search space optimized for large corpora (> 10 000 items).
100    ///
101    /// Includes PCA and UMAP (but not Laplacian eigenmap, which is O(N²)
102    /// on the affinity matrix). UMAP uses the ANN-backed kNN graph,
103    /// making it O(N log N) for graph construction.
104    pub fn large_corpus() -> Self {
105        Self {
106            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::UmapSphere],
107            // Laplacian is excluded — kept as singletons so the axes
108            // exist if a caller swaps the kind in later, but they cost
109            // nothing at grid time because Laplacian isn't enumerated.
110            laplacian_k_neighbors: vec![15],
111            laplacian_active_threshold: vec![0.05],
112            umap_n_neighbors: vec![10, 15, 30],
113            umap_n_epochs: vec![150, 300],
114            umap_category_weight: vec![0.0, 1.5, 3.0],
115            umap_min_dist: vec![0.0, 0.1, 0.25],
116            num_domain_groups: vec![3, 5, 7],
117            low_evr_threshold: vec![0.25, 0.35],
118            overlap_artifact_territorial: vec![0.2, 0.3],
119            threshold_base: vec![0.4, 0.5],
120            threshold_evr_penalty: vec![0.3, 0.5],
121            min_evr_improvement: vec![0.05, 0.10],
122        }
123    }
124}
125
126impl Default for SearchSpace {
127    fn default() -> Self {
128        Self {
129            // Kernel PCA has O(n²·d) fit and is excluded from the default
130            // sweep — callers who want it can add ProjectionKind::KernelPca
131            // explicitly, accepting the longer fit cost.
132            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
133            // Laplacian hyperparameters bracket the default values
134            // (k=15, threshold=0.05) widely enough that the tuner can
135            // actually move the projection's geometry.
136            laplacian_k_neighbors: vec![10, 15, 25],
137            laplacian_active_threshold: vec![0.03, 0.05, 0.10],
138            umap_n_neighbors: vec![10, 15, 30],
139            umap_n_epochs: vec![150, 250],
140            umap_category_weight: vec![0.0, 1.5, 3.0],
141            umap_min_dist: vec![0.0, 0.1, 0.25],
142            num_domain_groups: vec![3, 5, 7],
143            low_evr_threshold: vec![0.25, 0.35, 0.45],
144            overlap_artifact_territorial: vec![0.2, 0.3, 0.4],
145            threshold_base: vec![0.4, 0.5, 0.6],
146            threshold_evr_penalty: vec![0.2, 0.4, 0.6],
147            min_evr_improvement: vec![0.05, 0.10, 0.15],
148        }
149    }
150}
151
152impl SearchSpace {
153    /// Validate this space against a [`SearchStrategy`] and return a
154    /// structured [`PipelineError::InvalidSearchSpace`] if anything is
155    /// off — empty axis, missing kind-specific knobs, or budget too
156    /// small for the strategy to make progress. Called upfront from
157    /// [`auto_tune`] so every strategy fails at the same boundary
158    /// instead of panicking mid-trial (random/Bayesian) or silently
159    /// rolling up into `AllTrialsFailed { failures: [] }` (Grid).
160    pub fn validate(&self, strategy: &SearchStrategy) -> Result<(), PipelineError> {
161        if self.projection_kinds.is_empty() {
162            return Err(PipelineError::InvalidSearchSpace(
163                "axis `projection_kinds` is empty".into(),
164            ));
165        }
166        for &kind in &self.projection_kinds {
167            self.check_axes_non_empty(kind)?;
168        }
169        match strategy {
170            SearchStrategy::Grid => {}
171            SearchStrategy::Random { budget, .. } => {
172                if *budget == 0 {
173                    return Err(PipelineError::InvalidSearchSpace(
174                        "Random search requires budget >= 1".into(),
175                    ));
176                }
177            }
178            SearchStrategy::Bayesian {
179                budget,
180                warmup,
181                gamma,
182                ..
183            } => {
184                if *budget < 2 {
185                    return Err(PipelineError::InvalidSearchSpace(format!(
186                        "Bayesian search requires budget >= 2 (got {budget})"
187                    )));
188                }
189                if *warmup < 2 {
190                    return Err(PipelineError::InvalidSearchSpace(format!(
191                        "Bayesian search requires warmup >= 2 (got {warmup})"
192                    )));
193                }
194                if !gamma.is_finite() || *gamma <= 0.0 || *gamma >= 1.0 {
195                    return Err(PipelineError::InvalidSearchSpace(format!(
196                        "Bayesian gamma must be in (0, 1), got {gamma}"
197                    )));
198                }
199            }
200        }
201        Ok(())
202    }
203
204    fn check_axes_non_empty(&self, kind: ProjectionKind) -> Result<(), PipelineError> {
205        let common = [
206            ("num_domain_groups", self.num_domain_groups.len()),
207            ("low_evr_threshold", self.low_evr_threshold.len()),
208            (
209                "overlap_artifact_territorial",
210                self.overlap_artifact_territorial.len(),
211            ),
212            ("threshold_base", self.threshold_base.len()),
213            ("threshold_evr_penalty", self.threshold_evr_penalty.len()),
214            ("min_evr_improvement", self.min_evr_improvement.len()),
215        ];
216        for (name, len) in common {
217            if len == 0 {
218                return Err(PipelineError::InvalidSearchSpace(format!(
219                    "axis `{name}` is empty"
220                )));
221            }
222        }
223        if matches!(kind, ProjectionKind::LaplacianEigenmap) {
224            if self.laplacian_k_neighbors.is_empty() {
225                return Err(PipelineError::InvalidSearchSpace(
226                    "axis `laplacian_k_neighbors` is empty".into(),
227                ));
228            }
229            if self.laplacian_active_threshold.is_empty() {
230                return Err(PipelineError::InvalidSearchSpace(
231                    "axis `laplacian_active_threshold` is empty".into(),
232                ));
233            }
234        }
235        if matches!(kind, ProjectionKind::UmapSphere) {
236            if self.umap_n_neighbors.is_empty() {
237                return Err(PipelineError::InvalidSearchSpace(
238                    "axis `umap_n_neighbors` is empty".into(),
239                ));
240            }
241            if self.umap_n_epochs.is_empty() {
242                return Err(PipelineError::InvalidSearchSpace(
243                    "axis `umap_n_epochs` is empty".into(),
244                ));
245            }
246            if self.umap_category_weight.is_empty() {
247                return Err(PipelineError::InvalidSearchSpace(
248                    "axis `umap_category_weight` is empty".into(),
249                ));
250            }
251            if self.umap_min_dist.is_empty() {
252                return Err(PipelineError::InvalidSearchSpace(
253                    "axis `umap_min_dist` is empty".into(),
254                ));
255            }
256        }
257        Ok(())
258    }
259
260    /// Defensive guard for the public [`Self::config_at_index`] path.
261    /// `auto_tune` always validates upfront via [`Self::validate`], but
262    /// external callers that decode grid indices on a hand-built
263    /// [`SearchSpace`] still need a clear panic instead of a
264    /// mod-by-zero deeper in the decoder.
265    fn assert_axes_non_empty(&self, kind: ProjectionKind) {
266        if let Err(e) = self.check_axes_non_empty(kind) {
267            panic!("{e}");
268        }
269    }
270
271    /// Number of kind-agnostic knob combinations. Every projection kind's
272    /// grid slice is at least this large; Laplacian multiplies by its
273    /// specific knob counts on top.
274    fn common_cardinality(&self) -> usize {
275        self.num_domain_groups.len()
276            * self.low_evr_threshold.len()
277            * self.overlap_artifact_territorial.len()
278            * self.threshold_base.len()
279            * self.threshold_evr_penalty.len()
280            * self.min_evr_improvement.len()
281    }
282
283    /// Per-kind grid cardinality — common knobs × any kind-specific
284    /// knobs this kind opts into.
285    fn kind_cardinality(&self, kind: ProjectionKind) -> usize {
286        let common = self.common_cardinality();
287        match kind {
288            ProjectionKind::LaplacianEigenmap => {
289                common * self.laplacian_k_neighbors.len() * self.laplacian_active_threshold.len()
290            }
291            ProjectionKind::UmapSphere => {
292                common
293                    * self.umap_n_neighbors.len()
294                    * self.umap_n_epochs.len()
295                    * self.umap_category_weight.len()
296                    * self.umap_min_dist.len()
297            }
298            ProjectionKind::Pca | ProjectionKind::KernelPca => common,
299        }
300    }
301
302    /// Cardinality of the kind-conditional grid: the sum of each projection
303    /// kind's own slice. `grid` search visits exactly this many configurations.
304    pub fn grid_cardinality(&self) -> usize {
305        self.projection_kinds
306            .iter()
307            .map(|&k| self.kind_cardinality(k))
308            .sum()
309    }
310
311    /// Build a [`PipelineConfig`] from one grid index.
312    ///
313    /// The grid is laid out as disjoint per-kind slices concatenated in
314    /// the order of [`Self::projection_kinds`]: indices 0..c₀ enumerate
315    /// the first kind's subspace, c₀..c₀+c₁ the second kind's, etc. This
316    /// keeps kind-specific knobs (e.g. Laplacian's k, threshold) from
317    /// multiplying against trials of other kinds that wouldn't use them.
318    pub fn config_at_index(&self, index: usize, base: &PipelineConfig) -> Option<PipelineConfig> {
319        let mut offset = 0usize;
320        for &kind in &self.projection_kinds {
321            self.assert_axes_non_empty(kind);
322            let slice = self.kind_cardinality(kind);
323            if index < offset + slice {
324                return Some(self.config_at_kind_index(kind, index - offset, base));
325            }
326            offset += slice;
327        }
328        None
329    }
330
331    /// Decode an index within a single kind's slice.
332    fn config_at_kind_index(
333        &self,
334        kind: ProjectionKind,
335        mut idx: usize,
336        base: &PipelineConfig,
337    ) -> PipelineConfig {
338        let take = |idx: &mut usize, len: usize| -> usize {
339            let v = *idx % len;
340            *idx /= len;
341            v
342        };
343
344        let i_ndg = take(&mut idx, self.num_domain_groups.len());
345        let i_let = take(&mut idx, self.low_evr_threshold.len());
346        let i_oat = take(&mut idx, self.overlap_artifact_territorial.len());
347        let i_tb = take(&mut idx, self.threshold_base.len());
348        let i_tep = take(&mut idx, self.threshold_evr_penalty.len());
349        let i_mei = take(&mut idx, self.min_evr_improvement.len());
350
351        let mut cfg = base.clone();
352        cfg.projection_kind = kind;
353        cfg.routing = RoutingConfig {
354            num_domain_groups: self.num_domain_groups[i_ndg],
355            low_evr_threshold: self.low_evr_threshold[i_let],
356            ..base.routing.clone()
357        };
358        cfg.bridges = BridgeConfig {
359            threshold_base: self.threshold_base[i_tb],
360            threshold_evr_penalty: self.threshold_evr_penalty[i_tep],
361            overlap_artifact_territorial: self.overlap_artifact_territorial[i_oat],
362            ..base.bridges.clone()
363        };
364        cfg.inner_sphere = InnerSphereConfig {
365            min_evr_improvement: self.min_evr_improvement[i_mei],
366            ..base.inner_sphere.clone()
367        };
368
369        if matches!(kind, ProjectionKind::LaplacianEigenmap) {
370            let i_k = take(&mut idx, self.laplacian_k_neighbors.len());
371            let i_thr = take(&mut idx, self.laplacian_active_threshold.len());
372            cfg.laplacian = LaplacianConfig {
373                k_neighbors: self.laplacian_k_neighbors[i_k],
374                active_threshold: self.laplacian_active_threshold[i_thr],
375            };
376        }
377
378        if matches!(kind, ProjectionKind::UmapSphere) {
379            let i_nn = take(&mut idx, self.umap_n_neighbors.len());
380            let i_ne = take(&mut idx, self.umap_n_epochs.len());
381            let i_cw = take(&mut idx, self.umap_category_weight.len());
382            let i_md = take(&mut idx, self.umap_min_dist.len());
383            cfg.umap = UmapConfig {
384                n_neighbors: self.umap_n_neighbors[i_nn],
385                n_epochs: self.umap_n_epochs[i_ne],
386                category_weight: self.umap_category_weight[i_cw],
387                min_dist: self.umap_min_dist[i_md],
388                ..base.umap.clone()
389            };
390        }
391
392        cfg
393    }
394
395    /// Sample one random [`PipelineConfig`] from this space. Every knob's
396    /// value set is sampled uniformly and independently; kind-specific
397    /// knobs are only sampled when the sampled kind uses them. Internal
398    /// to the tuner — external callers go through [`auto_tune`] with a
399    /// [`SearchStrategy::Random`] strategy.
400    pub(crate) fn sample(&self, rng: &mut SplitMix64, base: &PipelineConfig) -> PipelineConfig {
401        // `auto_tune` calls [`Self::validate`] upfront, so by the time
402        // we reach here projection_kinds and every per-kind axis are
403        // guaranteed non-empty. `debug_assert!` keeps the invariant
404        // visible during development without re-introducing a runtime
405        // panic on the hot path.
406        debug_assert!(
407            !self.projection_kinds.is_empty(),
408            "SearchSpace::sample called without prior validate()"
409        );
410        let mut cfg = base.clone();
411        cfg.projection_kind = pick_uniform(rng, &self.projection_kinds);
412        debug_assert!(
413            self.check_axes_non_empty(cfg.projection_kind).is_ok(),
414            "SearchSpace::sample called without prior validate()"
415        );
416        cfg.routing = RoutingConfig {
417            num_domain_groups: pick_uniform(rng, &self.num_domain_groups),
418            low_evr_threshold: pick_uniform(rng, &self.low_evr_threshold),
419            ..base.routing.clone()
420        };
421        cfg.bridges = BridgeConfig {
422            threshold_base: pick_uniform(rng, &self.threshold_base),
423            threshold_evr_penalty: pick_uniform(rng, &self.threshold_evr_penalty),
424            overlap_artifact_territorial: pick_uniform(rng, &self.overlap_artifact_territorial),
425            ..base.bridges.clone()
426        };
427        cfg.inner_sphere = InnerSphereConfig {
428            min_evr_improvement: pick_uniform(rng, &self.min_evr_improvement),
429            ..base.inner_sphere.clone()
430        };
431
432        if matches!(cfg.projection_kind, ProjectionKind::LaplacianEigenmap) {
433            cfg.laplacian = LaplacianConfig {
434                k_neighbors: pick_uniform(rng, &self.laplacian_k_neighbors),
435                active_threshold: pick_uniform(rng, &self.laplacian_active_threshold),
436            };
437        }
438
439        if matches!(cfg.projection_kind, ProjectionKind::UmapSphere) {
440            cfg.umap = UmapConfig {
441                n_neighbors: pick_uniform(rng, &self.umap_n_neighbors),
442                n_epochs: pick_uniform(rng, &self.umap_n_epochs),
443                category_weight: pick_uniform(rng, &self.umap_category_weight),
444                min_dist: pick_uniform(rng, &self.umap_min_dist),
445                ..base.umap.clone()
446            };
447        }
448
449        cfg
450    }
451}
452
453// ── Prefit cache key ─────────────────────────────────────────────────
454
/// Cache key identifying a single fittable projection configuration.
///
/// Two [`PipelineConfig`]s that produce the same `ProjectionFitKey` share
/// a prefit projection; two that differ need distinct fits. The keying
/// rules are:
/// - PCA and Kernel PCA have no fit-affecting hyperparameters in the
///   current search space, so each kind has exactly one key.
/// - Laplacian keys on `(k_neighbors, active_threshold)`.
/// - UMAP keys on `(n_neighbors, n_epochs, category_weight, min_dist)`.
///
/// `f64` hyperparameters are stored as IEEE-754 bit patterns
/// (`f64::to_bits`) because `f64` is neither `Eq` nor `Hash`. Bit
/// equality is stricter than `==`: for example, `0.0` and `-0.0` map to
/// different keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ProjectionFitKey {
    Pca,
    KernelPca,
    Laplacian {
        k: usize,
        threshold_bits: u64,
    },
    UmapSphere {
        n_neighbors: usize,
        n_epochs: usize,
        category_weight_bits: u64,
        min_dist_bits: u64,
    },
}
477
478impl ProjectionFitKey {
479    fn from_config(cfg: &PipelineConfig) -> Self {
480        match cfg.projection_kind {
481            ProjectionKind::Pca => Self::Pca,
482            ProjectionKind::KernelPca => Self::KernelPca,
483            ProjectionKind::LaplacianEigenmap => Self::Laplacian {
484                k: cfg.laplacian.k_neighbors,
485                threshold_bits: cfg.laplacian.active_threshold.to_bits(),
486            },
487            ProjectionKind::UmapSphere => Self::UmapSphere {
488                n_neighbors: cfg.umap.n_neighbors,
489                n_epochs: cfg.umap.n_epochs,
490                category_weight_bits: cfg.umap.category_weight.to_bits(),
491                // min_dist sets the optimizer's kernel (a, b), so two
492                // configs differing only here must not share a fitted
493                // projection — though they still share the kNN graph,
494                // which is built before the optimizer runs.
495                min_dist_bits: cfg.umap.min_dist.to_bits(),
496            },
497        }
498    }
499}
500
501// ── Strategy, report, trial record ───────────────────────────────────────
502
/// How the tuner walks a [`SearchSpace`].
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// Visit every configuration of the kind-conditional grid exactly once.
    /// The number of trials equals [`SearchSpace::grid_cardinality`].
    Grid,
    /// Draw `budget` configurations uniformly at random, reproducibly
    /// under `seed`. Under [`auto_tune`], the caller's base config is
    /// evaluated first and counts against `budget`.
    Random {
        budget: usize,
        seed: u64,
        /// Optional wall-time cap in seconds. Once cumulative elapsed time
        /// passes this limit, the tuner stops proposing new trials; a trial
        /// already running is allowed to finish. `None` means unlimited
        /// (legacy behavior).
        max_wall_secs: Option<u64>,
    },
    /// Sequential Bayesian-ish search. The first `warmup` trials are
    /// uniform random. After that, each knob's value is chosen by the ratio
    /// of its per-value probability among the top-`gamma` fraction of
    /// trials ("good") to its probability among the remaining `1 − gamma`
    /// ("bad").
    ///
    /// This is an axis-parallel TPE-lite acquisition: independent across
    /// knobs, Laplace-smoothed, and reproducible under a fixed `seed`. It
    /// costs a constant factor more code than uniform random in exchange
    /// for better sample efficiency. The typical win on our default space
    /// is ~30% fewer trials to reach the random-search ceiling.
    Bayesian {
        budget: usize,
        /// Number of uniform random trials run before the acquisition
        /// takes over. Must be ≥ 2 so the "good" / "bad" split is
        /// non-degenerate.
        warmup: usize,
        /// Fraction of past trials treated as "good" when fitting the
        /// acquisition; must lie strictly inside (0, 1). 0.25 is the TPE
        /// default; smaller = more exploit, larger = more explore.
        gamma: f64,
        seed: u64,
        /// Optional wall-time cap in seconds, with the same semantics as
        /// the `max_wall_secs` field of [`Self::Random`].
        max_wall_secs: Option<u64>,
    },
}

impl SearchStrategy {
    /// Wall-time cap in seconds, or `None` when the strategy is uncapped.
    /// [`Self::Grid`] never carries a cap.
    fn max_wall_secs(&self) -> Option<u64> {
        match *self {
            Self::Grid => None,
            Self::Random { max_wall_secs, .. } | Self::Bayesian { max_wall_secs, .. } => {
                max_wall_secs
            }
        }
    }
}
555
/// One successful trial's observation.
///
/// Trials that failed to build are not recorded here; they land in
/// [`TuneReport::failures`] instead.
#[derive(Debug, Clone)]
pub struct TrialRecord {
    /// Full configuration this trial was built from.
    pub config: PipelineConfig,
    /// Metric score for `config`. Higher is better:
    /// [`TuneReport::ranked_trials`] sorts by descending score.
    pub score: f64,
    /// Wall-clock build time for this trial (pipeline rebuild only —
    /// projection fit is amortized across the tuner run).
    pub build_ms: u128,
    /// Per-component metric breakdown as `(name, weight, score)`.
    /// Populated when the metric is a composite (see
    /// [`QualityMetric::score_with_components`]); empty for leaf
    /// metrics. The fastest way to diagnose a flat tuner landscape:
    /// a component whose score barely varies across trials carries no
    /// signal for the knobs being swept.
    pub components: Vec<(String, f64, f64)>,
}
572
/// Full tuner output: the winning configuration plus a record of every
/// trial, successful or failed.
#[derive(Debug, Clone)]
pub struct TuneReport {
    /// Name identifying the metric these scores come from.
    pub metric_name: String,
    /// Highest score among the successful trials.
    pub best_score: f64,
    /// Configuration that achieved `best_score`. `auto_tune` replaces the
    /// incumbent whenever a trial scores at least as high, so ties go to
    /// the later trial.
    pub best_config: PipelineConfig,
    /// Every trial that built and scored successfully.
    pub trials: Vec<TrialRecord>,
    /// Trials that failed to build (e.g., too few embeddings, config
    /// combination rejected by a downstream validator). Each entry is
    /// `(config, error_message)`.
    pub failures: Vec<(PipelineConfig, String)>,
    /// Number of UMAP kNN graphs successfully built during the run.
    ///
    /// Graphs are cached by `n_neighbors`, so this is at most the number
    /// of distinct `n_neighbors` values tried across UMAP trials. A trial
    /// whose exact fit key was already seen reuses the cached projection
    /// and builds nothing. A value lower than the count of UMAP trials
    /// therefore shows the reuse path is actually firing.
    pub umap_graph_builds: usize,
}
591
592impl TuneReport {
593    /// Trials ranked by descending score.
594    pub fn ranked_trials(&self) -> Vec<&TrialRecord> {
595        let mut refs: Vec<&TrialRecord> = self.trials.iter().collect();
596        refs.sort_by(|a, b| {
597            b.score
598                .partial_cmp(&a.score)
599                .unwrap_or(std::cmp::Ordering::Equal)
600        });
601        refs
602    }
603
604    /// Mean score across successful trials. Useful for gauging how
605    /// sensitive the pipeline is to the tuned knobs: a flat landscape
606    /// means the knobs don't matter on this corpus.
607    pub fn mean_score(&self) -> f64 {
608        if self.trials.is_empty() {
609            return 0.0;
610        }
611        self.trials.iter().map(|t| t.score).sum::<f64>() / self.trials.len() as f64
612    }
613}
614
615// ── The tuner itself ─────────────────────────────────────────────────
616
617/// Run the auto-tuner and return the best pipeline plus a report.
618///
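/// Projections are fit lazily, once per distinct fit-affecting
/// hyperparameter tuple that some trial asks for:
/// - PCA and Kernel PCA once per kind;
/// - Laplacian per `(k_neighbors, active_threshold)`;
/// - UMAP per `(n_neighbors, n_epochs, category_weight, min_dist)`.
///
/// Each fit is reused by every later trial that maps to the same tuple.
/// UMAP's kNN graph is additionally shared across all UMAP trials with the
/// same `n_neighbors`. Once a trial's projection is cached, the remaining
/// work is the downstream pipeline build driven by the other
/// [`PipelineConfig`] knobs (bridge thresholds, inner-sphere gates,
/// domain-group counts, etc.). This keeps per-trial cost dominated by
/// spatial quality sampling and graph construction rather than projection
/// fitting.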
626///
627/// Under [`SearchStrategy::Random`] and [`SearchStrategy::Bayesian`],
628/// `base_config` itself is evaluated as trial 0 (counted against the
629/// budget) so a warm-start prediction competes directly with sampled
630/// candidates. [`SearchStrategy::Grid`] is excluded: its trial set is
631/// defined as the exact Cartesian enumeration of the space, and callers
632/// assert on [`SearchSpace::grid_cardinality`] matching the trial count.
633pub fn auto_tune<M: QualityMetric + ?Sized>(
634    input: PipelineInput,
635    space: &SearchSpace,
636    metric: &M,
637    strategy: SearchStrategy,
638    base_config: &PipelineConfig,
639) -> Result<(SphereQLPipeline, TuneReport), PipelineError> {
640    // Validate space + strategy upfront so every search mode fails at
641    // the same boundary with a structured PipelineError instead of
642    // panicking mid-trial (random/Bayesian) or silently rolling up
643    // into AllTrialsFailed { failures: [] } (Grid).
644    space.validate(&strategy)?;
645
646    // `PipelineInput` is owned — move the Vec<f64>s straight into the
647    // Embedding wrappers instead of cloning each row.
648    let categories = input.categories;
649    let embeddings: Vec<Embedding> = input.embeddings.into_iter().map(Embedding::new).collect();
650
651    let mut prefit: HashMap<ProjectionFitKey, ConfiguredProjection> = HashMap::new();
652    // UMAP kNN graphs are reusable across configs that share `n_neighbors`
653    // but differ in `n_epochs` / `category_weight` / `min_dist`. Building
654    // the graph
655    // dominates UMAP fit cost (O(N log N) for the ANN-backed graph plus
656    // PCA warm-start), so caching it collapses the per-config sweep onto
657    // a handful of graph builds.
658    let mut umap_graph_cache: HashMap<usize, crate::umap::UmapGraph> = HashMap::new();
659    let mut umap_graph_builds: usize = 0;
660    let mut trials: Vec<TrialRecord> = Vec::new();
661    let mut failures: Vec<(PipelineConfig, String)> = Vec::new();
662    // Only the current best trial's pipeline stays alive — keeping every
663    // trial's pipeline would multiply peak memory by the trial count at
664    // 500k scale. Replaced (and the old one dropped) whenever a trial
665    // scores at least as high, matching the old post-loop `max_by`
666    // selection (last max wins) without rebuilding the winner.
667    let mut best: Option<(f64, SphereQLPipeline)> = None;
668
669    // Closure: evaluate one config, update prefit cache, push record or
670    // failure. Shared by every strategy so they only differ in how they
671    // propose configs.
672    let run_trial = |cfg: PipelineConfig,
673                     prefit: &mut HashMap<ProjectionFitKey, ConfiguredProjection>,
674                     umap_graph_cache: &mut HashMap<usize, crate::umap::UmapGraph>,
675                     umap_graph_builds: &mut usize,
676                     trials: &mut Vec<TrialRecord>,
677                     failures: &mut Vec<(PipelineConfig, String)>,
678                     best: &mut Option<(f64, SphereQLPipeline)>| {
679        let key = ProjectionFitKey::from_config(&cfg);
680        let projection = if cfg.projection_kind == ProjectionKind::UmapSphere {
681            // UMAP fast path: build the kNN graph once per `n_neighbors`
682            // and reuse it across `(n_epochs, category_weight, min_dist)`
683            // variations.
684            // The fully-realized projection still goes into `prefit` so
685            // the final pipeline rebuild and any exact-config repeats are
686            // free.
687            match prefit.get(&key) {
688                Some(p) => p.clone(),
689                None => {
690                    let k = cfg.umap.n_neighbors;
691                    if let std::collections::hash_map::Entry::Vacant(entry) =
692                        umap_graph_cache.entry(k)
693                    {
694                        match crate::umap::UmapGraph::build(&embeddings, k) {
695                            Ok(g) => {
696                                entry.insert(g);
697                                *umap_graph_builds += 1;
698                            }
699                            Err(err) => {
700                                failures.push((cfg, err.to_string()));
701                                return;
702                            }
703                        }
704                    }
705                    let graph = &umap_graph_cache[&k];
706                    match fit_umap_from_graph(graph, &categories, &cfg) {
707                        Ok(p) => {
708                            prefit.insert(key, p.clone());
709                            p
710                        }
711                        Err(e) => {
712                            failures.push((cfg, e.to_string()));
713                            return;
714                        }
715                    }
716                }
717            }
718        } else {
719            match prefit.get(&key) {
720                Some(p) => p.clone(),
721                None => match fit_projection_for_config(&embeddings, &categories, &cfg) {
722                    Ok(p) => {
723                        prefit.insert(key, p.clone());
724                        p
725                    }
726                    Err(e) => {
727                        failures.push((cfg, e.to_string()));
728                        return;
729                    }
730                },
731            }
732        };
733
734        let start = Instant::now();
735        // Embeddings are borrowed (the pipeline doesn't retain them), so
736        // only `categories` — which it does own — is cloned per trial.
737        // TODO: an Arc<[String]> categories field would drop that clone
738        // too (~tens of MB at 500k), but it touches the pipeline's
739        // retained-field accessors and downstream crates.
740        match SphereQLPipeline::with_projection_parts(
741            categories.clone(),
742            &embeddings,
743            projection,
744            cfg.clone(),
745        ) {
746            Ok(pipeline) => {
747                let (score, components) = metric.score_with_components(&pipeline);
748                let build_ms = start.elapsed().as_millis();
749                trials.push(TrialRecord {
750                    config: cfg,
751                    score,
752                    build_ms,
753                    components,
754                });
755                let replace = match best {
756                    Some((best_score, _)) => !matches!(
757                        score.partial_cmp(best_score),
758                        Some(std::cmp::Ordering::Less)
759                    ),
760                    None => true,
761                };
762                if replace {
763                    *best = Some((score, pipeline));
764                }
765            }
766            Err(e) => {
767                failures.push((cfg, e.to_string()));
768            }
769        }
770    };
771
772    let wall_start = Instant::now();
773    let max_wall = strategy.max_wall_secs();
774    let wall_exceeded = || match max_wall {
775        Some(max_secs) => wall_start.elapsed().as_secs() >= max_secs,
776        None => false,
777    };
778
779    match &strategy {
780        SearchStrategy::Grid => {
781            // Grid deliberately skips the base-config seed trial: its
782            // contract is "trial set == the exact Cartesian enumeration"
783            // and callers assert on grid_cardinality matching the count.
784            for i in 0..space.grid_cardinality() {
785                if let Some(cfg) = space.config_at_index(i, base_config) {
786                    run_trial(
787                        cfg,
788                        &mut prefit,
789                        &mut umap_graph_cache,
790                        &mut umap_graph_builds,
791                        &mut trials,
792                        &mut failures,
793                        &mut best,
794                    );
795                }
796            }
797        }
798        SearchStrategy::Random { budget, seed, .. } => {
799            let mut rng = SplitMix64::new(*seed);
800            // Trial 0: warm-start seed. base_config competes directly
801            // with the sampled candidates and counts against the budget.
802            run_trial(
803                base_config.clone(),
804                &mut prefit,
805                &mut umap_graph_cache,
806                &mut umap_graph_builds,
807                &mut trials,
808                &mut failures,
809                &mut best,
810            );
811            if !wall_exceeded() {
812                for _ in 1..*budget {
813                    let cfg = space.sample(&mut rng, base_config);
814                    run_trial(
815                        cfg,
816                        &mut prefit,
817                        &mut umap_graph_cache,
818                        &mut umap_graph_builds,
819                        &mut trials,
820                        &mut failures,
821                        &mut best,
822                    );
823                    if wall_exceeded() {
824                        break;
825                    }
826                }
827            }
828        }
829        SearchStrategy::Bayesian {
830            budget,
831            warmup,
832            gamma,
833            seed,
834            ..
835        } => {
836            // budget/warmup/gamma already validated above by space.validate(&strategy).
837            let budget = *budget;
838            let mut rng = SplitMix64::new(*seed);
839            let warmup = (*warmup).clamp(2, budget);
840            let gamma = gamma.clamp(0.05, 0.95);
841
842            // Trial 0: warm-start seed (counts as the first warmup trial).
843            run_trial(
844                base_config.clone(),
845                &mut prefit,
846                &mut umap_graph_cache,
847                &mut umap_graph_builds,
848                &mut trials,
849                &mut failures,
850                &mut best,
851            );
852            // Remaining warmup: uniform random.
853            if !wall_exceeded() {
854                for _ in 1..warmup {
855                    let cfg = space.sample(&mut rng, base_config);
856                    run_trial(
857                        cfg,
858                        &mut prefit,
859                        &mut umap_graph_cache,
860                        &mut umap_graph_builds,
861                        &mut trials,
862                        &mut failures,
863                        &mut best,
864                    );
865                    if wall_exceeded() {
866                        break;
867                    }
868                }
869            }
870            // Acquisition: axis-parallel TPE-lite.
871            if !wall_exceeded() {
872                for _ in warmup..budget {
873                    let cfg = tpe_propose(space, base_config, &trials, gamma, &mut rng);
874                    run_trial(
875                        cfg,
876                        &mut prefit,
877                        &mut umap_graph_cache,
878                        &mut umap_graph_builds,
879                        &mut trials,
880                        &mut failures,
881                        &mut best,
882                    );
883                    if wall_exceeded() {
884                        break;
885                    }
886                }
887            }
888        }
889    }
890
891    if trials.is_empty() {
892        // Every candidate config was rejected downstream. Surface the
893        // real failure list instead of the misleading `TooFewEmbeddings`
894        // roll-up we used to return here.
895        return Err(PipelineError::AllTrialsFailed { failures });
896    }
897
898    // Every successful trial offered its pipeline to `best`, so a
899    // non-empty `trials` guarantees one was kept. Returning it directly
900    // saves rebuilding the winner from scratch (a second O(N·d)
901    // projection pass + category-layer build).
902    let (best_score, best_pipeline) = best.expect("non-empty trials imply a kept best pipeline");
903    let best_config = best_pipeline.config().clone();
904
905    let report = TuneReport {
906        metric_name: metric.name().to_string(),
907        best_score,
908        best_config,
909        trials,
910        failures,
911        umap_graph_builds,
912    };
913
914    Ok((best_pipeline, report))
915}
916
917// ── TPE-lite acquisition ──────────────────────────────────────────────
918
919/// Propose the next [`PipelineConfig`] using axis-parallel good/bad
920/// ratios over the trial history.
921///
922/// For each knob, counts how often each candidate value appeared in the
923/// top-`gamma` fraction ("good") of past trials vs. the rest ("bad").
924/// Samples the next value with probability proportional to
925/// `(good + 1) / (bad + 1)` per candidate, Laplace-smoothed so no value
926/// is ever assigned zero probability.
927///
928/// Kind-specific knobs (Laplacian's `k`, `active_threshold`) condition on
929/// kind — their histograms are built from kind-matching trials only, with
930/// a uniform fallback when fewer than 2 kind-matching trials exist.
931fn tpe_propose(
932    space: &SearchSpace,
933    base: &PipelineConfig,
934    trials: &[TrialRecord],
935    gamma: f64,
936    rng: &mut SplitMix64,
937) -> PipelineConfig {
938    // Sort by descending score, split at gamma threshold.
939    let mut sorted: Vec<&TrialRecord> = trials.iter().collect();
940    sorted.sort_by(|a, b| {
941        b.score
942            .partial_cmp(&a.score)
943            .unwrap_or(std::cmp::Ordering::Equal)
944    });
945    let n_good = ((sorted.len() as f64) * gamma).ceil() as usize;
946    let n_good = n_good.max(1).min(sorted.len().saturating_sub(1).max(1));
947    let good: Vec<&TrialRecord> = sorted.iter().take(n_good).copied().collect();
948    let bad: Vec<&TrialRecord> = sorted.iter().skip(n_good).copied().collect();
949
950    // Fall back to uniform sampling if we somehow don't have both sides.
951    if good.is_empty() || bad.is_empty() {
952        return space.sample(rng, base);
953    }
954
955    let pick_idx = |rng: &mut SplitMix64, good_counts: &[f64], bad_counts: &[f64]| -> usize {
956        let n_g = good_counts.iter().sum::<f64>() + good_counts.len() as f64;
957        let n_b = bad_counts.iter().sum::<f64>() + bad_counts.len() as f64;
958        let weights: Vec<f64> = good_counts
959            .iter()
960            .zip(bad_counts.iter())
961            .map(|(&g, &b)| ((g + 1.0) / n_g) / ((b + 1.0) / n_b))
962            .collect();
963        sample_categorical(rng, &weights)
964    };
965
966    // Projection kind (histogram across all trials).
967    let pk_g = hist_kind(&good, &space.projection_kinds);
968    let pk_b = hist_kind(&bad, &space.projection_kinds);
969    let kind = space.projection_kinds[pick_idx(rng, &pk_g, &pk_b)];
970
971    // Kind-agnostic knobs: histograms deliberately pool ALL trials
972    // regardless of projection kind. Conditioning each knob on the
973    // sampled kind would shrink the histograms to near-uselessness at
974    // typical budgets — accepting cross-kind aliasing is the
975    // axis-parallel TPE simplification.
976    let ndg_g = hist_usize(&good, &space.num_domain_groups, |c| {
977        c.routing.num_domain_groups
978    });
979    let ndg_b = hist_usize(&bad, &space.num_domain_groups, |c| {
980        c.routing.num_domain_groups
981    });
982    let let_g = hist_f64(&good, &space.low_evr_threshold, |c| {
983        c.routing.low_evr_threshold
984    });
985    let let_b = hist_f64(&bad, &space.low_evr_threshold, |c| {
986        c.routing.low_evr_threshold
987    });
988    let oat_g = hist_f64(&good, &space.overlap_artifact_territorial, |c| {
989        c.bridges.overlap_artifact_territorial
990    });
991    let oat_b = hist_f64(&bad, &space.overlap_artifact_territorial, |c| {
992        c.bridges.overlap_artifact_territorial
993    });
994    let tb_g = hist_f64(&good, &space.threshold_base, |c| c.bridges.threshold_base);
995    let tb_b = hist_f64(&bad, &space.threshold_base, |c| c.bridges.threshold_base);
996    let tep_g = hist_f64(&good, &space.threshold_evr_penalty, |c| {
997        c.bridges.threshold_evr_penalty
998    });
999    let tep_b = hist_f64(&bad, &space.threshold_evr_penalty, |c| {
1000        c.bridges.threshold_evr_penalty
1001    });
1002    let mei_g = hist_f64(&good, &space.min_evr_improvement, |c| {
1003        c.inner_sphere.min_evr_improvement
1004    });
1005    let mei_b = hist_f64(&bad, &space.min_evr_improvement, |c| {
1006        c.inner_sphere.min_evr_improvement
1007    });
1008
1009    let mut cfg = base.clone();
1010    cfg.projection_kind = kind;
1011    cfg.routing = RoutingConfig {
1012        num_domain_groups: space.num_domain_groups[pick_idx(rng, &ndg_g, &ndg_b)],
1013        low_evr_threshold: space.low_evr_threshold[pick_idx(rng, &let_g, &let_b)],
1014        ..base.routing.clone()
1015    };
1016    cfg.bridges = BridgeConfig {
1017        threshold_base: space.threshold_base[pick_idx(rng, &tb_g, &tb_b)],
1018        threshold_evr_penalty: space.threshold_evr_penalty[pick_idx(rng, &tep_g, &tep_b)],
1019        overlap_artifact_territorial: space.overlap_artifact_territorial
1020            [pick_idx(rng, &oat_g, &oat_b)],
1021        ..base.bridges.clone()
1022    };
1023    cfg.inner_sphere = InnerSphereConfig {
1024        min_evr_improvement: space.min_evr_improvement[pick_idx(rng, &mei_g, &mei_b)],
1025        ..base.inner_sphere.clone()
1026    };
1027
1028    // Kind-specific knobs: condition on kind-matching trials only.
1029    if matches!(kind, ProjectionKind::LaplacianEigenmap) {
1030        let good_l: Vec<&TrialRecord> = good
1031            .iter()
1032            .copied()
1033            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
1034            .collect();
1035        let bad_l: Vec<&TrialRecord> = bad
1036            .iter()
1037            .copied()
1038            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
1039            .collect();
1040        if good_l.is_empty() || bad_l.is_empty() {
1041            // Not enough Laplacian trials on both sides — uniform fallback.
1042            cfg.laplacian = LaplacianConfig {
1043                k_neighbors: pick_uniform(rng, &space.laplacian_k_neighbors),
1044                active_threshold: pick_uniform(rng, &space.laplacian_active_threshold),
1045            };
1046        } else {
1047            let k_g = hist_usize(&good_l, &space.laplacian_k_neighbors, |c| {
1048                c.laplacian.k_neighbors
1049            });
1050            let k_b = hist_usize(&bad_l, &space.laplacian_k_neighbors, |c| {
1051                c.laplacian.k_neighbors
1052            });
1053            let at_g = hist_f64(&good_l, &space.laplacian_active_threshold, |c| {
1054                c.laplacian.active_threshold
1055            });
1056            let at_b = hist_f64(&bad_l, &space.laplacian_active_threshold, |c| {
1057                c.laplacian.active_threshold
1058            });
1059            cfg.laplacian = LaplacianConfig {
1060                k_neighbors: space.laplacian_k_neighbors[pick_idx(rng, &k_g, &k_b)],
1061                active_threshold: space.laplacian_active_threshold[pick_idx(rng, &at_g, &at_b)],
1062            };
1063        }
1064    }
1065
1066    if matches!(kind, ProjectionKind::UmapSphere) {
1067        let good_u: Vec<&TrialRecord> = good
1068            .iter()
1069            .copied()
1070            .filter(|t| t.config.projection_kind == ProjectionKind::UmapSphere)
1071            .collect();
1072        let bad_u: Vec<&TrialRecord> = bad
1073            .iter()
1074            .copied()
1075            .filter(|t| t.config.projection_kind == ProjectionKind::UmapSphere)
1076            .collect();
1077        if good_u.is_empty() || bad_u.is_empty() {
1078            cfg.umap = UmapConfig {
1079                n_neighbors: pick_uniform(rng, &space.umap_n_neighbors),
1080                n_epochs: pick_uniform(rng, &space.umap_n_epochs),
1081                category_weight: pick_uniform(rng, &space.umap_category_weight),
1082                min_dist: pick_uniform(rng, &space.umap_min_dist),
1083                ..base.umap.clone()
1084            };
1085        } else {
1086            let nn_g = hist_usize(&good_u, &space.umap_n_neighbors, |c| c.umap.n_neighbors);
1087            let nn_b = hist_usize(&bad_u, &space.umap_n_neighbors, |c| c.umap.n_neighbors);
1088            let ne_g = hist_usize(&good_u, &space.umap_n_epochs, |c| c.umap.n_epochs);
1089            let ne_b = hist_usize(&bad_u, &space.umap_n_epochs, |c| c.umap.n_epochs);
1090            let cw_g = hist_f64(&good_u, &space.umap_category_weight, |c| {
1091                c.umap.category_weight
1092            });
1093            let cw_b = hist_f64(&bad_u, &space.umap_category_weight, |c| {
1094                c.umap.category_weight
1095            });
1096            let md_g = hist_f64(&good_u, &space.umap_min_dist, |c| c.umap.min_dist);
1097            let md_b = hist_f64(&bad_u, &space.umap_min_dist, |c| c.umap.min_dist);
1098            cfg.umap = UmapConfig {
1099                n_neighbors: space.umap_n_neighbors[pick_idx(rng, &nn_g, &nn_b)],
1100                n_epochs: space.umap_n_epochs[pick_idx(rng, &ne_g, &ne_b)],
1101                category_weight: space.umap_category_weight[pick_idx(rng, &cw_g, &cw_b)],
1102                min_dist: space.umap_min_dist[pick_idx(rng, &md_g, &md_b)],
1103                ..base.umap.clone()
1104            };
1105        }
1106    }
1107
1108    cfg
1109}
1110
1111fn hist_kind(trials: &[&TrialRecord], values: &[ProjectionKind]) -> Vec<f64> {
1112    let mut counts = vec![0.0f64; values.len()];
1113    for t in trials {
1114        if let Some(i) = values.iter().position(|&v| v == t.config.projection_kind) {
1115            counts[i] += 1.0;
1116        }
1117    }
1118    counts
1119}
1120
1121fn hist_usize(
1122    trials: &[&TrialRecord],
1123    values: &[usize],
1124    extract: impl Fn(&PipelineConfig) -> usize,
1125) -> Vec<f64> {
1126    let mut counts = vec![0.0f64; values.len()];
1127    for t in trials {
1128        let v = extract(&t.config);
1129        if let Some(i) = values.iter().position(|&x| x == v) {
1130            counts[i] += 1.0;
1131        }
1132    }
1133    counts
1134}
1135
1136/// f64 candidates are matched by nearest-neighbor since equality on
1137/// floats is fraught even when every sampled value came from the same
1138/// source slice. In practice the match is always exact but this keeps
1139/// us honest under future refactors.
1140fn hist_f64(
1141    trials: &[&TrialRecord],
1142    values: &[f64],
1143    extract: impl Fn(&PipelineConfig) -> f64,
1144) -> Vec<f64> {
1145    let mut counts = vec![0.0f64; values.len()];
1146    for t in trials {
1147        let v = extract(&t.config);
1148        if let Some((i, _)) = values.iter().enumerate().min_by(|a, b| {
1149            (a.1 - v)
1150                .abs()
1151                .partial_cmp(&(b.1 - v).abs())
1152                .unwrap_or(std::cmp::Ordering::Equal)
1153        }) {
1154            counts[i] += 1.0;
1155        }
1156    }
1157    counts
1158}
1159
1160/// Pick one element of `vals` uniformly at random. Panics if `vals` is
1161/// empty — callers always pass non-empty `SearchSpace` axes, so the
1162/// empty case would be a programmer error rather than a recoverable
1163/// input.
1164fn pick_uniform<T: Copy>(rng: &mut SplitMix64, vals: &[T]) -> T {
1165    // next_f64 instead of next_u64 % len: the modulo form is biased
1166    // toward low indices whenever len doesn't divide 2^64. The min
1167    // guards the next_f64 == 1.0 edge.
1168    vals[((rng.next_f64() * vals.len() as f64) as usize).min(vals.len() - 1)]
1169}
1170
1171fn sample_categorical(rng: &mut SplitMix64, weights: &[f64]) -> usize {
1172    let total: f64 = weights.iter().sum();
1173    if total <= 0.0 || !total.is_finite() {
1174        let n = weights.len().max(1);
1175        return ((rng.next_f64() * n as f64) as usize).min(n - 1);
1176    }
1177    let r = rng.next_f64() * total;
1178    let mut acc = 0.0;
1179    for (i, &w) in weights.iter().enumerate() {
1180        acc += w;
1181        if r <= acc {
1182            return i;
1183        }
1184    }
1185    weights.len() - 1
1186}
1187
1188// ── Tests ──────────────────────────────────────────────────────────
1189
1190#[cfg(test)]
1191mod tests {
1192    use super::*;
1193    use crate::quality_metric::{BridgeCoherence, CompositeMetric, TerritorialHealth};
1194
1195    fn make_input(n: usize, dim: usize) -> PipelineInput {
1196        let mut embeddings = Vec::new();
1197        let mut categories = Vec::new();
1198        for i in 0..n {
1199            let mut v = vec![0.0; dim];
1200            if i < n / 3 {
1201                v[0] = 1.0 + (i as f64 * 0.01);
1202                v[1] = 0.1;
1203                categories.push("one".into());
1204            } else if i < 2 * n / 3 {
1205                v[2] = 1.0 + (i as f64 * 0.01);
1206                v[3] = 0.1;
1207                categories.push("two".into());
1208            } else {
1209                v[4] = 1.0 + (i as f64 * 0.01);
1210                v[5] = 0.1;
1211                categories.push("three".into());
1212            }
1213            v[6] = 0.02 * i as f64;
1214            embeddings.push(v);
1215        }
1216        PipelineInput {
1217            categories,
1218            embeddings,
1219        }
1220    }
1221
1222    fn full_search_space() -> SearchSpace {
1223        SearchSpace {
1224            projection_kinds: vec![ProjectionKind::Pca],
1225            laplacian_k_neighbors: vec![15],
1226            laplacian_active_threshold: vec![0.05],
1227            umap_n_neighbors: vec![15],
1228            umap_n_epochs: vec![200],
1229            umap_category_weight: vec![1.5],
1230            umap_min_dist: vec![0.1],
1231            num_domain_groups: vec![3],
1232            low_evr_threshold: vec![0.3],
1233            overlap_artifact_territorial: vec![0.3],
1234            threshold_base: vec![0.5],
1235            threshold_evr_penalty: vec![0.4],
1236            min_evr_improvement: vec![0.10],
1237        }
1238    }
1239
1240    #[test]
1241    fn validate_rejects_empty_projection_kinds_for_every_strategy() {
1242        let mut s = full_search_space();
1243        s.projection_kinds.clear();
1244        for strategy in [
1245            SearchStrategy::Grid,
1246            SearchStrategy::Random {
1247                budget: 4,
1248                seed: 1,
1249                max_wall_secs: None,
1250            },
1251            SearchStrategy::Bayesian {
1252                budget: 4,
1253                warmup: 2,
1254                gamma: 0.25,
1255                seed: 1,
1256                max_wall_secs: None,
1257            },
1258        ] {
1259            match s.validate(&strategy) {
1260                Err(PipelineError::InvalidSearchSpace(msg)) => {
1261                    assert!(msg.contains("projection_kinds"), "msg = {msg:?}");
1262                }
1263                other => panic!("expected InvalidSearchSpace, got {other:?}"),
1264            }
1265        }
1266    }
1267
1268    #[test]
1269    fn validate_rejects_empty_axis() {
1270        let mut s = full_search_space();
1271        s.threshold_base.clear();
1272        match s.validate(&SearchStrategy::Grid) {
1273            Err(PipelineError::InvalidSearchSpace(msg)) => {
1274                assert!(msg.contains("threshold_base"), "msg = {msg:?}");
1275            }
1276            other => panic!("expected InvalidSearchSpace, got {other:?}"),
1277        }
1278    }
1279
1280    #[test]
1281    fn validate_rejects_empty_laplacian_axis_only_when_kind_present() {
1282        let mut s = full_search_space();
1283        s.laplacian_k_neighbors.clear();
1284        // PCA-only space: missing laplacian axis is fine because the
1285        // kind isn't in `projection_kinds`.
1286        assert!(s.validate(&SearchStrategy::Grid).is_ok());
1287        s.projection_kinds.push(ProjectionKind::LaplacianEigenmap);
1288        match s.validate(&SearchStrategy::Grid) {
1289            Err(PipelineError::InvalidSearchSpace(msg)) => {
1290                assert!(msg.contains("laplacian_k_neighbors"), "msg = {msg:?}");
1291            }
1292            other => panic!("expected InvalidSearchSpace, got {other:?}"),
1293        }
1294    }
1295
1296    #[test]
1297    fn validate_rejects_bad_bayesian_params() {
1298        let s = full_search_space();
1299        let cases: &[(SearchStrategy, &str)] = &[
1300            (
1301                SearchStrategy::Bayesian {
1302                    budget: 1,
1303                    warmup: 2,
1304                    gamma: 0.25,
1305                    seed: 1,
1306                    max_wall_secs: None,
1307                },
1308                "budget",
1309            ),
1310            (
1311                SearchStrategy::Bayesian {
1312                    budget: 5,
1313                    warmup: 1,
1314                    gamma: 0.25,
1315                    seed: 1,
1316                    max_wall_secs: None,
1317                },
1318                "warmup",
1319            ),
1320            (
1321                SearchStrategy::Bayesian {
1322                    budget: 5,
1323                    warmup: 2,
1324                    gamma: 0.0,
1325                    seed: 1,
1326                    max_wall_secs: None,
1327                },
1328                "gamma",
1329            ),
1330            (
1331                SearchStrategy::Bayesian {
1332                    budget: 5,
1333                    warmup: 2,
1334                    gamma: f64::NAN,
1335                    seed: 1,
1336                    max_wall_secs: None,
1337                },
1338                "gamma",
1339            ),
1340        ];
1341        for (strategy, needle) in cases {
1342            match s.validate(strategy) {
1343                Err(PipelineError::InvalidSearchSpace(msg)) => {
1344                    assert!(msg.contains(needle), "msg={msg:?} needle={needle:?}");
1345                }
1346                other => panic!("expected InvalidSearchSpace for {needle:?}, got {other:?}"),
1347            }
1348        }
1349    }
1350
1351    #[test]
1352    fn auto_tune_propagates_invalid_search_space_for_grid() {
1353        let s = SearchSpace {
1354            projection_kinds: vec![],
1355            ..full_search_space()
1356        };
1357        let metric = BridgeCoherence;
1358        let base = PipelineConfig::default();
1359        match auto_tune(make_input(30, 10), &s, &metric, SearchStrategy::Grid, &base) {
1360            Err(PipelineError::InvalidSearchSpace(_)) => {}
1361            Err(other) => panic!("expected InvalidSearchSpace, got {other:?}"),
1362            Ok(_) => panic!("expected error, got Ok"),
1363        }
1364    }
1365
1366    #[test]
1367    fn search_space_grid_cardinality_sums_per_kind() {
1368        let s = SearchSpace::default();
1369        let common = s.num_domain_groups.len()
1370            * s.low_evr_threshold.len()
1371            * s.overlap_artifact_territorial.len()
1372            * s.threshold_base.len()
1373            * s.threshold_evr_penalty.len()
1374            * s.min_evr_improvement.len();
1375        // Default kinds = {PCA, Laplacian}; PCA adds `common`, Laplacian
1376        // adds `common × k_neighbors × active_threshold`.
1377        let expected =
1378            common + common * s.laplacian_k_neighbors.len() * s.laplacian_active_threshold.len();
1379        assert_eq!(s.grid_cardinality(), expected);
1380    }
1381
1382    #[test]
1383    fn default_search_space_includes_pca_and_laplacian() {
1384        let s = SearchSpace::default();
1385        assert!(s.projection_kinds.contains(&ProjectionKind::Pca));
1386        assert!(
1387            s.projection_kinds
1388                .contains(&ProjectionKind::LaplacianEigenmap)
1389        );
1390        // Kernel PCA excluded by default (expensive fit).
1391        assert!(!s.projection_kinds.contains(&ProjectionKind::KernelPca));
1392    }
1393
1394    #[test]
1395    fn grid_index_enumerates_full_space() {
1396        let s = SearchSpace {
1397            projection_kinds: vec![ProjectionKind::Pca],
1398            laplacian_k_neighbors: vec![15],
1399            laplacian_active_threshold: vec![0.05],
1400            umap_n_neighbors: vec![15],
1401            umap_n_epochs: vec![200],
1402            umap_category_weight: vec![1.5],
1403            umap_min_dist: vec![0.1],
1404            num_domain_groups: vec![3, 5],
1405            low_evr_threshold: vec![0.3, 0.4],
1406            overlap_artifact_territorial: vec![0.3],
1407            threshold_base: vec![0.5],
1408            threshold_evr_penalty: vec![0.4],
1409            min_evr_improvement: vec![0.10],
1410        };
1411        let base = PipelineConfig::default();
1412        let n = s.grid_cardinality();
1413        let mut seen = std::collections::HashSet::new();
1414        for i in 0..n {
1415            let cfg = s.config_at_index(i, &base).unwrap();
1416            let key = (
1417                cfg.routing.num_domain_groups,
1418                (cfg.routing.low_evr_threshold * 1000.0) as i64,
1419            );
1420            seen.insert(key);
1421        }
1422        assert_eq!(seen.len(), n);
1423        assert!(s.config_at_index(n, &base).is_none());
1424    }
1425
1426    #[test]
1427    fn grid_index_enumerates_across_projection_kinds() {
1428        let s = SearchSpace {
1429            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
1430            laplacian_k_neighbors: vec![15],
1431            laplacian_active_threshold: vec![0.05],
1432            umap_n_neighbors: vec![15],
1433            umap_n_epochs: vec![200],
1434            umap_category_weight: vec![1.5],
1435            umap_min_dist: vec![0.1],
1436            num_domain_groups: vec![3],
1437            low_evr_threshold: vec![0.35],
1438            overlap_artifact_territorial: vec![0.3],
1439            threshold_base: vec![0.5],
1440            threshold_evr_penalty: vec![0.4],
1441            min_evr_improvement: vec![0.10],
1442        };
1443        let base = PipelineConfig::default();
1444        let kinds: std::collections::HashSet<ProjectionKind> = (0..s.grid_cardinality())
1445            .map(|i| s.config_at_index(i, &base).unwrap().projection_kind)
1446            .collect();
1447        assert_eq!(kinds.len(), 2);
1448        assert!(kinds.contains(&ProjectionKind::Pca));
1449        assert!(kinds.contains(&ProjectionKind::LaplacianEigenmap));
1450    }
1451
1452    #[test]
1453    fn grid_search_runs_and_picks_best() {
1454        let input = make_input(24, 8);
1455        let space = SearchSpace {
1456            projection_kinds: vec![ProjectionKind::Pca],
1457            laplacian_k_neighbors: vec![15],
1458            laplacian_active_threshold: vec![0.05],
1459            umap_n_neighbors: vec![15],
1460            umap_n_epochs: vec![200],
1461            umap_category_weight: vec![1.5],
1462            umap_min_dist: vec![0.1],
1463            num_domain_groups: vec![3, 5],
1464            low_evr_threshold: vec![0.35],
1465            overlap_artifact_territorial: vec![0.3],
1466            threshold_base: vec![0.5],
1467            threshold_evr_penalty: vec![0.4],
1468            min_evr_improvement: vec![0.10],
1469        };
1470        let metric = TerritorialHealth;
1471        let (pipeline, report) = auto_tune(
1472            input,
1473            &space,
1474            &metric,
1475            SearchStrategy::Grid,
1476            &PipelineConfig::default(),
1477        )
1478        .unwrap();
1479
1480        assert_eq!(report.trials.len(), 2);
1481        assert!(report.best_score >= report.mean_score() - 1e-9);
1482        assert!(pipeline.num_categories() > 0);
1483        assert_eq!(report.metric_name, "territorial_health");
1484        assert!(report.failures.is_empty());
1485    }
1486
1487    #[test]
1488    fn trial_records_carry_component_breakdown_for_composites() {
1489        let input = make_input(24, 8);
1490        let metric = CompositeMetric::default_composite();
1491        let (_p, report) = auto_tune(
1492            input,
1493            &full_search_space(),
1494            &metric,
1495            SearchStrategy::Grid,
1496            &PipelineConfig::default(),
1497        )
1498        .unwrap();
1499        assert!(!report.trials.is_empty());
1500        for t in &report.trials {
1501            assert_eq!(
1502                t.components.len(),
1503                4,
1504                "composite trials must record the 4-component breakdown"
1505            );
1506            let recomposed: f64 = t.components.iter().map(|(_, w, s)| w * s).sum();
1507            assert!(
1508                (t.score - recomposed).abs() < 1e-12,
1509                "breakdown must recompose to the recorded score"
1510            );
1511        }
1512    }
1513
1514    #[test]
1515    fn trial_records_have_empty_components_for_leaf_metrics() {
1516        let input = make_input(24, 8);
1517        let metric = TerritorialHealth;
1518        let (_p, report) = auto_tune(
1519            input,
1520            &full_search_space(),
1521            &metric,
1522            SearchStrategy::Grid,
1523            &PipelineConfig::default(),
1524        )
1525        .unwrap();
1526        assert!(!report.trials.is_empty());
1527        for t in &report.trials {
1528            assert!(t.components.is_empty());
1529        }
1530    }
1531
1532    #[test]
1533    fn random_search_respects_budget() {
1534        let input = make_input(24, 8);
1535        let space = SearchSpace::default();
1536        let metric = BridgeCoherence;
1537        let (_pipeline, report) = auto_tune(
1538            input,
1539            &space,
1540            &metric,
1541            SearchStrategy::Random {
1542                budget: 5,
1543                seed: 42,
1544                max_wall_secs: None,
1545            },
1546            &PipelineConfig::default(),
1547        )
1548        .unwrap();
1549        assert_eq!(report.trials.len(), 5);
1550    }
1551
1552    #[test]
1553    fn random_search_respects_wall_time_cap() {
1554        let input = make_input(24, 8);
1555        let space = SearchSpace::default();
1556        let metric = TerritorialHealth;
1557        let (_pipeline, report) = auto_tune(
1558            input,
1559            &space,
1560            &metric,
1561            SearchStrategy::Random {
1562                budget: 1000,
1563                // Some(0) trips the cap deterministically: trial 0 (the
1564                // warm-start seed) always runs, then `wall_exceeded()` is
1565                // true immediately, so exactly one trial completes
1566                // regardless of host throughput. Some(1) would be racy —
1567                // a fast machine can finish all 1000 trials under a second.
1568                seed: 42,
1569                max_wall_secs: Some(0),
1570            },
1571            &PipelineConfig::default(),
1572        )
1573        .unwrap();
1574        assert!(
1575            report.trials.len() < 1000,
1576            "wall time cap should have stopped early, got {} trials",
1577            report.trials.len()
1578        );
1579        assert!(
1580            !report.trials.is_empty(),
1581            "should complete at least one trial before checking wall time"
1582        );
1583    }
1584
1585    #[test]
1586    fn none_wall_time_is_unlimited() {
1587        let input = make_input(24, 8);
1588        let space = full_search_space();
1589        let metric = TerritorialHealth;
1590        let (_pipeline, report) = auto_tune(
1591            input,
1592            &space,
1593            &metric,
1594            SearchStrategy::Random {
1595                budget: 3,
1596                seed: 1,
1597                max_wall_secs: None,
1598            },
1599            &PipelineConfig::default(),
1600        )
1601        .unwrap();
1602        assert_eq!(report.trials.len(), 3);
1603    }
1604
1605    #[test]
1606    fn random_search_is_seed_reproducible() {
1607        let space = SearchSpace::default();
1608        let metric = TerritorialHealth;
1609
1610        let run = |seed: u64| {
1611            let input = make_input(24, 8);
1612            auto_tune(
1613                input,
1614                &space,
1615                &metric,
1616                SearchStrategy::Random {
1617                    budget: 8,
1618                    seed,
1619                    max_wall_secs: None,
1620                },
1621                &PipelineConfig::default(),
1622            )
1623            .unwrap()
1624            .1
1625        };
1626
1627        let a = run(7);
1628        let b = run(7);
1629        let c = run(13);
1630
1631        assert_eq!(a.trials.len(), b.trials.len());
1632        for (ta, tb) in a.trials.iter().zip(b.trials.iter()) {
1633            assert_eq!(
1634                ta.config.routing.num_domain_groups,
1635                tb.config.routing.num_domain_groups
1636            );
1637            assert!((ta.score - tb.score).abs() < 1e-12);
1638        }
1639        // Different seed should (very likely) produce a different trial
1640        // sequence. If it accidentally matches, the test is still valid
1641        // but we check at least one config differs.
1642        let any_differ = a.trials.iter().zip(c.trials.iter()).any(|(ta, tc)| {
1643            ta.config.routing.num_domain_groups != tc.config.routing.num_domain_groups
1644                || (ta.config.bridges.threshold_base - tc.config.bridges.threshold_base).abs()
1645                    > 1e-12
1646        });
1647        assert!(any_differ, "different seeds produced identical trial set");
1648    }
1649
1650    #[test]
1651    fn ranked_trials_are_descending() {
1652        let input = make_input(24, 8);
1653        let metric = CompositeMetric::default_composite();
1654        let (_p, report) = auto_tune(
1655            input,
1656            &SearchSpace::default(),
1657            &metric,
1658            SearchStrategy::Random {
1659                budget: 6,
1660                seed: 99,
1661                max_wall_secs: None,
1662            },
1663            &PipelineConfig::default(),
1664        )
1665        .unwrap();
1666        let ranked = report.ranked_trials();
1667        for w in ranked.windows(2) {
1668            assert!(w[0].score >= w[1].score);
1669        }
1670    }
1671
1672    #[test]
1673    fn best_config_actually_in_trials() {
1674        let input = make_input(24, 8);
1675        let metric = TerritorialHealth;
1676        let (_p, report) = auto_tune(
1677            input,
1678            &SearchSpace::default(),
1679            &metric,
1680            SearchStrategy::Random {
1681                budget: 4,
1682                seed: 1,
1683                max_wall_secs: None,
1684            },
1685            &PipelineConfig::default(),
1686        )
1687        .unwrap();
1688        let any_match = report.trials.iter().any(|t| {
1689            t.config.routing.num_domain_groups == report.best_config.routing.num_domain_groups
1690                && (t.config.routing.low_evr_threshold
1691                    - report.best_config.routing.low_evr_threshold)
1692                    .abs()
1693                    < 1e-12
1694                && (t.score - report.best_score).abs() < 1e-12
1695        });
1696        assert!(any_match, "best_config must appear in trials");
1697    }
1698
1699    #[test]
1700    fn grid_search_across_projection_kinds_yields_both() {
1701        let input = make_input(24, 8);
1702        let space = SearchSpace {
1703            projection_kinds: vec![ProjectionKind::Pca, ProjectionKind::LaplacianEigenmap],
1704            laplacian_k_neighbors: vec![10, 20],
1705            laplacian_active_threshold: vec![0.05],
1706            umap_n_neighbors: vec![15],
1707            umap_n_epochs: vec![200],
1708            umap_category_weight: vec![1.5],
1709            umap_min_dist: vec![0.1],
1710            num_domain_groups: vec![3],
1711            low_evr_threshold: vec![0.35],
1712            overlap_artifact_territorial: vec![0.3],
1713            threshold_base: vec![0.5],
1714            threshold_evr_penalty: vec![0.4],
1715            min_evr_improvement: vec![0.10],
1716        };
1717        let metric = TerritorialHealth;
1718        let (_pipeline, report) = auto_tune(
1719            input,
1720            &space,
1721            &metric,
1722            SearchStrategy::Grid,
1723            &PipelineConfig::default(),
1724        )
1725        .unwrap();
1726        // PCA contributes 1 trial; Laplacian contributes 2 × 1 = 2 trials
1727        // (two k_neighbors values × one threshold value). Total = 3.
1728        assert_eq!(report.trials.len(), 3);
1729        let kinds_in_trials: std::collections::HashSet<ProjectionKind> = report
1730            .trials
1731            .iter()
1732            .map(|t| t.config.projection_kind)
1733            .collect();
1734        assert!(kinds_in_trials.contains(&ProjectionKind::Pca));
1735        assert!(kinds_in_trials.contains(&ProjectionKind::LaplacianEigenmap));
1736        // Verify the two Laplacian trials actually use different k values.
1737        let lap_ks: std::collections::HashSet<usize> = report
1738            .trials
1739            .iter()
1740            .filter(|t| t.config.projection_kind == ProjectionKind::LaplacianEigenmap)
1741            .map(|t| t.config.laplacian.k_neighbors)
1742            .collect();
1743        assert_eq!(lap_ks.len(), 2);
1744    }
1745
1746    #[test]
1747    fn laplacian_knobs_produce_distinct_configs() {
1748        // Sanity check that when Laplacian is the only kind, varying its
1749        // hyperparameters produces configs whose LaplacianConfig actually
1750        // differs (and doesn't accidentally alias on same-(k, threshold) pairs).
1751        let s = SearchSpace {
1752            projection_kinds: vec![ProjectionKind::LaplacianEigenmap],
1753            laplacian_k_neighbors: vec![10, 20],
1754            laplacian_active_threshold: vec![0.03, 0.08],
1755            umap_n_neighbors: vec![15],
1756            umap_n_epochs: vec![200],
1757            umap_category_weight: vec![1.5],
1758            umap_min_dist: vec![0.1],
1759            num_domain_groups: vec![3],
1760            low_evr_threshold: vec![0.35],
1761            overlap_artifact_territorial: vec![0.3],
1762            threshold_base: vec![0.5],
1763            threshold_evr_penalty: vec![0.4],
1764            min_evr_improvement: vec![0.10],
1765        };
1766        let base = PipelineConfig::default();
1767        let configs: Vec<(usize, u64)> = (0..s.grid_cardinality())
1768            .map(|i| {
1769                let cfg = s.config_at_index(i, &base).unwrap();
1770                (
1771                    cfg.laplacian.k_neighbors,
1772                    cfg.laplacian.active_threshold.to_bits(),
1773                )
1774            })
1775            .collect();
1776        let unique: std::collections::HashSet<(usize, u64)> = configs.iter().copied().collect();
1777        assert_eq!(unique.len(), 4, "expected 4 distinct (k, threshold) pairs");
1778    }
1779
1780    #[test]
1781    fn bayesian_respects_budget() {
1782        let input = make_input(24, 8);
1783        let metric = TerritorialHealth;
1784        let (_p, report) = auto_tune(
1785            input,
1786            &SearchSpace::default(),
1787            &metric,
1788            SearchStrategy::Bayesian {
1789                budget: 10,
1790                warmup: 4,
1791                gamma: 0.25,
1792                seed: 42,
1793                max_wall_secs: None,
1794            },
1795            &PipelineConfig::default(),
1796        )
1797        .unwrap();
1798        assert_eq!(report.trials.len(), 10);
1799    }
1800
1801    #[test]
1802    fn bayesian_seed_reproducible() {
1803        let metric = TerritorialHealth;
1804        let run = |seed: u64| {
1805            let input = make_input(24, 8);
1806            auto_tune(
1807                input,
1808                &SearchSpace::default(),
1809                &metric,
1810                SearchStrategy::Bayesian {
1811                    budget: 8,
1812                    warmup: 3,
1813                    gamma: 0.25,
1814                    seed,
1815                    max_wall_secs: None,
1816                },
1817                &PipelineConfig::default(),
1818            )
1819            .unwrap()
1820            .1
1821        };
1822        let a = run(7);
1823        let b = run(7);
1824        assert_eq!(a.trials.len(), b.trials.len());
1825        for (ta, tb) in a.trials.iter().zip(b.trials.iter()) {
1826            assert_eq!(ta.config.projection_kind, tb.config.projection_kind);
1827            assert!((ta.score - tb.score).abs() < 1e-12);
1828        }
1829    }
1830
1831    #[test]
1832    fn bayesian_finds_something_under_default_metric() {
1833        // Only asserting the tuner runs to completion and best_score is a
1834        // valid [0, 1] value — not that Bayesian strictly beats random at
1835        // this small budget (it often does, but not monotonically).
1836        let input = make_input(30, 10);
1837        let metric = CompositeMetric::default_composite();
1838        let (_p, report) = auto_tune(
1839            input,
1840            &SearchSpace::default(),
1841            &metric,
1842            SearchStrategy::Bayesian {
1843                budget: 12,
1844                warmup: 4,
1845                gamma: 0.25,
1846                seed: 0xC0FFEE,
1847                max_wall_secs: None,
1848            },
1849            &PipelineConfig::default(),
1850        )
1851        .unwrap();
1852        assert_eq!(report.trials.len(), 12);
1853        assert!(report.best_score >= 0.0 && report.best_score <= 1.0);
1854    }
1855
1856    #[test]
1857    fn bayesian_warmup_clamped() {
1858        // warmup = 100 with budget = 5 should clamp to 5 (all warmup).
1859        let input = make_input(24, 8);
1860        let metric = TerritorialHealth;
1861        let (_p, report) = auto_tune(
1862            input,
1863            &SearchSpace::default(),
1864            &metric,
1865            SearchStrategy::Bayesian {
1866                budget: 5,
1867                warmup: 100,
1868                gamma: 0.25,
1869                seed: 1,
1870                max_wall_secs: None,
1871            },
1872            &PipelineConfig::default(),
1873        )
1874        .unwrap();
1875        assert_eq!(report.trials.len(), 5);
1876    }
1877
1878    #[test]
1879    fn umap_search_space_cardinality() {
1880        let s = SearchSpace::large_corpus();
1881        let common = s.num_domain_groups.len()
1882            * s.low_evr_threshold.len()
1883            * s.overlap_artifact_territorial.len()
1884            * s.threshold_base.len()
1885            * s.threshold_evr_penalty.len()
1886            * s.min_evr_improvement.len();
1887        let umap_specific = s.umap_n_neighbors.len()
1888            * s.umap_n_epochs.len()
1889            * s.umap_category_weight.len()
1890            * s.umap_min_dist.len();
1891        // PCA contributes `common`, UMAP contributes `common * umap_specific`.
1892        let expected = common + common * umap_specific;
1893        assert_eq!(s.grid_cardinality(), expected);
1894    }
1895
1896    #[test]
1897    fn umap_trials_produce_umap_configs() {
1898        let input = make_input(24, 8);
1899        let space = SearchSpace {
1900            projection_kinds: vec![ProjectionKind::UmapSphere],
1901            laplacian_k_neighbors: vec![15],
1902            laplacian_active_threshold: vec![0.05],
1903            umap_n_neighbors: vec![10, 20],
1904            umap_n_epochs: vec![50],
1905            umap_category_weight: vec![1.0],
1906            umap_min_dist: vec![0.1],
1907            num_domain_groups: vec![3],
1908            low_evr_threshold: vec![0.35],
1909            overlap_artifact_territorial: vec![0.3],
1910            threshold_base: vec![0.5],
1911            threshold_evr_penalty: vec![0.4],
1912            min_evr_improvement: vec![0.10],
1913        };
1914        let metric = TerritorialHealth;
1915        let (_pipeline, report) = auto_tune(
1916            input,
1917            &space,
1918            &metric,
1919            SearchStrategy::Grid,
1920            &PipelineConfig::default(),
1921        )
1922        .unwrap();
1923
1924        assert_eq!(report.trials.len(), 2);
1925        for t in &report.trials {
1926            assert_eq!(t.config.projection_kind, ProjectionKind::UmapSphere);
1927        }
1928        let nn_values: std::collections::HashSet<usize> = report
1929            .trials
1930            .iter()
1931            .map(|t| t.config.umap.n_neighbors)
1932            .collect();
1933        assert_eq!(nn_values.len(), 2);
1934    }
1935
1936    #[test]
1937    fn umap_graph_cache_reuses_across_trials_sharing_n_neighbors() {
1938        // Six UMAP configs all share `n_neighbors = 10` and differ only in
1939        // `n_epochs` × `category_weight`. The kNN graph + PCA warm-start
1940        // should be built once, then reused — `umap_graph_builds` must
1941        // equal the number of distinct `n_neighbors` values (= 1), not
1942        // the number of trials.
1943        let input = make_input(24, 8);
1944        let space = SearchSpace {
1945            projection_kinds: vec![ProjectionKind::UmapSphere],
1946            laplacian_k_neighbors: vec![15],
1947            laplacian_active_threshold: vec![0.05],
1948            umap_n_neighbors: vec![10],
1949            umap_n_epochs: vec![30, 60],
1950            umap_category_weight: vec![0.0, 1.0, 2.0],
1951            umap_min_dist: vec![0.1],
1952            num_domain_groups: vec![3],
1953            low_evr_threshold: vec![0.35],
1954            overlap_artifact_territorial: vec![0.3],
1955            threshold_base: vec![0.5],
1956            threshold_evr_penalty: vec![0.4],
1957            min_evr_improvement: vec![0.10],
1958        };
1959        let metric = TerritorialHealth;
1960        let (_pipeline, report) = auto_tune(
1961            input,
1962            &space,
1963            &metric,
1964            SearchStrategy::Grid,
1965            &PipelineConfig::default(),
1966        )
1967        .unwrap();
1968
1969        assert_eq!(report.trials.len(), 6, "6 UMAP configs in the grid");
1970        assert_eq!(
1971            report.umap_graph_builds, 1,
1972            "all 6 configs share n_neighbors=10, so the cache should build the graph exactly once"
1973        );
1974    }
1975
1976    #[test]
1977    fn umap_graph_cache_builds_one_per_unique_n_neighbors() {
1978        // Two distinct `n_neighbors` values × two `n_epochs` = 4 UMAP
1979        // trials. The cache builds the graph once per unique
1980        // `n_neighbors`, so `umap_graph_builds` should equal 2.
1981        let input = make_input(24, 8);
1982        let space = SearchSpace {
1983            projection_kinds: vec![ProjectionKind::UmapSphere],
1984            laplacian_k_neighbors: vec![15],
1985            laplacian_active_threshold: vec![0.05],
1986            umap_n_neighbors: vec![10, 20],
1987            umap_n_epochs: vec![30, 60],
1988            umap_category_weight: vec![0.0],
1989            umap_min_dist: vec![0.1],
1990            num_domain_groups: vec![3],
1991            low_evr_threshold: vec![0.35],
1992            overlap_artifact_territorial: vec![0.3],
1993            threshold_base: vec![0.5],
1994            threshold_evr_penalty: vec![0.4],
1995            min_evr_improvement: vec![0.10],
1996        };
1997        let metric = TerritorialHealth;
1998        let (_pipeline, report) = auto_tune(
1999            input,
2000            &space,
2001            &metric,
2002            SearchStrategy::Grid,
2003            &PipelineConfig::default(),
2004        )
2005        .unwrap();
2006
2007        assert_eq!(report.trials.len(), 4);
2008        assert_eq!(
2009            report.umap_graph_builds, 2,
2010            "n_neighbors ∈ {{10, 20}} should produce exactly 2 graph builds"
2011        );
2012    }
2013
2014    #[test]
2015    fn umap_fit_key_distinguishes_min_dist() {
2016        // min_dist changes the optimizer's kernel (a, b), so two configs
2017        // that differ only there must not share a cached fitted
2018        // projection. They still share the kNN graph (keyed by
2019        // n_neighbors alone), which is built before the optimizer.
2020        let a = PipelineConfig {
2021            projection_kind: ProjectionKind::UmapSphere,
2022            ..PipelineConfig::default()
2023        };
2024        let mut b = a.clone();
2025        b.umap.min_dist = 0.5;
2026        assert!(ProjectionFitKey::from_config(&a) == ProjectionFitKey::from_config(&a.clone()));
2027        assert!(ProjectionFitKey::from_config(&a) != ProjectionFitKey::from_config(&b));
2028    }
2029
2030    #[test]
2031    fn umap_graph_cache_zero_when_no_umap_trials() {
2032        // PCA-only search space — no UMAP trials, no graph builds.
2033        let input = make_input(24, 8);
2034        let space = SearchSpace {
2035            projection_kinds: vec![ProjectionKind::Pca],
2036            laplacian_k_neighbors: vec![15],
2037            laplacian_active_threshold: vec![0.05],
2038            umap_n_neighbors: vec![10],
2039            umap_n_epochs: vec![30],
2040            umap_category_weight: vec![0.0],
2041            umap_min_dist: vec![0.1],
2042            num_domain_groups: vec![3],
2043            low_evr_threshold: vec![0.35],
2044            overlap_artifact_territorial: vec![0.3],
2045            threshold_base: vec![0.5],
2046            threshold_evr_penalty: vec![0.4],
2047            min_evr_improvement: vec![0.10],
2048        };
2049        let metric = TerritorialHealth;
2050        let (_pipeline, report) = auto_tune(
2051            input,
2052            &space,
2053            &metric,
2054            SearchStrategy::Grid,
2055            &PipelineConfig::default(),
2056        )
2057        .unwrap();
2058
2059        assert_eq!(report.umap_graph_builds, 0);
2060    }
2061
2062    #[test]
2063    fn validate_rejects_empty_umap_axis_only_when_kind_present() {
2064        let mut s = full_search_space();
2065        s.umap_n_neighbors.clear();
2066        // PCA-only space: missing UMAP axis is fine.
2067        assert!(s.validate(&SearchStrategy::Grid).is_ok());
2068        s.projection_kinds.push(ProjectionKind::UmapSphere);
2069        match s.validate(&SearchStrategy::Grid) {
2070            Err(PipelineError::InvalidSearchSpace(msg)) => {
2071                assert!(msg.contains("umap_n_neighbors"), "msg = {msg:?}");
2072            }
2073            other => panic!("expected InvalidSearchSpace, got {other:?}"),
2074        }
2075    }
2076
2077    #[test]
2078    fn tpe_proposes_dominating_value_more_often_than_uniform() {
2079        // Hand-crafted history: every top-gamma trial used
2080        // num_domain_groups = 7, every bad trial used 3 or 5. The
2081        // acquisition should propose 7 far more often than the uniform
2082        // 1/3 baseline.
2083        let space = SearchSpace {
2084            projection_kinds: vec![ProjectionKind::Pca],
2085            laplacian_k_neighbors: vec![15],
2086            laplacian_active_threshold: vec![0.05],
2087            umap_n_neighbors: vec![15],
2088            umap_n_epochs: vec![200],
2089            umap_category_weight: vec![1.5],
2090            umap_min_dist: vec![0.1],
2091            num_domain_groups: vec![3, 5, 7],
2092            low_evr_threshold: vec![0.3],
2093            overlap_artifact_territorial: vec![0.3],
2094            threshold_base: vec![0.5],
2095            threshold_evr_penalty: vec![0.4],
2096            min_evr_improvement: vec![0.10],
2097        };
2098        let base = PipelineConfig::default();
2099
2100        let trial = |ndg: usize, score: f64| -> TrialRecord {
2101            let mut config = base.clone();
2102            config.projection_kind = ProjectionKind::Pca;
2103            config.routing.num_domain_groups = ndg;
2104            TrialRecord {
2105                config,
2106                score,
2107                build_ms: 0,
2108                components: Vec::new(),
2109            }
2110        };
2111
2112        let mut trials = Vec::new();
2113        for i in 0..4 {
2114            trials.push(trial(7, 0.9 + i as f64 * 0.01));
2115        }
2116        for i in 0..6 {
2117            trials.push(trial(3, 0.1 + i as f64 * 0.01));
2118            trials.push(trial(5, 0.1 + i as f64 * 0.005));
2119        }
2120
2121        let mut rng = SplitMix64::new(42);
2122        let n_proposals = 300;
2123        let mut count_7 = 0usize;
2124        for _ in 0..n_proposals {
2125            let cfg = tpe_propose(&space, &base, &trials, 0.25, &mut rng);
2126            if cfg.routing.num_domain_groups == 7 {
2127                count_7 += 1;
2128            }
2129        }
2130
2131        // Uniform would land near 100/300. The good/bad ratio for 7 puts
2132        // its sampling probability above 0.9, so 180 is a comfortable
2133        // margin that still fails if the acquisition stops conditioning
2134        // on the split.
2135        assert!(
2136            count_7 > 180,
2137            "dominating value proposed only {count_7}/{n_proposals} times (uniform ≈ {})",
2138            n_proposals / 3
2139        );
2140    }
2141
2142    #[test]
2143    fn random_seeds_base_config_as_trial_zero() {
2144        let input = make_input(24, 8);
2145        let mut base = PipelineConfig::default();
2146        base.bridges.overlap_artifact_territorial = 0.123; // off-axis
2147        let metric = TerritorialHealth;
2148        let (_p, report) = auto_tune(
2149            input,
2150            &full_search_space(),
2151            &metric,
2152            SearchStrategy::Random {
2153                budget: 4,
2154                seed: 9,
2155                max_wall_secs: None,
2156            },
2157            &base,
2158        )
2159        .unwrap();
2160
2161        assert_eq!(report.trials.len(), 4, "seed trial counts against budget");
2162        assert!(
2163            (report.trials[0].config.bridges.overlap_artifact_territorial - 0.123).abs() < 1e-12,
2164            "trial 0 must be base_config itself"
2165        );
2166        for t in &report.trials[1..] {
2167            assert!(
2168                (t.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-12,
2169                "sampled trials must come from the space's axes"
2170            );
2171        }
2172    }
2173
2174    #[test]
2175    fn bayesian_seeds_base_config_as_trial_zero() {
2176        let input = make_input(24, 8);
2177        let mut base = PipelineConfig::default();
2178        base.bridges.overlap_artifact_territorial = 0.123;
2179        let metric = TerritorialHealth;
2180        let (_p, report) = auto_tune(
2181            input,
2182            &full_search_space(),
2183            &metric,
2184            SearchStrategy::Bayesian {
2185                budget: 5,
2186                warmup: 2,
2187                gamma: 0.25,
2188                seed: 9,
2189                max_wall_secs: None,
2190            },
2191            &base,
2192        )
2193        .unwrap();
2194
2195        assert_eq!(report.trials.len(), 5);
2196        assert!(
2197            (report.trials[0].config.bridges.overlap_artifact_territorial - 0.123).abs() < 1e-12
2198        );
2199    }
2200
2201    #[test]
2202    fn grid_does_not_seed_base_config() {
2203        let input = make_input(24, 8);
2204        let mut base = PipelineConfig::default();
2205        base.bridges.overlap_artifact_territorial = 0.123;
2206        let metric = TerritorialHealth;
2207        let (_p, report) = auto_tune(
2208            input,
2209            &full_search_space(),
2210            &metric,
2211            SearchStrategy::Grid,
2212            &base,
2213        )
2214        .unwrap();
2215
2216        assert_eq!(
2217            report.trials.len(),
2218            full_search_space().grid_cardinality(),
2219            "grid trial count must stay the exact enumeration"
2220        );
2221        for t in &report.trials {
2222            assert!((t.config.bridges.overlap_artifact_territorial - 0.3).abs() < 1e-12);
2223        }
2224    }
2225
2226    #[test]
2227    fn returned_pipeline_uses_best_config() {
2228        let input = make_input(24, 8);
2229        let metric = TerritorialHealth;
2230        let (pipeline, report) = auto_tune(
2231            input,
2232            &SearchSpace::default(),
2233            &metric,
2234            SearchStrategy::Random {
2235                budget: 4,
2236                seed: 11,
2237                max_wall_secs: None,
2238            },
2239            &PipelineConfig::default(),
2240        )
2241        .unwrap();
2242        assert_eq!(
2243            pipeline.config().routing.num_domain_groups,
2244            report.best_config.routing.num_domain_groups
2245        );
2246        assert_eq!(
2247            pipeline.projection_kind(),
2248            report.best_config.projection_kind
2249        );
2250    }
2251}