use super::*;
/// Outcome of factoring a single row block: the Cholesky factor together with
/// a count of how many directions had to be deflated to obtain it.
#[derive(Debug, Clone)]
pub(crate) struct ArrowRowFactorResult {
/// Cholesky factor produced for the row block.
pub(crate) factor: Array2<f64>,
/// Number of directions that were deflated before factoring; `0` when the
/// block was factored without any deflation.
pub(crate) gauge_deflated_directions: usize,
}
/// Tries to factor every row block at once through
/// `crate::gpu::try_cholesky_batched_lower_inplace`.
///
/// Each block handed to the batched routine is a copy of `row.htt` with
/// `ridge_t` added to its diagonal.
///
/// Returns `None` when there is nothing to factor (`d == 0` or no rows), when
/// any row's `htt` / `gt` does not match `d`, when the GPU runtime reports it
/// is unavailable, when the batched factorization itself yields `None`, or
/// (unless `tolerate_ill_conditioning` is set) when any resulting factor fails
/// the safe-inversion check.
pub(crate) fn try_factor_blocks_batched(
    rows: &[ArrowRowBlock],
    ridge_t: f64,
    d: usize,
    tolerate_ill_conditioning: bool,
) -> Option<ArrowFactorSlab> {
    if rows.is_empty() || d == 0 {
        return None;
    }
    let shapes_match = rows
        .iter()
        .all(|row| row.htt.dim() == (d, d) && row.gt.len() == d);
    if !shapes_match {
        return None;
    }
    if !crate::gpu::device_runtime::GpuRuntime::is_available() {
        return None;
    }

    // Ridge-shifted copies of every H_tt block, factored in place below.
    let mut blocks: Vec<Array2<f64>> = rows
        .iter()
        .map(|row| {
            let mut shifted = row.htt.clone();
            for k in 0..d {
                shifted[[k, k]] += ridge_t;
            }
            shifted
        })
        .collect();
    crate::gpu::try_cholesky_batched_lower_inplace(&mut blocks)?;

    if !tolerate_ill_conditioning {
        // Stops at the first factor that is not safe to invert.
        let all_safe = rows.iter().zip(blocks.iter()).all(|(row, factor)| {
            let diag_scale = row_block_diag_scale(row, d);
            let kappa_est = cholesky_factor_kappa_estimate(factor);
            cholesky_factor_passes_safe_inversion(factor, d, diag_scale, kappa_est)
        });
        if !all_safe {
            return None;
        }
    }
    Some(ArrowFactorSlab::from_blocks(blocks))
}
/// Largest absolute value among the first `d` diagonal entries of `row.htt`,
/// clamped from below at `1.0`.
pub(crate) fn row_block_diag_scale(row: &ArrowRowBlock, d: usize) -> f64 {
    let mut largest = 0.0_f64;
    for k in 0..d {
        largest = f64::max(largest, row.htt[[k, k]].abs());
    }
    largest.max(1.0)
}
/// Condition-number estimate for the matrix whose Cholesky factor is `factor`,
/// computed as `(max diagonal / min diagonal)^2` of the factor.
///
/// Returns `f64::INFINITY` when the estimate is not meaningful: the smallest
/// diagonal entry is not strictly positive, the largest is not finite, or any
/// diagonal entry is NaN. (A NaN compares false against everything, so without
/// the explicit check it would be skipped and the remaining entries could
/// yield a misleading finite estimate.)
pub(crate) fn cholesky_factor_kappa_estimate(factor: &Array2<f64>) -> f64 {
    let d = factor.nrows();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for a in 0..d {
        let v = factor[[a, a]];
        if v.is_nan() {
            // A NaN pivot means the factor is unusable; never report a finite kappa.
            return f64::INFINITY;
        }
        if v < min_diag {
            min_diag = v;
        }
        if v > max_diag {
            max_diag = v;
        }
    }
    if min_diag > 0.0 && max_diag.is_finite() {
        let ratio = max_diag / min_diag;
        ratio * ratio
    } else {
        f64::INFINITY
    }
}
/// Smallest squared diagonal entry of `factor`.
///
/// Returns `0.0` for an empty factor, or as soon as a diagonal entry is found
/// that is not both finite and strictly positive.
pub(crate) fn cholesky_factor_min_pivot_estimate(factor: &Array2<f64>) -> f64 {
    let n = factor.nrows();
    if n == 0 {
        return 0.0;
    }
    (0..n)
        .try_fold(f64::INFINITY, |smallest, idx| {
            let diag = factor[[idx, idx]];
            // `None` aborts the fold at the first unusable diagonal entry.
            (diag.is_finite() && diag > 0.0).then(|| smallest.min(diag * diag))
        })
        .unwrap_or(0.0)
}
/// Minimum acceptable pivot: `sqrt(f64::EPSILON)` times `diag_scale`, where the
/// scale is clamped from below at `1.0`.
pub(crate) fn safe_spd_pivot_min(diag_scale: f64) -> f64 {
    let clamped_scale = f64::max(diag_scale, 1.0);
    let root_eps = f64::EPSILON.sqrt();
    root_eps * clamped_scale
}
/// Decides whether `factor` is well-conditioned enough to be inverted safely.
///
/// Passes only when `kappa_est` is finite and no larger than
/// `safe_spd_kappa_max(dim)`, and the factor's minimum pivot estimate is at
/// least `safe_spd_pivot_min(diag_scale)`.
pub(crate) fn cholesky_factor_passes_safe_inversion(
    factor: &Array2<f64>,
    dim: usize,
    diag_scale: f64,
    kappa_est: f64,
) -> bool {
    let kappa_ok = kappa_est.is_finite() && kappa_est <= safe_spd_kappa_max(dim);
    if !kappa_ok {
        return false;
    }
    let min_pivot = cholesky_factor_min_pivot_estimate(factor);
    min_pivot >= safe_spd_pivot_min(diag_scale)
}
/// Largest condition-number estimate accepted for a block of size `dim`:
/// `1 / (sqrt(f64::EPSILON) * max(dim, 1))`.
pub(crate) fn safe_spd_kappa_max(dim: usize) -> f64 {
    let d_scale = f64::max(dim as f64, 1.0);
    let denominator = f64::EPSILON.sqrt() * d_scale;
    denominator.recip()
}
/// Cholesky-factors `row.htt + ridge_eff * I`, dispatching to the
/// const-generic fixed-size routine for `d` in `1..=4` and to the dynamic
/// routine for every other `d`.
pub(crate) fn factor_row_block_cholesky(
    row: &ArrowRowBlock,
    ridge_eff: f64,
    d: usize,
) -> Result<Array2<f64>, String> {
    if d == 1 {
        return factor_row_block_cholesky_fixed::<1>(row, ridge_eff);
    }
    if d == 2 {
        return factor_row_block_cholesky_fixed::<2>(row, ridge_eff);
    }
    if d == 3 {
        return factor_row_block_cholesky_fixed::<3>(row, ridge_eff);
    }
    if d == 4 {
        return factor_row_block_cholesky_fixed::<4>(row, ridge_eff);
    }
    factor_row_block_cholesky_dynamic(row, ridge_eff, d)
}
/// Dynamic-size path: copies `row.htt`, adds `ridge_eff` to its first `d`
/// diagonal entries and hands the result to `cholesky_lower`.
pub(crate) fn factor_row_block_cholesky_dynamic(
    row: &ArrowRowBlock,
    ridge_eff: f64,
    d: usize,
) -> Result<Array2<f64>, String> {
    let mut shifted = row.htt.clone();
    (0..d).for_each(|k| shifted[[k, k]] += ridge_eff);
    cholesky_lower(&shifted)
}
/// Fixed-size Cholesky of the leading `D x D` part of `row.htt` with
/// `ridge_eff` added to the diagonal, computed on a stack array.
///
/// # Errors
/// Returns an error string when any entry of the shifted `D x D` block is
/// non-finite, or when a pivot is non-finite or not strictly positive.
pub(crate) fn factor_row_block_cholesky_fixed<const D: usize>(
    row: &ArrowRowBlock,
    ridge_eff: f64,
) -> Result<Array2<f64>, String> {
    // Entry (r, c) of the ridge-shifted block.
    let shifted = |r: usize, c: usize| -> f64 {
        let base = row.htt[[r, c]];
        if r == c { base + ridge_eff } else { base }
    };

    // Reject non-finite input up front, scanning the whole block row by row.
    for r in 0..D {
        for c in 0..D {
            if !shifted(r, c).is_finite() {
                let idx = r * D + c;
                return Err(format!(
                    "cholesky_lower: non-finite entry at linear index {idx}"
                ));
            }
        }
    }

    // Row-by-row factorization that only reads the lower triangle.
    let mut lower = [[0.0_f64; D]; D];
    for r in 0..D {
        for c in 0..=r {
            let mut acc = shifted(r, c);
            for k in 0..c {
                acc -= lower[r][k] * lower[c][k];
            }
            if r != c {
                lower[r][c] = acc / lower[c][c];
                continue;
            }
            if !acc.is_finite() || acc <= 0.0 {
                return Err(format!(
                    "non-PD pivot {acc} at index {r} (matrix is not positive definite)"
                ));
            }
            lower[r][c] = acc.sqrt();
        }
    }

    // Copy the lower triangle out; the strict upper triangle stays zero.
    let mut out = Array2::<f64>::zeros((D, D));
    for (r, lower_row) in lower.iter().enumerate() {
        for (c, &value) in lower_row.iter().enumerate().take(r + 1) {
            out[[r, c]] = value;
        }
    }
    Ok(out)
}
/// Quadratic form `gauge^T * row.htt * gauge`.
///
/// Returns `None` when `gauge` does not have length `d` or when the
/// accumulated value is not finite.
pub(crate) fn row_gauge_curvature(
    row: &ArrowRowBlock,
    d: usize,
    gauge: &Array1<f64>,
) -> Option<f64> {
    if gauge.len() != d {
        return None;
    }
    let mut total = 0.0_f64;
    for (r, &g_r) in gauge.iter().enumerate() {
        for (c, &g_c) in gauge.iter().enumerate() {
            total += g_r * row.htt[[r, c]] * g_c;
        }
    }
    total.is_finite().then_some(total)
}
/// Factors `row.htt` after adding `v v^T` for every accepted gauge direction `v`.
///
/// A gauge is accepted when it has length `d`, a usable squared norm
/// (finite and above `1e-24`), a curvature `|g^T H g|` no larger than
/// `GAUGE_RAYLEIGH_EPS * diag_scale * |g|^2`, and a usable residual after being
/// orthogonalized against the directions accepted so far. Accepted directions
/// are normalized to unit length.
///
/// Returns `None` when `gauges` is empty, the diagonal scale is unusable, a
/// curvature evaluation fails, no direction is accepted, or the Cholesky of
/// the deflated block fails.
pub(crate) fn factor_gauge_deflated_evidence_row(
    row: &ArrowRowBlock,
    d: usize,
    gauges: &[Array1<f64>],
) -> Option<ArrowRowFactorResult> {
    const GAUGE_RAYLEIGH_EPS: f64 = 1.0e-8;
    let squared_norm = |v: &Array1<f64>| v.iter().map(|&x| x * x).sum::<f64>();
    let usable_norm_sq = |n: f64| n.is_finite() && n > 1.0e-24;

    if gauges.is_empty() {
        return None;
    }
    let diag_scale = row_block_diag_scale(row, d);
    if !diag_scale.is_finite() || diag_scale <= 0.0 {
        return None;
    }

    let mut basis: Vec<Array1<f64>> = Vec::new();
    for gauge in gauges {
        if gauge.len() != d {
            continue;
        }
        let norm_sq = squared_norm(gauge);
        if !usable_norm_sq(norm_sq) {
            continue;
        }
        // A failed curvature evaluation aborts the whole deflation attempt.
        let curvature = row_gauge_curvature(row, d, gauge)?;
        let curvature_limit = GAUGE_RAYLEIGH_EPS * diag_scale * norm_sq;
        if curvature.abs() > curvature_limit {
            continue;
        }

        // Remove the components along the directions accepted so far.
        let mut direction = gauge.clone();
        for accepted in &basis {
            let coeff = direction.dot(accepted);
            for (slot, &component) in direction.iter_mut().zip(accepted.iter()) {
                *slot -= coeff * component;
            }
        }
        let residual_norm_sq = squared_norm(&direction);
        if !usable_norm_sq(residual_norm_sq) {
            continue;
        }
        let inv_norm = residual_norm_sq.sqrt().recip();
        direction.iter_mut().for_each(|value| *value *= inv_norm);
        basis.push(direction);
    }
    if basis.is_empty() {
        return None;
    }

    // Add one rank-one term per accepted direction, then factor.
    let mut deflated = row.htt.clone();
    for direction in &basis {
        for r in 0..d {
            for c in 0..d {
                deflated[[r, c]] += direction[r] * direction[c];
            }
        }
    }
    let factor = cholesky_lower(&deflated).ok()?;
    Some(ArrowRowFactorResult {
        factor,
        gauge_deflated_directions: basis.len(),
    })
}
/// Relative eigenvalue floor for spectral deflation: the absolute floor is this
/// value times the largest finite eigenvalue magnitude of the symmetrized block.
pub(crate) const SPECTRAL_DEFLATION_REL_FLOOR: f64 = 1.0e-8;
/// Fraction by which the deflation threshold sits below the floor: eigenvalues
/// are deflated only when they are not above `floor * (1 - this)`, while kept
/// eigenvalues are still raised to at least `floor`.
pub(crate) const SPECTRAL_DEFLATION_HYSTERESIS_FRACTION: f64 = 1.0e-2;
pub(crate) fn factor_spectral_deflated_evidence_row(
row: &ArrowRowBlock,
d: usize,
) -> Option<ArrowRowFactorResult> {
if d == 0 || row.htt.dim() != (d, d) {
return None;
}
let mut sym = Array2::<f64>::zeros((d, d));
for i in 0..d {
for j in 0..d {
let v = 0.5 * (row.htt[[i, j]] + row.htt[[j, i]]);
if !v.is_finite() {
return None;
}
sym[[i, j]] = v;
}
}
let (evals, evecs) = sym.eigh(Side::Lower).ok()?;
let max_abs = evals.iter().fold(
0.0_f64,
|acc, &v| if v.is_finite() { acc.max(v.abs()) } else { acc },
);
if !(max_abs.is_finite() && max_abs > 0.0) {
return None;
}
let floor = SPECTRAL_DEFLATION_REL_FLOOR * max_abs;
let deflate_floor = floor * (1.0 - SPECTRAL_DEFLATION_HYSTERESIS_FRACTION);
let mut conditioned = Array2::<f64>::zeros((d, d));
let mut deflated_count = 0usize;
for eig_idx in 0..evals.len() {
let lambda = evals[eig_idx];
let lambda_tilde = if lambda.is_finite() && lambda > deflate_floor {
lambda.max(floor)
} else {
deflated_count += 1;
1.0
};
for i in 0..d {
let vi = evecs[[i, eig_idx]];
for j in 0..d {
conditioned[[i, j]] += lambda_tilde * vi * evecs[[j, eig_idx]];
}
}
}
if deflated_count == 0 {
let mut min_idx = 0usize;
let mut min_lambda = f64::INFINITY;
for eig_idx in 0..evals.len() {
let lambda = evals[eig_idx];
if lambda < min_lambda {
min_lambda = lambda;
min_idx = eig_idx;
}
}
let kept = min_lambda.max(floor);
let delta = 1.0 - kept;
for i in 0..d {
let vi = evecs[[i, min_idx]];
for j in 0..d {
conditioned[[i, j]] += delta * vi * evecs[[j, min_idx]];
}
}
deflated_count = 1;
}
let factor = cholesky_lower(&conditioned).ok()?;
Some(ArrowRowFactorResult {
factor,
gauge_deflated_directions: deflated_count,
})
}
/// Solves `L L^T x = b` for a `D`-dimensional system, reading only the lower
/// triangle of `l`: forward substitution with `L`, then back substitution with
/// `L^T`.
///
/// # Panics
/// Panics when any of the first `D` diagonal entries of `l` is non-finite or
/// has magnitude below `f64::MIN_POSITIVE`.
pub(crate) fn cholesky_solve_vector_fixed<const D: usize>(
    l: ArrayView2<'_, f64>,
    b: ArrayView1<'_, f64>,
) -> Array1<f64> {
    for i in 0..D {
        let diag = l[[i, i]];
        assert!(
            diag.is_finite() && diag.abs() >= f64::MIN_POSITIVE,
            "cholesky_solve_vector_fixed: factor diagonal must be finite and non-subnormal"
        );
    }

    // Forward substitution: L y = b.
    let mut forward = [0.0_f64; D];
    for row in 0..D {
        let mut rhs = b[row];
        for col in 0..row {
            rhs -= l[[row, col]] * forward[col];
        }
        forward[row] = rhs / l[[row, row]];
    }

    // Back substitution: L^T x = y.
    let mut solution = [0.0_f64; D];
    for row in (0..D).rev() {
        let mut rhs = forward[row];
        for col in (row + 1)..D {
            rhs -= l[[col, row]] * solution[col];
        }
        solution[row] = rhs / l[[row, row]];
    }

    let mut out = Array1::<f64>::zeros(D);
    for (dst, &src) in out.iter_mut().zip(solution.iter()) {
        *dst = src;
    }
    out
}
/// Convenience wrapper around `factor_one_row_result` that supplies no row
/// gauges, disables spectral deflation, and returns only the factor.
pub(crate) fn factor_one_row(
    row: &ArrowRowBlock,
    ridge_t: f64,
    d: usize,
    row_idx: usize,
    tolerate_ill_conditioning: bool,
) -> Result<Array2<f64>, ArrowSchurError> {
    let outcome = factor_one_row_result(
        row,
        ridge_t,
        d,
        row_idx,
        tolerate_ill_conditioning,
        &[],
        false,
    )?;
    Ok(outcome.factor)
}
/// Factors one row's `H_tt` block and reports how many directions were deflated.
///
/// The block is Cholesky-factored at `ridge_t`. When
/// `tolerate_ill_conditioning` is `false`, the ridge is escalated (seeded from
/// the diagonal scale if it starts at zero, then multiplied by
/// `RIDGE_GROWTH_FACTOR`) until the factor both exists and passes the
/// safe-inversion check, or until the ridge cap is exceeded. When
/// `tolerate_ill_conditioning` is `true`, the ridge is never escalated: the
/// first successful factor is accepted as is, and a failed factorization can
/// only be rescued (at `ridge_t == 0.0`) by gauge deflation or, if
/// `allow_spectral_deflation` is set, by spectral deflation.
///
/// # Errors
/// * `ArrowSchurError::PerRowFactorFailed` when `row.htt` or `row.gt` does not
///   match `d`, when the block is non-PD in tolerant mode and no deflation
///   succeeded, or when the block stays non-PD up to the ridge cap.
/// * `ArrowSchurError::PerRowFactorIllConditioned` when a factor exists but
///   fails the safe-inversion check and the next ridge would exceed the cap.
pub(crate) fn factor_one_row_result(
row: &ArrowRowBlock,
ridge_t: f64,
d: usize,
row_idx: usize,
tolerate_ill_conditioning: bool,
row_gauges: &[Array1<f64>],
allow_spectral_deflation: bool,
) -> Result<ArrowRowFactorResult, ArrowSchurError> {
// Shape validation: both H_tt and g_t must agree with the latent dimension.
if row.htt.dim() != (d, d) {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt shape {:?} does not match per_point_hessian_block dimension ({d}, {d})",
row.htt.dim()
),
});
}
if row.gt.len() != d {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} g_t length {} does not match latent dimension {d}",
row.gt.len()
),
});
}
// Multiplier applied to the ridge on every escalation step.
const RIDGE_GROWTH_FACTOR: f64 = 10.0;
// First non-zero ridge, as a fraction of the diagonal scale, when starting from zero.
const RIDGE_SEED_DIAG_FRACTION: f64 = 1.0e-10;
// The cap is RIDGE_CAP_SCALE times the larger of `ridge_t` and
// RIDGE_CAP_DIAG_FRACTION * diag_scale.
const RIDGE_CAP_DIAG_FRACTION: f64 = 1.0e-12;
const RIDGE_CAP_SCALE: f64 = 1.0e12;
let diag_scale = row_block_diag_scale(row, d);
let ridge_cap = ridge_t.max(RIDGE_CAP_DIAG_FRACTION * diag_scale) * RIDGE_CAP_SCALE;
let mut ridge_eff = ridge_t;
// Each iteration either breaks with a factor, returns, or retries with a larger ridge.
let factor = loop {
match factor_row_block_cholesky(row, ridge_eff, d) {
Ok(factor) => {
if tolerate_ill_conditioning {
// Tolerant mode: at zero base ridge with gauges supplied, a successful
// gauge-deflated factorization takes precedence over the plain factor.
if ridge_t == 0.0
&& !row_gauges.is_empty()
&& let Some(deflated) =
factor_gauge_deflated_evidence_row(row, d, row_gauges)
{
return Ok(deflated);
}
break ArrowRowFactorResult {
factor,
gauge_deflated_directions: 0,
};
}
// Strict mode: accept the factor only if it is safe to invert.
let kappa_est = cholesky_factor_kappa_estimate(&factor);
if cholesky_factor_passes_safe_inversion(&factor, d, diag_scale, kappa_est) {
break ArrowRowFactorResult {
factor,
gauge_deflated_directions: 0,
};
}
// Otherwise grow the ridge (or seed it from the diagonal scale) and retry.
let next = if ridge_eff > 0.0 {
ridge_eff * RIDGE_GROWTH_FACTOR
} else {
RIDGE_SEED_DIAG_FRACTION * diag_scale
};
if !next.is_finite() || next > ridge_cap {
return Err(ArrowSchurError::PerRowFactorIllConditioned {
row: row_idx,
kappa_estimate: kappa_est,
});
}
ridge_eff = next;
}
Err(e) => {
if tolerate_ill_conditioning {
// Tolerant mode never escalates the ridge; at zero base ridge it tries
// gauge deflation first, then (if allowed) spectral deflation.
if ridge_t == 0.0 {
if let Some(deflated) =
factor_gauge_deflated_evidence_row(row, d, row_gauges)
{
return Ok(deflated);
}
if allow_spectral_deflation
&& let Some(deflated) = factor_spectral_deflated_evidence_row(row, d)
{
return Ok(deflated);
}
}
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt is non-PD at base ridge {ridge_t:e}; \
evidence mode preserves the genuine Cholesky of \
H_tt and does not condition non-PD blocks: {e}"
),
});
}
// Strict mode: same escalation rule as the ill-conditioned case above.
let next = if ridge_eff > 0.0 {
ridge_eff * RIDGE_GROWTH_FACTOR
} else {
RIDGE_SEED_DIAG_FRACTION * diag_scale
};
if !next.is_finite() || next > ridge_cap {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt remained non-PD up to ridge {ridge_eff:e} \
(base ridge_t={ridge_t}); last cholesky error: {e}"
),
});
}
ridge_eff = next;
}
}
};
Ok(factor)
}
/// Fingerprint of the latent manifold configuration.
///
/// Euclidean manifolds map to `EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT`. Otherwise
/// the hash covers a version tag, the observation count, the latent dimension,
/// the manifold structure and its flattened metric weights.
pub(crate) fn manifold_mode_fingerprint(latent: &LatentCoordValues) -> u64 {
    let manifold = latent.manifold();
    if manifold.is_euclidean() {
        return EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT;
    }

    let mut metric_weights = Vec::new();
    append_latent_metric_weights(&mut metric_weights, manifold);

    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-manifold-mode-v1");
    hasher.write_usize(latent.n_obs());
    hasher.write_usize(latent.latent_dim());
    write_latent_manifold(&mut hasher, manifold);
    hasher.write_usize(metric_weights.len());
    metric_weights
        .iter()
        .for_each(|&weight| hasher.write_f64(weight));
    hasher.finish_u64()
}
/// Fingerprint of the Hessian content of an `ArrowSchurSystem`.
///
/// Hashes the system sizes, every row's `htt`, the cross block (either the
/// address of the `htbeta_matvec` operator, plus the dense `htbeta` when
/// `htbeta_dense_supplement` is set, or the dense `htbeta` alone when there is
/// no operator), the penalty operator's fingerprint or the dense `hbb`, and the
/// optional `hbb_diag`.
pub(crate) fn row_hessian_fingerprint_for_system(sys: &ArrowSchurSystem) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-row-hessian-v2");
    hasher.write_usize(sys.rows.len());
    hasher.write_usize(sys.d);
    hasher.write_usize(sys.k);

    // The operator is identified by the address of its shared allocation.
    let op_addr: Option<usize> = sys
        .htbeta_matvec
        .as_ref()
        .map(|op| Arc::as_ptr(op) as *const () as usize);
    for row in sys.rows.iter() {
        hasher.write_f64_array2(&row.htt);
        if let Some(addr) = op_addr {
            hasher.write_usize(addr);
        }
        if op_addr.is_none() || sys.htbeta_dense_supplement {
            hasher.write_f64_array2(&row.htbeta);
        }
    }

    if let Some(op) = sys.penalty_op.as_ref() {
        hasher.write_bool(true);
        op.fingerprint(&mut hasher);
    } else {
        hasher.write_bool(false);
        hasher.write_f64_array2(&sys.hbb);
    }

    if let Some(diag) = sys.hbb_diag.as_ref() {
        hasher.write_bool(true);
        hasher.write_usize(diag.len());
        diag.iter().for_each(|&value| hasher.write_f64(value));
    } else {
        hasher.write_bool(false);
    }
    hasher.finish_u64()
}
/// Mixes a row fingerprint with a registry fingerprint.
///
/// A registry fingerprint of `0` leaves `row` unchanged; any other value is
/// hashed together with `row` under a version tag.
pub(crate) fn combine_row_and_registry_fingerprints(row: u64, registry: u64) -> u64 {
    match registry {
        0 => row,
        _ => {
            let mut hasher = Fingerprinter::new();
            hasher.write_str("arrow-schur-row-hessian-with-penalties-v1");
            hasher.write_u64(row);
            hasher.write_u64(registry);
            hasher.finish_u64()
        }
    }
}
/// Fingerprint of the row-Hessian contribution of an analytic penalty.
///
/// Returns `None` unless the penalty's tier is `PenaltyTier::Psi` and
/// `analytic_penalty_is_row_block_diagonal` accepts it. Otherwise the hash
/// covers the penalty name, the lengths of `target_t` and `rho_local`, every
/// entry of `rho_local`, and a variant-specific payload:
///
/// * `RowPrecisionPrior`: shape, weight, learnable flag (plus the active
///   weight when learnable) and every entry of `lambda_per_row`.
/// * `ParametricRowPrecisionPrior`: shapes, weight, learnable flag, `aux`, and
///   each of `log_alpha`, `raw_beta`, `mu` together with its `rho_local`-shifted
///   value. `rho_local` is read in the order
///   `[log_alpha | raw_beta | mu (row-major) | weight]`.
/// * any other variant: the diagonal returned by `hessian_diag`, or a zero
///   length marker when it returns `None`.
///
/// # Panics
/// Indexing into `rho_local` panics if it is shorter than the layout above
/// requires.
pub(crate) fn analytic_penalty_row_hessian_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> Option<u64> {
    if penalty.tier() != PenaltyTier::Psi || !analytic_penalty_is_row_block_diagonal(penalty) {
        return None;
    }
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-analytic-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    for &rho in rho_local.iter() {
        hasher.write_f64(rho);
    }
    match penalty {
        AnalyticPenaltyKind::RowPrecisionPrior(p) => {
            let (n, rows, cols) = p.lambda_per_row.dim();
            hasher.write_str("row-precision-fixed");
            hasher.write_usize(n);
            hasher.write_usize(rows);
            hasher.write_usize(cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            if p.learnable_weight {
                hasher.write_usize(p.rho_index);
                hasher.write_f64(p.weight * rho_local[p.rho_index].exp());
            }
            for &value in p.lambda_per_row.iter() {
                hasher.write_f64(value);
            }
        }
        AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
            let (aux_n, aux_dim) = p.aux.dim();
            let (mu_rows, mu_cols) = p.mu.dim();
            let weight_offset = p.log_alpha.len() + p.raw_beta.len() + p.mu.len();
            hasher.write_str("row-precision-parametric");
            hasher.write_usize(aux_n);
            hasher.write_usize(aux_dim);
            hasher.write_usize(mu_rows);
            hasher.write_usize(mu_cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            for &value in p.aux.iter() {
                hasher.write_f64(value);
            }
            for k in 0..p.log_alpha.len() {
                let active_log_alpha = p.log_alpha[k] + rho_local[k];
                hasher.write_f64(p.log_alpha[k]);
                hasher.write_f64(active_log_alpha);
                hasher.write_f64(active_log_alpha.exp());
            }
            let raw_beta_offset = p.log_alpha.len();
            for k in 0..p.raw_beta.len() {
                let active_raw_beta = p.raw_beta[k] + rho_local[raw_beta_offset + k];
                hasher.write_f64(p.raw_beta[k]);
                hasher.write_f64(active_raw_beta);
                hasher.write_f64(crate::linalg::utils::stable_softplus(active_raw_beta));
            }
            let mu_offset = p.log_alpha.len() + p.raw_beta.len();
            for k in 0..mu_rows {
                for a in 0..mu_cols {
                    // Row-major position of mu[k, a] inside the mu segment. The
                    // stride must be mu's own column count so the segment spans
                    // exactly `p.mu.len()` slots and ends right before
                    // `weight_offset`.
                    let idx = mu_offset + k * mu_cols + a;
                    hasher.write_f64(p.mu[[k, a]]);
                    hasher.write_f64(p.mu[[k, a]] + rho_local[idx]);
                }
            }
            if p.learnable_weight {
                hasher.write_usize(weight_offset);
                hasher.write_f64(p.weight * rho_local[weight_offset].exp());
            }
        }
        _ => {
            hasher.write_str("row-block-diagonal");
            if let Some(diag) = penalty.hessian_diag(target_t, rho_local) {
                hasher.write_usize(diag.len());
                for &value in diag.iter() {
                    hasher.write_f64(value);
                }
            } else {
                hasher.write_usize(0);
            }
        }
    }
    Some(hasher.finish_u64())
}
/// Fingerprint of a penalty's cross-row Hessian contribution.
///
/// Hashes a version tag, the penalty name, the lengths of `target_t` and
/// `rho_local`, every entry of `rho_local`, and the result of
/// `psd_majorizer_hvp` evaluated with `target_t` as the probe vector.
pub(crate) fn cross_row_penalty_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-analytic-cross-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    rho_local.iter().for_each(|&rho| hasher.write_f64(rho));

    let probe = penalty.psd_majorizer_hvp(target_t, rho_local, target_t);
    hasher.write_usize(probe.len());
    probe.iter().for_each(|&value| hasher.write_f64(value));
    hasher.finish_u64()
}
/// Feeds a structural description of `manifold` into `hasher`: a variant tag
/// followed by that variant's parameters, recursing into product components.
pub(crate) fn write_latent_manifold(hasher: &mut Fingerprinter, manifold: &LatentManifold) {
    // The variant tag always goes first.
    let tag = match manifold {
        LatentManifold::Euclidean => "euclidean",
        LatentManifold::Circle { .. } => "circle",
        LatentManifold::Sphere { .. } => "sphere",
        LatentManifold::Interval { .. } => "interval",
        LatentManifold::Product(_) => "product",
        LatentManifold::ProductWithMetric { .. } => "product-with-metric",
    };
    hasher.write_str(tag);

    // Then the variant-specific payload.
    match manifold {
        LatentManifold::Euclidean => {}
        LatentManifold::Circle { period } => hasher.write_f64(*period),
        LatentManifold::Sphere { dim } => hasher.write_usize(*dim),
        LatentManifold::Interval { lo, hi } => {
            hasher.write_f64(*lo);
            hasher.write_f64(*hi);
        }
        LatentManifold::Product(parts) => {
            hasher.write_usize(parts.len());
            for component in parts {
                write_latent_manifold(hasher, component);
            }
        }
        LatentManifold::ProductWithMetric { manifolds, weights } => {
            hasher.write_usize(manifolds.len());
            for component in manifolds {
                write_latent_manifold(hasher, component);
            }
            hasher.write_usize(weights.len());
            for w in weights {
                hasher.write_f64(*w);
            }
        }
    }
}
/// Appends the metric weights of `manifold` to `out`.
///
/// Euclidean contributes `1.0`; a circle `1 / period^2`; a sphere `dim` copies
/// of `1 / pi^2`; an interval `1 / (hi - lo)^2`; a product the concatenation of
/// its components' weights; a product with an explicit metric its stored
/// weights verbatim.
pub(crate) fn append_latent_metric_weights(out: &mut Vec<f64>, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::Euclidean => out.push(1.0),
        LatentManifold::Circle { period } => {
            let p = *period;
            out.push(1.0 / (p * p));
        }
        LatentManifold::Sphere { dim } => {
            let pi = std::f64::consts::PI;
            let weight = 1.0 / (pi * pi);
            out.extend(std::iter::repeat(weight).take(*dim));
        }
        LatentManifold::Interval { lo, hi } => {
            let width = hi - lo;
            out.push(1.0 / (width * width));
        }
        LatentManifold::Product(parts) => {
            for component in parts {
                append_latent_metric_weights(out, component);
            }
        }
        LatentManifold::ProductWithMetric { weights, .. } => {
            out.extend(weights.iter().copied());
        }
    }
}