use ndarray::{Array2, Array3, ArrayView1};
use std::sync::Arc;
use crate::normalize_fisher_rao_blocks;
/// Weighting applied to outputs: either no weighting or a shared factor matrix.
#[derive(Clone)]
pub enum WeightField {
/// No factor matrix is carried.
Identity,
/// Weighting backed by a reference-counted factor matrix.
Factored {
/// Shared factor matrix.
/// NOTE(review): presumably each row is flattened as `i * rank + k`, matching
/// the indexing in `project_jac_row_with_u` — confirm against callers.
u: Arc<Array2<f64>>,
/// Number of factor components per output.
rank: usize,
/// Output dimension (presumably `u.ncols() == p_out * rank` — TODO confirm).
p_out: usize,
},
}
/// Compact `Debug` output: prints the factor matrix shape instead of its contents.
impl std::fmt::Debug for WeightField {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WeightField::Factored { u, rank, p_out } => {
                // Only the dimensions of `u` are printed, never its entries.
                let (n_rows, n_cols) = (u.nrows(), u.ncols());
                let mut builder = f.debug_struct("Factored");
                builder.field("shape", &format_args!("{}×{}", n_rows, n_cols));
                builder.field("rank", rank);
                builder.field("p_out", p_out);
                builder.finish()
            }
            WeightField::Identity => f.write_str("Identity"),
        }
    }
}
impl WeightField {
    /// Projects one flattened Jacobian row through one flattened factor row.
    ///
    /// Returns the `rank × d` matrix `M` with
    /// `M[k, a] = Σ_{i in 0..p} u_row[i * rank + k] * jac_row[i * d + a]`.
    ///
    /// # Panics
    /// Panics if either slice is too short for an index that is touched.
    pub fn project_jac_row_with_u(
        u_row: &[f64],
        jac_row: &[f64],
        p: usize,
        rank: usize,
        d: usize,
    ) -> Array2<f64> {
        Array2::from_shape_fn((rank, d), |(k, a)| {
            // Explicit fold from 0.0 keeps the left-to-right accumulation over `i`.
            (0..p).fold(0.0, |acc, i| acc + u_row[i * rank + k] * jac_row[i * d + a])
        })
    }
}
/// Records which constructor produced a `RowMetric`.
///
/// `RowMetric` predicates (`whitens_likelihood`, `drives_gauge`,
/// `is_output_fisher_like`) branch on this tag.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum MetricProvenance {
/// Built by `RowMetric::euclidean`; no factor matrix is stored.
Euclidean,
/// Built by `RowMetric::output_fisher` or `RowMetric::output_fisher_with_solver_floor`.
OutputFisher { rank: usize },
/// Built by `RowMetric::output_fisher_downstream`.
OutputFisherDownstream { rank: usize },
/// Built by `RowMetric::whitened_structured`; the only variant for which
/// `RowMetric::whitens_likelihood` returns `true`.
WhitenedStructured { factor_rank: usize },
}
/// Per-row metric on a `p`-dimensional output, stored either implicitly
/// (identity) or as a factor matrix `U` so that the row metric is `U Uᵀ`.
#[derive(Clone, Debug)]
pub struct RowMetric {
/// Which constructor produced this metric.
provenance: MetricProvenance,
/// Number of rows; equals `factors.nrows()` when factors are present.
n_rows: usize,
/// Output dimension.
p: usize,
/// Factor components per output; set to `p` for the Euclidean metric.
rank: usize,
/// `None` for the Euclidean metric; otherwise an `n_rows × (p * rank)` matrix
/// whose entry `(i, k)` for a given row sits at column `i * rank + k`.
factors: Option<Arc<Array2<f64>>>,
/// Value returned by `solver_floor`; `0.0` unless set through
/// `output_fisher_with_solver_floor`.
solver_delta: f64,
/// Per-row trace of the metric: `p` for Euclidean, trace of `U Uᵀ` otherwise.
traces: ndarray::Array1<f64>,
}
impl RowMetric {
/// Builds the identity metric for `n_rows` rows of dimension `p`.
///
/// No factors are stored, `rank` is set to `p`, and every row trace is `p`.
/// This constructor always returns `Ok`.
pub fn euclidean(n_rows: usize, p: usize) -> Result<Self, String> {
    // Every per-row trace is filled with `p`.
    let traces = ndarray::Array1::<f64>::from_elem(n_rows, p as f64);
    let metric = Self {
        provenance: MetricProvenance::Euclidean,
        n_rows,
        p,
        rank: p,
        factors: None,
        solver_delta: 0.0,
        traces,
    };
    Ok(metric)
}
/// Builds a factor-backed metric tagged `OutputFisher`, with a zero solver floor.
///
/// # Errors
/// Forwards any validation error from `from_factors`.
pub fn output_fisher(u: Arc<Array2<f64>>, p: usize, rank: usize) -> Result<Self, String> {
    let provenance = MetricProvenance::OutputFisher { rank };
    Self::from_factors(provenance, u, p, rank, 0.0)
}
/// Builds a factor-backed metric tagged `OutputFisherDownstream`, with a zero
/// solver floor.
///
/// # Errors
/// Forwards any validation error from `from_factors`.
pub fn output_fisher_downstream(
    u: Arc<Array2<f64>>,
    p: usize,
    rank: usize,
) -> Result<Self, String> {
    let provenance = MetricProvenance::OutputFisherDownstream { rank };
    Self::from_factors(provenance, u, p, rank, 0.0)
}
/// Builds a metric tagged `OutputFisher` that also records `solver_delta`,
/// later readable through `solver_floor`.
///
/// # Errors
/// Returns an error when `solver_delta` is NaN, infinite, or negative, and
/// forwards any validation error from `from_factors`.
pub fn output_fisher_with_solver_floor(
    u: Arc<Array2<f64>>,
    p: usize,
    rank: usize,
    solver_delta: f64,
) -> Result<Self, String> {
    // NaN and infinities fail `is_finite`, so the `<` test only sees finite values.
    let rejected = !solver_delta.is_finite() || solver_delta < 0.0;
    if rejected {
        return Err(format!(
            "RowMetric::output_fisher_with_solver_floor: solver_delta must be finite and \
             non-negative; got {solver_delta}"
        ));
    }
    let provenance = MetricProvenance::OutputFisher { rank };
    Self::from_factors(provenance, u, p, rank, solver_delta)
}
/// Builds a factor-backed metric tagged `WhitenedStructured` (with
/// `factor_rank = rank`) and a zero solver floor.
///
/// # Errors
/// Forwards any validation error from `from_factors`.
pub fn whitened_structured(u: Arc<Array2<f64>>, p: usize, rank: usize) -> Result<Self, String> {
    let provenance = MetricProvenance::WhitenedStructured { factor_rank: rank };
    Self::from_factors(provenance, u, p, rank, 0.0)
}
/// Validates a factor matrix and builds the metric, caching per-row traces.
///
/// `u` must have `p * rank` columns; for a given row, entry `(i, k)` is read
/// from column `i * rank + k`. For each row the `p × p` block `U Uᵀ` is
/// formed, handed to `normalize_fisher_rao_blocks`, and its trace is stored.
///
/// # Errors
/// - `p * rank` overflows `usize`, or `u.ncols()` differs from `p * rank`;
/// - any entry of `u` is NaN or infinite;
/// - `normalize_fisher_rao_blocks` returns an error for a row's block (the
///   message is prefixed with the row index).
fn from_factors(
    provenance: MetricProvenance,
    u: Arc<Array2<f64>>,
    p: usize,
    rank: usize,
    solver_delta: f64,
) -> Result<Self, String> {
    let n_rows = u.nrows();
    // An unchecked product panics in debug builds and wraps in release builds,
    // where a wrapped value could let a wrongly shaped matrix through.
    let expected_cols = p.checked_mul(rank).ok_or_else(|| {
        format!("RowMetric::from_factors: p*rank = {p}*{rank} overflows usize")
    })?;
    if u.ncols() != expected_cols {
        return Err(format!(
            "RowMetric::from_factors: factor matrix has {} cols; expected p*rank = {}*{} = {}",
            u.ncols(),
            p,
            rank,
            expected_cols
        ));
    }
    if !u.iter().all(|v| v.is_finite()) {
        return Err("RowMetric::from_factors: factors must be finite".to_string());
    }
    let mut traces = ndarray::Array1::<f64>::zeros(n_rows);
    // Single scratch block reused for every row.
    let mut full = Array3::<f64>::zeros((1, p, p));
    for row in 0..n_rows {
        // `U Uᵀ` is symmetric and entries (i, j) and (j, i) use the same
        // products in the same order, so compute the upper triangle and mirror.
        for i in 0..p {
            for j in i..p {
                let mut acc = 0.0;
                for k in 0..rank {
                    acc += u[[row, i * rank + k]] * u[[row, j * rank + k]];
                }
                full[[0, i, j]] = acc;
                full[[0, j, i]] = acc;
            }
        }
        // NOTE(review): only the error path is used here; the Ok value is
        // discarded — presumably this call acts as a validity check. Confirm.
        normalize_fisher_rao_blocks(full.view().into_dyn(), 1, p)
            .map_err(|e| format!("RowMetric::from_factors: row {row}: {e}"))?;
        let mut tr = 0.0_f64;
        for i in 0..p {
            tr += full[[0, i, i]];
        }
        traces[row] = tr;
    }
    Ok(Self {
        provenance,
        n_rows,
        p,
        rank,
        factors: Some(u),
        solver_delta,
        traces,
    })
}
/// Returns the tag recording which constructor built this metric.
pub fn provenance(&self) -> MetricProvenance {
self.provenance
}
/// Returns `true` only when the provenance is `WhitenedStructured`.
pub fn whitens_likelihood(&self) -> bool {
    match self.provenance {
        MetricProvenance::WhitenedStructured { .. } => true,
        MetricProvenance::Euclidean
        | MetricProvenance::OutputFisher { .. }
        | MetricProvenance::OutputFisherDownstream { .. } => false,
    }
}
/// Returns `true` for every provenance except `Euclidean`.
pub fn drives_gauge(&self) -> bool {
    // `Euclidean` carries no payload, so derived equality matches only that variant.
    self.provenance != MetricProvenance::Euclidean
}
/// Returns `true` when the provenance is `OutputFisher` or
/// `OutputFisherDownstream`.
pub fn is_output_fisher_like(&self) -> bool {
    match self.provenance {
        MetricProvenance::OutputFisher { .. }
        | MetricProvenance::OutputFisherDownstream { .. } => true,
        MetricProvenance::Euclidean | MetricProvenance::WhitenedStructured { .. } => false,
    }
}
/// Returns the number of rows this metric covers.
pub fn n_rows(&self) -> usize {
self.n_rows
}
/// Returns the output dimension `p`.
pub fn p_out(&self) -> usize {
self.p
}
/// Returns the stored rank (equal to `p` for the Euclidean metric).
pub fn metric_rank(&self) -> usize {
self.rank
}
/// Returns a borrowed view of the per-row metric traces computed at construction.
pub fn row_traces(&self) -> ndarray::ArrayView1<'_, f64> {
self.traces.view()
}
/// Maps a residual through the row's factor: returns `Uᵀ r` (length `rank`),
/// or a copy of `r` when no factors are stored.
///
/// # Panics
/// With factors present, panics if `row` is out of range or `r` has fewer
/// than `p` entries.
pub fn whiten_residual_row(&self, row: usize, r: ArrayView1<'_, f64>) -> Vec<f64> {
    let u = match &self.factors {
        Some(u) => u,
        None => return r.to_vec(),
    };
    let rank = self.rank;
    (0..rank)
        .map(|k| {
            // Component k: Σ_i U[i, k] * r[i], accumulated from 0.0 in index order.
            (0..self.p).fold(0.0, |acc, i| acc + u[[row, i * rank + k]] * r[i])
        })
        .collect()
}
/// Returns factor entry `(i, k)` for `row`; without stored factors the factor
/// is the identity, so the result is `1.0` when `i == k` and `0.0` otherwise.
///
/// # Panics
/// With factors present, panics if the indexed position is out of bounds.
#[inline]
pub fn factor_entry(&self, row: usize, i: usize, k: usize) -> f64 {
    if let Some(u) = &self.factors {
        return u[[row, i * self.rank + k]];
    }
    if i == k {
        1.0
    } else {
        0.0
    }
}
/// Applies the row metric to `x`: returns `U (Uᵀ x)` (length `p`), or a copy
/// of `x` when no factors are stored.
///
/// # Panics
/// With factors present, panics if `row` is out of range or `x` has fewer
/// than `p` entries.
pub fn apply_metric_row(&self, row: usize, x: ArrayView1<'_, f64>) -> Vec<f64> {
    let u = match &self.factors {
        Some(u) => u,
        None => return x.to_vec(),
    };
    // First stage, `Uᵀ x`, is the same computation as whitening a residual.
    let projected = self.whiten_residual_row(row, x);
    let rank = self.rank;
    (0..self.p)
        .map(|i| {
            projected
                .iter()
                .enumerate()
                .fold(0.0, |acc, (k, &w_k)| acc + u[[row, i * rank + k]] * w_k)
        })
        .collect()
}
/// Pulls the row metric back through a flattened `p × d` Jacobian row
/// (`J[i, a] = j_row[i * d + a]`), returning the symmetric `d × d` matrix
/// `Jᵀ J` without factors, or `(Uᵀ J)ᵀ (Uᵀ J)` with them.
///
/// # Panics
/// Panics if `j_row` is too short, or (with factors) if `row` is out of range.
pub fn pullback(&self, row: usize, j_row: &[f64], d: usize) -> Array2<f64> {
    /// Gram matrix of an `inner × d` matrix given by `at(t, a)`: computes
    /// the upper triangle and mirrors it.
    fn gram(d: usize, inner: usize, at: impl Fn(usize, usize) -> f64) -> Array2<f64> {
        let mut g = Array2::<f64>::zeros((d, d));
        for a in 0..d {
            for b in a..d {
                let dot = (0..inner).fold(0.0, |acc, t| acc + at(t, a) * at(t, b));
                g[[a, b]] = dot;
                g[[b, a]] = dot;
            }
        }
        g
    }

    match &self.factors {
        None => gram(d, self.p, |i, a| j_row[i * d + a]),
        Some(_) => {
            // With factors present `whiten_jacobian` yields the `rank × d` product `Uᵀ J`.
            let m = self.whiten_jacobian(row, j_row, d);
            gram(d, self.rank, |k, a| m[[k, a]])
        }
    }
}
/// Evaluates the quadratic form of the row metric at `r`: the sum of squares
/// of `r` without factors, or of `Uᵀ r` with them.
///
/// # Panics
/// With factors present, panics under the same conditions as
/// `whiten_residual_row`.
#[inline]
pub fn quad_form(&self, row: usize, r: ArrayView1<'_, f64>) -> f64 {
    if self.factors.is_none() {
        return r.iter().map(|&v| v * v).sum();
    }
    let whitened = self.whiten_residual_row(row, r);
    whitened.iter().map(|&w| w * w).sum()
}
/// Maps a flattened `p × d` Jacobian row (`J[i, a] = j_row[i * d + a]`)
/// through the row's factor: returns the `rank × d` product `Uᵀ J`, or `J`
/// itself as a `p × d` array when no factors are stored.
///
/// # Panics
/// Panics if `j_row` is too short, or (with factors) if `row` is out of range.
pub fn whiten_jacobian(&self, row: usize, j_row: &[f64], d: usize) -> Array2<f64> {
    if let Some(u) = &self.factors {
        let rank = self.rank;
        Array2::from_shape_fn((rank, d), |(k, a)| {
            // Entry (k, a): Σ_i U[i, k] * J[i, a], accumulated from 0.0 in index order.
            (0..self.p).fold(0.0, |acc, i| acc + u[[row, i * rank + k]] * j_row[i * d + a])
        })
    } else {
        Array2::from_shape_fn((self.p, d), |(i, a)| j_row[i * d + a])
    }
}
/// Alias for `quad_form`: evaluates the row metric's quadratic form at `x`.
#[inline]
pub fn fisher_mass(&self, row: usize, x: ArrayView1<'_, f64>) -> f64 {
self.quad_form(row, x)
}
/// Returns the stored `solver_delta` (`0.0` unless the metric was built with
/// `output_fisher_with_solver_floor`).
pub fn solver_floor(&self) -> f64 {
self.solver_delta
}
pub fn to_weight_field(&self) -> crate::WeightField {
use crate::WeightField;
match &self.factors {
None => WeightField::Identity,
Some(u) => WeightField::Factored {
u: Arc::clone(u),
rank: self.rank,
p_out: self.p,
},
}
}
}