1use gam_problem::{MetricProvenance, RowMetric};
60use gam_linalg::faer_ndarray::{FaerEigh, FaerSvd};
61use gam_linalg::utils::splitmix64_hash;
62use faer::Side;
63use ndarray::{Array2, ArrayView2};
64
/// Records how the weights of a [`RowSamplingMeasure`] were obtained.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum MeasureProvenance {
    /// Every row carries the same weight `1 / n`. Produced by
    /// `RowSamplingMeasure::uniform`, which is also what `from_metric` and
    /// `from_masses` fall back to for Euclidean metrics and degenerate masses.
    Uniform,
    /// Weights are per-row masses normalized to sum to one; the payload is the
    /// provenance of the metric the masses were taken from.
    FisherMass(MetricProvenance),
}
80
/// A weight per row, together with a record of where the weights came from.
///
/// The constructors in this file (`uniform`, `from_masses`, `from_metric`)
/// build non-negative weights intended to sum to one over a non-empty row set.
#[derive(Clone, Debug)]
pub struct RowSamplingMeasure {
    /// How `weights` was produced.
    provenance: MeasureProvenance,
    /// One weight per row, indexed by row.
    weights: Vec<f64>,
}
94
/// Error budget carried by a coreset.
///
/// Values built through [`CoresetCertificate::new`] satisfy
/// `0 <= eps_spectral < 1` and `eps_likelihood >= 0`, both finite.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CoresetCertificate {
    /// Spectral tolerance, in `[0, 1)`.
    pub eps_spectral: f64,
    /// Likelihood tolerance, finite and non-negative.
    pub eps_likelihood: f64,
    /// Dimension that multiplies the per-direction log-determinant error.
    pub dim_effective: usize,
    /// Number of rows the certified selection contains.
    pub n_selected: usize,
}

impl CoresetCertificate {
    /// Builds a certificate after validating both tolerances.
    ///
    /// # Errors
    ///
    /// Returns a message when `eps_spectral` is not in `[0, 1)` (NaN included)
    /// or when `eps_likelihood` is negative, NaN or infinite.
    pub fn new(
        eps_spectral: f64,
        eps_likelihood: f64,
        dim_effective: usize,
        n_selected: usize,
    ) -> Result<Self, String> {
        // A half-open range test rejects NaN and both infinities on its own.
        if !(0.0..1.0).contains(&eps_spectral) {
            return Err(format!(
                "coreset certificate requires 0 <= eps_spectral < 1, got {eps_spectral}"
            ));
        }
        let likelihood_ok = eps_likelihood.is_finite() && eps_likelihood >= 0.0;
        if !likelihood_ok {
            return Err(format!(
                "coreset certificate requires finite non-negative eps_likelihood, got {eps_likelihood}"
            ));
        }
        Ok(Self {
            eps_spectral,
            eps_likelihood,
            dim_effective,
            n_selected,
        })
    }

    /// Returns `dim_effective * ln((1 + eps_spectral) / (1 - eps_spectral))`.
    pub fn logdet_error_bound(&self) -> f64 {
        let ratio = (1.0 + self.eps_spectral) / (1.0 - self.eps_spectral);
        self.dim_effective as f64 * ratio.ln()
    }

    /// Returns twice the sum of the log-determinant bound and `eps_likelihood`.
    pub fn race_transfer_margin(&self) -> f64 {
        let one_sided = self.logdet_error_bound() + self.eps_likelihood;
        2.0 * one_sided
    }

    /// Compares `decision_margin` against [`Self::race_transfer_margin`].
    ///
    /// The margin is certified only when it is finite and strictly larger than
    /// the required margin; equality, NaN and infinities are all reported as
    /// insufficient. Both variants echo the two numbers that were compared.
    pub fn certify_margin(&self, decision_margin: f64) -> CoresetMarginVerdict {
        let required_margin = self.race_transfer_margin();
        let clears = decision_margin.is_finite() && decision_margin > required_margin;
        if !clears {
            return CoresetMarginVerdict::InsufficientMargin {
                decision_margin,
                required_margin,
            };
        }
        CoresetMarginVerdict::Certified {
            decision_margin,
            required_margin,
        }
    }
}

/// Outcome of [`CoresetCertificate::certify_margin`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CoresetMarginVerdict {
    /// The supplied margin is finite and strictly exceeds the required one.
    Certified {
        decision_margin: f64,
        required_margin: f64,
    },
    /// The supplied margin is non-finite or does not exceed the required one.
    InsufficientMargin {
        decision_margin: f64,
        required_margin: f64,
    },
}
180
/// Result of `bss_spectral_coreset_certified`: the selected rows, their
/// weights, and the certificate describing the selection.
#[derive(Clone, Debug, PartialEq)]
pub struct SpectralCoreset {
    /// Selected row indices, in ascending order.
    pub indices: Vec<usize>,
    /// Weight of each selected row, parallel to `indices`; every entry is
    /// strictly positive.
    pub weights: Vec<f64>,
    /// Certificate recording the requested tolerance, the effective dimension
    /// and the number of selected rows.
    pub certificate: CoresetCertificate,
}
192
193pub fn bss_spectral_coreset_certified<'a, I>(
204 rows: I,
205 target_eps: f64,
206) -> Result<SpectralCoreset, String>
207where
208 I: IntoIterator<Item = ArrayView2<'a, f64>>,
209{
210 if !(target_eps.is_finite() && target_eps > 0.0 && target_eps < 1.0) {
211 return Err(format!(
212 "BSS spectral coreset requires 0 < target_eps < 1, got {target_eps}"
213 ));
214 }
215
216 let factors = collect_row_factors(rows)?;
217 let n = factors.len();
218 if n == 0 {
219 let certificate = CoresetCertificate::new(target_eps, 0.0, 0, 0)?;
220 return Ok(SpectralCoreset {
221 indices: Vec::new(),
222 weights: Vec::new(),
223 certificate,
224 });
225 }
226
227 let ambient_dim = factors[0].ncols();
228 let effective = stacked_factor_whitener(&factors, ambient_dim)?;
229 let dim = effective.ncols();
230 if dim == 0 {
231 let certificate = CoresetCertificate::new(target_eps, 0.0, 0, 0)?;
232 return Ok(SpectralCoreset {
233 indices: Vec::new(),
234 weights: Vec::new(),
235 certificate,
236 });
237 }
238
239 let whitened = whiten_row_factors(&factors, &effective);
240 let eta = 0.5 * target_eps;
241 let steps = ((dim as f64) / (eta * eta)).ceil().max(dim as f64) as usize;
242 let delta_lower = 1.0_f64;
243 let delta_upper = (1.0 + eta) / (1.0 - eta);
244 let root = (steps as f64 * dim as f64).sqrt();
245 let mut barrier_matrix = Array2::<f64>::zeros((dim, dim));
246 let mut row_weights = vec![0.0_f64; n];
247
248 for step in 0..steps {
249 let lower = step as f64 - root;
250 let upper = delta_upper * (step as f64 + root);
251 let lower_next = lower + delta_lower;
252 let upper_next = upper + delta_upper;
253
254 let lower_inv = inverse_shifted_lower(&barrier_matrix, lower_next)?;
255 let upper_inv = inverse_shifted_upper(&barrier_matrix, upper_next)?;
256 let lower_denom = lower_potential(&barrier_matrix, lower_next)?
257 - lower_potential(&barrier_matrix, lower)?;
258 let upper_denom = upper_potential(&barrier_matrix, upper)?
259 - upper_potential(&barrier_matrix, upper_next)?;
260 if !(lower_denom.is_finite() && lower_denom > 0.0) {
261 return Err(format!(
262 "BSS lower potential denominator became invalid at step {step}: {lower_denom}"
263 ));
264 }
265 if !(upper_denom.is_finite() && upper_denom > 0.0) {
266 return Err(format!(
267 "BSS upper potential denominator became invalid at step {step}: {upper_denom}"
268 ));
269 }
270
271 let mut chosen: Option<(usize, f64, f64)> = None;
272 for (row, factor) in whitened.iter().enumerate() {
273 let lower_trace = trace_factor_quadratic(factor, &lower_inv);
274 let lower_trace_sq = trace_factor_quadratic_square(factor, &lower_inv);
275 let upper_trace = trace_factor_quadratic(factor, &upper_inv);
276 let upper_trace_sq = trace_factor_quadratic_square(factor, &upper_inv);
277 let lower_score = lower_trace_sq / lower_denom - lower_trace;
278 let upper_score = upper_trace_sq / upper_denom + upper_trace;
279 if lower_score.is_finite()
280 && upper_score.is_finite()
281 && lower_score > 0.0
282 && upper_score > 0.0
283 && lower_score + BSS_SCORE_TOL >= upper_score
284 {
285 match chosen {
286 None => chosen = Some((row, lower_score, upper_score)),
287 Some((best_row, best_lower, best_upper)) => {
288 let gap = lower_score - upper_score;
289 let best_gap = best_lower - best_upper;
290 if gap > best_gap + BSS_SCORE_TOL
291 || ((gap - best_gap).abs() <= BSS_SCORE_TOL && row < best_row)
292 {
293 chosen = Some((row, lower_score, upper_score));
294 }
295 }
296 }
297 }
298 }
299
300 let (row, lower_score, upper_score) = chosen
301 .ok_or_else(|| format!("BSS failed to find a barrier-admissible row at step {step}"))?;
302 let inv_step_weight = 0.5 * (lower_score + upper_score);
303 if !(inv_step_weight.is_finite() && inv_step_weight > 0.0) {
304 return Err(format!(
305 "BSS invalid inverse step weight at step {step}: {inv_step_weight}"
306 ));
307 }
308 let step_weight = 1.0 / inv_step_weight;
309 add_factor_gram_scaled(&mut barrier_matrix, &whitened[row], step_weight);
310 row_weights[row] += step_weight;
311 }
312
313 let lower_final = steps as f64 - root;
314 let upper_final = delta_upper * (steps as f64 + root);
315 let scale = 2.0 / (lower_final + upper_final);
316 let mut indexed: Vec<(usize, f64)> = row_weights
317 .into_iter()
318 .enumerate()
319 .filter_map(|(row, weight)| {
320 let scaled = weight * scale;
321 (scaled > 0.0).then_some((row, scaled))
322 })
323 .collect();
324 indexed.sort_by_key(|&(row, _)| row);
325 let indices: Vec<usize> = indexed.iter().map(|&(row, _)| row).collect();
326 let weights: Vec<f64> = indexed.iter().map(|&(_, weight)| weight).collect();
327 let certificate = CoresetCertificate::new(target_eps, 0.0, dim, indices.len())?;
328 Ok(SpectralCoreset {
329 indices,
330 weights,
331 certificate,
332 })
333}
334
/// Inflates each leverage score by the common factor
/// `1 + kappa_hat * chart_radius`.
///
/// # Errors
///
/// Returns a message when `kappa_hat` or `chart_radius` is negative or
/// non-finite, when any leverage is negative or non-finite, or when the
/// inflation factor or an inflated bound overflows. Without the overflow
/// checks, valid finite inputs could yield `inf`, or `NaN` from `0 * inf`,
/// inside an `Ok` result.
pub fn sensitivity_upper_bounds(
    linear_anchor_leverage: &[f64],
    kappa_hat: f64,
    chart_radius: f64,
) -> Result<Vec<f64>, String> {
    if !(kappa_hat.is_finite() && kappa_hat >= 0.0) {
        return Err(format!(
            "sensitivity bounds require finite non-negative kappa_hat, got {kappa_hat}"
        ));
    }
    if !(chart_radius.is_finite() && chart_radius >= 0.0) {
        return Err(format!(
            "sensitivity bounds require finite non-negative chart_radius, got {chart_radius}"
        ));
    }
    let inflation = 1.0 + kappa_hat * chart_radius;
    // Both factors are finite, but their product can still overflow.
    if !inflation.is_finite() {
        return Err(format!(
            "sensitivity inflation overflowed for kappa_hat {kappa_hat} and chart_radius {chart_radius}"
        ));
    }
    linear_anchor_leverage
        .iter()
        .enumerate()
        .map(|(row, &lev)| {
            if !(lev.is_finite() && lev >= 0.0) {
                return Err(format!(
                    "sensitivity leverage at row {row} must be finite and non-negative, got {lev}"
                ));
            }
            let bound = lev * inflation;
            if bound.is_finite() {
                Ok(bound)
            } else {
                Err(format!(
                    "sensitivity upper bound at row {row} overflowed: leverage {lev} times inflation {inflation}"
                ))
            }
        })
        .collect()
}
373
/// Rows picked by [`greedy_sensitivity_coreset`] and how the sensitivity mass
/// splits between picked and unpicked rows.
#[derive(Clone, Debug, PartialEq)]
pub struct SensitivityCoreset {
    /// Picked row indices, ordered by decreasing bound, ties by increasing
    /// index.
    pub indices: Vec<usize>,
    /// Sum of the bounds of the picked rows.
    pub selected_sensitivity_mass: f64,
    /// Sum of the bounds of the rows that were not picked.
    pub residual_sensitivity_mass: f64,
}

/// Picks the `budget` rows with the largest sensitivity upper bounds.
///
/// Equal bounds are resolved in favour of the smaller row index. A budget
/// larger than the number of rows selects every row.
///
/// # Errors
///
/// Returns a message naming the first row whose bound is negative, NaN or
/// infinite.
pub fn greedy_sensitivity_coreset(
    sigma_upper_bounds: &[f64],
    budget: usize,
) -> Result<SensitivityCoreset, String> {
    let first_invalid = sigma_upper_bounds
        .iter()
        .position(|sigma| !(sigma.is_finite() && *sigma >= 0.0));
    if let Some(row) = first_invalid {
        let sigma = sigma_upper_bounds[row];
        return Err(format!(
            "sensitivity upper bound at row {row} must be finite and non-negative, got {sigma}"
        ));
    }

    let mut ranked: Vec<(usize, f64)> = sigma_upper_bounds.iter().copied().enumerate().collect();
    // Descending by bound; an undecided or equal comparison falls back to the
    // ascending row index.
    ranked.sort_by(
        |&(row_a, sigma_a), &(row_b, sigma_b)| match sigma_b.partial_cmp(&sigma_a) {
            Some(std::cmp::Ordering::Equal) | None => row_a.cmp(&row_b),
            Some(decided) => decided,
        },
    );

    let cut = budget.min(ranked.len());
    let (picked, dropped) = ranked.split_at(cut);
    let indices: Vec<usize> = picked.iter().map(|&(row, _)| row).collect();
    let selected_sensitivity_mass: f64 = picked.iter().map(|&(_, sigma)| sigma).sum();
    let residual_sensitivity_mass: f64 = dropped.iter().map(|&(_, sigma)| sigma).sum();
    Ok(SensitivityCoreset {
        indices,
        selected_sensitivity_mass,
        residual_sensitivity_mass,
    })
}
427
428impl RowSamplingMeasure {
429 pub fn from_metric(metric: &RowMetric) -> Self {
441 let n = metric.n_rows();
442 if n == 0 {
443 return Self {
444 provenance: MeasureProvenance::Uniform,
445 weights: Vec::new(),
446 };
447 }
448
449 if matches!(metric.provenance(), MetricProvenance::Euclidean) {
454 return Self::uniform(n);
455 }
456
457 let mass = per_row_fisher_mass(metric);
458 Self::from_masses(metric.provenance(), mass)
459 }
460
461 pub fn uniform(n: usize) -> Self {
464 let w = if n == 0 { 0.0 } else { 1.0 / n as f64 };
465 Self {
466 provenance: MeasureProvenance::Uniform,
467 weights: vec![w; n],
468 }
469 }
470
471 pub fn from_masses(metric_provenance: MetricProvenance, masses: Vec<f64>) -> Self {
478 let n = masses.len();
479 if n == 0 {
480 return Self::uniform(0);
481 }
482 let mut total = 0.0_f64;
485 let mut clean = vec![0.0_f64; n];
486 let mut all_finite = true;
487 for (i, &m) in masses.iter().enumerate() {
488 if !m.is_finite() {
489 all_finite = false;
490 break;
491 }
492 let v = if m > 0.0 { m } else { 0.0 };
493 clean[i] = v;
494 total += v;
495 }
496
497 if !all_finite || !(total > 0.0) {
498 return Self::uniform(n);
500 }
501
502 let inv = 1.0 / total;
503 for w in clean.iter_mut() {
504 *w *= inv;
505 }
506 Self {
507 provenance: MeasureProvenance::FisherMass(metric_provenance),
508 weights: clean,
509 }
510 }
511
512 pub fn weights(&self) -> &[f64] {
515 &self.weights
516 }
517
518 pub fn provenance(&self) -> MeasureProvenance {
521 self.provenance
522 }
523
524 pub fn n_rows(&self) -> usize {
526 self.weights.len()
527 }
528
529 pub fn is_enriched(&self) -> bool {
532 matches!(self.provenance, MeasureProvenance::FisherMass(_))
533 }
534
535 pub fn enrichment_order(&self, count: usize, seed: u64) -> Vec<usize> {
558 let n = self.weights.len();
559 if n == 0 || count == 0 {
560 return Vec::new();
561 }
562
563 let u = {
566 let bits = splitmix64_hash(seed ^ ENRICHMENT_SALT);
567 let mantissa = (bits >> 11) as f64; mantissa / ((1_u64 << 53) as f64)
569 };
570
571 let mut cdf = vec![0.0_f64; n];
575 let mut acc = 0.0_f64;
576 for i in 0..n {
577 acc += self.weights[i];
578 cdf[i] = acc;
579 }
580 cdf[n - 1] = 1.0;
581
582 let mut out = Vec::with_capacity(count);
583 let step = 1.0 / count as f64;
584 let mut cursor = 0usize;
585 for j in 0..count {
586 let pointer = (j as f64 + u) * step;
587 while cursor < n - 1 && pointer > cdf[cursor] {
590 cursor += 1;
591 }
592 out.push(cursor);
593 }
594 out
595 }
596
597 pub fn expected_representation(&self, count: usize) -> Vec<f64> {
603 let c = count as f64;
604 self.weights.iter().map(|&w| c * w).collect()
605 }
606
607 pub fn designed_subsample(&self, budget: usize, seed: u64) -> DesignedRowSample {
643 let n = self.weights.len();
644 if n == 0 || budget == 0 {
645 return DesignedRowSample {
646 provenance: self.provenance,
647 rows: Vec::new(),
648 likelihood_weights: Vec::new(),
649 expected_size: 0.0,
650 };
651 }
652 if budget >= n {
653 return DesignedRowSample {
654 provenance: self.provenance,
655 rows: (0..n).collect(),
656 likelihood_weights: vec![1.0; n],
657 expected_size: n as f64,
658 };
659 }
660
661 let eps = DESIGNED_SAMPLE_UNIFORM_MIX;
663 let unif = 1.0 / n as f64;
664 let mixed: Vec<f64> = self
665 .weights
666 .iter()
667 .map(|&w| (1.0 - eps) * w + eps * unif)
668 .collect();
669
670 let mut order: Vec<usize> = (0..n).collect();
673 order.sort_by(|&a, &b| {
674 mixed[b]
675 .partial_cmp(&mixed[a])
676 .unwrap_or(std::cmp::Ordering::Equal)
677 .then(a.cmp(&b))
678 });
679 let total: f64 = mixed.iter().sum();
680 let mut capped = 0usize;
681 let mut tail_mass = total;
682 let mut tau = budget as f64 / tail_mass;
683 while capped < n {
684 let next = mixed[order[capped]];
685 if tau * next <= 1.0 {
686 break;
687 }
688 capped += 1;
690 tail_mass -= next;
691 let remaining_budget = budget as f64 - capped as f64;
692 if remaining_budget <= 0.0 || tail_mass <= 0.0 {
693 break;
694 }
695 tau = remaining_budget / tail_mass;
696 }
697 let mut pi = vec![0.0_f64; n];
698 for (rank, &i) in order.iter().enumerate() {
699 pi[i] = if rank < capped {
700 1.0
701 } else {
702 (tau * mixed[i]).min(1.0)
703 };
704 }
705
706 let u = {
710 let bits = splitmix64_hash(seed ^ DESIGNED_SAMPLE_SALT);
711 let mantissa = (bits >> 11) as f64;
712 mantissa / ((1_u64 << 53) as f64)
713 };
714 let mut rows = Vec::with_capacity(budget + 1);
715 let mut likelihood_weights = Vec::with_capacity(budget + 1);
716 let mut acc = 0.0_f64;
717 for (i, &p) in pi.iter().enumerate() {
718 let before = acc;
719 acc += p;
720 if (acc - u).floor() > (before - u).floor() {
722 rows.push(i);
723 likelihood_weights.push(1.0 / p);
724 }
725 }
726 DesignedRowSample {
727 provenance: self.provenance,
728 rows,
729 likelihood_weights,
730 expected_size: pi.iter().sum(),
731 }
732 }
733
734 pub fn designed_subsample_certified<'a, I>(
761 &self,
762 row_factors: I,
763 target_eps: f64,
764 leverage: &[f64],
765 kappa_hat: f64,
766 chart_radius: f64,
767 budget: usize,
768 ) -> Result<CertifiedRowSample, String>
769 where
770 I: IntoIterator<Item = ArrayView2<'a, f64>>,
771 {
772 let spectral = bss_spectral_coreset_certified(row_factors, target_eps)?;
774
775 let sigma = sensitivity_upper_bounds(leverage, kappa_hat, chart_radius)?;
778 let sensitivity = greedy_sensitivity_coreset(&sigma, budget)?;
779 let total_sensitivity =
780 sensitivity.selected_sensitivity_mass + sensitivity.residual_sensitivity_mass;
781 let eps_likelihood = if total_sensitivity > 0.0 {
782 sensitivity.residual_sensitivity_mass / total_sensitivity
783 } else {
784 0.0
785 };
786
787 let n = self.weights.len();
792 let bss_weight: std::collections::BTreeMap<usize, f64> = spectral
793 .indices
794 .iter()
795 .zip(spectral.weights.iter())
796 .map(|(&i, &w)| (i, w))
797 .collect();
798 let mut selected: std::collections::BTreeSet<usize> =
799 spectral.indices.iter().copied().collect();
800 for &i in &sensitivity.indices {
801 selected.insert(i);
802 }
803 let selected_len = selected.len().max(1);
804 let ht_scale = if n > 0 {
805 n as f64 / selected_len as f64
806 } else {
807 1.0
808 };
809
810 let rows: Vec<usize> = selected.iter().copied().collect();
811 let weights: Vec<f64> = rows
812 .iter()
813 .map(|i| *bss_weight.get(i).unwrap_or(&ht_scale))
814 .collect();
815
816 let certificate = CoresetCertificate::new(
817 spectral.certificate.eps_spectral,
818 eps_likelihood,
819 spectral.certificate.dim_effective,
820 rows.len(),
821 )?;
822
823 Ok(CertifiedRowSample {
824 provenance: self.provenance,
825 rows,
826 weights,
827 certificate,
828 })
829 }
830}
831
832#[derive(Clone, Debug)]
836pub struct DesignedRowSample {
837 pub provenance: MeasureProvenance,
840 pub rows: Vec<usize>,
842 pub likelihood_weights: Vec<f64>,
846 pub expected_size: f64,
849}
850
851impl DesignedRowSample {
852 pub fn len(&self) -> usize {
854 self.rows.len()
855 }
856
857 pub fn is_empty(&self) -> bool {
858 self.rows.is_empty()
859 }
860
861 pub fn estimated_corpus_rows(&self) -> f64 {
865 self.likelihood_weights.iter().sum()
866 }
867}
868
869#[derive(Clone, Debug)]
875pub struct CertifiedRowSample {
876 pub provenance: MeasureProvenance,
878 pub rows: Vec<usize>,
881 pub weights: Vec<f64>,
885 pub certificate: CoresetCertificate,
889}
890
891impl CertifiedRowSample {
892 pub fn len(&self) -> usize {
893 self.rows.len()
894 }
895
896 pub fn is_empty(&self) -> bool {
897 self.rows.is_empty()
898 }
899
900 pub fn race_transfer_margin(&self) -> f64 {
903 self.certificate.race_transfer_margin()
904 }
905}
906
/// Share of uniform mass blended into the measure's weights before
/// `designed_subsample` derives inclusion probabilities. Because it is
/// positive, every blended weight is positive.
const DESIGNED_SAMPLE_UNIFORM_MIX: f64 = 0.1;

/// XORed with the caller's seed before hashing in `designed_subsample`.
const DESIGNED_SAMPLE_SALT: u64 = 0x73AD_0987_5EED_D51F;

/// XORed with the caller's seed before hashing in `enrichment_order`.
const ENRICHMENT_SALT: u64 = 0x980E_1C45_F00D_AC70;

/// Slack used by `bss_spectral_coreset_certified` when testing whether a row's
/// lower score reaches its upper score and when comparing score gaps.
const BSS_SCORE_TOL: f64 = 1e-10;
923
/// Returns an owned copy of the metric's per-row traces, which
/// `RowSamplingMeasure::from_metric` uses as the row masses.
pub fn per_row_fisher_mass(metric: &RowMetric) -> Vec<f64> {
    metric.row_traces().to_vec()
}
933
934fn collect_row_factors<'a, I>(rows: I) -> Result<Vec<Array2<f64>>, String>
935where
936 I: IntoIterator<Item = ArrayView2<'a, f64>>,
937{
938 let mut out = Vec::new();
939 let mut ambient_dim: Option<usize> = None;
940 for (row, factor) in rows.into_iter().enumerate() {
941 if factor.iter().any(|value| !value.is_finite()) {
942 return Err(format!("BSS row factor {row} contains a non-finite value"));
943 }
944 match ambient_dim {
945 None => ambient_dim = Some(factor.ncols()),
946 Some(expected) if expected != factor.ncols() => {
947 return Err(format!(
948 "BSS row factor {row} has {} columns, expected {expected}",
949 factor.ncols()
950 ));
951 }
952 Some(_) => {}
953 }
954 out.push(factor.to_owned());
955 }
956 Ok(out)
957}
958
959fn stacked_factor_whitener(
960 factors: &[Array2<f64>],
961 ambient_dim: usize,
962) -> Result<Array2<f64>, String> {
963 let total_factor_rows: usize = factors.iter().map(|factor| factor.nrows()).sum();
964 if total_factor_rows == 0 || ambient_dim == 0 {
965 return Ok(Array2::<f64>::zeros((ambient_dim, 0)));
966 }
967
968 let mut stacked = Array2::<f64>::zeros((total_factor_rows, ambient_dim));
969 let mut cursor = 0usize;
970 for factor in factors {
971 for row in 0..factor.nrows() {
972 for col in 0..ambient_dim {
973 stacked[[cursor + row, col]] = factor[[row, col]];
974 }
975 }
976 cursor += factor.nrows();
977 }
978
979 let (_, singular, vt) = stacked
980 .svd(false, true)
981 .map_err(|err| format!("BSS stacked row-factor SVD failed: {err}"))?;
982 let vt = vt.ok_or_else(|| "BSS stacked row-factor SVD did not return Vt".to_string())?;
983 let max_sigma = singular.iter().copied().fold(0.0_f64, f64::max);
984 if !(max_sigma.is_finite() && max_sigma >= 0.0) {
985 return Err("BSS stacked row sketch has invalid singular values".to_string());
986 }
987 let tol = (ambient_dim.max(1) as f64) * f64::EPSILON * max_sigma.max(1.0) * 100.0;
988 let kept: Vec<usize> = singular
989 .iter()
990 .enumerate()
991 .filter_map(|(idx, &sigma)| (sigma > tol).then_some(idx))
992 .collect();
993 let mut whitener = Array2::<f64>::zeros((ambient_dim, kept.len()));
994 for (out_col, &sv_col) in kept.iter().enumerate() {
995 let scale = 1.0 / singular[sv_col];
996 for ambient_col in 0..ambient_dim {
997 whitener[[ambient_col, out_col]] = vt[[sv_col, ambient_col]] * scale;
998 }
999 }
1000 Ok(whitener)
1001}
1002
1003fn whiten_row_factors(factors: &[Array2<f64>], whitener: &Array2<f64>) -> Vec<Array2<f64>> {
1004 factors.iter().map(|factor| factor.dot(whitener)).collect()
1005}
1006
1007fn inverse_shifted_lower(matrix: &Array2<f64>, lower: f64) -> Result<Array2<f64>, String> {
1008 let n = matrix.nrows();
1009 let mut shifted = matrix.clone();
1010 for i in 0..n {
1011 shifted[[i, i]] -= lower;
1012 }
1013 inverse_symmetric_positive(&shifted, "BSS lower barrier inverse")
1014}
1015
1016fn inverse_shifted_upper(matrix: &Array2<f64>, upper: f64) -> Result<Array2<f64>, String> {
1017 let n = matrix.nrows();
1018 let mut shifted = Array2::<f64>::zeros((n, n));
1019 for i in 0..n {
1020 shifted[[i, i]] = upper;
1021 }
1022 for i in 0..n {
1023 for j in 0..n {
1024 shifted[[i, j]] -= matrix[[i, j]];
1025 }
1026 }
1027 inverse_symmetric_positive(&shifted, "BSS upper barrier inverse")
1028}
1029
1030fn inverse_symmetric_positive(matrix: &Array2<f64>, context: &str) -> Result<Array2<f64>, String> {
1031 let (evals, evecs) = matrix
1032 .eigh(Side::Lower)
1033 .map_err(|err| format!("{context} eigendecomposition failed: {err}"))?;
1034 let n = matrix.nrows();
1035 let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
1036 let tol = (n.max(1) as f64) * f64::EPSILON * max_eval * 100.0;
1037 let mut inv = Array2::<f64>::zeros((n, n));
1038 for k in 0..n {
1039 let lambda = evals[k];
1040 if !(lambda.is_finite() && lambda > tol) {
1041 return Err(format!(
1042 "{context} expected a positive barrier matrix, eigenvalue {k} was {lambda}"
1043 ));
1044 }
1045 let inv_lambda = 1.0 / lambda;
1046 for i in 0..n {
1047 for j in 0..n {
1048 inv[[i, j]] += evecs[[i, k]] * inv_lambda * evecs[[j, k]];
1049 }
1050 }
1051 }
1052 Ok(inv)
1053}
1054
1055fn lower_potential(matrix: &Array2<f64>, lower: f64) -> Result<f64, String> {
1056 let inv = inverse_shifted_lower(matrix, lower)?;
1057 Ok((0..inv.nrows()).map(|i| inv[[i, i]]).sum())
1058}
1059
1060fn upper_potential(matrix: &Array2<f64>, upper: f64) -> Result<f64, String> {
1061 let inv = inverse_shifted_upper(matrix, upper)?;
1062 Ok((0..inv.nrows()).map(|i| inv[[i, i]]).sum())
1063}
1064
1065fn trace_factor_quadratic(factor: &Array2<f64>, matrix: &Array2<f64>) -> f64 {
1066 let mut trace = 0.0_f64;
1067 for row in 0..factor.nrows() {
1068 for i in 0..factor.ncols() {
1069 let xi = factor[[row, i]];
1070 if xi == 0.0 {
1071 continue;
1072 }
1073 for j in 0..factor.ncols() {
1074 trace += xi * matrix[[i, j]] * factor[[row, j]];
1075 }
1076 }
1077 }
1078 trace
1079}
1080
1081fn trace_factor_quadratic_square(factor: &Array2<f64>, matrix: &Array2<f64>) -> f64 {
1082 let mut trace = 0.0_f64;
1083 for row in 0..factor.nrows() {
1084 for i in 0..factor.ncols() {
1085 let mut v = 0.0_f64;
1086 for j in 0..factor.ncols() {
1087 v += matrix[[i, j]] * factor[[row, j]];
1088 }
1089 trace += v * v;
1090 }
1091 }
1092 trace
1093}
1094
1095fn add_factor_gram_scaled(target: &mut Array2<f64>, factor: &Array2<f64>, scale: f64) {
1096 let dim = factor.ncols();
1097 for row in 0..factor.nrows() {
1098 for i in 0..dim {
1099 let xi = factor[[row, i]];
1100 if xi == 0.0 {
1101 continue;
1102 }
1103 for j in 0..dim {
1104 target[[i, j]] += scale * xi * factor[[row, j]];
1105 }
1106 }
1107 }
1108}
1109
// Unit tests for the sampling measure, the designed subsamples and the
// certified coresets defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use ndarray::array;
    use std::sync::Arc;

    /// Unweighted sum of the factor Gram matrices.
    fn summed_factor_gram(factors: &[Array2<f64>], ambient_dim: usize) -> Array2<f64> {
        let mut total = Array2::<f64>::zeros((ambient_dim, ambient_dim));
        for factor in factors {
            add_factor_gram_scaled(&mut total, factor, 1.0);
        }
        total
    }

    /// Packs `rows` into an `n x (p * rank)` matrix, one input row per matrix
    /// row, leaving unspecified entries at zero.
    fn factors_from_rows(rows: &[Vec<f64>], p: usize, rank: usize) -> Arc<Array2<f64>> {
        let n = rows.len();
        let mut u = Array2::<f64>::zeros((n, p * rank));
        for (r, row) in rows.iter().enumerate() {
            for (c, &v) in row.iter().enumerate() {
                u[[r, c]] = v;
            }
        }
        Arc::new(u)
    }

    // A Euclidean metric must produce the uniform measure.
    #[test]
    fn euclidean_degrades_to_uniform() {
        let metric = RowMetric::euclidean(5, 3).expect("euclidean");
        let measure = RowSamplingMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        assert!(!measure.is_enriched());
        for &w in measure.weights() {
            assert!((w - 0.2).abs() < 1e-12);
        }
    }

    // The expected weights 1/12 and 9/12 correspond to masses proportional to
    // the squared entries (1, 1, 9, 1).
    #[test]
    fn weights_normalize_to_one_and_track_mass() {
        let rows = vec![vec![1.0], vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        assert!(measure.is_enriched());
        let w = measure.weights();
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
        assert!((w[0] - 1.0 / 12.0).abs() < 1e-12);
        assert!((w[2] - 9.0 / 12.0).abs() < 1e-12);
        assert!(w[2] > w[0] * 8.0);
    }

    // With no positive mass the measure must fall back to uniform.
    #[test]
    fn all_zero_mass_degrades_to_uniform() {
        let rows = vec![vec![0.0], vec![0.0], vec![0.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        for &w in measure.weights() {
            assert!((w - 1.0 / 3.0).abs() < 1e-12);
        }
    }

    // Same seed, same ordering; a different seed still yields `count` draws.
    #[test]
    fn enrichment_order_is_deterministic() {
        let rows = vec![vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let a = measure.enrichment_order(20, 7);
        let b = measure.enrichment_order(20, 7);
        assert_eq!(a, b, "same seed must give identical ordering");
        let c = measure.enrichment_order(20, 8);
        assert_eq!(c.len(), 20);
    }

    // The row with the large entry must be drawn far more often than a quiet
    // row.
    #[test]
    fn enrichment_oversamples_loud_row() {
        let rows = vec![vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let count = 110;
        let order = measure.enrichment_order(count, 1);
        let loud = order.iter().filter(|&&r| r == 1).count();
        let quiet0 = order.iter().filter(|&&r| r == 0).count();
        assert!(
            loud > quiet0 * 5,
            "loud row must be oversampled: loud={loud} quiet0={quiet0}"
        );
    }

    // Expected counts are `count * weight`: 10 draws split 1 to 9 here.
    #[test]
    fn expected_representation_matches_count_times_weight() {
        let rows = vec![vec![1.0], vec![3.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let rep = measure.expected_representation(10);
        assert!((rep[0] - 1.0).abs() < 1e-12);
        assert!((rep[1] - 9.0).abs() < 1e-12);
    }

    // Checks reproducibility, expected size, realized size within one of the
    // budget, the corpus-size estimate, ascending rows and weights >= 1.
    #[test]
    fn designed_subsample_is_deterministic_and_honest() {
        let n = 200usize;
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|i| vec![if i % 10 == 0 { 3.0 } else { 1.0 }])
            .collect();
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);

        let budget = 40usize;
        let a = measure.designed_subsample(budget, 17);
        let b = measure.designed_subsample(budget, 17);
        assert_eq!(a.rows, b.rows, "same seed must give the identical design");
        assert_eq!(a.likelihood_weights, b.likelihood_weights);

        assert!((a.expected_size - budget as f64).abs() < 1e-9);
        assert!(a.len() == budget || a.len() == budget + 1 || a.len() + 1 == budget);

        let est = a.estimated_corpus_rows();
        assert!(
            (est - n as f64).abs() < 0.25 * n as f64,
            "HT corpus estimate {est} too far from n = {n}"
        );

        assert!(a.rows.windows(2).all(|w| w[0] < w[1]));
        assert!(
            a.likelihood_weights
                .iter()
                .all(|&w| w.is_finite() && w >= 1.0 - 1e-12)
        );
    }

    // A budget at or above the row count returns every row with weight one.
    #[test]
    fn designed_subsample_full_budget_is_the_exact_pass() {
        let measure = RowSamplingMeasure::uniform(7);
        let s = measure.designed_subsample(7, 3);
        assert_eq!(s.rows, (0..7).collect::<Vec<_>>());
        assert!(s.likelihood_weights.iter().all(|&w| w == 1.0));
        let s = measure.designed_subsample(100, 3);
        assert_eq!(s.rows.len(), 7);
    }

    // Under a uniform measure every selected row carries weight n / budget.
    #[test]
    fn designed_subsample_uniform_measure_gives_flat_weights() {
        let n = 120usize;
        let budget = 30usize;
        let measure = RowSamplingMeasure::uniform(n);
        let s = measure.designed_subsample(budget, 5);
        assert_eq!(s.provenance, MeasureProvenance::Uniform);
        let expect = n as f64 / budget as f64;
        for &w in &s.likelihood_weights {
            assert!(
                (w - expect).abs() < 1e-9,
                "uniform design weight {w} != {expect}"
            );
        }
        assert_eq!(s.len(), budget);
    }

    // The dominant row must be selected and carry a smaller likelihood weight
    // than a quiet row.
    #[test]
    fn designed_subsample_oversamples_loud_rows_with_downweighted_loss() {
        let rows: Vec<Vec<f64>> = (0..50)
            .map(|i| vec![if i == 7 { 30.0 } else { 1.0 }])
            .collect();
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowSamplingMeasure::from_metric(&metric);
        let s = measure.designed_subsample(10, 99);
        let pos = s.rows.iter().position(|&r| r == 7);
        assert!(pos.is_some(), "the dominant-mass row must be in the design");
        let w7 = s.likelihood_weights[pos.unwrap()];
        let w_other = s
            .likelihood_weights
            .iter()
            .enumerate()
            .filter(|&(k, _)| s.rows[k] != 7)
            .map(|(_, &w)| w)
            .next()
            .expect("some quiet row selected");
        assert!(
            w7 < w_other,
            "loud row weight {w7} must be below quiet row weight {w_other}"
        );
    }

    /// Weighted sum of the Gram matrices of the rows the coreset selected.
    fn coreset_dense_oracle(rows: &[Array2<f64>], coreset: &SpectralCoreset) -> Array2<f64> {
        let dim = rows[0].ncols();
        let mut approx = Array2::<f64>::zeros((dim, dim));
        for (&row, &weight) in coreset.indices.iter().zip(coreset.weights.iter()) {
            add_factor_gram_scaled(&mut approx, &rows[row], weight);
        }
        approx
    }

    /// Eigenvalues of `approx` after whitening by `full`, restricted to the
    /// eigen-directions of `full` that exceed a relative tolerance.
    fn generalized_effective_spectrum(full: &Array2<f64>, approx: &Array2<f64>) -> Vec<f64> {
        let (evals, evecs) = full.eigh(Side::Lower).expect("oracle eigh");
        let max_eval = evals.iter().copied().fold(0.0_f64, f64::max);
        let tol = (full.ncols().max(1) as f64) * f64::EPSILON * max_eval.max(1.0) * 100.0;
        let kept: Vec<usize> = evals
            .iter()
            .enumerate()
            .filter_map(|(idx, &lambda)| (lambda > tol).then_some(idx))
            .collect();
        let mut whitener = Array2::<f64>::zeros((full.ncols(), kept.len()));
        for (col, &eig_idx) in kept.iter().enumerate() {
            let scale = 1.0 / evals[eig_idx].sqrt();
            for row in 0..full.ncols() {
                whitener[[row, col]] = evecs[[row, eig_idx]] * scale;
            }
        }
        let reduced = whitener.t().dot(approx).dot(&whitener);
        let (spectrum, _) = reduced.eigh(Side::Lower).expect("reduced oracle eigh");
        spectrum.to_vec()
    }

    // Rank-2 rows embedded in 4 dimensions: the generalized eigenvalues of the
    // coreset must lie within [1 - eps, 1 + eps].
    #[test]
    fn bss_planted_low_rank_rows_match_dense_oracle_spectrum() {
        let rows = vec![
            array![[1.0, 0.0, 0.0, 0.0]],
            array![[0.0, 2.0, 0.0, 0.0]],
            array![[1.0, 1.0, 0.0, 0.0]],
            array![[2.0, -1.0, 0.0, 0.0]],
            array![[0.5, 1.5, 0.0, 0.0]],
            array![[1.25, -0.25, 0.0, 0.0]],
        ];
        let eps = 0.35;
        let coreset = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), eps)
            .expect("BSS coreset");
        let full = summed_factor_gram(&rows, rows[0].ncols());
        let approx = coreset_dense_oracle(&rows, &coreset);
        let spectrum = generalized_effective_spectrum(&full, &approx);

        assert_eq!(coreset.certificate.dim_effective, 2);
        assert_eq!(spectrum.len(), 2);
        for lambda in spectrum {
            assert!(
                lambda >= 1.0 - eps - 1e-8 && lambda <= 1.0 + eps + 1e-8,
                "coreset generalized eigenvalue {lambda} outside [{}, {}]",
                1.0 - eps,
                1.0 + eps
            );
        }
    }

    // Row 3 is the only one with a non-zero second coordinate.
    #[test]
    fn bss_selects_single_row_carrying_unique_direction() {
        let rows = vec![
            array![[3.0, 0.0]],
            array![[2.0, 0.0]],
            array![[1.0, 0.0]],
            array![[0.0, 4.0]],
        ];
        let coreset = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), 0.4)
            .expect("BSS coreset");
        assert!(
            coreset.indices.contains(&3),
            "the only row carrying direction e2 must be selected: {:?}",
            coreset.indices
        );
    }

    // Two runs on identical input must agree exactly.
    #[test]
    fn bss_selection_is_deterministic() {
        let rows = vec![
            array![[1.0, 0.0, 0.0]],
            array![[0.0, 1.0, 0.0]],
            array![[0.0, 0.0, 1.0]],
            array![[1.0, 1.0, 0.0]],
            array![[0.0, 1.0, 1.0]],
        ];
        let a = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), 0.45)
            .expect("first BSS coreset");
        let b = bss_spectral_coreset_certified(rows.iter().map(|row| row.view()), 0.45)
            .expect("second BSS coreset");
        assert_eq!(a.indices, b.indices);
        assert_eq!(a.weights, b.weights);
        assert_eq!(a.certificate, b.certificate);
    }

    // A margin equal to the required one is insufficient; a larger one is
    // certified.
    #[test]
    fn certificate_reports_insufficient_margin_explicitly() {
        let certificate = CoresetCertificate::new(0.1, 0.25, 3, 5).expect("certificate");
        let required = certificate.race_transfer_margin();
        assert!(matches!(
            certificate.certify_margin(required),
            CoresetMarginVerdict::InsufficientMargin { .. }
        ));
        assert!(matches!(
            certificate.certify_margin(required + 1.0),
            CoresetMarginVerdict::Certified { .. }
        ));
    }

    // Inflation is 1 + 2.0 * 0.25 = 1.5; the tie between rows 1 and 2 is
    // resolved by index.
    #[test]
    fn sensitivity_bounds_and_greedy_budget_are_deterministic() {
        let leverage = vec![0.2, 0.5, 0.5, 0.1];
        let sigma = sensitivity_upper_bounds(&leverage, 2.0, 0.25).expect("sigma");
        let expected = [0.3, 0.75, 0.75, 0.15];
        for (got, want) in sigma.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12);
        }
        let selected = greedy_sensitivity_coreset(&sigma, 2).expect("greedy");
        assert_eq!(selected.indices, vec![1, 2]);
        assert!((selected.selected_sensitivity_mass - 1.5).abs() < 1e-12);
        assert!((selected.residual_sensitivity_mass - 0.45).abs() < 1e-12);
    }

    // Row 4 has the dominant leverage and is the only row spanning the second
    // coordinate, so it must be in the certified selection.
    #[test]
    fn certified_subsample_forces_the_heavy_tail_row_and_carries_a_certificate() {
        let row_factors = vec![
            array![[1.0, 0.0]],
            array![[1.0, 0.0]],
            array![[1.0, 0.0]],
            array![[1.0, 0.0]],
            array![[0.0, 5.0]],
        ];
        let leverage = vec![0.05, 0.05, 0.05, 0.05, 0.9];
        let measure = RowSamplingMeasure::uniform(5);
        let certified = measure
            .designed_subsample_certified(
                row_factors.iter().map(|r| r.view()),
                0.4,
                &leverage,
                1.0,
                0.1,
                1,
            )
            .expect("certified subsample");

        assert!(
            certified.rows.contains(&4),
            "the heavy-tail row carrying the curvature signal must be forced in: {:?}",
            certified.rows
        );
        assert_eq!(certified.rows.len(), certified.weights.len());
        assert!(
            (certified.race_transfer_margin() - certified.certificate.race_transfer_margin()).abs()
                < 1e-12
        );
        assert!(certified.certificate.race_transfer_margin() > 0.0);
        assert_eq!(certified.certificate.n_selected, certified.rows.len());
    }
}