use crate::inference::row_metric::{MetricProvenance, RowMetric};
use crate::linalg::utils::splitmix64_hash;
/// Records where a `RowMeasure`'s weights came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeasureProvenance {
    /// Every row carries the same weight.
    Uniform,
    /// Weights are normalized per-row Fisher masses computed from a metric with the
    /// wrapped provenance.
    FisherMass(MetricProvenance),
}
/// A weighting of the rows of a dataset.
///
/// As built by this module's constructors, the weights are non-negative and (for a
/// non-empty measure) sum to one up to floating-point rounding.
#[derive(Debug, Clone)]
pub struct RowMeasure {
    /// How the weights were obtained.
    provenance: MeasureProvenance,
    /// One weight per row; `weights[i]` belongs to row `i`.
    weights: Vec<f64>,
}
impl RowMeasure {
    /// Builds the row measure induced by `metric`.
    ///
    /// An empty metric, or one whose provenance is `MetricProvenance::Euclidean`,
    /// yields the uniform measure. Otherwise each row is weighted by its Fisher mass
    /// (see `per_row_fisher_mass`) and the masses are normalized by `from_masses`,
    /// which itself falls back to uniform on degenerate input.
    pub fn from_metric(metric: &RowMetric) -> Self {
        let n = metric.n_rows();
        // `uniform(0)` is the empty uniform measure, so both early exits coincide.
        if n == 0 || matches!(metric.provenance(), MetricProvenance::Euclidean) {
            return Self::uniform(n);
        }
        let mass = per_row_fisher_mass(metric);
        Self::from_masses(metric.provenance(), mass)
    }

    /// The uniform measure on `n` rows: weight `1 / n` each, empty when `n == 0`.
    pub fn uniform(n: usize) -> Self {
        let w = if n == 0 { 0.0 } else { 1.0 / n as f64 };
        Self {
            provenance: MeasureProvenance::Uniform,
            weights: vec![w; n],
        }
    }

    /// Normalizes raw per-row `masses` into a Fisher-mass measure.
    ///
    /// Non-positive masses get zero weight. The result falls back to the uniform
    /// measure when any mass is NaN or infinite, or when no mass is positive. If
    /// the masses are individually finite but their sum overflows `f64`, they are
    /// rescaled by their maximum first. Without that step the normalizer would be
    /// `1 / inf = 0` and every weight would collapse to zero.
    pub(crate) fn from_masses(metric_provenance: MetricProvenance, masses: Vec<f64>) -> Self {
        let n = masses.len();
        if n == 0 {
            return Self::uniform(0);
        }
        if masses.iter().any(|m| !m.is_finite()) {
            return Self::uniform(n);
        }
        let mut weights: Vec<f64> = masses
            .into_iter()
            .map(|m| if m > 0.0 { m } else { 0.0 })
            .collect();
        let mut total = weights.iter().fold(0.0_f64, |acc, &w| acc + w);
        if total.is_infinite() {
            // Every weight is finite and at least one is huge, so `max` is finite
            // and positive; after dividing, each weight is <= 1 and the sum is <= n.
            let max = weights.iter().copied().fold(0.0_f64, f64::max);
            for w in weights.iter_mut() {
                *w /= max;
            }
            total = weights.iter().fold(0.0_f64, |acc, &w| acc + w);
        }
        // No row carries positive mass: nothing to prefer, so stay uniform.
        if !(total > 0.0) {
            return Self::uniform(n);
        }
        let inv = 1.0 / total;
        for w in weights.iter_mut() {
            *w *= inv;
        }
        Self {
            provenance: MeasureProvenance::FisherMass(metric_provenance),
            weights,
        }
    }

    /// Per-row weights; `weights()[i]` belongs to row `i`.
    pub fn weights(&self) -> &[f64] {
        &self.weights
    }

    /// Where the weights came from.
    pub fn provenance(&self) -> MeasureProvenance {
        self.provenance
    }

    /// Number of rows the measure covers.
    pub fn n_rows(&self) -> usize {
        self.weights.len()
    }

    /// `true` when the weights come from Fisher mass rather than being uniform.
    pub fn is_enriched(&self) -> bool {
        matches!(self.provenance, MeasureProvenance::FisherMass(_))
    }

    /// Draws `count` row indices by systematic sampling from the measure.
    ///
    /// A single offset `u` in `[0, 1)`, derived deterministically from `seed`,
    /// places pointers at `(j + u) / count` for `j in 0..count`. Row `i` owns the
    /// half-open CDF interval `[cdf[i-1], cdf[i])`, and each pointer selects the
    /// row whose interval contains it. The output is non-decreasing in row index.
    /// Up to floating-point rounding, row `i` appears within one of
    /// `count * weights[i]` times. Rows with zero weight are never returned.
    ///
    /// Returns an empty vector when the measure is empty or `count == 0`.
    pub fn enrichment_order(&self, count: usize, seed: u64) -> Vec<usize> {
        let n = self.weights.len();
        if n == 0 || count == 0 {
            return Vec::new();
        }
        let u = Self::unit_draw(seed, ENRICHMENT_SALT);
        // The last row that may be selected; trailing zero-weight rows are excluded.
        let last = self
            .weights
            .iter()
            .rposition(|&w| w > 0.0)
            .unwrap_or(n - 1);
        let mut cdf: Vec<f64> = self
            .weights
            .iter()
            .scan(0.0_f64, |acc, &w| {
                *acc += w;
                Some(*acc)
            })
            .collect();
        // Absorb accumulated rounding: the last selectable row's interval ends at 1.
        cdf[last] = 1.0;
        let step = 1.0 / count as f64;
        let mut out = Vec::with_capacity(count);
        let mut cursor = 0usize;
        for j in 0..count {
            let pointer = (j as f64 + u) * step;
            // `>=` makes intervals half-open, so a zero-width (zero-weight) row is
            // always stepped over, even when the pointer lands exactly on its edge.
            while cursor < last && pointer >= cdf[cursor] {
                cursor += 1;
            }
            out.push(cursor);
        }
        out
    }

    /// Expected number of appearances of each row among `count` draws:
    /// `count * weights[i]`.
    pub fn expected_representation(&self, count: usize) -> Vec<f64> {
        let c = count as f64;
        self.weights.iter().map(|&w| c * w).collect()
    }

    /// Draws a deterministic probability-proportional-to-size design of about
    /// `budget` rows.
    ///
    /// The algorithm runs in three steps:
    /// 1. The measure is blended with the uniform measure in proportion
    ///    `DESIGNED_SAMPLE_UNIFORM_MIX`, so every row has positive inclusion
    ///    probability.
    /// 2. Inclusion probabilities `pi_i = min(1, tau * mixed_i)` are chosen to sum
    ///    to `budget`, with the heaviest rows capped at certainty.
    /// 3. A systematic sample is drawn with a single start derived from `seed`.
    ///
    /// Selected rows are returned in ascending order with likelihood weight
    /// `1 / pi_i`, so `DesignedRowSample::estimated_corpus_rows` is a
    /// Horvitz–Thompson estimate of the row count. The realized size may differ
    /// from `budget` by one because of rounding in the cumulative probabilities.
    ///
    /// An empty measure or `budget == 0` gives an empty sample. When
    /// `budget >= n_rows()`, every row is returned with weight `1.0`.
    pub fn designed_subsample(&self, budget: usize, seed: u64) -> DesignedRowSample {
        let n = self.weights.len();
        if n == 0 || budget == 0 {
            return DesignedRowSample {
                provenance: self.provenance,
                rows: Vec::new(),
                likelihood_weights: Vec::new(),
                expected_size: 0.0,
            };
        }
        if budget >= n {
            return DesignedRowSample {
                provenance: self.provenance,
                rows: (0..n).collect(),
                likelihood_weights: vec![1.0; n],
                expected_size: n as f64,
            };
        }
        let eps = DESIGNED_SAMPLE_UNIFORM_MIX;
        let unif = 1.0 / n as f64;
        let mixed: Vec<f64> = self
            .weights
            .iter()
            .map(|&w| (1.0 - eps) * w + eps * unif)
            .collect();
        let pi = Self::capped_inclusion_probabilities(&mixed, budget);
        let u = Self::unit_draw(seed, DESIGNED_SAMPLE_SALT);
        let mut rows = Vec::with_capacity(budget + 1);
        let mut likelihood_weights = Vec::with_capacity(budget + 1);
        let mut acc = 0.0_f64;
        for (i, &p) in pi.iter().enumerate() {
            let before = acc;
            acc += p;
            // Row i is drawn iff some lattice point `u + k` (k integer) falls in
            // `(before, acc]`; rows with `p == 1` therefore always qualify.
            if (acc - u).floor() > (before - u).floor() {
                rows.push(i);
                likelihood_weights.push(1.0 / p);
            }
        }
        DesignedRowSample {
            provenance: self.provenance,
            rows,
            likelihood_weights,
            expected_size: pi.iter().sum(),
        }
    }

    /// Computes inclusion probabilities `pi_i = min(1, tau * mixed_i)` that sum to
    /// `budget`.
    ///
    /// Rows are visited from largest to smallest `mixed` value, with ties broken by
    /// lower index. Any row whose proportional share `tau * mixed_i` exceeds 1 is
    /// fixed at certainty. The remaining budget is then re-spread over the
    /// uncapped tail, and `tau` is recomputed each time a row is capped.
    fn capped_inclusion_probabilities(mixed: &[f64], budget: usize) -> Vec<f64> {
        let n = mixed.len();
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by(|&a, &b| {
            mixed[b]
                .partial_cmp(&mixed[a])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(a.cmp(&b))
        });
        let total: f64 = mixed.iter().sum();
        let mut capped = 0usize;
        let mut tail_mass = total;
        let mut tau = budget as f64 / tail_mass;
        while capped < n {
            let next = mixed[order[capped]];
            if tau * next <= 1.0 {
                break;
            }
            capped += 1;
            tail_mass -= next;
            let remaining_budget = budget as f64 - capped as f64;
            if remaining_budget <= 0.0 || tail_mass <= 0.0 {
                break;
            }
            tau = remaining_budget / tail_mass;
        }
        let mut pi = vec![0.0_f64; n];
        for (rank, &i) in order.iter().enumerate() {
            pi[i] = if rank < capped {
                1.0
            } else {
                (tau * mixed[i]).min(1.0)
            };
        }
        pi
    }

    /// Deterministic draw in `[0, 1)`: the top 53 bits of
    /// `splitmix64_hash(seed ^ salt)` scaled by `2^-53`.
    fn unit_draw(seed: u64, salt: u64) -> f64 {
        let bits = splitmix64_hash(seed ^ salt);
        let mantissa = (bits >> 11) as f64;
        mantissa / ((1_u64 << 53) as f64)
    }
}
/// A row subsample together with the weights needed for unbiased estimation.
///
/// As produced by `RowMeasure::designed_subsample`, `rows` is strictly
/// ascending and `likelihood_weights[k]` is the inverse inclusion probability of
/// `rows[k]`.
#[derive(Debug, Clone)]
pub struct DesignedRowSample {
    /// Provenance of the measure the design was drawn from.
    pub provenance: MeasureProvenance,
    /// Selected row indices.
    pub rows: Vec<usize>,
    /// One weight per selected row, aligned with `rows`.
    pub likelihood_weights: Vec<f64>,
    /// Sum of the inclusion probabilities over all rows (the expected sample size).
    pub expected_size: f64,
}
impl DesignedRowSample {
    /// Number of rows actually selected.
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// `true` when the design selected no rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Estimate of the full corpus size: the sum of the per-row likelihood weights
    /// (a Horvitz–Thompson total when the weights are inverse inclusion
    /// probabilities).
    pub fn estimated_corpus_rows(&self) -> f64 {
        self.likelihood_weights.iter().copied().sum::<f64>()
    }
}
/// Share of uniform mass blended into the measure before a subsample is designed,
/// so that every row ends up with a strictly positive inclusion probability.
const DESIGNED_SAMPLE_UNIFORM_MIX: f64 = 0.1;
/// Salt XOR-ed into the caller's seed before hashing in `designed_subsample`.
const DESIGNED_SAMPLE_SALT: u64 = 0x73AD0987_5EEDD51F;
/// Salt XOR-ed into the caller's seed before hashing in `enrichment_order`.
const ENRICHMENT_SALT: u64 = 0x980E1C45_F00DAC70;
/// Fisher mass of every row: the sum of the first `p_out` diagonal entries of that
/// row's metric block (its trace, assuming blocks are `p_out x p_out`).
pub(crate) fn per_row_fisher_mass(metric: &RowMetric) -> Vec<f64> {
    let blocks = metric.blocks();
    let n_rows = metric.n_rows();
    let p = metric.p_out();
    (0..n_rows)
        .map(|row| (0..p).fold(0.0_f64, |trace, d| trace + blocks[[row, d, d]]))
        .collect()
}
// Unit tests for `RowMeasure` and `DesignedRowSample`. The expected values below
// are derived in comments so a failing assertion can be checked by hand.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use std::sync::Arc;

    // Packs `rows` into an `n x (p * rank)` factor matrix: row `r` holds `rows[r]`
    // left-aligned, and any remaining columns stay zero.
    fn factors_from_rows(rows: &[Vec<f64>], p: usize, rank: usize) -> Arc<Array2<f64>> {
        let n = rows.len();
        let mut u = Array2::<f64>::zeros((n, p * rank));
        for (r, row) in rows.iter().enumerate() {
            for (c, &v) in row.iter().enumerate() {
                u[[r, c]] = v;
            }
        }
        Arc::new(u)
    }

    // A Euclidean metric takes the uniform path: 5 rows -> weight 0.2 each, and the
    // measure does not report itself as enriched.
    #[test]
    fn euclidean_degrades_to_uniform() {
        let metric = RowMetric::euclidean(5, 3).expect("euclidean");
        let measure = RowMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        assert!(!measure.is_enriched());
        for &w in measure.weights() {
            assert!((w - 0.2).abs() < 1e-12);
        }
    }

    // With p = rank = 1 the asserted weights correspond to Fisher masses 1, 1, 9, 1
    // (the square of each factor entry), total 12, so w = [1, 1, 9, 1] / 12.
    #[test]
    fn weights_normalize_to_one_and_track_mass() {
        let rows = vec![vec![1.0], vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        assert!(measure.is_enriched());
        let w = measure.weights();
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
        assert!((w[0] - 1.0 / 12.0).abs() < 1e-12);
        assert!((w[2] - 9.0 / 12.0).abs() < 1e-12);
        assert!(w[2] > w[0] * 8.0);
    }

    // All-zero factors give zero total mass, which must fall back to the uniform
    // measure (1/3 each) rather than dividing by zero.
    #[test]
    fn all_zero_mass_degrades_to_uniform() {
        let rows = vec![vec![0.0], vec![0.0], vec![0.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        for &w in measure.weights() {
            assert!((w - 1.0 / 3.0).abs() < 1e-12);
        }
    }

    // The same seed reproduces the draw exactly; a different seed still yields
    // exactly `count` indices.
    #[test]
    fn enrichment_order_is_deterministic() {
        let rows = vec![vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        let a = measure.enrichment_order(20, 7);
        let b = measure.enrichment_order(20, 7);
        assert_eq!(a, b, "same seed must give identical ordering");
        let c = measure.enrichment_order(20, 8);
        assert_eq!(c.len(), 20);
    }

    // Masses 1, 9, 1 give weights 1/11, 9/11, 1/11; over 110 draws the expected
    // counts are about 10, 90, 10, comfortably satisfying loud > 5 * quiet.
    #[test]
    fn enrichment_oversamples_loud_row() {
        let rows = vec![vec![1.0], vec![3.0], vec![1.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        let count = 110;
        let order = measure.enrichment_order(count, 1);
        let loud = order.iter().filter(|&&r| r == 1).count();
        let quiet0 = order.iter().filter(|&&r| r == 0).count();
        assert!(
            loud > quiet0 * 5,
            "loud row must be oversampled: loud={loud} quiet0={quiet0}"
        );
    }

    // Masses 1 and 9 give weights 0.1 and 0.9; times 10 draws -> 1 and 9.
    #[test]
    fn expected_representation_matches_count_times_weight() {
        let rows = vec![vec![1.0], vec![3.0]];
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        let rep = measure.expected_representation(10);
        assert!((rep[0] - 1.0).abs() < 1e-12);
        assert!((rep[1] - 9.0).abs() < 1e-12);
    }

    // Every tenth row is "loud". The design must be reproducible for a fixed seed,
    // and its inclusion probabilities must sum to the budget. The realized size
    // may be off by one from rounding. The corpus-size estimate must land near n,
    // rows must be strictly ascending, and every weight (1 / pi) must be finite
    // and at least 1.
    #[test]
    fn designed_subsample_is_deterministic_and_honest() {
        let n = 200usize;
        let rows: Vec<Vec<f64>> = (0..n)
            .map(|i| vec![if i % 10 == 0 { 3.0 } else { 1.0 }])
            .collect();
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        let budget = 40usize;
        let a = measure.designed_subsample(budget, 17);
        let b = measure.designed_subsample(budget, 17);
        assert_eq!(a.rows, b.rows, "same seed must give the identical design");
        assert_eq!(a.likelihood_weights, b.likelihood_weights);
        assert!((a.expected_size - budget as f64).abs() < 1e-9);
        assert!(a.len() == budget || a.len() == budget + 1 || a.len() + 1 == budget);
        let est = a.estimated_corpus_rows();
        assert!(
            (est - n as f64).abs() < 0.25 * n as f64,
            "HT corpus estimate {est} too far from n = {n}"
        );
        assert!(a.rows.windows(2).all(|w| w[0] < w[1]));
        assert!(
            a.likelihood_weights
                .iter()
                .all(|&w| w.is_finite() && w >= 1.0 - 1e-12)
        );
    }

    // A budget at or above the row count degenerates to the exact full pass: every
    // row once, each with weight 1.
    #[test]
    fn designed_subsample_full_budget_is_the_exact_pass() {
        let measure = RowMeasure::uniform(7);
        let s = measure.designed_subsample(7, 3);
        assert_eq!(s.rows, (0..7).collect::<Vec<_>>());
        assert!(s.likelihood_weights.iter().all(|&w| w == 1.0));
        let s = measure.designed_subsample(100, 3);
        assert_eq!(s.rows.len(), 7);
    }

    // Uniform measure: every inclusion probability is budget / n = 1/4, so every
    // likelihood weight is n / budget = 4 and exactly `budget` rows are drawn.
    #[test]
    fn designed_subsample_uniform_measure_gives_flat_weights() {
        let n = 120usize;
        let budget = 30usize;
        let measure = RowMeasure::uniform(n);
        let s = measure.designed_subsample(budget, 5);
        assert_eq!(s.provenance, MeasureProvenance::Uniform);
        let expect = n as f64 / budget as f64;
        for &w in &s.likelihood_weights {
            assert!(
                (w - expect).abs() < 1e-9,
                "uniform design weight {w} != {expect}"
            );
        }
        assert_eq!(s.len(), budget);
    }

    // Row 7 (factor 30 vs 1 for the other 49 rows) dominates the measure. It must
    // be in the design, and because it is far more likely to be drawn, its
    // likelihood weight must be smaller than that of any quiet row.
    #[test]
    fn designed_subsample_oversamples_loud_rows_with_downweighted_loss() {
        let rows: Vec<Vec<f64>> = (0..50)
            .map(|i| vec![if i == 7 { 30.0 } else { 1.0 }])
            .collect();
        let u = factors_from_rows(&rows, 1, 1);
        let metric = RowMetric::output_fisher(u, 1, 1).expect("of");
        let measure = RowMeasure::from_metric(&metric);
        let s = measure.designed_subsample(10, 99);
        let pos = s.rows.iter().position(|&r| r == 7);
        assert!(pos.is_some(), "the dominant-mass row must be in the design");
        let w7 = s.likelihood_weights[pos.unwrap()];
        let w_other = s
            .likelihood_weights
            .iter()
            .enumerate()
            .filter(|&(k, _)| s.rows[k] != 7)
            .map(|(_, &w)| w)
            .next()
            .expect("some quiet row selected");
        assert!(
            w7 < w_other,
            "loud row weight {w7} must be below quiet row weight {w_other}"
        );
    }
}