Skip to main content

gam_solve/
row_sampling_measure.rs

1//! `RowSamplingMeasure` — the Fisher-mass **enrichment** producer (role (c) of #980).
2//!
3//! # What this is, and what it must never be
4//!
5//! A [`RowSamplingMeasure`] turns a [`RowMetric`] into a per-row **sampling measure**:
6//! a normalized non-negative weight per row, proportional to that row's
7//! behavioral *liveness* (its output-Fisher mass). It exists for **discovery /
8//! seeding only** — to OVERSAMPLE the behaviorally-live rows so that a rare but
9//! behaviorally-important feature (few rows, high Fisher mass, drowned among
10//! many common low-coupling rows) is actually *seen* by a discovery batch.
11//!
12//! ## The load-bearing invariant
13//!
14//! **The measure NEVER enters the reconstruction loss, the gradient, the
15//! evidence criterion, or any optimizer-facing quantity.** Sampling ADDS
16//! attention; it never reweights representation. Concretely:
17//!
18//! * it does not multiply any residual, any `quad_form`, any whitened Jacobian,
19//!   or any penalty;
20//! * it does not feed REML / LAML, the ρ trust-region ratio, or `φ̂`;
21//! * it only chooses *which rows a discovery/seeding pass looks at first*, and
22//!   how many times, leaving every per-row loss bit-for-bit unchanged.
23//!
24//! This is the dual of the #980 failure mode (where an output-Fisher inner
25//! product silently replaced the reconstruction loss): here the Fisher mass is
26//! used *strictly* as an attention prior over rows, with the loss untouched.
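//! The enrichment ordering returns row indices with multiplicity — the consumer
//! visits those rows for *seeding/proposal* purposes; the fit it then runs on
//! any selected row uses the unmodified per-row objective.
//!
//! The one deliberate exception is [`RowSamplingMeasure::designed_subsample`],
//! which *replaces* the full corpus with a subsample and therefore returns
//! `1 / π_i` Horvitz–Thompson inclusion weights for the caller to fold into the
//! likelihood. Those are sampling-design corrections that keep the subsampled
//! criterion unbiased for the full-corpus criterion — not a Fisher reweighting
//! of residuals.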
30//!
31//! # Graceful degradation (absent harvest ⇒ today's behavior)
32//!
33//! The measure is **magic-by-default**, mirroring [`RowMetric`]:
34//!
35//! * [`MetricProvenance::Euclidean`] (no per-row Fisher factors were harvested)
36//!   ⇒ every row's liveness is identical (`tr(I_p) = p`), so the measure is
37//!   **exactly uniform** and the enrichment ordering is the plain index order
38//!   with uniform multiplicity. Absent harvest is therefore bit-for-bit today's
39//!   "look at every row equally" behavior, never an error.
40//! * A factored provenance ([`MetricProvenance::OutputFisher`] /
41//!   [`MetricProvenance::WhitenedStructured`]) ⇒ rows are weighted by their
42//!   `tr(M_n)` Fisher mass, oversampling the live rows.
43//!
44//! Any pathological metric (all-zero mass, a non-finite block) also degrades to
45//! the uniform measure rather than producing a degenerate or `NaN` sampling
46//! distribution.
47//!
48//! # Why `tr(M_n)` is the right liveness scalar
49//!
50//! The per-row metric `M_n = U_n U_nᵀ` is the output-Fisher inner product on
51//! latent motion at row `n`. Its trace `tr(M_n) = Σ_i e_iᵀ M_n e_i =
52//! Σ_i fisher_mass(n, e_i)` is the total behavioral mass of that row summed over
53//! output coordinates — basis-independent and exactly the quantity
54//! [`RowMetric::fisher_mass`] reports for a unit of motion along each axis. It
55//! is the canonical row liveness derivable from the metric *alone*, with no
56//! external tangent supplied, and it collapses to the constant `p` under
57//! Euclidean — which is precisely the uniform-measure degeneracy we want.
58
59use gam_problem::{MetricProvenance, RowMetric};
60use gam_linalg::faer_ndarray::{FaerEigh, FaerSvd};
61use gam_linalg::utils::splitmix64_hash;
62use faer::Side;
63use ndarray::{Array2, ArrayView2};
64
65/// Where a [`RowSamplingMeasure`] came from — the honest record of whether the
66/// enrichment is real (Fisher-mass driven) or the graceful uniform fallback.
67#[derive(Clone, Copy, PartialEq, Eq, Debug)]
68pub enum MeasureProvenance {
69    /// No behavioral signal was available (Euclidean metric, or a degenerate
70    /// metric that produced no usable mass). The measure is exactly uniform:
71    /// every row carries weight `1 / n`. This is bit-for-bit "look at every row
72    /// equally" — today's behavior with no harvest.
73    Uniform,
74    /// The measure is `∝ tr(M_n)` from a factored [`RowMetric`]. Behaviorally
75    /// live rows carry proportionally more sampling weight. The carried
76    /// [`MetricProvenance`] is the metric provenance that produced the mass, so
77    /// a consumer can certify the inner product behind the enrichment.
78    FisherMass(MetricProvenance),
79}
80
81/// A per-row **sampling measure** over `n` rows, normalized to sum to 1.
82///
83/// Built from a [`RowMetric`] via [`RowSamplingMeasure::from_metric`]. The weights are a
84/// proper probability measure (non-negative, finite, summing to 1) used for
85/// **discovery/seeding oversampling only** — see the module docs for the
86/// invariant that it touches no loss / gradient / criterion.
87#[derive(Clone, Debug)]
88pub struct RowSamplingMeasure {
89    provenance: MeasureProvenance,
90    /// Normalized per-row sampling weights; `weights.len() == n_rows` and
91    /// `Σ weights == 1` (exactly uniform `1/n` in the fallback).
92    weights: Vec<f64>,
93}
94
/// Error budget certified for a row coreset, handed to race consumers so they
/// can decide whether a coreset verdict transfers to the full corpus.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CoresetCertificate {
    /// Spectral radius backing the log-determinant term:
    /// `(1 - eps_spectral)H <= H_C <= (1 + eps_spectral)H` on the effective
    /// eigenspace.
    pub eps_spectral: f64,
    /// Additive likelihood error radius contributed by the sensitivity coreset
    /// on its documented chart ball.
    pub eps_likelihood: f64,
    /// Rank of the certified subspace (factored border plus active
    /// coordinates). Null directions of the summed row sketch do not count.
    pub dim_effective: usize,
    /// How many distinct rows the coreset keeps.
    pub n_selected: usize,
}

impl CoresetCertificate {
    /// Validated constructor.
    ///
    /// # Errors
    ///
    /// Returns `Err` unless `0 <= eps_spectral < 1` and `eps_likelihood` is
    /// finite and non-negative (NaN fails both checks).
    pub fn new(
        eps_spectral: f64,
        eps_likelihood: f64,
        dim_effective: usize,
        n_selected: usize,
    ) -> Result<Self, String> {
        let spectral_ok = eps_spectral.is_finite() && (0.0..1.0).contains(&eps_spectral);
        if !spectral_ok {
            return Err(format!(
                "coreset certificate requires 0 <= eps_spectral < 1, got {eps_spectral}"
            ));
        }

        let likelihood_ok = eps_likelihood.is_finite() && eps_likelihood >= 0.0;
        if !likelihood_ok {
            return Err(format!(
                "coreset certificate requires finite non-negative eps_likelihood, got {eps_likelihood}"
            ));
        }

        Ok(Self {
            eps_spectral,
            eps_likelihood,
            dim_effective,
            n_selected,
        })
    }

    /// Worst-case error on the log-determinant implied by the spectral
    /// certificate: `dim_effective · ln((1 + eps) / (1 - eps))`.
    pub fn logdet_error_bound(&self) -> f64 {
        let ratio = (1.0 + self.eps_spectral) / (1.0 - self.eps_spectral);
        (self.dim_effective as f64) * ratio.ln()
    }

    /// Margin a coreset race decision must strictly exceed before a consumer
    /// may inherit the full-corpus verdict: twice the combined log-determinant
    /// and likelihood error (one error budget per side of the race).
    pub fn race_transfer_margin(&self) -> f64 {
        let per_side = self.logdet_error_bound() + self.eps_likelihood;
        2.0 * per_side
    }

    /// Classify a proposed coreset decision margin against
    /// [`Self::race_transfer_margin`]. Only a finite margin strictly above the
    /// requirement is `Certified`; consumers should propagate
    /// [`CoresetMarginVerdict::InsufficientMargin`] rather than decide silently.
    pub fn certify_margin(&self, decision_margin: f64) -> CoresetMarginVerdict {
        let required_margin = self.race_transfer_margin();
        let clears_bar = decision_margin.is_finite() && decision_margin > required_margin;
        if !clears_bar {
            return CoresetMarginVerdict::InsufficientMargin {
                decision_margin,
                required_margin,
            };
        }
        CoresetMarginVerdict::Certified {
            decision_margin,
            required_margin,
        }
    }
}

/// Outcome of gating a coreset-backed race decision on its certificate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CoresetMarginVerdict {
    /// The decision margin strictly exceeds the certified transfer margin.
    Certified {
        decision_margin: f64,
        required_margin: f64,
    },
    /// The decision margin is non-finite or does not exceed the transfer margin.
    InsufficientMargin {
        decision_margin: f64,
        required_margin: f64,
    },
}
180
181/// Output of deterministic BSS spectral row selection.
182#[derive(Clone, Debug, PartialEq)]
183pub struct SpectralCoreset {
184    /// Distinct selected row indices, ascending.
185    pub indices: Vec<usize>,
186    /// Non-negative row weights aligned with `indices`.
187    pub weights: Vec<f64>,
188    /// Spectral certificate for this row coreset. `eps_likelihood` is zero here;
189    /// combine with a sensitivity certificate before certifying full evidence.
190    pub certificate: CoresetCertificate,
191}
192
/// Deterministic Batson–Spielman–Srivastava (BSS) spectral row coreset,
/// returned together with its spectral certificate.
///
/// Each input item is a small row factor `R_i` with contribution
/// `H_i = R_i.t() R_i`. Selection is run on the effective eigenspace of
/// `sum_i H_i`; rank-null directions are excluded from the certificate. The
/// algorithm whitens the factors into that effective space, then applies the
/// standard two-barrier BSS potential update with deterministic row-index
/// tie-breaking. Per-row dense `H_i` blocks are never materialized.
///
/// # Arguments
///
/// * `rows` — the per-row factors `R_i`, in row order; returned indices refer
///   to positions in this sequence.
/// * `target_eps` — requested spectral radius, `0 < target_eps < 1`.
///
/// # Returns
///
/// A [`SpectralCoreset`] with the following contents:
///
/// * `indices`: distinct, ascending.
/// * `weights`: strictly positive, aligned with `indices`.
/// * `certificate`: a [`CoresetCertificate`] with
///   * `eps_spectral = target_eps`,
///   * `eps_likelihood = 0`,
///   * `dim_effective` equal to the whitened dimension,
///   * `n_selected = indices.len()`.
///
/// Empty input, or input whose summed sketch has no effective rank, yields an
/// empty coreset with `dim_effective = 0`.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
///
/// * `target_eps` is outside `(0, 1)`.
/// * Factor collection or whitening fails.
/// * A barrier inverse or potential evaluation fails.
/// * A potential denominator becomes non-positive or non-finite.
/// * No barrier-admissible row exists at some step.
///
/// # Cost
///
/// Runs `max(ceil(dim / (target_eps / 2)^2), dim)` steps, each of which scans
/// every row, so a small `target_eps` is expensive.
pub fn bss_spectral_coreset_certified<'a, I>(
    rows: I,
    target_eps: f64,
) -> Result<SpectralCoreset, String>
where
    I: IntoIterator<Item = ArrayView2<'a, f64>>,
{
    if !(target_eps.is_finite() && target_eps > 0.0 && target_eps < 1.0) {
        return Err(format!(
            "BSS spectral coreset requires 0 < target_eps < 1, got {target_eps}"
        ));
    }

    let factors = collect_row_factors(rows)?;
    let n = factors.len();
    if n == 0 {
        // No rows at all: nothing to certify.
        let certificate = CoresetCertificate::new(target_eps, 0.0, 0, 0)?;
        return Ok(SpectralCoreset {
            indices: Vec::new(),
            weights: Vec::new(),
            certificate,
        });
    }

    let ambient_dim = factors[0].ncols();
    let effective = stacked_factor_whitener(&factors, ambient_dim)?;
    let dim = effective.ncols();
    if dim == 0 {
        // The summed sketch has no effective rank: no direction to certify.
        let certificate = CoresetCertificate::new(target_eps, 0.0, 0, 0)?;
        return Ok(SpectralCoreset {
            indices: Vec::new(),
            weights: Vec::new(),
            certificate,
        });
    }

    let whitened = whiten_row_factors(&factors, &effective);
    // Run BSS at half the target radius. With `steps >= dim / eta^2` we get
    // `root <= eta * steps`, so the final barriers satisfy
    // `upper / lower <= ((1 + eta) / (1 - eta))^2`. Rescaling by
    // `2 / (lower + upper)` then keeps the spectrum within
    // `1 ± 2*eta / (1 + eta^2)`, which lies inside `1 ± target_eps`.
    let eta = 0.5 * target_eps;
    let steps = ((dim as f64) / (eta * eta)).ceil().max(dim as f64) as usize;
    let delta_lower = 1.0_f64;
    let delta_upper = (1.0 + eta) / (1.0 - eta);
    let root = (steps as f64 * dim as f64).sqrt();
    let mut barrier_matrix = Array2::<f64>::zeros((dim, dim));
    let mut row_weights = vec![0.0_f64; n];

    for step in 0..steps {
        // Barriers for this step and the next. Both advance linearly in `step`
        // and start at `-root` / `delta_upper * root`, so the zero matrix is
        // strictly inside them at step 0.
        let lower = step as f64 - root;
        let upper = delta_upper * (step as f64 + root);
        let lower_next = lower + delta_lower;
        let upper_next = upper + delta_upper;

        let lower_inv = inverse_shifted_lower(&barrier_matrix, lower_next)?;
        let upper_inv = inverse_shifted_upper(&barrier_matrix, upper_next)?;
        // Potential drops between the current and next barriers. These form the
        // denominators of the BSS lower/upper scores and must be strictly
        // positive while the iterate stays inside the barriers.
        let lower_denom = lower_potential(&barrier_matrix, lower_next)?
            - lower_potential(&barrier_matrix, lower)?;
        let upper_denom = upper_potential(&barrier_matrix, upper)?
            - upper_potential(&barrier_matrix, upper_next)?;
        if !(lower_denom.is_finite() && lower_denom > 0.0) {
            return Err(format!(
                "BSS lower potential denominator became invalid at step {step}: {lower_denom}"
            ));
        }
        if !(upper_denom.is_finite() && upper_denom > 0.0) {
            return Err(format!(
                "BSS upper potential denominator became invalid at step {step}: {upper_denom}"
            ));
        }

        // Scan every row for one that is barrier-admissible: U(v) <= L(v), up to
        // `BSS_SCORE_TOL` of round-off slack. Among admissible rows, keep the
        // one with the largest gap L - U. Rows are scanned in ascending order,
        // so a later row replaces the incumbent only if its gap is larger by
        // more than the tolerance; near-ties therefore keep the smaller index.
        let mut chosen: Option<(usize, f64, f64)> = None;
        for (row, factor) in whitened.iter().enumerate() {
            let lower_trace = trace_factor_quadratic(factor, &lower_inv);
            let lower_trace_sq = trace_factor_quadratic_square(factor, &lower_inv);
            let upper_trace = trace_factor_quadratic(factor, &upper_inv);
            let upper_trace_sq = trace_factor_quadratic_square(factor, &upper_inv);
            let lower_score = lower_trace_sq / lower_denom - lower_trace;
            let upper_score = upper_trace_sq / upper_denom + upper_trace;
            if lower_score.is_finite()
                && upper_score.is_finite()
                && lower_score > 0.0
                && upper_score > 0.0
                && lower_score + BSS_SCORE_TOL >= upper_score
            {
                match chosen {
                    None => chosen = Some((row, lower_score, upper_score)),
                    Some((best_row, best_lower, best_upper)) => {
                        let gap = lower_score - upper_score;
                        let best_gap = best_lower - best_upper;
                        if gap > best_gap + BSS_SCORE_TOL
                            || ((gap - best_gap).abs() <= BSS_SCORE_TOL && row < best_row)
                        {
                            chosen = Some((row, lower_score, upper_score));
                        }
                    }
                }
            }
        }

        let (row, lower_score, upper_score) = chosen
            .ok_or_else(|| format!("BSS failed to find a barrier-admissible row at step {step}"))?;
        // BSS requires U(v) <= 1/t <= L(v). Take 1/t at the midpoint of the
        // admissible interval and add t · v vᵀ to the iterate.
        let inv_step_weight = 0.5 * (lower_score + upper_score);
        if !(inv_step_weight.is_finite() && inv_step_weight > 0.0) {
            return Err(format!(
                "BSS invalid inverse step weight at step {step}: {inv_step_weight}"
            ));
        }
        let step_weight = 1.0 / inv_step_weight;
        add_factor_gram_scaled(&mut barrier_matrix, &whitened[row], step_weight);
        row_weights[row] += step_weight;
    }

    // Final rescale: map the barrier sandwich [lower_final, upper_final]
    // symmetrically around 1. Rows that were never chosen (weight 0) are dropped.
    let lower_final = steps as f64 - root;
    let upper_final = delta_upper * (steps as f64 + root);
    let scale = 2.0 / (lower_final + upper_final);
    let mut indexed: Vec<(usize, f64)> = row_weights
        .into_iter()
        .enumerate()
        .filter_map(|(row, weight)| {
            let scaled = weight * scale;
            (scaled > 0.0).then_some((row, scaled))
        })
        .collect();
    // `enumerate` already yields ascending rows; the sort keeps the documented
    // ascending-order contract explicit.
    indexed.sort_by_key(|&(row, _)| row);
    let indices: Vec<usize> = indexed.iter().map(|&(row, _)| row).collect();
    let weights: Vec<f64> = indexed.iter().map(|&(_, weight)| weight).collect();
    let certificate = CoresetCertificate::new(target_eps, 0.0, dim, indices.len())?;
    Ok(SpectralCoreset {
        indices,
        weights,
        certificate,
    })
}
334
/// Sensitivity upper bounds on the chart ball
/// `||chart(theta) - chart(theta_anchor)|| <= chart_radius`.
///
/// The bound uses the linear-anchor leverage and inflates it by the curvature
/// slack `kappa_hat * chart_radius`, i.e.
/// `sigma_i <= leverage_i * (1 + kappa_hat * chart_radius)`. The same ball and
/// curvature estimate must be used by the likelihood consumer that interprets
/// the returned additive `eps_likelihood` certificate.
///
/// Every returned bound is finite and non-negative.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
///
/// * `kappa_hat` or `chart_radius` is negative or non-finite.
/// * The inflation factor `1 + kappa_hat * chart_radius` overflows to infinity.
/// * Any leverage is negative or non-finite.
/// * An inflated bound overflows to a non-finite value.
///
/// Rejecting overflow here matters: silently returning `inf` or `NaN` (e.g.
/// `0 * inf`) would poison any downstream sensitivity mass.
pub fn sensitivity_upper_bounds(
    linear_anchor_leverage: &[f64],
    kappa_hat: f64,
    chart_radius: f64,
) -> Result<Vec<f64>, String> {
    if !(kappa_hat.is_finite() && kappa_hat >= 0.0) {
        return Err(format!(
            "sensitivity bounds require finite non-negative kappa_hat, got {kappa_hat}"
        ));
    }
    if !(chart_radius.is_finite() && chart_radius >= 0.0) {
        return Err(format!(
            "sensitivity bounds require finite non-negative chart_radius, got {chart_radius}"
        ));
    }
    // Both factors are finite, but their product can still overflow.
    let inflation = 1.0 + kappa_hat * chart_radius;
    if !inflation.is_finite() {
        return Err(format!(
            "sensitivity bounds require a finite inflation 1 + kappa_hat * chart_radius, \
             got {inflation} (kappa_hat = {kappa_hat}, chart_radius = {chart_radius})"
        ));
    }
    linear_anchor_leverage
        .iter()
        .enumerate()
        .map(|(row, &lev)| {
            if !(lev.is_finite() && lev >= 0.0) {
                return Err(format!(
                    "sensitivity leverage at row {row} must be finite and non-negative, got {lev}"
                ));
            }
            let bound = lev * inflation;
            if bound.is_finite() {
                Ok(bound)
            } else {
                Err(format!(
                    "sensitivity bound at row {row} overflowed: leverage {lev} * inflation {inflation}"
                ))
            }
        })
        .collect()
}
373
374/// Greedy deterministic sensitivity coreset under a row budget.
375#[derive(Clone, Debug, PartialEq)]
376pub struct SensitivityCoreset {
377    /// Selected rows sorted by decreasing sensitivity, then row index.
378    pub indices: Vec<usize>,
379    /// Sensitivity mass retained by the selected rows.
380    pub selected_sensitivity_mass: f64,
381    /// Sensitivity mass not retained by the budget. A likelihood consumer can
382    /// map this to its additive `eps_likelihood` on the documented chart ball.
383    pub residual_sensitivity_mass: f64,
384}
385
386pub fn greedy_sensitivity_coreset(
387    sigma_upper_bounds: &[f64],
388    budget: usize,
389) -> Result<SensitivityCoreset, String> {
390    let mut indexed = Vec::with_capacity(sigma_upper_bounds.len());
391    for (row, &sigma) in sigma_upper_bounds.iter().enumerate() {
392        if !(sigma.is_finite() && sigma >= 0.0) {
393            return Err(format!(
394                "sensitivity upper bound at row {row} must be finite and non-negative, got {sigma}"
395            ));
396        }
397        indexed.push((row, sigma));
398    }
399    indexed.sort_by(|&(row_a, sigma_a), &(row_b, sigma_b)| {
400        sigma_b
401            .partial_cmp(&sigma_a)
402            .unwrap_or(std::cmp::Ordering::Equal)
403            .then(row_a.cmp(&row_b))
404    });
405    let selected_len = budget.min(indexed.len());
406    let indices: Vec<usize> = indexed
407        .iter()
408        .take(selected_len)
409        .map(|&(row, _)| row)
410        .collect();
411    let selected_sensitivity_mass: f64 = indexed
412        .iter()
413        .take(selected_len)
414        .map(|&(_, sigma)| sigma)
415        .sum();
416    let residual_sensitivity_mass: f64 = indexed
417        .iter()
418        .skip(selected_len)
419        .map(|&(_, sigma)| sigma)
420        .sum();
421    Ok(SensitivityCoreset {
422        indices,
423        selected_sensitivity_mass,
424        residual_sensitivity_mass,
425    })
426}
427
428impl RowSamplingMeasure {
429    /// Build the enrichment measure from a [`RowMetric`].
430    ///
431    /// The per-row liveness is the Fisher mass `tr(M_n)` read from the metric's
432    /// validated PSD blocks. The result is normalized to a proper sampling
433    /// measure. Degrades to the **uniform** measure (every row `1/n`) when the
434    /// metric is Euclidean, carries no usable mass (all rows ≤ 0), or yields any
435    /// non-finite mass — never an error, mirroring [`RowMetric`]'s
436    /// magic-by-default discipline.
437    ///
438    /// This function reads only the metric's geometry; it writes nothing into
439    /// the metric, the loss, the gradient, or any criterion.
440    pub fn from_metric(metric: &RowMetric) -> Self {
441        let n = metric.n_rows();
442        if n == 0 {
443            return Self {
444                provenance: MeasureProvenance::Uniform,
445                weights: Vec::new(),
446            };
447        }
448
449        // Euclidean ⇒ exactly uniform by construction. Short-circuit so the
450        // fallback is bit-for-bit `1/n`, not "tr(I_p)=p then renormalize" (which
451        // is the same value, but the explicit path documents intent and avoids
452        // any floating-point renormalization noise).
453        if matches!(metric.provenance(), MetricProvenance::Euclidean) {
454            return Self::uniform(n);
455        }
456
457        let mass = per_row_fisher_mass(metric);
458        Self::from_masses(metric.provenance(), mass)
459    }
460
461    /// The uniform measure over `n` rows: every row weight `1 / n`. The graceful
462    /// fallback and the explicit "no behavioral harvest" measure.
463    pub fn uniform(n: usize) -> Self {
464        let w = if n == 0 { 0.0 } else { 1.0 / n as f64 };
465        Self {
466            provenance: MeasureProvenance::Uniform,
467            weights: vec![w; n],
468        }
469    }
470
471    /// Construct from raw per-row masses, normalizing to a proper measure.
472    /// Falls back to uniform if the masses carry no usable signal.
473    ///
474    /// Crate-visible so the two-tier harvest (`gam_inference::harvest`)
475    /// can lift designed-subsample Fisher masses to a full-corpus measure
476    /// through the same validation/normalization path.
477    pub fn from_masses(metric_provenance: MetricProvenance, masses: Vec<f64>) -> Self {
478        let n = masses.len();
479        if n == 0 {
480            return Self::uniform(0);
481        }
482        // Clamp negatives to zero (a validated PSD block has `tr ≥ 0`, but a
483        // tiny normalizer round-off could dip below) and reject non-finite.
484        let mut total = 0.0_f64;
485        let mut clean = vec![0.0_f64; n];
486        let mut all_finite = true;
487        for (i, &m) in masses.iter().enumerate() {
488            if !m.is_finite() {
489                all_finite = false;
490                break;
491            }
492            let v = if m > 0.0 { m } else { 0.0 };
493            clean[i] = v;
494            total += v;
495        }
496
497        if !all_finite || !(total > 0.0) {
498            // No usable behavioral signal ⇒ degrade to uniform, never NaN.
499            return Self::uniform(n);
500        }
501
502        let inv = 1.0 / total;
503        for w in clean.iter_mut() {
504            *w *= inv;
505        }
506        Self {
507            provenance: MeasureProvenance::FisherMass(metric_provenance),
508            weights: clean,
509        }
510    }
511
512    /// The normalized per-row sampling weights (`Σ == 1`). Read-only; this is a
513    /// sampling measure, never a loss weight.
514    pub fn weights(&self) -> &[f64] {
515        &self.weights
516    }
517
518    /// The measure's provenance — `Uniform` (graceful fallback / no harvest) or
519    /// `FisherMass` (real behavioral enrichment).
520    pub fn provenance(&self) -> MeasureProvenance {
521        self.provenance
522    }
523
524    /// Number of rows the measure is defined over.
525    pub fn n_rows(&self) -> usize {
526        self.weights.len()
527    }
528
529    /// Whether this measure actually enriches (is non-uniform Fisher-mass).
530    /// `false` for the uniform fallback.
531    pub fn is_enriched(&self) -> bool {
532        matches!(self.provenance, MeasureProvenance::FisherMass(_))
533    }
534
535    /// Deterministic **systematic-resampling** enrichment ordering.
536    ///
537    /// Returns a length-`count` vector of row indices drawn `∝ weights`, using
538    /// low-variance systematic resampling with a fixed, *index-derived* jitter —
539    /// there is **no clock randomness**; the same `(measure, count, seed)`
540    /// always yields the same ordering. Behaviorally-live rows therefore appear
541    /// with multiplicity proportional to their Fisher mass, so a rare-but-live
542    /// feature's rows are oversampled relative to uniform.
543    ///
544    /// Systematic resampling places `count` equally spaced pointers
545    /// `(j + u) / count`, `j = 0..count`, against the cumulative weight CDF and
546    /// emits the row each pointer lands in. The single offset `u ∈ [0, 1)` is a
547    /// `splitmix64`-hash of `seed` (deterministic), giving an unbiased draw
548    /// whose per-row expected count is `count · weights[row]` while guaranteeing
549    /// every weight-`≥ 1/count` row appears at least once (the recall property
550    /// the rare-feature control asserts).
551    ///
552    /// The uniform fallback reproduces an even, deterministic round-robin over
553    /// all rows — i.e. plain attention to every row, today's behavior.
554    ///
555    /// This ordering is consumed **only** by a discovery/seeding pass. The rows
556    /// it names carry their ordinary, unmodified per-row objective.
557    pub fn enrichment_order(&self, count: usize, seed: u64) -> Vec<usize> {
558        let n = self.weights.len();
559        if n == 0 || count == 0 {
560            return Vec::new();
561        }
562
563        // Deterministic offset u ∈ [0, 1) from the seed (index-/seed-derived,
564        // never the clock). 53-bit mantissa for an exact double in [0, 1).
565        let u = {
566            let bits = splitmix64_hash(seed ^ ENRICHMENT_SALT);
567            let mantissa = (bits >> 11) as f64; // top 53 bits
568            mantissa / ((1_u64 << 53) as f64)
569        };
570
571        // Cumulative distribution over rows. `weights` already sums to 1; guard
572        // the last bucket to exactly 1.0 against round-off so every pointer
573        // lands in a valid row.
574        let mut cdf = vec![0.0_f64; n];
575        let mut acc = 0.0_f64;
576        for i in 0..n {
577            acc += self.weights[i];
578            cdf[i] = acc;
579        }
580        cdf[n - 1] = 1.0;
581
582        let mut out = Vec::with_capacity(count);
583        let step = 1.0 / count as f64;
584        let mut cursor = 0usize;
585        for j in 0..count {
586            let pointer = (j as f64 + u) * step;
587            // Advance the CDF cursor to the first bucket whose cumulative mass
588            // covers the pointer. Monotone in `j`, so this is one linear sweep.
589            while cursor < n - 1 && pointer > cdf[cursor] {
590                cursor += 1;
591            }
592            out.push(cursor);
593        }
594        out
595    }
596
597    /// Expected number of times each row is drawn in a `count`-sized enrichment
598    /// batch: `count · weights[row]`. A diagnostic for the discovery-recall
599    /// control — it lets a test assert that a rare-but-live feature's rows have
600    /// markedly higher expected representation under enrichment than under
601    /// uniform, with no sampling noise.
602    pub fn expected_representation(&self, count: usize) -> Vec<f64> {
603        let c = count as f64;
604        self.weights.iter().map(|&w| c * w).collect()
605    }
606
    /// Draw a **designed subsample** with honest inclusion weights — the
    /// frontier estimator of #987 (mechanizing the #973 subsample-honesty
    /// contract for measure-driven designs).
    ///
    /// This is a different animal from [`Self::enrichment_order`], and the
    /// distinction is load-bearing:
    ///
    /// * **Enrichment** orders rows for *discovery/seeding attention*; each
    ///   visited row keeps its ordinary, unweighted per-row objective. The
    ///   measure never touches the loss.
    /// * A **designed subsample** *replaces the full corpus* as what the fit
    ///   sums over. That is only sound if every selected row's loss term is
    ///   multiplied by `1 / π_i` (its inclusion probability), so that the
    ///   subsampled criterion is **unbiased** for the full-corpus criterion:
    ///   `E[Σ_{i ∈ S} ℓ_i / π_i] = Σ_i ℓ_i`. The returned
    ///   [`DesignedRowSample`] carries exactly those weights; the caller folds
    ///   them into the likelihood as row weights. These are sampling-design
    ///   corrections — they are *not* a Fisher reweighting of residuals (the
    ///   #980 failure mode), and under the uniform measure they degrade to the
    ///   constant `n / budget`, the plain Horvitz–Thompson scale-up.
    ///
    /// Design: inclusion probabilities are water-filled as
    /// `π_i = min(1, τ · w'_i)` with `τ` solved so `Σ π_i = budget`, where
    /// `w'` is the measure defensively mixed with
    /// [`DESIGNED_SAMPLE_UNIFORM_MIX`] of uniform — the standard
    /// defensive-mixture guard that keeps every row's `π_i > 0` (no row's loss
    /// is unreachable, so the estimator stays unbiased) and bounds the largest
    /// weight. Selection is Madow systematic sampling against the cumulative
    /// `π` with a single deterministic `splitmix64`-derived offset — no clock
    /// randomness; the same `(measure, budget, seed)` always yields the same
    /// sample. Rows are returned in ascending order (stream-friendly).
    ///
    /// `budget ≥ n` returns every row with weight `1.0` — the exact full pass,
    /// bit-for-bit today's behavior, so a driver can call this unconditionally
    /// and let the budget decide. `budget == 0` (or an empty measure) returns an
    /// empty sample with `expected_size == 0`.
    ///
    /// The returned `expected_size` is `Σ π_i`, nominally `budget` after water
    /// filling. The realized row count comes from systematic sampling, so it
    /// can differ from `budget` by round-off; the buffers reserve `budget + 1`
    /// for that.
    pub fn designed_subsample(&self, budget: usize, seed: u64) -> DesignedRowSample {
        let n = self.weights.len();
        if n == 0 || budget == 0 {
            return DesignedRowSample {
                provenance: self.provenance,
                rows: Vec::new(),
                likelihood_weights: Vec::new(),
                expected_size: 0.0,
            };
        }
        if budget >= n {
            // Budget covers the corpus: exact full pass, every row weight 1.
            return DesignedRowSample {
                provenance: self.provenance,
                rows: (0..n).collect(),
                likelihood_weights: vec![1.0; n],
                expected_size: n as f64,
            };
        }

        // Defensive mixture: w' = (1 − ε)·w + ε/n. Keeps every π_i > 0.
        let eps = DESIGNED_SAMPLE_UNIFORM_MIX;
        let unif = 1.0 / n as f64;
        let mixed: Vec<f64> = self
            .weights
            .iter()
            .map(|&w| (1.0 - eps) * w + eps * unif)
            .collect();

        // Water-fill τ so that Σ min(1, τ·w'_i) = budget. Sort descending and
        // peel off the capped prefix; deterministic (index tie-break).
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by(|&a, &b| {
            mixed[b]
                .partial_cmp(&mixed[a])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(a.cmp(&b))
        });
        let total: f64 = mixed.iter().sum();
        let mut capped = 0usize;
        let mut tail_mass = total;
        let mut tau = budget as f64 / tail_mass;
        while capped < n {
            let next = mixed[order[capped]];
            // The heaviest uncapped row already has τ·w' ≤ 1, so the rest do
            // too: τ is final.
            if tau * next <= 1.0 {
                break;
            }
            // Cap this row at π = 1 and re-solve τ over the remainder.
            capped += 1;
            tail_mass -= next;
            let remaining_budget = budget as f64 - capped as f64;
            // NOTE(review): if this guard fires, `tau` keeps its previous value.
            // It appears unreachable while budget < n and every mixed weight is
            // positive (a single tail row cannot outweigh the whole tail), but
            // that depends on DESIGNED_SAMPLE_UNIFORM_MIX > 0 — confirm.
            if remaining_budget <= 0.0 || tail_mass <= 0.0 {
                break;
            }
            tau = remaining_budget / tail_mass;
        }
        let mut pi = vec![0.0_f64; n];
        for (rank, &i) in order.iter().enumerate() {
            pi[i] = if rank < capped {
                1.0
            } else {
                (tau * mixed[i]).min(1.0)
            };
        }

        // Madow systematic selection in row order: row i is selected iff an
        // integer pointer k + u falls inside its cumulative-π interval.
        // Deterministic offset u ∈ [0, 1) from the seed.
        let u = {
            let bits = splitmix64_hash(seed ^ DESIGNED_SAMPLE_SALT);
            let mantissa = (bits >> 11) as f64;
            mantissa / ((1_u64 << 53) as f64)
        };
        let mut rows = Vec::with_capacity(budget + 1);
        let mut likelihood_weights = Vec::with_capacity(budget + 1);
        let mut acc = 0.0_f64;
        for (i, &p) in pi.iter().enumerate() {
            let before = acc;
            acc += p;
            // Selected iff ⌊acc − u⌋ > ⌊before − u⌋ (a pointer crossed).
            // A π = 1 row spans a unit interval and is always selected; a π = 0
            // row spans an empty interval and never is, so `1.0 / p` below is
            // only evaluated for p > 0.
            if (acc - u).floor() > (before - u).floor() {
                rows.push(i);
                // Horvitz–Thompson correction: 1 / inclusion probability.
                likelihood_weights.push(1.0 / p);
            }
        }
        DesignedRowSample {
            provenance: self.provenance,
            rows,
            likelihood_weights,
            expected_size: pi.iter().sum(),
        }
    }
733
734    /// Draw a **certified** designed subsample within a target `eps` of the full
735    /// corpus on BOTH evidence halves (#1012).
736    ///
737    /// Unlike [`Self::designed_subsample`] — whose Horvitz–Thompson design is
738    /// unbiased only in expectation — this is the deterministic CERTIFIED mode:
739    ///
740    /// * **spectral half (`½log|H|`):** deterministic Batson–Spielman–Srivastava
741    ///   selection of `O(dim/eps²)` weighted rows from the per-row factors
742    ///   `R_i` (`H_i = R_iᵀR_i`), giving `(1−eps)H ⪯ H_C ⪯ (1+eps)H` and hence
743    ///   `|log|H_C| − log|H|| ≤ dim·log((1+eps)/(1−eps))`;
744    /// * **likelihood half (`L`):** the sensitivity bounds
745    ///   `σ_i ≤ leverage_i·(1 + κ̂·chart_radius)` on the documented chart ball,
746    ///   greedily selected against the row budget; the residual sensitivity mass
747    ///   is the additive `eps_likelihood·L` the certificate carries.
748    ///
749    /// The two selections are unioned (a row certified for either half is kept),
750    /// the rows carry their deterministic BSS / sensitivity weights, and the
751    /// [`CoresetCertificate`] rides the result so a race consumer can gate the
752    /// transfer with [`CoresetCertificate::race_transfer_margin`] — the SAME
753    /// margin seam the enclosure path (#1011) declares. Below that margin the
754    /// consumer must grow the coreset, never silently decide.
755    ///
756    /// `row_factors` is the per-row factor list aligned with this measure's rows;
757    /// `leverage`, `kappa_hat`, `chart_radius` are the sensitivity inputs (the
758    /// #1007 SVD-anchor leverage and the #1008 curvature slack). `budget` caps
759    /// the likelihood-half greedy selection.
760    pub fn designed_subsample_certified<'a, I>(
761        &self,
762        row_factors: I,
763        target_eps: f64,
764        leverage: &[f64],
765        kappa_hat: f64,
766        chart_radius: f64,
767        budget: usize,
768    ) -> Result<CertifiedRowSample, String>
769    where
770        I: IntoIterator<Item = ArrayView2<'a, f64>>,
771    {
772        // Spectral half: deterministic BSS coreset + its spectral certificate.
773        let spectral = bss_spectral_coreset_certified(row_factors, target_eps)?;
774
775        // Likelihood half: sensitivity-bounded greedy coreset; the residual mass
776        // not covered by the budget becomes the additive eps_likelihood.
777        let sigma = sensitivity_upper_bounds(leverage, kappa_hat, chart_radius)?;
778        let sensitivity = greedy_sensitivity_coreset(&sigma, budget)?;
779        let total_sensitivity =
780            sensitivity.selected_sensitivity_mass + sensitivity.residual_sensitivity_mass;
781        let eps_likelihood = if total_sensitivity > 0.0 {
782            sensitivity.residual_sensitivity_mass / total_sensitivity
783        } else {
784            0.0
785        };
786
787        // Union the two selections; a row certified for either half is retained.
788        // Carry the BSS weight where present, else the HT scale-up `1/π` proxy
789        // (uniform `n/|S|`) so the likelihood-only rows still enter the criterion
790        // unbiasedly.
791        let n = self.weights.len();
792        let bss_weight: std::collections::BTreeMap<usize, f64> = spectral
793            .indices
794            .iter()
795            .zip(spectral.weights.iter())
796            .map(|(&i, &w)| (i, w))
797            .collect();
798        let mut selected: std::collections::BTreeSet<usize> =
799            spectral.indices.iter().copied().collect();
800        for &i in &sensitivity.indices {
801            selected.insert(i);
802        }
803        let selected_len = selected.len().max(1);
804        let ht_scale = if n > 0 {
805            n as f64 / selected_len as f64
806        } else {
807            1.0
808        };
809
810        let rows: Vec<usize> = selected.iter().copied().collect();
811        let weights: Vec<f64> = rows
812            .iter()
813            .map(|i| *bss_weight.get(i).unwrap_or(&ht_scale))
814            .collect();
815
816        let certificate = CoresetCertificate::new(
817            spectral.certificate.eps_spectral,
818            eps_likelihood,
819            spectral.certificate.dim_effective,
820            rows.len(),
821        )?;
822
823        Ok(CertifiedRowSample {
824            provenance: self.provenance,
825            rows,
826            weights,
827            certificate,
828        })
829    }
830}
831
832/// A designed importance subsample with honest Horvitz–Thompson likelihood
833/// weights — what a frontier fit sums over instead of the full corpus
834/// (#987 / #973). Produced by [`RowSamplingMeasure::designed_subsample`].
835#[derive(Clone, Debug)]
836pub struct DesignedRowSample {
837    /// Provenance of the measure that shaped the design (uniform fallback or
838    /// Fisher mass), echoed for consumer certification.
839    pub provenance: MeasureProvenance,
840    /// Selected row indices, ascending.
841    pub rows: Vec<usize>,
842    /// Per-selected-row likelihood weight `1 / π_i`, aligned with `rows`.
843    /// Multiplying row `i`'s loss term by this makes the subsampled criterion
844    /// unbiased for the full-corpus criterion.
845    pub likelihood_weights: Vec<f64>,
846    /// `Σ π_i` — the design's expected sample size (≈ the requested budget;
847    /// Madow selection realizes `⌊·⌋` or `⌈·⌉` of it).
848    pub expected_size: f64,
849}
850
851impl DesignedRowSample {
852    /// Number of rows actually selected.
853    pub fn len(&self) -> usize {
854        self.rows.len()
855    }
856
857    pub fn is_empty(&self) -> bool {
858        self.rows.is_empty()
859    }
860
861    /// `Σ 1/π_i` over the selected rows — the Horvitz–Thompson estimate of the
862    /// corpus row count. A consumer can sanity-gate the design by checking
863    /// this lands near `n` (it is exactly `n` in expectation).
864    pub fn estimated_corpus_rows(&self) -> f64 {
865        self.likelihood_weights.iter().sum()
866    }
867}
868
869/// A **certified** designed subsample (#1012): the rows that certify BOTH
870/// evidence halves within the target `eps`, their deterministic BSS /
871/// sensitivity weights, and the [`CoresetCertificate`] a race consumer gates
872/// the verdict transfer against. Produced by
873/// [`RowSamplingMeasure::designed_subsample_certified`].
874#[derive(Clone, Debug)]
875pub struct CertifiedRowSample {
876    /// Provenance of the measure that shaped the design.
877    pub provenance: MeasureProvenance,
878    /// Selected row indices, ascending (union of the spectral and sensitivity
879    /// coresets).
880    pub rows: Vec<usize>,
881    /// Per-selected-row weight aligned with `rows`: the BSS spectral weight
882    /// where the row was chosen for the log-determinant half, else the
883    /// Horvitz–Thompson scale-up for a likelihood-only row.
884    pub weights: Vec<f64>,
885    /// The certificate bounding the worst-case evidence transfer error. Feed
886    /// [`CoresetCertificate::race_transfer_margin`] to the race consumer's
887    /// margin gate.
888    pub certificate: CoresetCertificate,
889}
890
891impl CertifiedRowSample {
892    pub fn len(&self) -> usize {
893        self.rows.len()
894    }
895
896    pub fn is_empty(&self) -> bool {
897        self.rows.is_empty()
898    }
899
900    /// The race-transfer margin a consumer must clear before inheriting the
901    /// full-corpus verdict from this coreset — the shared #1011/#1012 seam.
902    pub fn race_transfer_margin(&self) -> f64 {
903        self.certificate.race_transfer_margin()
904    }
905}
906
/// Fraction `ε` of uniform mass blended into the measure by
/// [`RowSamplingMeasure::designed_subsample`], which draws from
/// `(1 − ε)·measure + ε·uniform`.
///
/// This is the standard defensive-importance-sampling guard with two effects:
///
/// * every row keeps a positive inclusion probability, which Horvitz–Thompson
///   unbiasedness needs (`π_i > 0` wherever `ℓ_i ≠ 0`);
/// * no `1/π` weight can exceed `n / (ε · budget)`.
const DESIGNED_SAMPLE_UNIFORM_MIX: f64 = 1.0e-1;

/// Seed salt for the designed-sample systematic offset. It differs from
/// [`ENRICHMENT_SALT`], so one numeric seed never drives both draws from the
/// same hash stream.
const DESIGNED_SAMPLE_SALT: u64 = 0x73AD0987_5EEDD51F;

/// Seed salt for the enrichment offset hash. It keeps that hash distinct from
/// any other `splitmix64_hash` call in the crate that receives the same numeric
/// seed.
const ENRICHMENT_SALT: u64 = 0x980E1C45_F00DAC70;

/// Numerical tolerance for BSS selection scores. NOTE(review): this constant
/// is not referenced in this part of the file. It is presumably consumed by the
/// BSS coreset routine; confirm its exact role there.
const BSS_SCORE_TOL: f64 = 1.0e-10;
923
924/// Per-row Fisher mass `tr(M_n)` from the metric's criterion-facing traces.
925///
926/// The traces are recorded at metric construction (un-floored), so the solver
927/// `δ` never enters the measure — consistent with the `RowMetric` #747
928/// discipline, and irrelevant anyway because the measure feeds no criterion.
929/// Pure read; touches nothing.
930pub fn per_row_fisher_mass(metric: &RowMetric) -> Vec<f64> {
931    metric.row_traces().to_vec()
932}
933
934fn collect_row_factors<'a, I>(rows: I) -> Result<Vec<Array2<f64>>, String>
935where
936    I: IntoIterator<Item = ArrayView2<'a, f64>>,
937{
938    let mut out = Vec::new();
939    let mut ambient_dim: Option<usize> = None;
940    for (row, factor) in rows.into_iter().enumerate() {
941        if factor.iter().any(|value| !value.is_finite()) {
942            return Err(format!("BSS row factor {row} contains a non-finite value"));
943        }
944        match ambient_dim {
945            None => ambient_dim = Some(factor.ncols()),
946            Some(expected) if expected != factor.ncols() => {
947                return Err(format!(
948                    "BSS row factor {row} has {} columns, expected {expected}",
949                    factor.ncols()
950                ));
951            }
952            Some(_) => {}
953        }
954        out.push(factor.to_owned());
955    }
956    Ok(out)
957}
958
959fn stacked_factor_whitener(
960    factors: &[Array2<f64>],
961    ambient_dim: usize,
962) -> Result<Array2<f64>, String> {
963    let total_factor_rows: usize = factors.iter().map(|factor| factor.nrows()).sum();
964    if total_factor_rows == 0 || ambient_dim == 0 {
965        return Ok(Array2::<f64>::zeros((ambient_dim, 0)));
966    }
967
968    let mut stacked = Array2::<f64>::zeros((total_factor_rows, ambient_dim));
969    let mut cursor = 0usize;
970    for factor in factors {
971        for row in 0..factor.nrows() {
972            for col in 0..ambient_dim {
973                stacked[[cursor + row, col]] = factor[[row, col]];
974            }
975        }
976        cursor += factor.nrows();
977    }
978
979    let (_, singular, vt) = stacked
980        .svd(false, true)
981        .map_err(|err| format!("BSS stacked row-factor SVD failed: {err}"))?;
982    let vt = vt.ok_or_else(|| "BSS stacked row-factor SVD did not return Vt".to_string())?;
983    let max_sigma = singular.iter().copied().fold(0.0_f64, f64::max);
984    if !(max_sigma.is_finite() && max_sigma >= 0.0) {
985        return Err("BSS stacked row sketch has invalid singular values".to_string());
986    }
987    let tol = (ambient_dim.max(1) as f64) * f64::EPSILON * max_sigma.max(1.0) * 100.0;
988    let kept: Vec<usize> = singular
989        .iter()
990        .enumerate()
991        .filter_map(|(idx, &sigma)| (sigma > tol).then_some(idx))
992        .collect();
993    let mut whitener = Array2::<f64>::zeros((ambient_dim, kept.len()));
994    for (out_col, &sv_col) in kept.iter().enumerate() {
995        let scale = 1.0 / singular[sv_col];
996        for ambient_col in 0..ambient_dim {
997            whitener[[ambient_col, out_col]] = vt[[sv_col, ambient_col]] * scale;
998        }
999    }
1000    Ok(whitener)
1001}
1002
1003fn whiten_row_factors(factors: &[Array2<f64>], whitener: &Array2<f64>) -> Vec<Array2<f64>> {
1004    factors.iter().map(|factor| factor.dot(whitener)).collect()
1005}
1006
1007fn inverse_shifted_lower(matrix: &Array2<f64>, lower: f64) -> Result<Array2<f64>, String> {
1008    let n = matrix.nrows();
1009    let mut shifted = matrix.clone();
1010    for i in 0..n {
1011        shifted[[i, i]] -= lower;
1012    }
1013    inverse_symmetric_positive(&shifted, "BSS lower barrier inverse")
1014}
1015
1016fn inverse_shifted_upper(matrix: &Array2<f64>, upper: f64) -> Result<Array2<f64>, String> {
1017    let n = matrix.nrows();
1018    let mut shifted = Array2::<f64>::zeros((n, n));
1019    for i in 0..n {
1020        shifted[[i, i]] = upper;
1021    }
1022    for i in 0..n {
1023        for j in 0..n {
1024            shifted[[i, j]] -= matrix[[i, j]];
1025        }
1026    }
1027    inverse_symmetric_positive(&shifted, "BSS upper barrier inverse")
1028}
1029
1030fn inverse_symmetric_positive(matrix: &Array2<f64>, context: &str) -> Result<Array2<f64>, String> {
1031    let (evals, evecs) = matrix
1032        .eigh(Side::Lower)
1033        .map_err(|err| format!("{context} eigendecomposition failed: {err}"))?;
1034    let n = matrix.nrows();
1035    let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
1036    let tol = (n.max(1) as f64) * f64::EPSILON * max_eval * 100.0;
1037    let mut inv = Array2::<f64>::zeros((n, n));
1038    for k in 0..n {
1039        let lambda = evals[k];
1040        if !(lambda.is_finite() && lambda > tol) {
1041            return Err(format!(
1042                "{context} expected a positive barrier matrix, eigenvalue {k} was {lambda}"
1043            ));
1044        }
1045        let inv_lambda = 1.0 / lambda;
1046        for i in 0..n {
1047            for j in 0..n {
1048                inv[[i, j]] += evecs[[i, k]] * inv_lambda * evecs[[j, k]];
1049            }
1050        }
1051    }
1052    Ok(inv)
1053}
1054
1055fn lower_potential(matrix: &Array2<f64>, lower: f64) -> Result<f64, String> {
1056    let inv = inverse_shifted_lower(matrix, lower)?;
1057    Ok((0..inv.nrows()).map(|i| inv[[i, i]]).sum())
1058}
1059
1060fn upper_potential(matrix: &Array2<f64>, upper: f64) -> Result<f64, String> {
1061    let inv = inverse_shifted_upper(matrix, upper)?;
1062    Ok((0..inv.nrows()).map(|i| inv[[i, i]]).sum())
1063}
1064
1065fn trace_factor_quadratic(factor: &Array2<f64>, matrix: &Array2<f64>) -> f64 {
1066    let mut trace = 0.0_f64;
1067    for row in 0..factor.nrows() {
1068        for i in 0..factor.ncols() {
1069            let xi = factor[[row, i]];
1070            if xi == 0.0 {
1071                continue;
1072            }
1073            for j in 0..factor.ncols() {
1074                trace += xi * matrix[[i, j]] * factor[[row, j]];
1075            }
1076        }
1077    }
1078    trace
1079}
1080
1081fn trace_factor_quadratic_square(factor: &Array2<f64>, matrix: &Array2<f64>) -> f64 {
1082    let mut trace = 0.0_f64;
1083    for row in 0..factor.nrows() {
1084        for i in 0..factor.ncols() {
1085            let mut v = 0.0_f64;
1086            for j in 0..factor.ncols() {
1087                v += matrix[[i, j]] * factor[[row, j]];
1088            }
1089            trace += v * v;
1090        }
1091    }
1092    trace
1093}
1094
1095fn add_factor_gram_scaled(target: &mut Array2<f64>, factor: &Array2<f64>, scale: f64) {
1096    let dim = factor.ncols();
1097    for row in 0..factor.nrows() {
1098        for i in 0..dim {
1099            let xi = factor[[row, i]];
1100            if xi == 0.0 {
1101                continue;
1102            }
1103            for j in 0..dim {
1104                target[[i, j]] += scale * xi * factor[[row, j]];
1105            }
1106        }
1107    }
1108}
1109
// Unit tests for the sampling measure (uniform fallback, Fisher-mass weights,
// enrichment ordering), the Horvitz–Thompson designed subsample, the BSS
// spectral coreset, and the certified (#1012) subsample.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use ndarray::array;
    use std::sync::Arc;

    // Dense oracle for the full-corpus Gram `Σ_r F_rᵀ F_r`.
    fn summed_factor_gram(factors: &[Array2<f64>], ambient_dim: usize) -> Array2<f64> {
        let mut total = Array2::<f64>::zeros((ambient_dim, ambient_dim));
        for factor in factors {
            add_factor_gram_scaled(&mut total, factor, 1.0);
        }
        total
    }

    // Pack per-row values into an `n × (p·rank)` matrix (zeros past each row's
    // length), wrapped in `Arc` for `RowMetric::output_fisher`.
    fn factors_from_rows(rows: &[Vec<f64>], p: usize, rank: usize) -> Arc<Array2<f64>> {
        let n = rows.len();
        let mut u = Array2::<f64>::zeros((n, p * rank));
        for (r, row) in rows.iter().enumerate() {
            for (c, &v) in row.iter().enumerate() {
                u[[r, c]] = v;
            }
        }
        Arc::new(u)
    }

    #[test]
    fn euclidean_degrades_to_uniform() {
        let metric = RowMetric::euclidean(5, 3).expect("euclidean");
        let measure = RowSamplingMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        assert!(!measure.is_enriched());
        for &w in measure.weights() {
            assert!((w - 0.2).abs() < 1e-12);
        }
    }

    #[test]
    fn weights_normalize_to_one_and_track_mass() {
        // p = 1, rank = 1 ⇒ tr(M_n) = u_n². Row 2 is far louder.
        let rows = vec![vec![1.0], vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        assert!(measure.is_enriched());
        let w = measure.weights();
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
        // tr masses: 1, 1, 9, 1 ⇒ total 12.
        assert!((w[0] - 1.0 / 12.0).abs() < 1e-12);
        assert!((w[2] - 9.0 / 12.0).abs() < 1e-12);
        assert!(w[2] > w[0] * 8.0);
    }

    #[test]
    fn all_zero_mass_degrades_to_uniform() {
        let rows = vec![vec![0.0], vec![0.0], vec![0.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        for &w in measure.weights() {
            assert!((w - 1.0 / 3.0).abs() < 1e-12);
        }
    }

    #[test]
    fn enrichment_order_is_deterministic() {
        let rows = vec![vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let a = measure.enrichment_order(20, 7);
        let b = measure.enrichment_order(20, 7);
        assert_eq!(a, b, "same seed must give identical ordering");
        let c = measure.enrichment_order(20, 8);
        // Different seed ⇒ (generally) different ordering, but same length.
        assert_eq!(c.len(), 20);
    }

    #[test]
    fn enrichment_oversamples_loud_row() {
        // Row 1 has 9x the mass of rows 0 and 2.
        let rows = vec![vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let count = 110;
        let order = measure.enrichment_order(count, 1);
        let loud = order.iter().filter(|&&r| r == 1).count();
        let quiet0 = order.iter().filter(|&&r| r == 0).count();
        // Expected: 9/11 of 110 = 90 for the loud row, 10 each for the quiet.
        assert!(
            loud > quiet0 * 5,
            "loud row must be oversampled: loud={loud} quiet0={quiet0}"
        );
    }

    #[test]
    fn expected_representation_matches_count_times_weight() {
        let rows = vec![vec![1.0], vec![3.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let rep = measure.expected_representation(10);
        // masses 1, 9 ⇒ weights 0.1, 0.9 ⇒ reps 1.0, 9.0.
        assert!((rep[0] - 1.0).abs() < 1e-12);
        assert!((rep[1] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn designed_subsample_is_deterministic_and_honest() {
        // 200 rows, one loud block. The design must (a) be reproducible for a
        // fixed seed, (b) carry weights 1/π whose HT total estimates n, and
        // (c) hit roughly the requested budget.
        let n = 200usize;
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|i| vec![if i % 10 == 0 { 3.0 } else { 1.0 }])
            .collect();
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);

        let budget = 40usize;
        let a = measure.designed_subsample(budget, 17);
        let b = measure.designed_subsample(budget, 17);
        assert_eq!(a.rows, b.rows, "same seed must give the identical design");
        assert_eq!(a.likelihood_weights, b.likelihood_weights);

        // Madow realizes ⌊Σπ⌋ or ⌈Σπ⌉ rows; Σπ is the budget by construction.
        assert!((a.expected_size - budget as f64).abs() < 1e-9);
        assert!(a.len() == budget || a.len() == budget + 1 || a.len() + 1 == budget);

        // Horvitz–Thompson corpus-size identity: Σ 1/π over a systematic
        // sample concentrates near n (exact in expectation; systematic
        // sampling keeps it within a small relative band here).
        let est = a.estimated_corpus_rows();
        assert!(
            (est - n as f64).abs() < 0.25 * n as f64,
            "HT corpus estimate {est} too far from n = {n}"
        );

        // Rows ascend and weights are finite and ≥ 1 (π ≤ 1).
        assert!(a.rows.windows(2).all(|w| w[0] < w[1]));
        assert!(
            a.likelihood_weights
                .iter()
                .all(|&w| w.is_finite() && w >= 1.0 - 1e-12)
        );
    }

    // A budget of at least n selects every row with π = 1 (weight 1).
    #[test]
    fn designed_subsample_full_budget_is_the_exact_pass() {
        let measure = RowSamplingMeasure::uniform(7);
        let s = measure.designed_subsample(7, 3);
        assert_eq!(s.rows, (0..7).collect::<Vec<_>>());
        assert!(s.likelihood_weights.iter().all(|&w| w == 1.0));
        let s = measure.designed_subsample(100, 3);
        assert_eq!(s.rows.len(), 7);
    }

    #[test]
    fn designed_subsample_uniform_measure_gives_flat_weights() {
        // Under the uniform fallback every π is budget/n, so every selected
        // row carries the same n/budget weight — plain HT scale-up.
        let n = 120usize;
        let budget = 30usize;
        let measure = RowSamplingMeasure::uniform(n);
        let s = measure.designed_subsample(budget, 5);
        assert_eq!(s.provenance, MeasureProvenance::Uniform);
        let expect = n as f64 / budget as f64;
        for &w in &s.likelihood_weights {
            assert!(
                (w - expect).abs() < 1e-9,
                "uniform design weight {w} != {expect}"
            );
        }
        assert_eq!(s.len(), budget);
    }

    #[test]
    fn designed_subsample_oversamples_loud_rows_with_downweighted_loss() {
        // A loud row should be (nearly) always included — but with a SMALLER
        // likelihood weight (its π is larger), so inclusion does not bias the
        // criterion toward loud rows.
        let rows: Vec<Vec<f64>> = (0..50)
            .map(|i| vec![if i == 7 { 30.0 } else { 1.0 }])
            .collect();
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let s = measure.designed_subsample(10, 99);
        let pos = s.rows.iter().position(|&r| r == 7);
        assert!(pos.is_some(), "the dominant-mass row must be in the design");
        let w7 = s.likelihood_weights[pos.unwrap()];
        let w_other = s
            .likelihood_weights
            .iter()
            .enumerate()
            .filter(|&(k, _)| s.rows[k] != 7)
            .map(|(_, &w)| w)
            .next()
            .expect("some quiet row selected");
        assert!(
            w7 < w_other,
            "loud row weight {w7} must be below quiet row weight {w_other}"
        );
    }

    // Weighted Gram `Σ_{k ∈ coreset} w_k F_kᵀ F_k` rebuilt from the selected rows.
    fn coreset_dense_oracle(rows: &[Array2<f64>], coreset: &SpectralCoreset) -> Array2<f64> {
        let dim = rows[0].ncols();
        let mut approx = Array2::<f64>::zeros((dim, dim));
        for (&row, &weight) in coreset.indices.iter().zip(coreset.weights.iter()) {
            add_factor_gram_scaled(&mut approx, &rows[row], weight);
        }
        approx
    }

    // Eigenvalues of `Wᵀ · approx · W`, where `W = V Λ^{-1/2}` whitens `full` on
    // its numerically nonzero eigenvalues. For a spectral `eps`-approximation on
    // that range, every returned value lies in `[1 − eps, 1 + eps]`.
    fn generalized_effective_spectrum(full: &Array2<f64>, approx: &Array2<f64>) -> Vec<f64> {
        let (evals, evecs) = full.eigh(Side::Lower).expect("oracle eigh");
        let max_eval = evals.iter().copied().fold(0.0_f64, f64::max);
        let tol = (full.ncols().max(1) as f64) * f64::EPSILON * max_eval.max(1.0) * 100.0;
        let kept: Vec<usize> = evals
            .iter()
            .enumerate()
            .filter_map(|(idx, &lambda)| (lambda > tol).then_some(idx))
            .collect();
        let mut whitener = Array2::<f64>::zeros((full.ncols(), kept.len()));
        for (col, &eig_idx) in kept.iter().enumerate() {
            let scale = 1.0 / evals[eig_idx].sqrt();
            for row in 0..full.ncols() {
                whitener[[row, col]] = evecs[[row, eig_idx]] * scale;
            }
        }
        let reduced = whitener.t().dot(approx).dot(&whitener);
        let (spectrum, _) = reduced.eigh(Side::Lower).expect("reduced oracle eigh");
        spectrum.to_vec()
    }

    // Rows span only the first two coordinates, so the effective dimension is 2
    // and the coreset must match the dense Gram within eps on that subspace.
    #[test]
    fn bss_planted_low_rank_rows_match_dense_oracle_spectrum() {
        let rows = vec![
            array![[1.0, 0.0, 0.0, 0.0]],
            array![[0.0, 2.0, 0.0, 0.0]],
            array![[1.0, 1.0, 0.0, 0.0]],
            array![[2.0, -1.0, 0.0, 0.0]],
            array![[0.5, 1.5, 0.0, 0.0]],
            array![[1.25, -0.25, 0.0, 0.0]],
        ];
        let eps = 0.35;
        let coreset = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), eps)
            .expect("BSS coreset");
        let full = summed_factor_gram(&rows, rows[0].ncols());
        let approx = coreset_dense_oracle(&rows, &coreset);
        let spectrum = generalized_effective_spectrum(&full, &approx);

        assert_eq!(coreset.certificate.dim_effective, 2);
        assert_eq!(spectrum.len(), 2);
        for lambda in spectrum {
            assert!(
                lambda >= 1.0 - eps - 1e-8 && lambda <= 1.0 + eps + 1e-8,
                "coreset generalized eigenvalue {lambda} outside [{}, {}]",
                1.0 - eps,
                1.0 + eps
            );
        }
    }

    #[test]
    fn bss_selects_single_row_carrying_unique_direction() {
        let rows = vec![
            array![[3.0, 0.0]],
            array![[2.0, 0.0]],
            array![[1.0, 0.0]],
            array![[0.0, 4.0]],
        ];
        let coreset = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), 0.4)
            .expect("BSS coreset");
        assert!(
            coreset.indices.contains(&3),
            "the only row carrying direction e2 must be selected: {:?}",
            coreset.indices
        );
    }

    #[test]
    fn bss_selection_is_deterministic() {
        let rows = vec![
            array![[1.0, 0.0, 0.0]],
            array![[0.0, 1.0, 0.0]],
            array![[0.0, 0.0, 1.0]],
            array![[1.0, 1.0, 0.0]],
            array![[0.0, 1.0, 1.0]],
        ];
        let a = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), 0.45)
            .expect("first BSS coreset");
        let b = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), 0.45)
            .expect("second BSS coreset");
        assert_eq!(a.indices, b.indices);
        assert_eq!(a.weights, b.weights);
        assert_eq!(a.certificate, b.certificate);
    }

    // A margin equal to the required one is NOT enough; a strictly larger one certifies.
    #[test]
    fn certificate_reports_insufficient_margin_explicitly() {
        let certificate = CoresetCertificate::new(0.1, 0.25, 3, 5).expect("certificate");
        let required = certificate.race_transfer_margin();
        assert!(matches!(
            certificate.certify_margin(required),
            CoresetMarginVerdict::InsufficientMargin { .. }
        ));
        assert!(matches!(
            certificate.certify_margin(required + 1.0),
            CoresetMarginVerdict::Certified { .. }
        ));
    }

    // σ_i = leverage_i · (1 + κ̂·radius) = leverage_i · 1.5 here; the greedy picks
    // the two largest (ties broken by index) and leaves the rest as residual mass.
    #[test]
    fn sensitivity_bounds_and_greedy_budget_are_deterministic() {
        let leverage = vec![0.2, 0.5, 0.5, 0.1];
        let sigma = sensitivity_upper_bounds(&leverage, 2.0, 0.25).expect("sigma");
        let expected = [0.3, 0.75, 0.75, 0.15];
        for (got, want) in sigma.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12);
        }
        let selected = greedy_sensitivity_coreset(&sigma, 2).expect("greedy");
        assert_eq!(selected.indices, vec![1, 2]);
        assert!((selected.selected_sensitivity_mass - 1.5).abs() < 1e-12);
        assert!((selected.residual_sensitivity_mass - 0.45).abs() < 1e-12);
    }

    /// #1012 certified designed subsample: the result carries a certificate
    /// whose race-transfer margin equals the certificate's, and the adversarial
    /// heavy-tail row (one row carrying the curvature signal) is FORCED into the
    /// coreset by the sensitivity bound — the classic uniform-subsampling miss.
    #[test]
    fn certified_subsample_forces_the_heavy_tail_row_and_carries_a_certificate() {
        // Five rows: four ordinary low-leverage rows and one heavy-tail row
        // (index 4) with an order-of-magnitude larger leverage and a unique
        // spectral direction e2.
        let row_factors = vec![
            array![[1.0, 0.0]],
            array![[1.0, 0.0]],
            array![[1.0, 0.0]],
            array![[1.0, 0.0]],
            array![[0.0, 5.0]],
        ];
        let leverage = vec![0.05, 0.05, 0.05, 0.05, 0.9];
        let measure = RowSamplingMeasure::uniform(5);
        let certified = measure
            .designed_subsample_certified(
                row_factors.iter().map(|r| r.view()),
                0.4,
                &leverage,
                1.0,
                0.1,
                1, // budget admits a single sensitivity row
            )
            .expect("certified subsample");

        assert!(
            certified.rows.contains(&4),
            "the heavy-tail row carrying the curvature signal must be forced in: {:?}",
            certified.rows
        );
        assert_eq!(certified.rows.len(), certified.weights.len());
        // The race-transfer margin is the certificate's — the shared #1011/#1012
        // seam a race consumer gates on.
        assert!(
            (certified.race_transfer_margin() - certified.certificate.race_transfer_margin()).abs()
                < 1e-12
        );
        assert!(certified.certificate.race_transfer_margin() > 0.0);
        // The certificate's selected count matches the realized coreset.
        assert_eq!(certified.certificate.n_selected, certified.rows.len());
    }
}