use crate::inference::row_metric::{MetricProvenance, RowMetric};
use crate::linalg::utils::splitmix64_hash;
/// Records how the weights of a `RowMeasure` were obtained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeasureProvenance {
    /// Every row carries the same weight (`1 / n`); also used as the fallback when
    /// metric-derived masses are unusable.
    Uniform,
    /// Weights are normalized per-row Fisher masses taken from a metric whose
    /// provenance is carried here.
    FisherMass(MetricProvenance),
}
/// A weight vector over rows together with a record of where it came from.
///
/// The constructors in this module normalize the weights to sum to one, or leave
/// them empty when there are no rows.
#[derive(Debug, Clone)]
pub struct RowMeasure {
    /// How `weights` were derived (uniform or metric-based Fisher mass).
    provenance: MeasureProvenance,
    /// One non-negative weight per row.
    weights: Vec<f64>,
}
impl RowMeasure {
    /// Builds the row measure induced by `metric`.
    ///
    /// An empty metric, or one with [`MetricProvenance::Euclidean`] provenance, yields the
    /// uniform measure. Otherwise each row is weighted by its Fisher mass (as returned by
    /// `per_row_fisher_mass`) and the masses are normalized to sum to one. See
    /// `from_masses` for the fallbacks applied to degenerate masses.
    pub fn from_metric(metric: &RowMetric) -> Self {
        let n = metric.n_rows();
        // `||` short-circuits, so `provenance()` is not consulted for an empty metric.
        if n == 0 || matches!(metric.provenance(), MetricProvenance::Euclidean) {
            return Self::uniform(n);
        }
        Self::from_masses(metric.provenance(), per_row_fisher_mass(metric))
    }

    /// Returns the uniform measure over `n` rows, each with weight `1 / n`.
    ///
    /// `n == 0` yields an empty weight vector.
    pub fn uniform(n: usize) -> Self {
        let w = if n == 0 { 0.0 } else { 1.0 / n as f64 };
        Self {
            provenance: MeasureProvenance::Uniform,
            weights: vec![w; n],
        }
    }

    /// Normalizes raw per-row masses into weights that sum to one.
    ///
    /// Negative masses (and `-0.0`) are clamped to `+0.0`. Falls back to
    /// [`RowMeasure::uniform`] when any mass is non-finite (NaN or ±inf), or when the
    /// clamped masses sum to zero. Very large but finite masses whose sum overflows, and
    /// very small (subnormal) totals, are both normalized correctly.
    fn from_masses(metric_provenance: MetricProvenance, masses: Vec<f64>) -> Self {
        let n = masses.len();
        if n == 0 {
            return Self::uniform(0);
        }
        if masses.iter().any(|m| !m.is_finite()) {
            return Self::uniform(n);
        }
        let mut clean: Vec<f64> = masses
            .into_iter()
            .map(|m| if m > 0.0 { m } else { 0.0 })
            .collect();
        let mut total = clean.iter().fold(0.0_f64, |acc, &v| acc + v);
        if total.is_infinite() {
            // Every mass is finite but their sum overflowed. Rescale by the largest
            // mass, which is positive and finite here, so each entry is <= 1 and the sum
            // is bounded by `n`.
            let max = clean.iter().copied().fold(0.0_f64, f64::max);
            for w in clean.iter_mut() {
                *w /= max;
            }
            total = clean.iter().fold(0.0_f64, |acc, &v| acc + v);
        }
        if !(total > 0.0) {
            return Self::uniform(n);
        }
        // Divide instead of multiplying by `1.0 / total`. For a subnormal total the
        // reciprocal overflows to +inf, which would turn positive weights into inf and
        // zero weights into NaN (0 * inf).
        for w in clean.iter_mut() {
            *w /= total;
        }
        Self {
            provenance: MeasureProvenance::FisherMass(metric_provenance),
            weights: clean,
        }
    }

    /// Normalized per-row weights, one entry per row.
    pub fn weights(&self) -> &[f64] {
        &self.weights
    }

    /// How the weights were derived.
    pub fn provenance(&self) -> MeasureProvenance {
        self.provenance
    }

    /// Number of rows covered by this measure.
    pub fn n_rows(&self) -> usize {
        self.weights.len()
    }

    /// `true` when the weights come from Fisher masses rather than the uniform fallback.
    pub fn is_enriched(&self) -> bool {
        matches!(self.provenance, MeasureProvenance::FisherMass(_))
    }

    /// Draws `count` row indices from this measure by systematic resampling.
    ///
    /// A single offset `u` in `[0, 1)` is derived deterministically from `seed` via
    /// `splitmix64_hash(seed ^ ENRICHMENT_SALT)`. The pointers `(j + u) / count`, for
    /// `j in 0..count`, are then walked against the cumulative weights. Row `i` owns the
    /// half-open interval `[cdf[i-1], cdf[i])`. As a result, zero-weight rows are never
    /// returned, and (up to floating-point rounding) each row appears `floor(count * w)`
    /// or `ceil(count * w)` times. The result is non-decreasing in row index and is
    /// identical for identical weights, `count`, and `seed`.
    ///
    /// Returns an empty vector when there are no rows or `count == 0`.
    pub fn enrichment_order(&self, count: usize, seed: u64) -> Vec<usize> {
        let n = self.weights.len();
        if n == 0 || count == 0 {
            return Vec::new();
        }
        // Top 53 bits of the hash give a uniform double in [0, 1).
        let u = {
            let bits = splitmix64_hash(seed ^ ENRICHMENT_SALT);
            (bits >> 11) as f64 / (1_u64 << 53) as f64
        };
        let mut cdf: Vec<f64> = self
            .weights
            .iter()
            .scan(0.0_f64, |acc, &w| {
                *acc += w;
                Some(*acc)
            })
            .collect();
        // Pin the CDF to exactly 1.0 from the last positive-weight row onward. This way,
        // accumulated rounding cannot leave a gap below 1.0, and cannot route a pointer
        // into a trailing zero-weight row.
        let last_live = self
            .weights
            .iter()
            .rposition(|&w| w > 0.0)
            .unwrap_or(n - 1);
        cdf[last_live..].fill(1.0);

        let step = 1.0 / count as f64;
        let mut out = Vec::with_capacity(count);
        let mut cursor = 0usize;
        for j in 0..count {
            let pointer = (j as f64 + u) * step;
            // `>=` makes the intervals half-open, so a pointer landing exactly on a
            // boundary moves to the next row, and zero-width intervals are skipped.
            while cursor < last_live && pointer >= cdf[cursor] {
                cursor += 1;
            }
            out.push(cursor);
        }
        out
    }

    /// Expected number of appearances of each row among `count` draws, i.e.
    /// `count * weight` per row.
    pub fn expected_representation(&self, count: usize) -> Vec<f64> {
        let c = count as f64;
        self.weights.iter().map(|&w| c * w).collect()
    }
}
/// Constant XOR-ed into the caller's seed before hashing in `RowMeasure::enrichment_order`.
/// It is presumably meant to keep this stream distinct from other hashes of the same
/// seed (NOTE(review): confirm intent).
const ENRICHMENT_SALT: u64 = 0x980E1C45F00DAC70;
/// Computes one Fisher mass per row: the sum of the diagonal entries
/// `blocks[[row, i, i]]` for `i in 0..p_out`, i.e. the trace of that row's block.
///
/// Assumes `metric.blocks()` is indexable as `[row, i, j]` with at least
/// `n_rows × p_out × p_out` entries (NOTE(review): confirm against `RowMetric`).
/// With `p_out == 0`, every mass is `+0.0`.
fn per_row_fisher_mass(metric: &RowMetric) -> Vec<f64> {
    let blocks = metric.blocks();
    let n_rows = metric.n_rows();
    let width = metric.p_out();
    (0..n_rows)
        .map(|row| {
            (0..width)
                .map(|d| blocks[[row, d, d]])
                // Left-to-right accumulation from +0.0, matching a plain `+=` loop.
                .fold(0.0_f64, |trace, diag| trace + diag)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use std::sync::Arc;

    /// Copies `rows` into a zero-initialized `rows.len() × (p * rank)` factor matrix.
    ///
    /// Rows shorter than `p * rank` stay zero-padded; a longer row panics on the
    /// out-of-bounds write.
    fn factors_from_rows(rows: &[Vec<f64>], p: usize, rank: usize) -> Arc<Array2<f64>> {
        let mut factors = Array2::<f64>::zeros((rows.len(), p * rank));
        for (r, row) in rows.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                factors[[r, c]] = value;
            }
        }
        Arc::new(factors)
    }

    /// Builds the measure of an output-Fisher metric over single-column factor rows
    /// (`p = 1`, `rank = 1`), the configuration shared by the enriched tests below.
    fn scalar_fisher_measure(rows: &[Vec<f64>]) -> RowMeasure {
        let factors = factors_from_rows(rows, 1, 1);
        let metric = RowMetric::output_fisher(factors, 1, 1).expect("of");
        RowMeasure::from_metric(&metric)
    }

    #[test]
    fn euclidean_degrades_to_uniform() {
        let metric = RowMetric::euclidean(5, 3).expect("euclidean");
        let measure = RowMeasure::from_metric(&metric);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        assert!(!measure.is_enriched());
        for &w in measure.weights() {
            assert!((w - 0.2).abs() < 1e-12);
        }
    }

    #[test]
    fn weights_normalize_to_one_and_track_mass() {
        // The expected weights 1/12 and 9/12 correspond to masses 1, 1, 9, 1.
        let measure = scalar_fisher_measure(&[vec![1.0], vec![1.0], vec![3.0], vec![1.0]]);
        assert!(measure.is_enriched());
        let w = measure.weights();
        let sum = w.iter().sum::<f64>();
        assert!((sum - 1.0).abs() < 1e-12);
        assert!((w[0] - 1.0 / 12.0).abs() < 1e-12);
        assert!((w[2] - 9.0 / 12.0).abs() < 1e-12);
        assert!(w[2] > w[0] * 8.0);
    }

    #[test]
    fn all_zero_mass_degrades_to_uniform() {
        let measure = scalar_fisher_measure(&[vec![0.0], vec![0.0], vec![0.0]]);
        assert_eq!(measure.provenance(), MeasureProvenance::Uniform);
        for &w in measure.weights() {
            assert!((w - 1.0 / 3.0).abs() < 1e-12);
        }
    }

    #[test]
    fn enrichment_order_is_deterministic() {
        let measure = scalar_fisher_measure(&[vec![1.0], vec![3.0], vec![1.0]]);
        let a = measure.enrichment_order(20, 7);
        let b = measure.enrichment_order(20, 7);
        assert_eq!(a, b, "same seed must give identical ordering");
        let c = measure.enrichment_order(20, 8);
        assert_eq!(c.len(), 20);
    }

    #[test]
    fn enrichment_oversamples_loud_row() {
        let measure = scalar_fisher_measure(&[vec![1.0], vec![3.0], vec![1.0]]);
        let count = 110;
        let order = measure.enrichment_order(count, 1);
        let tally = |target: usize| order.iter().filter(|&&r| r == target).count();
        let loud = tally(1);
        let quiet0 = tally(0);
        assert!(
            loud > quiet0 * 5,
            "loud row must be oversampled: loud={loud} quiet0={quiet0}"
        );
    }

    #[test]
    fn expected_representation_matches_count_times_weight() {
        let measure = scalar_fisher_measure(&[vec![1.0], vec![3.0]]);
        let rep = measure.expected_representation(10);
        assert!((rep[0] - 1.0).abs() < 1e-12);
        assert!((rep[1] - 9.0).abs() < 1e-12);
    }
}