use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use std::ops::Range;
use std::sync::Arc;
use crate::cache::Fingerprinter;
use crate::linalg::faer_ndarray::{FaerArrayView, FaerLlt};
use crate::linalg::triangular::{
cholesky_solve_matrix, cholesky_solve_vector, forward_substitution_lower_matrix,
};
use crate::terms::analytic_penalties::{AnalyticPenaltyKind, AnalyticPenaltyRegistry, PenaltyTier};
use crate::terms::latent_coord::{LatentCoordValues, LatentManifold};
/// Largest system size `k` for which `ArrowSolverMode::automatic` picks `Direct`
/// (and `automatic_for_single_precision` picks `SqrtBA`); larger systems get
/// `InexactPCG`.
const DIRECT_SOLVE_MAX_K: usize = 2_000;
/// Default for `ArrowPcgOptions::max_iterations` and
/// `ArrowTrustRegionOptions::max_iterations`.
const DEFAULT_PCG_MAX_ITERATIONS: usize = 200;
/// Default for `ArrowPcgOptions::relative_tolerance` and
/// `ArrowTrustRegionOptions::steihaug_relative_tolerance`.
const DEFAULT_PCG_RELATIVE_TOLERANCE: f64 = 1e-4;
// NOTE(review): not referenced in this part of the file; presumably a lower
// bound on the absolute PCG residual tolerance — confirm at the use site.
const PCG_ABSOLUTE_TOLERANCE_FLOOR: f64 = 1e-14;
/// Default `ArrowTrustRegionOptions::radius`. Infinity presumably means "no
/// radius limit" unless a caller sets one — confirm at the use site.
const DEFAULT_TRUST_REGION_RADIUS: f64 = f64::INFINITY;
/// Default `ArrowProximalCorrectionOptions::initial_ridge`.
pub const DEFAULT_PROXIMAL_INITIAL_RIDGE: f64 = 1e-8;
/// Unit roundoff of IEEE single precision: `f32::EPSILON / 2` (= 2^-24).
const F32_UNIT_ROUNDOFF: f64 = (f32::EPSILON as f64) * 0.5;
// The next three are the defaults used by `MixedPrecisionPolicy::certified()`
// (max refinement steps, residual relative tolerance, kappa margin).
const DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS: usize = 6;
const DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE: f64 = 1e-11;
const DEFAULT_MIXED_PRECISION_KAPPA_MARGIN: f64 = 0.5;
// NOTE(review): the next two are not referenced in this part of the file; by
// their names they scale / cap the mixed-precision certificate — confirm.
const MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER: f64 = 64.0;
const MIXED_PRECISION_KAPPA_MARGIN_CEILING: f64 = 1.0;
/// Undirected edge between two beta blocks of a [`BetaCouplingGraph`].
///
/// `BetaCouplingGraph::build` always stores the smaller block index in `a`
/// and the larger in `b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct BetaEdge {
a: usize,
b: usize,
}
/// Coupling graph over beta blocks: two blocks are adjacent when some row's
/// `htbeta` matrix has a non-zero entry in a column of each block.
///
/// Adjacency is stored in CSR form: the neighbours of block `i` are
/// `adj_targets[adj_start[i]..adj_start[i + 1]]`.
#[derive(Debug, Clone)]
struct BetaCouplingGraph {
// Number of beta blocks (graph nodes).
num_blocks: usize,
// Sorted, de-duplicated edge list (each edge has `a < b`).
edges: Vec<BetaEdge>,
// CSR row pointer; length `num_blocks + 1`.
adj_start: Vec<usize>,
// Concatenated neighbour lists indexed through `adj_start`.
adj_targets: Vec<usize>,
}
impl BetaCouplingGraph {
fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
let num_blocks = block_offsets.len();
if num_blocks == 0 {
return Self {
num_blocks: 0,
edges: Vec::new(),
adj_start: vec![0],
adj_targets: Vec::new(),
};
}
let mut edge_set = Vec::<(usize, usize)>::new();
for row in htbeta_rows {
let mut active = Vec::<usize>::new();
for (block, range) in block_offsets.iter().enumerate() {
if range
.clone()
.any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
{
active.push(block);
}
}
for i in 0..active.len() {
for j in (i + 1)..active.len() {
edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
}
}
}
edge_set.sort_unstable();
edge_set.dedup();
let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
let mut degree = vec![0usize; num_blocks];
for &BetaEdge { a, b } in &edges {
degree[a] += 1;
degree[b] += 1;
}
let mut adj_start = vec![0usize; num_blocks + 1];
for block in 0..num_blocks {
adj_start[block + 1] = adj_start[block] + degree[block];
}
let mut adj_targets = vec![0usize; adj_start[num_blocks]];
let mut cursor = adj_start[..num_blocks].to_vec();
for &BetaEdge { a, b } in &edges {
adj_targets[cursor[a]] = b;
cursor[a] += 1;
adj_targets[cursor[b]] = a;
cursor[b] += 1;
}
Self {
num_blocks,
edges,
adj_start,
adj_targets,
}
}
fn neighbours(&self, node: usize) -> &[usize] {
&self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
}
fn component_partition(&self) -> Vec<Vec<usize>> {
let mut parent: Vec<usize> = (0..self.num_blocks).collect();
let mut rank = vec![0u8; self.num_blocks];
fn find(parent: &mut [usize], mut x: usize) -> usize {
while parent[x] != x {
parent[x] = parent[parent[x]];
x = parent[x];
}
x
}
for &BetaEdge { a, b } in &self.edges {
let lhs = find(&mut parent, a);
let rhs = find(&mut parent, b);
if lhs != rhs {
if rank[lhs] < rank[rhs] {
parent[lhs] = rhs;
} else if rank[lhs] > rank[rhs] {
parent[rhs] = lhs;
} else {
parent[rhs] = lhs;
rank[lhs] += 1;
}
}
}
let mut label_map = vec![usize::MAX; self.num_blocks];
let mut parts = Vec::<Vec<usize>>::new();
for block in 0..self.num_blocks {
let root = find(&mut parent, block);
let label = if label_map[root] == usize::MAX {
label_map[root] = parts.len();
parts.push(Vec::new());
label_map[root]
} else {
label_map[root]
};
parts[label].push(block);
}
parts
}
fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
let mut expanded = seed.to_vec();
for &block in seed {
expanded.extend_from_slice(self.neighbours(block));
}
expanded.sort_unstable();
expanded.dedup();
expanded
}
}
/// Default `ArrowProximalCorrectionOptions::ridge_growth`.
pub const DEFAULT_PROXIMAL_RIDGE_GROWTH: f64 = 10.0;
/// Default `ArrowProximalCorrectionOptions::max_attempts`.
pub const DEFAULT_PROXIMAL_MAX_ATTEMPTS: usize = 22;
/// Default `ArrowProximalCorrectionOptions::armijo_c1`.
const DEFAULT_ARMIJO_C1: f64 = 1e-4;
/// Default `ArrowProximalCorrectionOptions::gradient_tolerance`.
const DEFAULT_GRADIENT_TOLERANCE: f64 = 1e-10;
/// Default `ArrowProximalCorrectionOptions::convergence_objective_rel_tol`.
const DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL: f64 = 8e-12;
// NOTE(review): not referenced in this part of the file; presumably the
// fingerprint tag for a Euclidean (non-manifold) latent mode — confirm.
const EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT: u64 = 0;
// 256 MiB. NOTE(review): not referenced in this part of the file; by its name
// a memory budget for caching htbeta data alongside arrow factors — confirm.
const ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES: usize = 256 * 1024 * 1024;
/// Shared, thread-safe closure `(x, out)` applying a linear operator to a beta
/// vector. `MatvecDiagPenaltyOp` always hands it a zeroed `out` of length `k`.
pub type SharedBetaMatvec =
Arc<dyn for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Per-row closure `(row, x, out)`; presumably applies row `row`'s `Htbeta`
/// block to `x` — TODO confirm against callers.
pub type RowHtbetaMatvec =
Arc<dyn for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Per-row closure `(row, x, out)`; presumably applies the transpose of row
/// `row`'s `Htbeta` block — TODO confirm against callers.
pub type RowHtbetaTransposeMatvec =
Arc<dyn for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Fallible factory producing the `ArrowRowBlock` for a given row index
/// (presumably used to stream rows instead of materialising them all).
pub type StreamingArrowRowBuilder =
Arc<dyn Fn(usize) -> Result<ArrowRowBlock, ArrowSchurError> + Send + Sync>;
/// External matvec hook stored in `ArrowSolveOptions::gpu_matvec`.
pub type GpuSchurMatvec = Arc<dyn Fn(&Array1<f64>, &mut Array1<f64>) + Send + Sync>;
// Unsized slice alias; the meaning of the weights is defined at use sites
// outside this part of the file.
type MetricWeights = [f64];
/// Index of a beta block; selects `offsets[id.0]` in [`BetaPenaltyOp::block`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BetaBlockId(pub usize);
/// Penalty operator `S` acting on the full beta vector of dimension `dim()`.
///
/// Every implementation in this file *accumulates* into its output buffer
/// (`+=`) instead of overwriting it (`to_dense` excepted, which returns a
/// fresh matrix), so callers are responsible for zeroing buffers first.
pub trait BetaPenaltyOp: Send + Sync {
/// Dimension of the beta vector the operator acts on.
fn dim(&self) -> usize;
/// Accumulates `S x` into `y`.
fn matvec(&self, x: &[f64], y: &mut [f64]);
/// Accumulates the penalty gradient at `beta` into `out`; every
/// implementation in this file adds `S beta` (same arithmetic as `matvec`).
fn gradient(&self, beta: &[f64], out: &mut [f64]);
/// Accumulates the diagonal of `S` into `diag`.
fn diagonal(&self, diag: &mut [f64]);
/// Accumulates the diagonal block `S[r, r]`, with `r = offsets[id.0]`, into
/// `out` (indexed locally, `0..r.len()` on both axes).
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
/// Materialises `S` as a dense matrix.
fn to_dense(&self) -> Array2<f64>;
/// Writes a content fingerprint of the operator into `hasher` (presumably a
/// cache key — see each implementation for what is covered).
fn fingerprint(&self, hasher: &mut Fingerprinter);
}
/// Penalty stored as an explicit dense matrix; its row count is the beta
/// dimension.
pub struct DensePenaltyOp(pub Array2<f64>);
impl BetaPenaltyOp for DensePenaltyOp {
    fn dim(&self) -> usize {
        self.0.nrows()
    }

    /// Accumulates `S x` into `y`, one row dot product at a time.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let mat = &self.0;
        let n = mat.nrows();
        for row in 0..n {
            let dot = (0..n).fold(0.0_f64, |sum, col| sum + mat[[row, col]] * x[col]);
            y[row] += dot;
        }
    }

    /// Penalty gradient at `beta`: accumulates `S beta` into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds `S[j, j]` for every `j` below both the matrix size and `diag.len()`.
    fn diagonal(&self, diag: &mut [f64]) {
        let n = self.0.nrows();
        for (j, slot) in diag.iter_mut().take(n).enumerate() {
            *slot += self.0[[j, j]];
        }
    }

    /// Adds the square sub-matrix for `offsets[id.0]` into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        let base = span.start;
        for r in 0..width {
            for c in 0..width {
                out[[r, c]] += self.0[[base + r, base + c]];
            }
        }
    }

    fn to_dense(&self) -> Array2<f64> {
        self.0.clone()
    }

    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("dense-penalty-op-v1");
        hasher.write_f64_array2(&self.0);
    }
}
/// Sum of square local penalties: each `(off, local)` adds `local` at rows and
/// columns `off..off + local.nrows()` of a `k x k` operator.
pub struct BlockPenaltyOp {
    pub k: usize,
    pub blocks: Vec<(usize, Array2<f64>)>,
}
impl BetaPenaltyOp for BlockPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Accumulates `S x` into `y`, block by block.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for (start, mat) in &self.blocks {
            let start = *start;
            let size = mat.nrows();
            for r in 0..size {
                let dot = (0..size).fold(0.0_f64, |sum, c| sum + mat[[r, c]] * x[start + c]);
                y[start + r] += dot;
            }
        }
    }

    /// Penalty gradient at `beta`: accumulates `S beta` into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    fn diagonal(&self, diag: &mut [f64]) {
        for (start, mat) in &self.blocks {
            for j in 0..mat.nrows() {
                diag[start + j] += mat[[j, j]];
            }
        }
    }

    /// Adds the part of every local block that overlaps `offsets[id.0]`
    /// (on both axes) into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let range = &offsets[id.0];
        for (start, mat) in &self.blocks {
            let start = *start;
            let stop = start + mat.nrows();
            // Global indices covered by both this local block and the requested range.
            let lo = start.max(range.start);
            let hi = stop.min(range.end);
            for gi in lo..hi {
                for gj in lo..hi {
                    out[[gi - range.start, gj - range.start]] += mat[[gi - start, gj - start]];
                }
            }
        }
    }

    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for (start, mat) in &self.blocks {
            let size = mat.nrows();
            for r in 0..size {
                for c in 0..size {
                    dense[[start + r, start + c]] += mat[[r, c]];
                }
            }
        }
        dense
    }

    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("block-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for (start, mat) in &self.blocks {
            hasher.write_usize(*start);
            hasher.write_f64_array2(mat);
        }
    }
}
/// Kronecker-product penalty `A ⊗ B` embedded in a `k x k` operator.
///
/// With `off = global_offset`, `p_a = factor_a.nrows()`, `p_b = factor_b.nrows()`,
/// entry `(off + i_a * p_b + i_b, off + j_a * p_b + j_b)` equals
/// `factor_a[[i_a, j_a]] * factor_b[[i_b, j_b]]`; every other entry is zero.
/// Both factors are indexed as square matrices (only `nrows()` is read).
pub struct KroneckerPenaltyOp {
pub factor_a: Array2<f64>,
pub factor_b: Array2<f64>,
// First global beta index covered by the Kronecker block.
pub global_offset: usize,
// Full beta dimension (size of the operator), not the size of the block.
pub k: usize,
}
impl BetaPenaltyOp for KroneckerPenaltyOp {
fn dim(&self) -> usize {
self.k
}
// Accumulates (A ⊗ B) x into y for the block's rows; a zero entry of A skips
// the whole inner p_b loop for that (i_a, j_a) pair.
fn matvec(&self, x: &[f64], y: &mut [f64]) {
let p_a = self.factor_a.nrows();
let p_b = self.factor_b.nrows();
let off = self.global_offset;
for i_a in 0..p_a {
for i_b in 0..p_b {
let gi = off + i_a * p_b + i_b;
let mut acc = 0.0_f64;
for j_a in 0..p_a {
let a_ij = self.factor_a[[i_a, j_a]];
if a_ij == 0.0 {
continue;
}
for j_b in 0..p_b {
acc += a_ij * self.factor_b[[i_b, j_b]] * x[off + j_a * p_b + j_b];
}
}
y[gi] += acc;
}
}
}
// Same arithmetic as `matvec`, applied to `beta`.
fn gradient(&self, beta: &[f64], out: &mut [f64]) {
let p_a = self.factor_a.nrows();
let p_b = self.factor_b.nrows();
let off = self.global_offset;
for i_a in 0..p_a {
for i_b in 0..p_b {
let gi = off + i_a * p_b + i_b;
let mut acc = 0.0_f64;
for j_a in 0..p_a {
let a_ij = self.factor_a[[i_a, j_a]];
if a_ij == 0.0 {
continue;
}
for j_b in 0..p_b {
acc += a_ij * self.factor_b[[i_b, j_b]] * beta[off + j_a * p_b + j_b];
}
}
out[gi] += acc;
}
}
}
// Diagonal of A ⊗ B is A[i_a, i_a] * B[i_b, i_b].
fn diagonal(&self, diag: &mut [f64]) {
let p_a = self.factor_a.nrows();
let p_b = self.factor_b.nrows();
let off = self.global_offset;
for i_a in 0..p_a {
for i_b in 0..p_b {
diag[off + i_a * p_b + i_b] +=
self.factor_a[[i_a, i_a]] * self.factor_b[[i_b, i_b]];
}
}
}
// Adds the entries of A ⊗ B whose global row and column both fall in
// `offsets[id.0]`; each global index is decoded back to (i_a, i_b).
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
let range = &offsets[id.0];
let b = range.end - range.start;
let p_a = self.factor_a.nrows();
let p_b = self.factor_b.nrows();
let off = self.global_offset;
let block_end = off + p_a * p_b;
// No overlap between the Kronecker block and the requested range.
if block_end <= range.start || off >= range.end {
return;
}
for bi in 0..b {
let gi = range.start + bi;
if gi < off || gi >= block_end {
continue;
}
let li = gi - off;
let i_a = li / p_b;
let i_b = li % p_b;
for bj in 0..b {
let gj = range.start + bj;
if gj < off || gj >= block_end {
continue;
}
let lj = gj - off;
let j_a = lj / p_b;
let j_b = lj % p_b;
out[[bi, bj]] += self.factor_a[[i_a, j_a]] * self.factor_b[[i_b, j_b]];
}
}
}
// Dense k x k matrix with A ⊗ B placed at [off.., off..].
fn to_dense(&self) -> Array2<f64> {
let p_a = self.factor_a.nrows();
let p_b = self.factor_b.nrows();
let off = self.global_offset;
let mut out = Array2::<f64>::zeros((self.k, self.k));
for i_a in 0..p_a {
for i_b in 0..p_b {
let gi = off + i_a * p_b + i_b;
for j_a in 0..p_a {
let a_ij = self.factor_a[[i_a, j_a]];
if a_ij == 0.0 {
continue;
}
for j_b in 0..p_b {
let gj = off + j_a * p_b + j_b;
out[[gi, gj]] += a_ij * self.factor_b[[i_b, j_b]];
}
}
}
}
out
}
fn fingerprint(&self, hasher: &mut Fingerprinter) {
hasher.write_str("kronecker-penalty-op-v1");
hasher.write_usize(self.global_offset);
hasher.write_usize(self.k);
hasher.write_f64_array2(&self.factor_a);
hasher.write_f64_array2(&self.factor_b);
}
}
/// Penalty `A ⊗ I_p` embedded in a `k x k` operator.
///
/// With `off = global_offset` and `p_a = factor_a.nrows()`, entry
/// `(off + i_a * p + i_b, off + j_a * p + i_b)` equals `factor_a[[i_a, j_a]]`
/// (same `i_b` on both sides); every other entry is zero.
pub struct IdentityRightKroneckerPenaltyOp {
pub factor_a: Array2<f64>,
// Size of the identity factor.
pub p: usize,
// First global beta index covered by the block.
pub global_offset: usize,
// Full beta dimension (size of the operator).
pub k: usize,
}
impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
fn dim(&self) -> usize {
self.k
}
// Accumulates (A ⊗ I_p) x into y: each output only mixes inputs sharing i_b.
fn matvec(&self, x: &[f64], y: &mut [f64]) {
let p_a = self.factor_a.nrows();
let p = self.p;
let off = self.global_offset;
for i_a in 0..p_a {
for i_b in 0..p {
let gi = off + i_a * p + i_b;
let mut acc = 0.0_f64;
for j_a in 0..p_a {
let a_ij = self.factor_a[[i_a, j_a]];
if a_ij == 0.0 {
continue;
}
acc += a_ij * x[off + j_a * p + i_b];
}
y[gi] += acc;
}
}
}
fn gradient(&self, beta: &[f64], out: &mut [f64]) {
self.matvec(beta, out);
}
// Diagonal of A ⊗ I_p repeats A[i_a, i_a] p times.
fn diagonal(&self, diag: &mut [f64]) {
let p_a = self.factor_a.nrows();
let p = self.p;
let off = self.global_offset;
for i_a in 0..p_a {
let a_ii = self.factor_a[[i_a, i_a]];
for i_b in 0..p {
diag[off + i_a * p + i_b] += a_ii;
}
}
}
// Adds the entries whose global row and column both fall in `offsets[id.0]`;
// only pairs with matching identity index (i_b == j_b) are non-zero.
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
let range = &offsets[id.0];
let b = range.end - range.start;
let p_a = self.factor_a.nrows();
let p = self.p;
let off = self.global_offset;
let block_end = off + p_a * p;
if block_end <= range.start || off >= range.end {
return;
}
for bi in 0..b {
let gi = range.start + bi;
if gi < off || gi >= block_end {
continue;
}
let li = gi - off;
let i_a = li / p;
let i_b = li % p;
for bj in 0..b {
let gj = range.start + bj;
if gj < off || gj >= block_end {
continue;
}
let lj = gj - off;
let j_a = lj / p;
let j_b = lj % p;
if i_b == j_b {
out[[bi, bj]] += self.factor_a[[i_a, j_a]];
}
}
}
}
fn to_dense(&self) -> Array2<f64> {
let p_a = self.factor_a.nrows();
let p = self.p;
let off = self.global_offset;
let mut out = Array2::<f64>::zeros((self.k, self.k));
for i_a in 0..p_a {
for j_a in 0..p_a {
let a_ij = self.factor_a[[i_a, j_a]];
if a_ij == 0.0 {
continue;
}
for i_b in 0..p {
let gi = off + i_a * p + i_b;
let gj = off + j_a * p + i_b;
out[[gi, gj]] += a_ij;
}
}
}
out
}
fn fingerprint(&self, hasher: &mut Fingerprinter) {
hasher.write_str("identity-right-kronecker-penalty-op-v1");
hasher.write_usize(self.global_offset);
hasher.write_usize(self.k);
hasher.write_usize(self.p);
hasher.write_f64_array2(&self.factor_a);
}
}
/// One block of a [`SparseBlockKroneckerPenaltyOp`].
///
/// Offsets are in coefficient (not beta) units: entry `data[[li, lj]]`
/// couples beta indices `(row_off + li) * p + oc` and `(col_off + lj) * p + oc`
/// for every `oc in 0..p`.
#[derive(Debug, Clone)]
pub struct SparseGBlock {
pub row_off: usize,
pub col_off: usize,
pub data: Array2<f64>,
}
/// Penalty `Σ_blocks data ⊗ I_p`, each block placed at its
/// `(row_off, col_off)` coefficient position in a `k x k` operator.
pub struct SparseBlockKroneckerPenaltyOp {
// Size of the identity factor.
pub p: usize,
// NOTE(review): only hashed in `fingerprint` within this part of the file;
// presumably the number of coefficient rows — confirm.
pub dim_a: usize,
// Full beta dimension.
pub k: usize,
pub blocks: Vec<SparseGBlock>,
}
// NOTE(review): the two structs below are not referenced in this part of the
// file. Their fields mirror `IdentityRightKroneckerPenaltyOp` /
// `SparseGBlock`, presumably as plain data for a device-side PCG path — confirm.
#[derive(Debug, Clone)]
pub struct DeviceSaeSmoothBlock {
pub global_offset: usize,
pub factor_a: Array2<f64>,
}
#[derive(Debug, Clone)]
pub struct DeviceSaePcgData {
pub p: usize,
pub beta_dim: usize,
pub a_phi: Vec<Vec<(usize, f64)>>,
pub local_jac: Vec<Vec<f64>>,
pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
pub sparse_g_blocks: Vec<SparseGBlock>,
}
impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Accumulates `S x` into `y`, where every block contributes `data ⊗ I_p`
    /// at coefficient position `(row_off, col_off)`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                for lj in 0..m_j {
                    let a_ij = blk.data[[li, lj]];
                    if a_ij == 0.0 {
                        continue;
                    }
                    let gj_base = (blk.col_off + lj) * p;
                    for oc in 0..p {
                        y[gi_base + oc] += a_ij * x[gj_base + oc];
                    }
                }
            }
        }
    }

    /// Penalty gradient at `beta`: accumulates `S beta` into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Accumulates the diagonal of `S` into `diag`.
    ///
    /// A block hits the global diagonal wherever `row_off + li == col_off + lj`,
    /// not only when `row_off == col_off`. Any block whose row and column
    /// coefficient ranges overlap therefore contributes, which keeps this
    /// consistent with `to_dense` and `block`.
    fn diagonal(&self, diag: &mut [f64]) {
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            // Coefficient ranges that never meet cannot touch the diagonal.
            if blk.row_off + m_i <= blk.col_off || blk.col_off + m_j <= blk.row_off {
                continue;
            }
            for li in 0..m_i {
                let coeff = blk.row_off + li;
                if coeff < blk.col_off {
                    continue;
                }
                // Local column landing on the same coefficient as local row `li`.
                let lj = coeff - blk.col_off;
                if lj >= m_j {
                    continue;
                }
                let a_ii = blk.data[[li, lj]];
                let gi_base = coeff * p;
                for oc in 0..p {
                    diag[gi_base + oc] += a_ii;
                }
            }
        }
    }

    /// Adds the entries of `S` whose global row and column both fall in
    /// `offsets[id.0]`. Only pairs with the same identity index (`oc`) are non-zero.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let range = &offsets[id.0];
        let b = range.end - range.start;
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            let row_start = blk.row_off * p;
            let row_end = (blk.row_off + m_i) * p;
            let col_start = blk.col_off * p;
            let col_end = (blk.col_off + m_j) * p;
            if row_end <= range.start
                || row_start >= range.end
                || col_end <= range.start
                || col_start >= range.end
            {
                continue;
            }
            for bi in 0..b {
                let gi = range.start + bi;
                if gi < row_start || gi >= row_end {
                    continue;
                }
                let li = (gi - row_start) / p;
                let oc_i = (gi - row_start) % p;
                for bj in 0..b {
                    let gj = range.start + bj;
                    if gj < col_start || gj >= col_end {
                        continue;
                    }
                    let oc_j = (gj - col_start) % p;
                    if oc_i != oc_j {
                        continue;
                    }
                    let lj = (gj - col_start) / p;
                    out[[bi, bj]] += blk.data[[li, lj]];
                }
            }
        }
    }

    fn to_dense(&self) -> Array2<f64> {
        let p = self.p;
        let mut out = Array2::<f64>::zeros((self.k, self.k));
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                for lj in 0..m_j {
                    let a_ij = blk.data[[li, lj]];
                    if a_ij == 0.0 {
                        continue;
                    }
                    let gj_base = (blk.col_off + lj) * p;
                    for oc in 0..p {
                        out[[gi_base + oc, gj_base + oc]] += a_ij;
                    }
                }
            }
        }
        out
    }

    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("sparse-block-kronecker-penalty-op-v1");
        hasher.write_usize(self.p);
        hasher.write_usize(self.dim_a);
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for blk in &self.blocks {
            hasher.write_usize(blk.row_off);
            hasher.write_usize(blk.col_off);
            hasher.write_f64_array2(&blk.data);
        }
    }
}
/// One `(atom_i, atom_j)` block of a [`FactoredFrameKroneckerOp`], equal to
/// `g ⊗ w`.
///
/// `FactoredFrameKroneckerOp::new` requires `g` to be
/// `basis_sizes[atom_i] x basis_sizes[atom_j]` and `w` to be
/// `ranks[atom_i] x ranks[atom_j]`.
#[derive(Debug, Clone)]
pub struct FactoredFrameGBlock {
pub atom_i: usize,
pub atom_j: usize,
pub g: Array2<f64>,
pub w: Array2<f64>,
}
/// Block-structured penalty over atoms: atom `a` occupies the beta range
/// `offsets[a]..offsets[a + 1]` of length `basis_sizes[a] * ranks[a]`, laid
/// out as `li * ranks[a] + r` (basis index outer, rank index inner).
pub struct FactoredFrameKroneckerOp {
pub ranks: Vec<usize>,
pub basis_sizes: Vec<usize>,
// Running atom starts plus the total length (`n_atoms + 1` entries).
pub offsets: Vec<usize>,
// Total beta dimension, equal to `offsets[n_atoms]`.
pub dim: usize,
pub blocks: Vec<FactoredFrameGBlock>,
}
/// Cross-Gram matrix `W = U_iᵀ U_j` of two frames with the same number of rows.
///
/// `u_i` is `p x r_i` and `u_j` is `p x r_j`; the result is `r_i x r_j`. Each
/// entry is accumulated in row order, starting from `0.0`.
///
/// # Panics
/// If `u_i` and `u_j` have different row counts.
pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
    let (p_i, r_i) = u_i.dim();
    let (p_j, r_j) = u_j.dim();
    assert_eq!(
        p_i, p_j,
        "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
    );
    // Entry (a, b) is the dot product of column a of u_i with column b of u_j.
    Array2::from_shape_fn((r_i, r_j), |(a, b)| {
        (0..p_i).fold(0.0, |acc, c| acc + u_i[[c, a]] * u_j[[c, b]])
    })
}
impl FactoredFrameKroneckerOp {
    /// Builds the operator, computing atom offsets and validating every block.
    ///
    /// Atom `a` occupies `basis_sizes[a] * ranks[a]` consecutive coordinates;
    /// `offsets` is the running start of each atom followed by the total.
    ///
    /// # Errors
    /// Returns a message when `ranks` and `basis_sizes` differ in length, when
    /// a block names an atom out of range, or when a block's `g` / `w` has the
    /// wrong shape.
    pub fn new(
        ranks: Vec<usize>,
        basis_sizes: Vec<usize>,
        blocks: Vec<FactoredFrameGBlock>,
    ) -> Result<Self, String> {
        if ranks.len() != basis_sizes.len() {
            return Err(format!(
                "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
                ranks.len(),
                basis_sizes.len()
            ));
        }
        let n_atoms = ranks.len();
        let mut offsets = vec![0usize];
        offsets.extend(
            basis_sizes
                .iter()
                .zip(&ranks)
                .scan(0usize, |running, (&m, &r)| {
                    *running += m * r;
                    Some(*running)
                }),
        );
        let dim = *offsets
            .last()
            .expect("offsets always starts with a leading zero");

        for blk in &blocks {
            let (i, j) = (blk.atom_i, blk.atom_j);
            if i >= n_atoms || j >= n_atoms {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
                    i, j
                ));
            }
            let want_g = (basis_sizes[i], basis_sizes[j]);
            if blk.g.dim() != want_g {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
                    i,
                    j,
                    blk.g.dim(),
                    want_g.0,
                    want_g.1
                ));
            }
            let want_w = (ranks[i], ranks[j]);
            if blk.w.dim() != want_w {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
                    i,
                    j,
                    blk.w.dim(),
                    want_w.0,
                    want_w.1
                ));
            }
        }
        Ok(Self {
            ranks,
            basis_sizes,
            offsets,
            dim,
            blocks,
        })
    }

    /// Builds the operator from optional per-atom frames and per-pair `g` blocks.
    ///
    /// A `Some(u)` frame must be `p x r` with `r <= p` and gives the atom rank
    /// `r`. A `None` frame stands for the `p x p` identity (rank `p`). Each
    /// block's `w` is the cross-Gram of the two atoms' frames.
    ///
    /// # Errors
    /// Returns a message for mismatched lengths, frames with the wrong row count
    /// or excess rank, out-of-range block atoms, or any error from [`Self::new`].
    pub fn from_frames_and_blocks(
        frames: &[Option<Array2<f64>>],
        basis_sizes: &[usize],
        p: usize,
        g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
    ) -> Result<Self, String> {
        if frames.len() != basis_sizes.len() {
            return Err(format!(
                "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
                frames.len(),
                basis_sizes.len()
            ));
        }
        let n_atoms = frames.len();
        let ranks = frames
            .iter()
            .enumerate()
            .map(|(k, frame)| match frame {
                None => Ok(p),
                Some(u) => {
                    let (pr, r) = u.dim();
                    if pr != p {
                        Err(format!(
                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
                        ))
                    } else if r > p {
                        Err(format!(
                            "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
                        ))
                    } else {
                        Ok(r)
                    }
                }
            })
            .collect::<Result<Vec<usize>, String>>()?;

        let identity = Array2::<f64>::eye(p);
        let view_of = |k: usize| -> ArrayView2<f64> {
            if let Some(u) = &frames[k] {
                u.view()
            } else {
                identity.view()
            }
        };
        let blocks = g_blocks
            .iter()
            .map(|(&(atom_i, atom_j), g)| {
                if atom_i >= n_atoms || atom_j >= n_atoms {
                    return Err(format!(
                        "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
                    ));
                }
                let w = frame_output_gram(view_of(atom_i), view_of(atom_j));
                Ok(FactoredFrameGBlock {
                    atom_i,
                    atom_j,
                    g: g.clone(),
                    w,
                })
            })
            .collect::<Result<Vec<FactoredFrameGBlock>, String>>()?;
        Self::new(ranks, basis_sizes.to_vec(), blocks)
    }
}
impl BetaPenaltyOp for FactoredFrameKroneckerOp {
    fn dim(&self) -> usize {
        self.dim
    }

    /// Accumulates `S x` into `y`, where block `(i, j)` contributes `g ⊗ w`.
    ///
    /// `W x_lj` does not depend on the output coefficient `li`, so it is
    /// computed once per input coefficient `lj` rather than once per non-zero
    /// `g[li, lj]`. The summation order of each `W x_lj` entry is unchanged, so
    /// results are bit-identical to the direct triple loop. Columns of `g`
    /// that are entirely zero are never read from `x`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for blk in &self.blocks {
            let r_i = self.ranks[blk.atom_i];
            let r_j = self.ranks[blk.atom_j];
            let off_i = self.offsets[blk.atom_i];
            let off_j = self.offsets[blk.atom_j];
            let (m_i, m_j) = blk.g.dim();
            if m_i == 0 || m_j == 0 {
                continue;
            }
            // wx[lj * r_i + a] = Σ_b w[a, b] * x[off_j + lj * r_j + b]
            let mut wx = vec![0.0_f64; m_j * r_i];
            for lj in 0..m_j {
                if (0..m_i).all(|li| blk.g[[li, lj]] == 0.0) {
                    continue; // never consumed below
                }
                let xj_base = off_j + lj * r_j;
                for a in 0..r_i {
                    wx[lj * r_i + a] =
                        (0..r_j).fold(0.0, |acc, b| acc + blk.w[[a, b]] * x[xj_base + b]);
                }
            }
            for li in 0..m_i {
                let yi_base = off_i + li * r_i;
                for lj in 0..m_j {
                    let g = blk.g[[li, lj]];
                    if g == 0.0 {
                        continue;
                    }
                    let col = &wx[lj * r_i..(lj + 1) * r_i];
                    for (a, &v) in col.iter().enumerate() {
                        y[yi_base + a] += g * v;
                    }
                }
            }
        }
    }

    /// Penalty gradient at `beta`: accumulates `S beta` into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Accumulates the diagonal of `S`. Atoms occupy disjoint beta ranges, so
    /// only blocks with `atom_i == atom_j` can reach the diagonal.
    fn diagonal(&self, diag: &mut [f64]) {
        for blk in &self.blocks {
            if blk.atom_i != blk.atom_j {
                continue;
            }
            let r = self.ranks[blk.atom_i];
            let off = self.offsets[blk.atom_i];
            let (m_i, m_j) = blk.g.dim();
            let m = m_i.min(m_j);
            for li in 0..m {
                let gii = blk.g[[li, li]];
                let base = off + li * r;
                for a in 0..r {
                    diag[base + a] += gii * blk.w[[a, a]];
                }
            }
        }
    }

    /// Adds the entries of `S` whose global row and column both fall in
    /// `offsets[id.0]`. Blocks whose row or column footprint misses the
    /// range are skipped without scanning their entries.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let range = &offsets[id.0];
        for blk in &self.blocks {
            let r_i = self.ranks[blk.atom_i];
            let r_j = self.ranks[blk.atom_j];
            let off_i = self.offsets[blk.atom_i];
            let off_j = self.offsets[blk.atom_j];
            let (m_i, m_j) = blk.g.dim();
            let row_end = off_i + m_i * r_i;
            let col_end = off_j + m_j * r_j;
            if row_end <= range.start
                || off_i >= range.end
                || col_end <= range.start
                || off_j >= range.end
            {
                continue;
            }
            for li in 0..m_i {
                for a in 0..r_i {
                    let gi = off_i + li * r_i + a;
                    if !range.contains(&gi) {
                        continue;
                    }
                    let bi = gi - range.start;
                    for lj in 0..m_j {
                        let g = blk.g[[li, lj]];
                        if g == 0.0 {
                            continue;
                        }
                        for b in 0..r_j {
                            let gj = off_j + lj * r_j + b;
                            if range.contains(&gj) {
                                out[[bi, gj - range.start]] += g * blk.w[[a, b]];
                            }
                        }
                    }
                }
            }
        }
    }

    fn to_dense(&self) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((self.dim, self.dim));
        for blk in &self.blocks {
            let r_i = self.ranks[blk.atom_i];
            let r_j = self.ranks[blk.atom_j];
            let off_i = self.offsets[blk.atom_i];
            let off_j = self.offsets[blk.atom_j];
            let (m_i, m_j) = blk.g.dim();
            for li in 0..m_i {
                for lj in 0..m_j {
                    let g = blk.g[[li, lj]];
                    if g == 0.0 {
                        continue;
                    }
                    for a in 0..r_i {
                        let gi = off_i + li * r_i + a;
                        for b in 0..r_j {
                            let gj = off_j + lj * r_j + b;
                            out[[gi, gj]] += g * blk.w[[a, b]];
                        }
                    }
                }
            }
        }
        out
    }

    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("factored-frame-kronecker-op-v1");
        hasher.write_usize(self.dim);
        for &r in &self.ranks {
            hasher.write_usize(r);
        }
        for &m in &self.basis_sizes {
            hasher.write_usize(m);
        }
        hasher.write_usize(self.blocks.len());
        for blk in &self.blocks {
            hasher.write_usize(blk.atom_i);
            hasher.write_usize(blk.atom_j);
            hasher.write_f64_array2(&blk.g);
            hasher.write_f64_array2(&blk.w);
        }
    }
}
/// Sum of several penalty operators over the same `k`-dimensional beta vector.
pub struct CompositePenaltyOp {
    pub k: usize,
    pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
}
impl BetaPenaltyOp for CompositePenaltyOp {
    fn dim(&self) -> usize {
        self.k
    }

    /// Each member accumulates its own `S_m x` into `y`, in order.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        self.ops.iter().for_each(|op| op.matvec(x, y));
    }

    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.ops.iter().for_each(|op| op.gradient(beta, out));
    }

    fn diagonal(&self, diag: &mut [f64]) {
        self.ops.iter().for_each(|op| op.diagonal(diag));
    }

    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        self.ops.iter().for_each(|op| op.block(id, offsets, out));
    }

    /// Sum of every member's dense matrix, starting from a `k x k` zero matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.ops
            .iter()
            .fold(Array2::<f64>::zeros((self.k, self.k)), |mut acc, op| {
                acc += &op.to_dense();
                acc
            })
    }

    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("composite-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.ops.len());
        self.ops.iter().for_each(|op| op.fingerprint(hasher));
    }
}
/// Penalty operator given only implicitly: a matvec closure plus an explicit
/// copy of its diagonal.
pub struct MatvecDiagPenaltyOp {
// Beta dimension.
k: usize,
// Applies the operator; called with a zeroed output buffer of length `k`.
matvec: SharedBetaMatvec,
// Diagonal of the operator, length `k`.
diagonal_vec: Array1<f64>,
}
impl MatvecDiagPenaltyOp {
/// Wraps `matvec` with its precomputed diagonal.
///
/// # Panics
/// If `diagonal_vec.len() != k`.
pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
assert_eq!(diagonal_vec.len(), k);
Self {
k,
matvec,
diagonal_vec,
}
}
}
impl BetaPenaltyOp for MatvecDiagPenaltyOp {
fn dim(&self) -> usize {
self.k
}
fn matvec(&self, x: &[f64], y: &mut [f64]) {
let x_arr = Array1::from_iter(x.iter().copied());
let mut out = Array1::<f64>::zeros(self.k);
(self.matvec)(x_arr.view(), &mut out);
for a in 0..self.k {
y[a] += out[a];
}
}
fn gradient(&self, beta: &[f64], out: &mut [f64]) {
let beta_arr = Array1::from_iter(beta.iter().copied());
let mut hb = Array1::<f64>::zeros(self.k);
(self.matvec)(beta_arr.view(), &mut hb);
for a in 0..self.k {
out[a] += hb[a];
}
}
fn diagonal(&self, diag: &mut [f64]) {
for j in 0..self.k.min(diag.len()) {
diag[j] += self.diagonal_vec[j];
}
}
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
let range = &offsets[id.0];
let b = range.end - range.start;
let mut probe = Array1::<f64>::zeros(self.k);
for bj in 0..b {
probe.fill(0.0);
probe[range.start + bj] = 1.0;
let mut col = Array1::<f64>::zeros(self.k);
(self.matvec)(probe.view(), &mut col);
for bi in 0..b {
out[[bi, bj]] += col[range.start + bi];
}
}
}
fn to_dense(&self) -> Array2<f64> {
let k = self.k;
let mut out = Array2::<f64>::zeros((k, k));
let mut probe = Array1::<f64>::zeros(k);
for j in 0..k {
probe.fill(0.0);
probe[j] = 1.0;
let mut col = Array1::<f64>::zeros(k);
(self.matvec)(probe.view(), &mut col);
for i in 0..k {
out[[i, j]] = col[i];
}
}
out
}
fn fingerprint(&self, hasher: &mut Fingerprinter) {
hasher.write_str("matvec-diag-penalty-op-v1");
hasher.write_usize(self.k);
for &value in self.diagonal_vec.iter() {
hasher.write_f64(value);
}
}
}
/// Solver strategy selected through `ArrowSolveOptions::mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
    Direct,
    SqrtBA,
    InexactPCG,
}
impl ArrowSolverMode {
    /// `Direct` when `k <= DIRECT_SOLVE_MAX_K`, otherwise `InexactPCG`.
    pub const fn automatic(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            return Self::InexactPCG;
        }
        Self::Direct
    }

    /// `SqrtBA` when `k <= DIRECT_SOLVE_MAX_K`, otherwise `InexactPCG`.
    pub const fn automatic_for_single_precision(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            return Self::InexactPCG;
        }
        Self::SqrtBA
    }
}
/// Why the PCG / Steihaug iteration stopped; defaults to `Converged`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
#[default]
Converged,
MaxIter,
TrustRegion,
Indefinite,
Stagnation,
}
/// Counters and outcome reported by an arrow solve; all fields default to
/// zero / `false` / the enum defaults. They are populated outside this part of
/// the file, so per-field meanings below follow the names and should be
/// confirmed there.
#[derive(Debug, Default, Clone, Copy)]
pub struct PcgDiagnostics {
pub iterations: usize,
pub matvec_calls: usize,
pub precond_apply_calls: usize,
pub ridge_escalations: usize,
pub final_relative_residual: f64,
pub stopping_reason: PcgStopReason,
pub mixed_precision_status: MixedPrecisionStatus,
pub used_device_arrow: bool,
}
/// Outcome of the mixed-precision path; defaults to `Off`. `Certified` records
/// the number of refinement steps taken; `F64Fallback` presumably means the
/// certificate failed and a double-precision solve was used — confirm.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MixedPrecisionStatus {
#[default]
Off,
Certified { refinement_steps: usize },
F64Fallback,
}
/// Iteration budget and relative residual tolerance for the PCG path.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
pub max_iterations: usize,
pub relative_tolerance: f64,
}
// Defaults: `DEFAULT_PCG_MAX_ITERATIONS` and `DEFAULT_PCG_RELATIVE_TOLERANCE`.
impl Default for ArrowPcgOptions {
fn default() -> Self {
Self {
max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
}
}
}
/// Trust-region radius plus Steihaug-CG tolerance and iteration budget.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
pub radius: f64,
pub steihaug_relative_tolerance: f64,
pub max_iterations: usize,
}
// Defaults: infinite radius and the same tolerance / iteration budget as PCG.
impl Default for ArrowTrustRegionOptions {
fn default() -> Self {
Self {
radius: DEFAULT_TRUST_REGION_RADIUS,
steihaug_relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
}
}
}
/// Mixed-precision policy for the arrow solve.
///
/// `Off` (the default) disables it. `Certified` carries a refinement-step
/// budget, a residual tolerance and a kappa / unit-roundoff margin; their
/// exact use is defined by the solver.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum MixedPrecisionPolicy {
    #[default]
    Off,
    Certified {
        max_refinement_steps: usize,
        residual_relative_tolerance: f64,
        kappa_unit_roundoff_margin: f64,
    },
}
impl MixedPrecisionPolicy {
    /// `Certified` policy populated with this module's default refinement
    /// budget, residual tolerance and kappa margin.
    pub fn certified() -> Self {
        let max_refinement_steps = DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS;
        let residual_relative_tolerance = DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE;
        let kappa_unit_roundoff_margin = DEFAULT_MIXED_PRECISION_KAPPA_MARGIN;
        Self::Certified {
            max_refinement_steps,
            residual_relative_tolerance,
            kappa_unit_roundoff_margin,
        }
    }

    /// True for any `Certified` policy, false for `Off`.
    fn is_enabled(self) -> bool {
        match self {
            Self::Off => false,
            Self::Certified { .. } => true,
        }
    }
}
/// Full configuration for an arrow solve (mode, PCG / trust-region settings,
/// streaming, external matvec hook, conditioning tolerance, mixed precision).
#[derive(Clone)]
pub struct ArrowSolveOptions {
pub mode: ArrowSolverMode,
pub pcg: ArrowPcgOptions,
pub trust_region: ArrowTrustRegionOptions,
// `None` disables chunking; `with_streaming_chunk_size` maps `Some(0)` to `None`.
pub streaming_chunk_size: Option<usize>,
pub riemannian_trust_region: bool,
pub gpu_matvec: Option<GpuSchurMatvec>,
pub tolerate_ill_conditioning: bool,
pub mixed_precision: MixedPrecisionPolicy,
}
// Manual `Debug` because the `gpu_matvec` closure is not `Debug`; only its
// presence (`is_some()`) is printed.
impl std::fmt::Debug for ArrowSolveOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ArrowSolveOptions")
.field("mode", &self.mode)
.field("pcg", &self.pcg)
.field("trust_region", &self.trust_region)
.field("streaming_chunk_size", &self.streaming_chunk_size)
.field("riemannian_trust_region", &self.riemannian_trust_region)
.field("gpu_matvec", &self.gpu_matvec.is_some())
.field("tolerate_ill_conditioning", &self.tolerate_ill_conditioning)
.field("mixed_precision", &self.mixed_precision)
.finish()
}
}
/// Settings for the proximal correction step: ridge schedule (initial value,
/// growth factor, attempt budget), Armijo constant and stopping tolerances.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
pub initial_ridge: f64,
pub ridge_growth: f64,
pub max_attempts: usize,
pub armijo_c1: f64,
pub gradient_tolerance: f64,
pub convergence_objective_rel_tol: f64,
}
// Defaults come from the `DEFAULT_PROXIMAL_*`, `DEFAULT_ARMIJO_C1` and
// `DEFAULT_GRADIENT_TOLERANCE` constants.
impl Default for ArrowProximalCorrectionOptions {
fn default() -> Self {
Self {
initial_ridge: DEFAULT_PROXIMAL_INITIAL_RIDGE,
ridge_growth: DEFAULT_PROXIMAL_RIDGE_GROWTH,
max_attempts: DEFAULT_PROXIMAL_MAX_ATTEMPTS,
armijo_c1: DEFAULT_ARMIJO_C1,
gradient_tolerance: DEFAULT_GRADIENT_TOLERANCE,
convergence_objective_rel_tol: DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL,
}
}
}
/// Record of an accepted proximal step: the step in `t` and `beta`, the
/// ridges used, objective values before / after, the directional derivative
/// `gradient · step`, and how many attempts it took. It is produced outside
/// this part of the file; field meanings follow the names — confirm there.
#[derive(Debug, Clone)]
pub struct ArrowAcceptedProximalStep {
pub delta_t: Array1<f64>,
pub delta_beta: Array1<f64>,
pub ridge_t: f64,
pub ridge_beta: f64,
pub proximal_ridge: f64,
pub objective_value: f64,
pub trial_objective_value: f64,
pub gradient_dot_step: f64,
pub attempts: usize,
}
impl ArrowSolveOptions {
pub fn automatic(k: usize) -> Self {
Self {
mode: ArrowSolverMode::automatic(k),
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn direct() -> Self {
Self {
mode: ArrowSolverMode::Direct,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn sqrt_ba() -> Self {
Self {
mode: ArrowSolverMode::SqrtBA,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn inexact_pcg() -> Self {
Self {
mode: ArrowSolverMode::InexactPCG,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
self
}
pub fn with_ill_conditioning_tolerated(mut self) -> Self {
self.tolerate_ill_conditioning = true;
self
}
pub fn with_mixed_precision_policy(mut self, policy: MixedPrecisionPolicy) -> Self {
self.mixed_precision = policy;
self
}
#[must_use]
pub fn with_streaming_mixed_precision_default(&self) -> Self {
let mut out = self.clone();
if matches!(out.mixed_precision, MixedPrecisionPolicy::Off) {
out.mixed_precision = MixedPrecisionPolicy::certified();
}
out
}
}
/// Backend for the per-row dense block work of the arrow solve. The
/// descriptions below follow the CPU implementation
/// (`CpuBatchedBlockSolver`); other backends should match them.
pub trait BatchedBlockSolver {
/// Factors every row block with ridge `ridge_t` and block size `d`, returning
/// all factors as one `ArrowFactorSlab`. The CPU implementation returns
/// the first per-row error it encounters.
fn factor_blocks(
&self,
rows: &[ArrowRowBlock],
ridge_t: f64,
d: usize,
tolerate_ill_conditioning: bool,
) -> Result<ArrowFactorSlab, ArrowSchurError>;
/// Solves `(L Lᵀ) x = rhs` given the Cholesky factor `factor`.
fn solve_block_vector(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView1<'_, f64>,
) -> Array1<f64>;
/// Matrix right-hand-side version of `solve_block_vector`.
fn solve_block_matrix(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Array2<f64>;
/// Half solve `L⁻¹ rhs` (forward substitution with the lower factor).
fn sqrt_solve_block_matrix(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Array2<f64>;
/// In-place update `schur -= leftᵀ · right` (`left`, `right`: `d x k`;
/// `schur`: `k x k`).
fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
}
/// Per-row lists of direction vectors ("gauge directions"), stored behind an
/// `Arc` so clones share one allocation.
#[derive(Debug, Clone)]
pub struct ArrowRowGaugeDeflation {
    pub directions: Arc<[Vec<Array1<f64>>]>,
}
impl ArrowRowGaugeDeflation {
    /// Takes ownership of one direction list per row.
    pub fn new(directions: Vec<Vec<Array1<f64>>>) -> Self {
        let directions: Arc<[Vec<Array1<f64>>]> = directions.into();
        Self { directions }
    }

    /// Directions registered for `row`; rows past the end have none.
    fn row(&self, row: usize) -> &[Array1<f64>] {
        match self.directions.get(row) {
            Some(dirs) => dirs.as_slice(),
            None => &[],
        }
    }
}
/// CPU implementation of [`BatchedBlockSolver`].
///
/// Factorization first tries the batched GPU path (`try_factor_blocks_batched`).
/// When that declines, it falls back to per-row Cholesky with ridge escalation
/// (`factor_one_row`).
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;
impl BatchedBlockSolver for CpuBatchedBlockSolver {
    fn factor_blocks(
        &self,
        rows: &[ArrowRowBlock],
        ridge_t: f64,
        d: usize,
        tolerate_ill_conditioning: bool,
    ) -> Result<ArrowFactorSlab, ArrowSchurError> {
        if let Some(batched) =
            try_factor_blocks_batched(rows, ridge_t, d, tolerate_ill_conditioning)
        {
            return Ok(batched);
        }
        // Per-row fallback; the first row that cannot be factored aborts with its error.
        let mut out = Vec::with_capacity(rows.len());
        for (row_idx, row) in rows.iter().enumerate() {
            out.push(factor_one_row(
                row,
                ridge_t,
                d,
                row_idx,
                tolerate_ill_conditioning,
            )?);
        }
        Ok(ArrowFactorSlab::from_blocks(out))
    }
    fn solve_block_vector(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        // Small square systems (size 1..=4) use the const-generic stack-array
        // kernel; anything else goes through the general routine.
        match (factor.nrows(), factor.ncols(), rhs.len()) {
            (1, 1, 1) => cholesky_solve_vector_fixed::<1>(factor, rhs),
            (2, 2, 2) => cholesky_solve_vector_fixed::<2>(factor, rhs),
            (3, 3, 3) => cholesky_solve_vector_fixed::<3>(factor, rhs),
            (4, 4, 4) => cholesky_solve_vector_fixed::<4>(factor, rhs),
            _ => cholesky_solve_vector(factor, rhs),
        }
    }
    fn solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        cholesky_solve_matrix(factor, rhs)
    }
    fn sqrt_solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        forward_substitution_lower_matrix(factor, rhs)
    }
    /// `schur -= leftᵀ · right`, where `left` and `right` are both `d × k`.
    ///
    /// Zero entries of `left` are skipped as a sparsity shortcut. As a result,
    /// non-finite values in the matching row of `right` do not propagate
    /// through those entries.
    ///
    /// # Panics
    /// Panics when the shapes are inconsistent, including a row-count mismatch
    /// between `left` and `right`. Previously, a taller `right` had its extra
    /// rows silently ignored.
    fn block_gemm_subtract(
        &self,
        schur: &mut Array2<f64>,
        left: &Array2<f64>,
        right: &Array2<f64>,
    ) {
        let k = schur.nrows();
        let d = left.nrows();
        assert_eq!(left.ncols(), k);
        assert_eq!(right.ncols(), k);
        assert_eq!(
            right.nrows(),
            d,
            "block_gemm_subtract: left and right must have the same row count"
        );
        assert_eq!(schur.ncols(), k);
        for c in 0..d {
            for a in 0..k {
                let lca = left[[c, a]];
                if lca == 0.0 {
                    continue;
                }
                for b in 0..k {
                    schur[[a, b]] -= lca * right[[c, b]];
                }
            }
        }
    }
}
/// Outcome of factoring a single row's `H_tt` block.
#[derive(Debug, Clone)]
struct ArrowRowFactorResult {
/// Lower-triangular Cholesky factor of the (possibly ridged or gauge-deflated) block.
factor: Array2<f64>,
/// Number of orthonormal gauge directions added to `H_tt` before factoring;
/// `0` when no gauge deflation was applied.
gauge_deflated_directions: usize,
}
/// Attempts to Cholesky-factor every `H_tt + ridge_t·I` block in one batched GPU call.
///
/// Returns `None`, so the caller can fall back to the per-row CPU path, when:
/// - `d` is zero or there are no rows;
/// - any row has an inconsistent shape;
/// - the GPU runtime is unavailable;
/// - the batched factorization declines;
/// - a factor fails the safe-inversion check (only when `tolerate_ill_conditioning` is false).
///
/// The GPU path never escalates the ridge.
fn try_factor_blocks_batched(
    rows: &[ArrowRowBlock],
    ridge_t: f64,
    d: usize,
    tolerate_ill_conditioning: bool,
) -> Option<ArrowFactorSlab> {
    if d == 0
        || rows.is_empty()
        || !rows
            .iter()
            .all(|row| row.htt.dim() == (d, d) && row.gt.len() == d)
        || !crate::gpu::runtime::GpuRuntime::is_available()
    {
        return None;
    }
    let mut blocks: Vec<Array2<f64>> = rows
        .iter()
        .map(|row| {
            let mut shifted = row.htt.clone();
            for a in 0..d {
                shifted[[a, a]] += ridge_t;
            }
            shifted
        })
        .collect();
    crate::gpu::try_cholesky_batched_lower_inplace(&mut blocks)?;
    let acceptable = tolerate_ill_conditioning
        || rows.iter().zip(&blocks).all(|(row, factor)| {
            let diag_scale = row_block_diag_scale(row, d);
            let kappa_est = cholesky_factor_kappa_estimate(factor);
            cholesky_factor_passes_safe_inversion(factor, d, diag_scale, kappa_est)
        });
    acceptable.then(|| ArrowFactorSlab::from_blocks(blocks))
}
/// Largest absolute diagonal entry among the first `d` diagonal entries of `row.htt`, floored at `1.0`.
///
/// NaN diagonal entries are ignored, because `f64::max` prefers the non-NaN operand.
fn row_block_diag_scale(row: &ArrowRowBlock, d: usize) -> f64 {
    let mut largest_abs = 0.0_f64;
    for a in 0..d {
        largest_abs = largest_abs.max(row.htt[[a, a]].abs());
    }
    largest_abs.max(1.0)
}
/// Cheap condition-number estimate for `A = L Lᵀ`, computed from the diagonal of the Cholesky factor `L`.
///
/// Returns `(max L_aa / min L_aa)²`. The squared diagonal entries are the
/// Cholesky pivots, which lie between the extreme eigenvalues of `A`, so this
/// can underestimate the true 2-norm condition number.
///
/// Returns `f64::INFINITY` when any diagonal entry is non-positive, infinite,
/// or NaN. Previously a NaN entry was silently skipped by the `<`/`>`
/// comparisons, so a factor with a NaN pivot could report a finite estimate.
/// An empty factor yields `0.0`.
fn cholesky_factor_kappa_estimate(factor: &Array2<f64>) -> f64 {
    let d = factor.nrows();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for a in 0..d {
        let v = factor[[a, a]];
        if v.is_nan() {
            return f64::INFINITY;
        }
        min_diag = min_diag.min(v);
        max_diag = max_diag.max(v);
    }
    if min_diag > 0.0 && max_diag.is_finite() {
        let ratio = max_diag / min_diag;
        ratio * ratio
    } else {
        f64::INFINITY
    }
}
/// Smallest Cholesky pivot `min_a L_aa²` of the factor `L`.
///
/// Returns `0.0` for an empty factor, or as soon as any diagonal entry is not
/// a finite positive number.
fn cholesky_factor_min_pivot_estimate(factor: &Array2<f64>) -> f64 {
    let n = factor.nrows();
    if n == 0 {
        return 0.0;
    }
    (0..n)
        .map(|a| factor[[a, a]])
        .try_fold(f64::INFINITY, |smallest, diag| {
            (diag > 0.0 && diag.is_finite()).then(|| smallest.min(diag * diag))
        })
        .unwrap_or(0.0)
}
/// Minimum acceptable Cholesky pivot: `sqrt(ε) · max(diag_scale, 1)`.
///
/// A NaN `diag_scale` is treated as `1`.
fn safe_spd_pivot_min(diag_scale: f64) -> f64 {
    let root_eps = f64::EPSILON.sqrt();
    let floored_scale = if diag_scale > 1.0 { diag_scale } else { 1.0 };
    root_eps * floored_scale
}
/// Returns whether a Cholesky factor is well-conditioned enough to invert safely.
///
/// Two conditions must hold:
/// - `kappa_est` is finite and at most `safe_spd_kappa_max(dim)`;
/// - the smallest pivot is at least `safe_spd_pivot_min(diag_scale)`.
fn cholesky_factor_passes_safe_inversion(
    factor: &Array2<f64>,
    dim: usize,
    diag_scale: f64,
    kappa_est: f64,
) -> bool {
    if !kappa_est.is_finite() || kappa_est > safe_spd_kappa_max(dim) {
        return false;
    }
    let min_pivot = cholesky_factor_min_pivot_estimate(factor);
    min_pivot >= safe_spd_pivot_min(diag_scale)
}
/// Largest acceptable condition estimate for a `dim × dim` block: `1 / (sqrt(ε) · max(dim, 1))`.
fn safe_spd_kappa_max(dim: usize) -> f64 {
    // An empty block is scaled like a 1×1 block.
    let dim_factor = if dim == 0 { 1.0 } else { dim as f64 };
    (f64::EPSILON.sqrt() * dim_factor).recip()
}
/// Cholesky-factors `row.htt + ridge_eff·I`, choosing the kernel by block size.
///
/// Sizes 1 through 4 use the const-generic stack-array kernel. Every other
/// size, including 0, uses the dynamic path.
fn factor_row_block_cholesky(
    row: &ArrowRowBlock,
    ridge_eff: f64,
    d: usize,
) -> Result<Array2<f64>, String> {
    match d {
        1 => factor_row_block_cholesky_fixed::<1>(row, ridge_eff),
        2 => factor_row_block_cholesky_fixed::<2>(row, ridge_eff),
        3 => factor_row_block_cholesky_fixed::<3>(row, ridge_eff),
        4 => factor_row_block_cholesky_fixed::<4>(row, ridge_eff),
        other_dim => factor_row_block_cholesky_dynamic(row, ridge_eff, other_dim),
    }
}
/// General-size path: shifts the diagonal of a copy of `row.htt` by `ridge_eff`
/// and hands it to `cholesky_lower`.
fn factor_row_block_cholesky_dynamic(
    row: &ArrowRowBlock,
    ridge_eff: f64,
    d: usize,
) -> Result<Array2<f64>, String> {
    let mut shifted = row.htt.clone();
    (0..d).for_each(|a| shifted[[a, a]] += ridge_eff);
    cholesky_lower(&shifted)
}
/// Fixed-size Cholesky factorization of `row.htt + ridge_eff·I` for a
/// compile-time block dimension `D`.
///
/// Works in a stack array `[[f64; D]; D]`, then copies the lower triangle into a
/// fresh `D×D` `Array2`; the upper triangle of the result is zero.
///
/// # Errors
/// Returns a message when any entry of the shifted block is non-finite, or when
/// a diagonal pivot is non-positive or non-finite (the block is not positive definite).
fn factor_row_block_cholesky_fixed<const D: usize>(
row: &ArrowRowBlock,
ridge_eff: f64,
) -> Result<Array2<f64>, String> {
// Pre-scan the whole shifted block (both triangles) for non-finite entries. The
// message is prefixed `cholesky_lower:`, presumably to match the dynamic path's
// error text (TODO confirm).
for i in 0..D {
for j in 0..D {
let value = if i == j {
row.htt[[i, j]] + ridge_eff
} else {
row.htt[[i, j]]
};
if !value.is_finite() {
let idx = i * D + j;
return Err(format!(
"cholesky_lower: non-finite entry at linear index {idx}"
));
}
}
}
// Row-oriented factorization. Only the lower triangle (j <= i) of `htt` is read,
// and the subtraction order over `kk` is fixed, so results are reproducible.
let mut l = [[0.0_f64; D]; D];
for i in 0..D {
for j in 0..=i {
let mut sum = if i == j {
row.htt[[i, j]] + ridge_eff
} else {
row.htt[[i, j]]
};
for kk in 0..j {
sum -= l[i][kk] * l[j][kk];
}
if i == j {
// Diagonal pivot must be strictly positive and finite.
if !sum.is_finite() || sum <= 0.0 {
return Err(format!(
"non-PD pivot {sum} at index {i} (matrix is not positive definite)"
));
}
l[i][j] = sum.sqrt();
} else {
l[i][j] = sum / l[j][j];
}
}
}
// Copy the lower triangle out; the zero-initialized upper triangle stays zero.
let mut out = Array2::<f64>::zeros((D, D));
for i in 0..D {
for j in 0..=i {
out[[i, j]] = l[i][j];
}
}
Ok(out)
}
/// Quadratic form `gᵀ · H_tt · g` of the row block along `gauge`.
///
/// Returns `None` when `gauge` does not have length `d` or when the result is non-finite.
fn row_gauge_curvature(row: &ArrowRowBlock, d: usize, gauge: &Array1<f64>) -> Option<f64> {
    if gauge.len() != d {
        return None;
    }
    let mut quad_form = 0.0_f64;
    for (i, &gi) in gauge.iter().enumerate() {
        for (j, &gj) in gauge.iter().enumerate() {
            quad_form += gi * row.htt[[i, j]] * gj;
        }
    }
    quad_form.is_finite().then_some(quad_form)
}
/// Evidence-mode fallback for a non-PD `H_tt`: adds an orthonormal basis of the
/// supplied near-null gauge directions to the block and Cholesky-factors the result.
///
/// Gauge processing:
/// - Gauges of the wrong length, with (near-)zero norm, or with relative
///   curvature `|gᵀ H_tt g| / (max_diag · ‖g‖²)` above `GAUGE_RAYLEIGH_EPS` are skipped.
/// - Surviving gauges are orthonormalized with modified Gram–Schmidt. A gauge is
///   dropped when its residual against the current basis is negligible relative
///   to its own scale.
///
/// Returns `None` when:
/// - no gauges are supplied or none survive;
/// - the diagonal scale is unusable;
/// - any gauge curvature is non-finite;
/// - the deflated block still fails Cholesky.
fn factor_gauge_deflated_evidence_row(
    row: &ArrowRowBlock,
    d: usize,
    gauges: &[Array1<f64>],
) -> Option<ArrowRowFactorResult> {
    const GAUGE_RAYLEIGH_EPS: f64 = 1.0e-8;
    // Squared-norm floor below which a gauge (or its residual) is treated as zero.
    const GAUGE_NORM_SQ_FLOOR: f64 = 1.0e-24;
    if gauges.is_empty() {
        return None;
    }
    let max_diag = row_block_diag_scale(row, d);
    if !(max_diag.is_finite() && max_diag > 0.0) {
        return None;
    }
    let mut basis: Vec<Array1<f64>> = Vec::new();
    for gauge in gauges {
        if gauge.len() != d {
            continue;
        }
        let norm_sq = gauge.iter().map(|&v| v * v).sum::<f64>();
        if !(norm_sq.is_finite() && norm_sq > GAUGE_NORM_SQ_FLOOR) {
            continue;
        }
        // A non-finite curvature aborts the whole deflation attempt.
        let curvature = row_gauge_curvature(row, d, gauge)?;
        if curvature.abs() > GAUGE_RAYLEIGH_EPS * max_diag * norm_sq {
            continue;
        }
        let mut direction = gauge.clone();
        for existing in &basis {
            let coeff = direction.dot(existing);
            for idx in 0..d {
                direction[idx] -= coeff * existing[idx];
            }
        }
        let residual_norm_sq = direction.iter().map(|&v| v * v).sum::<f64>();
        // The independence floor scales with the gauge's own squared norm (never below
        // the absolute floor). For norm ≤ 1 this is the previous absolute test. For a
        // large-norm gauge that is numerically parallel to the basis, the residual is
        // pure rounding noise of size ~ε·‖g‖. An absolute 1e-24 test let that noise
        // through, and it was then normalized into a spurious unit direction.
        let independence_floor = GAUGE_NORM_SQ_FLOOR * norm_sq.max(1.0);
        if !(residual_norm_sq.is_finite() && residual_norm_sq > independence_floor) {
            continue;
        }
        let inv_norm = residual_norm_sq.sqrt().recip();
        for value in direction.iter_mut() {
            *value *= inv_norm;
        }
        basis.push(direction);
    }
    if basis.is_empty() {
        return None;
    }
    // H_tt + Σ u uᵀ over the orthonormal gauge basis.
    let mut deflated = row.htt.clone();
    for direction in &basis {
        for i in 0..d {
            for j in 0..d {
                deflated[[i, j]] += direction[i] * direction[j];
            }
        }
    }
    let factor = cholesky_lower(&deflated).ok()?;
    Some(ArrowRowFactorResult {
        factor,
        gauge_deflated_directions: basis.len(),
    })
}
/// Solves `(L Lᵀ) x = b` for a compile-time size `D`.
///
/// It runs forward substitution `L y = b`, then back substitution `Lᵀ x = y`,
/// entirely in stack arrays.
fn cholesky_solve_vector_fixed<const D: usize>(
    l: ArrayView2<'_, f64>,
    b: ArrayView1<'_, f64>,
) -> Array1<f64> {
    let mut y = [0.0_f64; D];
    for i in 0..D {
        let partial = (0..i).fold(b[i], |acc, k| acc - l[[i, k]] * y[k]);
        y[i] = partial / l[[i, i]];
    }
    let mut x = [0.0_f64; D];
    for i in (0..D).rev() {
        let partial = ((i + 1)..D).fold(y[i], |acc, k| acc - l[[k, i]] * x[k]);
        x[i] = partial / l[[i, i]];
    }
    Array1::from(x.to_vec())
}
fn factor_one_row(
row: &ArrowRowBlock,
ridge_t: f64,
d: usize,
row_idx: usize,
tolerate_ill_conditioning: bool,
) -> Result<Array2<f64>, ArrowSchurError> {
factor_one_row_result(row, ridge_t, d, row_idx, tolerate_ill_conditioning, &[])
.map(|result| result.factor)
}
/// Cholesky-factors one row block `H_tt + ridge·I`, with the row's own policy for
/// ill-conditioned or non-PD blocks.
///
/// - Strict mode (`tolerate_ill_conditioning == false`): starting from `ridge_t`,
///   the ridge is escalated (seeded at `1e-10·diag_scale` when the current ridge
///   is not positive, then ×10) until the factor passes
///   `cholesky_factor_passes_safe_inversion`. The ridge is capped at
///   `max(ridge_t, 1e-12·diag_scale)·1e12`.
/// - Evidence mode (`tolerate_ill_conditioning == true`): the first successful
///   factor at `ridge_t` is returned unmodified. On failure, when `ridge_t == 0`,
///   the gauge-deflation fallback (`factor_gauge_deflated_evidence_row` with
///   `row_gauges`) is tried before erroring.
///
/// # Errors
/// - `PerRowFactorFailed`: the shape of `H_tt` or `g_t` does not match `d`; the
///   block is non-PD in evidence mode; or the block stays non-PD up to the ridge cap.
/// - `PerRowFactorIllConditioned`: strict mode reaches the ridge cap while the factor
///   still fails the conditioning check. Carries the last kappa estimate.
fn factor_one_row_result(
row: &ArrowRowBlock,
ridge_t: f64,
d: usize,
row_idx: usize,
tolerate_ill_conditioning: bool,
row_gauges: &[Array1<f64>],
) -> Result<ArrowRowFactorResult, ArrowSchurError> {
// Shape validation happens before any factorization work.
if row.htt.dim() != (d, d) {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt shape {:?} does not match per_point_hessian_block dimension ({d}, {d})",
row.htt.dim()
),
});
}
if row.gt.len() != d {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} g_t length {} does not match latent dimension {d}",
row.gt.len()
),
});
}
// Ridge schedule: multiply by RIDGE_GROWTH_FACTOR each retry; start from
// RIDGE_SEED_DIAG_FRACTION·diag_scale when the current ridge is not positive;
// give up once the next ridge would exceed
// max(ridge_t, RIDGE_CAP_DIAG_FRACTION·diag_scale)·RIDGE_CAP_SCALE.
const RIDGE_GROWTH_FACTOR: f64 = 10.0;
const RIDGE_SEED_DIAG_FRACTION: f64 = 1.0e-10;
const RIDGE_CAP_DIAG_FRACTION: f64 = 1.0e-12;
const RIDGE_CAP_SCALE: f64 = 1.0e12;
let diag_scale = row_block_diag_scale(row, d);
let ridge_cap = ridge_t.max(RIDGE_CAP_DIAG_FRACTION * diag_scale) * RIDGE_CAP_SCALE;
let mut ridge_eff = ridge_t;
let factor = loop {
match factor_row_block_cholesky(row, ridge_eff, d) {
Ok(factor) => {
// Evidence mode keeps the genuine factor, however ill-conditioned.
if tolerate_ill_conditioning {
break ArrowRowFactorResult {
factor,
gauge_deflated_directions: 0,
};
}
let kappa_est = cholesky_factor_kappa_estimate(&factor);
if cholesky_factor_passes_safe_inversion(&factor, d, diag_scale, kappa_est) {
break ArrowRowFactorResult {
factor,
gauge_deflated_directions: 0,
};
}
// Factor exists but is too ill-conditioned: escalate the ridge and retry.
let next = if ridge_eff > 0.0 {
ridge_eff * RIDGE_GROWTH_FACTOR
} else {
RIDGE_SEED_DIAG_FRACTION * diag_scale
};
if !next.is_finite() || next > ridge_cap {
return Err(ArrowSchurError::PerRowFactorIllConditioned {
row: row_idx,
kappa_estimate: kappa_est,
});
}
ridge_eff = next;
}
Err(e) => {
// Evidence mode never adds ridge. Only an unridged block may fall back to
// gauge deflation; everything else is reported as non-PD.
if tolerate_ill_conditioning {
if ridge_t == 0.0 {
if let Some(deflated) =
factor_gauge_deflated_evidence_row(row, d, row_gauges)
{
return Ok(deflated);
}
}
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt is non-PD at base ridge {ridge_t:e}; \
evidence mode preserves the genuine Cholesky of \
H_tt and does not condition non-PD blocks: {e}"
),
});
}
// Strict mode: the block is not PD at this ridge; escalate and retry.
let next = if ridge_eff > 0.0 {
ridge_eff * RIDGE_GROWTH_FACTOR
} else {
RIDGE_SEED_DIAG_FRACTION * diag_scale
};
if !next.is_finite() || next > ridge_cap {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt remained non-PD up to ridge {ridge_eff:e} \
(base ridge_t={ridge_t}); last cholesky error: {e}"
),
});
}
ridge_eff = next;
}
}
};
Ok(factor)
}
/// Fingerprint identifying the latent manifold geometry applied to the system.
///
/// Euclidean latents share the constant `EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT`.
/// Otherwise the hash covers:
/// - the observation count and latent dimension;
/// - the manifold structure;
/// - the flattened per-coordinate metric weights.
fn manifold_mode_fingerprint(latent: &LatentCoordValues) -> u64 {
    let manifold = latent.manifold();
    if manifold.is_euclidean() {
        return EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT;
    }
    let mut weights: Vec<f64> = Vec::new();
    append_latent_metric_weights(&mut weights, manifold);
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-manifold-mode-v1");
    hasher.write_usize(latent.n_obs());
    hasher.write_usize(latent.latent_dim());
    write_latent_manifold(&mut hasher, manifold);
    hasher.write_usize(weights.len());
    for &w in &weights {
        hasher.write_f64(w);
    }
    hasher.finish_u64()
}
/// Structural fingerprint of the per-row and β-β Hessian data of `sys`.
///
/// What gets hashed:
/// - Each row's `htt`.
/// - Each row's `htbeta`, but only when there is no matrix-free `H_tβ`
///   operator or the dense supplement is active.
/// - A matrix-free `H_tβ` operator, identified by its `Arc` address.
/// - The penalty operator's own fingerprint, or the dense `hbb` when no
///   operator is set.
/// - The optional `hbb_diag`.
fn row_hessian_fingerprint_for_system(sys: &ArrowSchurSystem) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-row-hessian-v2");
    hasher.write_usize(sys.rows.len());
    hasher.write_usize(sys.d);
    hasher.write_usize(sys.k);
    let operator_addr: Option<usize> = sys
        .htbeta_matvec
        .as_ref()
        .map(|op| Arc::as_ptr(op) as *const () as usize);
    for row in &sys.rows {
        hasher.write_f64_array2(&row.htt);
        if let Some(addr) = operator_addr {
            hasher.write_usize(addr);
            if sys.htbeta_dense_supplement {
                hasher.write_f64_array2(&row.htbeta);
            }
        } else {
            hasher.write_f64_array2(&row.htbeta);
        }
    }
    if let Some(op) = sys.penalty_op.as_ref() {
        hasher.write_bool(true);
        op.fingerprint(&mut hasher);
    } else {
        hasher.write_bool(false);
        hasher.write_f64_array2(&sys.hbb);
    }
    if let Some(diag) = sys.hbb_diag.as_ref() {
        hasher.write_bool(true);
        hasher.write_usize(diag.len());
        for &entry in diag.iter() {
            hasher.write_f64(entry);
        }
    } else {
        hasher.write_bool(false);
    }
    hasher.finish_u64()
}
/// Mixes the analytic-penalty registry fingerprint into the structural row fingerprint.
///
/// A registry fingerprint of `0` means "no penalty contribution" and returns `row` unchanged.
fn combine_row_and_registry_fingerprints(row: u64, registry: u64) -> u64 {
    match registry {
        0 => row,
        _ => {
            let mut hasher = Fingerprinter::new();
            hasher.write_str("arrow-schur-row-hessian-with-penalties-v1");
            hasher.write_u64(row);
            hasher.write_u64(registry);
            hasher.finish_u64()
        }
    }
}
/// Fingerprint of one analytic penalty's row-block-diagonal Hessian contribution.
///
/// The result feeds `ArrowSchurSystem::analytic_row_hessian_fingerprint`.
/// Returns `None` for penalties that are not Ψ-tier or not row-block-diagonal;
/// cross-row penalties are fingerprinted by `cross_row_penalty_fingerprint`.
fn analytic_penalty_row_hessian_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> Option<u64> {
    if penalty.tier() != PenaltyTier::Psi || !analytic_penalty_is_row_block_diagonal(penalty) {
        return None;
    }
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-analytic-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    for &rho in rho_local.iter() {
        hasher.write_f64(rho);
    }
    match penalty {
        AnalyticPenaltyKind::RowPrecisionPrior(p) => {
            let (n, rows, cols) = p.lambda_per_row.dim();
            hasher.write_str("row-precision-fixed");
            hasher.write_usize(n);
            hasher.write_usize(rows);
            hasher.write_usize(cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            if p.learnable_weight {
                hasher.write_usize(p.rho_index);
                hasher.write_f64(p.weight * rho_local[p.rho_index].exp());
            }
            for &value in p.lambda_per_row.iter() {
                hasher.write_f64(value);
            }
        }
        AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
            let (aux_n, aux_dim) = p.aux.dim();
            let (mu_rows, mu_cols) = p.mu.dim();
            // Local rho layout as indexed below:
            // [log_alpha | raw_beta | mu (row-major) | weight (if learnable)].
            let raw_beta_offset = p.log_alpha.len();
            let mu_offset = raw_beta_offset + p.raw_beta.len();
            let weight_offset = mu_offset + p.mu.len();
            hasher.write_str("row-precision-parametric");
            hasher.write_usize(aux_n);
            hasher.write_usize(aux_dim);
            hasher.write_usize(mu_rows);
            hasher.write_usize(mu_cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            for &value in p.aux.iter() {
                hasher.write_f64(value);
            }
            for k in 0..p.log_alpha.len() {
                let active_log_alpha = p.log_alpha[k] + rho_local[k];
                hasher.write_f64(p.log_alpha[k]);
                hasher.write_f64(active_log_alpha);
                hasher.write_f64(active_log_alpha.exp());
            }
            for k in 0..p.raw_beta.len() {
                let active_raw_beta = p.raw_beta[k] + rho_local[raw_beta_offset + k];
                hasher.write_f64(p.raw_beta[k]);
                hasher.write_f64(active_raw_beta);
                hasher.write_f64(crate::linalg::utils::stable_softplus(active_raw_beta));
            }
            for k in 0..mu_rows {
                for a in 0..mu_cols {
                    // Stride by mu's own column count, so the mu block spans exactly
                    // `p.mu.len()` entries, consistent with `weight_offset`. The
                    // previous stride was `p.aux.ncols()`; it is identical whenever
                    // mu has `aux_dim` columns. NOTE(review): confirm against the
                    // penalty's rho layout.
                    let idx = mu_offset + k * mu_cols + a;
                    hasher.write_f64(p.mu[[k, a]]);
                    hasher.write_f64(p.mu[[k, a]] + rho_local[idx]);
                }
            }
            if p.learnable_weight {
                hasher.write_usize(weight_offset);
                hasher.write_f64(p.weight * rho_local[weight_offset].exp());
            }
        }
        _ => {
            hasher.write_str("row-block-diagonal");
            if let Some(diag) = penalty.hessian_diag(target_t, rho_local) {
                hasher.write_usize(diag.len());
                for &value in diag.iter() {
                    hasher.write_f64(value);
                }
            } else {
                hasher.write_usize(0);
            }
        }
    }
    Some(hasher.finish_u64())
}
/// Fingerprint of a cross-row (non-row-block-diagonal) Ψ penalty.
///
/// Besides the name, sizes and local ρ, it hashes the majorizer HVP applied to
/// `target_t` itself as a cheap probe of the Hessian.
fn cross_row_penalty_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-analytic-cross-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    for &rho_value in rho_local.iter() {
        hasher.write_f64(rho_value);
    }
    let self_probe = penalty.psd_majorizer_hvp(target_t, rho_local, target_t);
    hasher.write_usize(self_probe.len());
    for &entry in self_probe.iter() {
        hasher.write_f64(entry);
    }
    hasher.finish_u64()
}
/// Recursively hashes the structure and parameters of a latent manifold.
///
/// Each variant writes a distinct tag followed by its parameters. Product
/// manifolds write their component count and then recurse into each component.
fn write_latent_manifold(hasher: &mut Fingerprinter, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::Euclidean => {
            hasher.write_str("euclidean");
        }
        LatentManifold::Circle { period } => {
            hasher.write_str("circle");
            hasher.write_f64(*period);
        }
        LatentManifold::Sphere { dim } => {
            hasher.write_str("sphere");
            hasher.write_usize(*dim);
        }
        LatentManifold::Interval { lo, hi } => {
            hasher.write_str("interval");
            hasher.write_f64(*lo);
            hasher.write_f64(*hi);
        }
        LatentManifold::Product(components) => {
            hasher.write_str("product");
            hasher.write_usize(components.len());
            for component in components.iter() {
                write_latent_manifold(hasher, component);
            }
        }
        LatentManifold::ProductWithMetric { manifolds, weights } => {
            hasher.write_str("product-with-metric");
            hasher.write_usize(manifolds.len());
            for component in manifolds.iter() {
                write_latent_manifold(hasher, component);
            }
            hasher.write_usize(weights.len());
            for &w in weights.iter() {
                hasher.write_f64(w);
            }
        }
    }
}
/// Appends the metric weights of `manifold`'s coordinates to `out`.
///
/// - Euclidean: `1`.
/// - Circle: `1/period²`.
/// - Sphere: `dim` copies of `1/π²`.
/// - Interval: `1/(hi − lo)²`.
/// - Product: the concatenation of its components' weights.
/// - ProductWithMetric: its explicit weights.
fn append_latent_metric_weights(out: &mut Vec<f64>, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::Euclidean => out.push(1.0),
        LatentManifold::Circle { period } => out.push(1.0 / (period * period)),
        LatentManifold::Sphere { dim } => {
            let pi = std::f64::consts::PI;
            let weight = 1.0 / (pi * pi);
            out.extend(std::iter::repeat(weight).take(*dim));
        }
        LatentManifold::Interval { lo, hi } => {
            let width = hi - lo;
            out.push(1.0 / (width * width));
        }
        LatentManifold::Product(components) => {
            components
                .iter()
                .for_each(|component| append_latent_metric_weights(out, component));
        }
        LatentManifold::ProductWithMetric { weights, .. } => {
            out.extend(weights.iter().copied());
        }
    }
}
/// One row's slice of the arrow-shaped Hessian system.
///
/// - `htt`: the row's latent-latent block, `d × d`.
/// - `htbeta`: its latent-β coupling, `d × htbeta_cols`.
/// - `gt`: its latent gradient, length `d`.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
    pub htt: Array2<f64>,
    pub htbeta: Array2<f64>,
    pub gt: Array1<f64>,
}
impl ArrowRowBlock {
    /// Zero row block whose coupling spans all `k` β columns.
    pub fn new(d: usize, k: usize) -> Self {
        let htbeta_cols = k;
        Self::new_with_htbeta_cols(d, htbeta_cols)
    }
    /// Zero row block with an explicit coupling width.
    pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
        let htt = Array2::<f64>::zeros((d, d));
        let htbeta = Array2::<f64>::zeros((d, htbeta_cols));
        let gt = Array1::<f64>::zeros(d);
        Self { htt, htbeta, gt }
    }
}
/// Arrow-shaped Newton system: per-row latent blocks coupled to a shared β block.
///
/// `Clone` is derived. The previous hand-written impl cloned every field
/// one-for-one, which is exactly what the derive generates, and it had to be
/// edited on every field change.
#[derive(Clone)]
pub struct ArrowSchurSystem {
    /// Per-row `H_tt`, `H_tβ` and `g_t` blocks.
    pub rows: Vec<ArrowRowBlock>,
    /// Dense β-β block. Constructors create either a zeroed `(k, k)` matrix or an
    /// empty `(0, 0)` one (matrix-free / operator-backed variants).
    pub hbb: Array2<f64>,
    /// Shared matrix-free β operator, installed together with `hbb_diag`.
    pub hbb_matvec: Option<SharedBetaMatvec>,
    /// Matrix-free per-row `H_tβ` forward operator.
    pub htbeta_matvec: Option<RowHtbetaMatvec>,
    /// Matrix-free per-row `H_tβ` transpose operator.
    pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
    /// When set alongside `htbeta_matvec`, the dense `row.htbeta` blocks are also
    /// fingerprinted (presumably because they supplement the operator — TODO confirm).
    pub htbeta_dense_supplement: bool,
    /// Diagonal of the β operator, used when no dense `hbb` is available.
    pub hbb_diag: Option<Array1<f64>>,
    /// β gradient, length `k`.
    pub gb: Array1<f64>,
    /// Maximum per-row latent dimension (equal to every row's dimension for uniform systems).
    pub d: usize,
    /// Latent dimension of each row.
    pub row_dims: Arc<[usize]>,
    /// Prefix sums of `row_dims` (length `n + 1`, starting at 0).
    pub row_offsets: Arc<[usize]>,
    /// β dimension.
    pub k: usize,
    /// Fingerprint of the latent manifold geometry (see `apply_riemannian_latent_geometry`).
    pub manifold_mode_fingerprint: u64,
    /// Cached combined row-Hessian fingerprint (see `refresh_row_hessian_fingerprint`).
    pub row_hessian_fingerprint: u64,
    /// Analytic-penalty contribution to the fingerprint; `0` when there is none.
    pub analytic_row_hessian_fingerprint: u64,
    /// Column ranges of the β blocks, indexed by `BetaBlockId`.
    pub block_offsets: Arc<[Range<usize>]>,
    /// Optional β penalty operator; when `None`, the dense `hbb` is used.
    pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
    /// Optional device-side data for the SAE PCG path.
    pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
    /// Ψ penalties whose latent Hessian couples rows; applied matrix-free.
    pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
    /// Optional per-row gauge directions for evidence-mode deflation.
    pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
}
/// A Ψ-tier analytic penalty whose latent Hessian is not row-block-diagonal.
///
/// `add_analytic_penalty_contributions` adds its gradient to the rows directly and
/// records it here so `apply_cross_row_penalty_hessian` can apply its Hessian
/// matrix-free through `psd_majorizer_hvp`.
#[derive(Clone)]
pub struct CrossRowLatentPenalty {
/// The penalty, cloned from the registry.
pub penalty: AnalyticPenaltyKind,
/// This penalty's slice of the global ρ vector.
pub rho_local: Array1<f64>,
/// The flattened latent target the penalty was evaluated at.
pub target_t: Array1<f64>,
}
impl ArrowSchurSystem {
/// `n` rows of latent dimension `d`, a zeroed dense `k × k` β block, and full-width couplings.
pub fn new(n: usize, d: usize, k: usize) -> Self {
    let dense_hbb = Array2::<f64>::zeros((k, k));
    Self::new_with_hbb(n, d, k, dense_hbb)
}
/// Uniform system with an empty dense β block and full-width (`k`) couplings.
pub fn new_with_empty_hbb(n: usize, d: usize, k: usize) -> Self {
    let htbeta_cols = k;
    Self::new_with_empty_hbb_and_htbeta_cols(n, d, k, htbeta_cols)
}
pub fn new_with_empty_hbb_and_htbeta_cols(
n: usize,
d: usize,
k: usize,
htbeta_cols: usize,
) -> Self {
let rows = (0..n)
.map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
.collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let mut sys = Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
};
sys.refresh_row_hessian_fingerprint();
sys
}
/// Uniform system reusing `hbb` (must be `k × k`; it is zeroed) with full-width couplings.
pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
    let htbeta_cols = k;
    Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, htbeta_cols)
}
pub fn new_with_hbb_and_htbeta_cols(
n: usize,
d: usize,
k: usize,
mut hbb: Array2<f64>,
htbeta_cols: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
hbb.fill(0.0);
let rows = (0..n)
.map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
.collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let mut sys = Self {
rows,
hbb,
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
};
sys.refresh_row_hessian_fingerprint();
sys
}
pub fn new_matrix_free_shared<F>(
n: usize,
d: usize,
k: usize,
matvec: F,
diag: Array1<f64>,
) -> Self
where
F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
assert_eq!(diag.len(), k);
let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
k,
Arc::clone(&matvec_arc),
diag.clone(),
)));
let mut sys = Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: Some(matvec_arc),
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: Some(diag),
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
};
sys.refresh_row_hessian_fingerprint();
sys
}
/// Variable-dimension system with a zeroed dense `k × k` β block and full-width couplings.
pub fn new_with_per_row_dims(per_row_dims: Vec<usize>, k: usize) -> Self {
    let dense_hbb = Array2::<f64>::zeros((k, k));
    Self::new_with_per_row_dims_and_hbb(per_row_dims, k, dense_hbb)
}
/// Variable-dimension system with an empty dense β block and full-width (`k`) couplings.
pub fn new_with_per_row_dims_empty_hbb(per_row_dims: Vec<usize>, k: usize) -> Self {
    let htbeta_cols = k;
    Self::new_with_per_row_dims_empty_hbb_and_htbeta_cols(per_row_dims, k, htbeta_cols)
}
pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
per_row_dims: Vec<usize>,
k: usize,
htbeta_cols: usize,
) -> Self {
let n = per_row_dims.len();
let d = per_row_dims.iter().copied().max().unwrap_or(0);
let mut offsets = Vec::with_capacity(n + 1);
let mut cursor = 0usize;
offsets.push(cursor);
for &dim in &per_row_dims {
cursor += dim;
offsets.push(cursor);
}
let rows = per_row_dims
.iter()
.map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
.collect();
let mut sys = Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims: Arc::from(per_row_dims.into_boxed_slice()),
row_offsets: Arc::from(offsets.into_boxed_slice()),
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
};
sys.refresh_row_hessian_fingerprint();
sys
}
/// Variable-dimension system reusing `hbb` (must be `k × k`; it is zeroed) with full-width couplings.
pub fn new_with_per_row_dims_and_hbb(
    per_row_dims: Vec<usize>,
    k: usize,
    hbb: Array2<f64>,
) -> Self {
    let htbeta_cols = k;
    Self::new_with_per_row_dims_and_hbb_and_htbeta_cols(per_row_dims, k, hbb, htbeta_cols)
}
pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
per_row_dims: Vec<usize>,
k: usize,
mut hbb: Array2<f64>,
htbeta_cols: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
hbb.fill(0.0);
let n = per_row_dims.len();
let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
let mut off_vec = Vec::with_capacity(n + 1);
let mut cursor = 0usize;
for &di in &per_row_dims {
off_vec.push(cursor);
cursor += di;
}
off_vec.push(cursor);
let row_offsets: Arc<[usize]> = off_vec.into();
let rows = per_row_dims
.iter()
.map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
.collect();
let mut sys = Self {
rows,
hbb,
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d: max_d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
};
sys.refresh_row_hessian_fingerprint();
sys
}
/// Installs per-row gauge directions, replacing any previous set.
///
/// The row-Hessian fingerprint is not refreshed; gauges are not part of it.
pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
    let _previous = self.row_gauge_deflation.replace(deflation);
}
/// Number of rows (observations) in the system.
pub fn n(&self) -> usize {
    let rows = &self.rows;
    rows.len()
}
/// Freshly computed structural fingerprint, excluding the analytic-penalty contribution.
pub fn compute_row_hessian_fingerprint(&self) -> u64 {
    let structural = row_hessian_fingerprint_for_system(self);
    structural
}
/// Freshly computed fingerprint, including the analytic-penalty contribution.
pub fn current_row_hessian_fingerprint(&self) -> u64 {
    let structural = self.compute_row_hessian_fingerprint();
    let registry = self.analytic_row_hessian_fingerprint;
    combine_row_and_registry_fingerprints(structural, registry)
}
/// Recomputes and caches `row_hessian_fingerprint`.
pub fn refresh_row_hessian_fingerprint(&mut self) {
    let fresh = self.current_row_hessian_fingerprint();
    self.row_hessian_fingerprint = fresh;
}
/// Installs a matrix-free β operator `matvec` with diagonal `diag`.
///
/// The operator is used both as `hbb_matvec` and, wrapped in a
/// `MatvecDiagPenaltyOp`, as the penalty operator. The fingerprint is refreshed.
///
/// # Panics
/// Panics when `diag.len() != self.k`.
pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
where
    F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
    assert_eq!(diag.len(), self.k);
    let shared: SharedBetaMatvec = Arc::new(matvec);
    let diag_op = MatvecDiagPenaltyOp::new(self.k, Arc::clone(&shared), diag.clone());
    self.penalty_op = Some(Arc::new(diag_op));
    self.hbb_matvec = Some(shared);
    self.hbb_diag = Some(diag);
    self.refresh_row_hessian_fingerprint();
}
/// Marks the dense `row.htbeta` blocks as active alongside a matrix-free `H_tβ`
/// operator, so they are fingerprinted too, and refreshes the fingerprint.
pub fn activate_dense_htbeta_supplement(&mut self) {
    let flag = &mut self.htbeta_dense_supplement;
    *flag = true;
    self.refresh_row_hessian_fingerprint();
}
pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
where
F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
self.htbeta_matvec = Some(Arc::new(forward));
self.htbeta_transpose_matvec = Some(Arc::new(transpose));
self.refresh_row_hessian_fingerprint();
}
/// Replaces the β block column ranges.
///
/// The fingerprint is not refreshed; block offsets are not hashed.
pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
    let _previous = std::mem::replace(&mut self.block_offsets, offsets);
}
/// Installs the β penalty operator and refreshes the fingerprint.
pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
    let _previous = self.penalty_op.replace(op);
    self.refresh_row_hessian_fingerprint();
}
/// Attaches device-side SAE PCG data.
///
/// # Panics
/// Panics when the data's β dimension or per-row lengths do not match this system.
pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
    let n_rows = self.rows.len();
    assert_eq!(data.beta_dim, self.k);
    assert_eq!(data.a_phi.len(), n_rows);
    assert_eq!(data.local_jac.len(), n_rows);
    let shared = Arc::new(data);
    self.device_sae_pcg = Some(shared);
}
/// The β penalty operator in effect.
///
/// Returns the installed operator, or a `DensePenaltyOp` wrapping a copy of `hbb`.
pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
    if let Some(op) = &self.penalty_op {
        return Arc::clone(op);
    }
    Arc::new(DensePenaltyOp(self.hbb.clone()))
}
/// Adds the β penalty applied to `x` into `y`.
///
/// Uses the installed operator when present; otherwise computes the dense
/// product `y += hbb · x` over `hbb.nrows()` rows.
#[inline]
fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
    match self.penalty_op.as_ref() {
        Some(op) => {
            op.matvec(x, y);
        }
        None => {
            let k = self.hbb.nrows();
            for a in 0..k {
                let dot = (0..k).fold(0.0_f64, |acc, b| acc + self.hbb[[a, b]] * x[b]);
                y[a] += dot;
            }
        }
    }
}
/// Adds the β penalty diagonal into `diag`.
///
/// Priority order: the penalty operator, then `hbb_diag`, then the diagonal of
/// the dense `hbb`. The fallbacks touch only the overlapping prefix.
#[inline]
fn penalty_diagonal_add(&self, diag: &mut [f64]) {
    match (self.penalty_op.as_ref(), self.hbb_diag.as_ref()) {
        (Some(op), _) => {
            op.diagonal(diag);
        }
        (None, Some(hbb_diag)) => {
            for (slot, &h) in diag.iter_mut().zip(hbb_diag.iter()) {
                *slot += h;
            }
        }
        (None, None) => {
            let k = self.hbb.nrows().min(diag.len());
            for (j, slot) in diag.iter_mut().enumerate().take(k) {
                *slot += self.hbb[[j, j]];
            }
        }
    }
}
/// Adds the β penalty's diagonal block for `id` into `out`.
///
/// Priority order:
/// - the penalty operator;
/// - the matching sub-block of a full `k × k` dense `hbb`;
/// - only the diagonal, taken from `hbb_diag`.
///
/// If none of these is available, nothing is added.
#[inline]
fn penalty_block_add(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
    if let Some(op) = self.penalty_op.as_ref() {
        op.block(id, offsets, out);
        return;
    }
    let range = &offsets[id.0];
    let start = range.start;
    let width = range.end - range.start;
    if self.hbb.dim() == (self.k, self.k) {
        for bi in 0..width {
            for bj in 0..width {
                out[[bi, bj]] += self.hbb[[start + bi, start + bj]];
            }
        }
    } else if let Some(hbb_diag) = self.hbb_diag.as_ref() {
        for bi in 0..width {
            out[[bi, bi]] += hbb_diag[start + bi];
        }
    }
}
/// Adds the β penalty restricted to the arbitrary index set `cols` into `out`.
///
/// With an operator, each column is extracted by applying the operator to a
/// unit probe vector. Otherwise the dense `hbb` is used when it is full
/// `k × k`, falling back to the diagonal from `hbb_diag`.
#[inline]
fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
    match self.penalty_op.as_ref() {
        Some(op) => {
            let mut unit = Array1::<f64>::zeros(self.k);
            let mut column = Array1::<f64>::zeros(self.k);
            for (bj, &col_j) in cols.iter().enumerate() {
                unit.fill(0.0);
                unit[col_j] = 1.0;
                column.fill(0.0);
                op.matvec(
                    unit.as_slice().expect("probe contiguous"),
                    column.as_slice_mut().expect("result contiguous"),
                );
                for (bi, &col_i) in cols.iter().enumerate() {
                    out[[bi, bj]] += column[col_i];
                }
            }
        }
        None if self.hbb.dim() == (self.k, self.k) => {
            for (bi, &col_i) in cols.iter().enumerate() {
                for (bj, &col_j) in cols.iter().enumerate() {
                    out[[bi, bj]] += self.hbb[[col_i, col_j]];
                }
            }
        }
        None => {
            if let Some(hbb_diag) = self.hbb_diag.as_ref() {
                for (bi, &col_i) in cols.iter().enumerate() {
                    out[[bi, bi]] += hbb_diag[col_i];
                }
            }
        }
    }
}
/// Adds every registered analytic penalty's contribution to the system.
///
/// Contributions by tier:
/// - Row-block-diagonal Ψ penalties: gradient and Hessian go into the row blocks.
/// - Other Ψ penalties: gradient only, and they are recorded in
///   `cross_row_penalties` for matrix-free Hessian application.
/// - β penalties: go into the β block.
/// - ρ penalties: contribute nothing here.
///
/// Afterwards `analytic_row_hessian_fingerprint` is recomputed from the
/// per-penalty fingerprints, and the cached row fingerprint is refreshed.
/// Always returns `Ok(())`.
pub fn add_analytic_penalty_contributions(
    &mut self,
    registry: &AnalyticPenaltyRegistry,
    target_t: ArrayView1<'_, f64>,
    target_beta: ArrayView1<'_, f64>,
    rho_global: ArrayView1<'_, f64>,
) -> Result<(), ArrowSchurError> {
    let layout = registry.rho_layout();
    let mut fingerprints: Vec<u64> = Vec::new();
    self.cross_row_penalties.clear();
    for (penalty, (rho_range, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
        let rho_local = rho_global.slice(ndarray::s![rho_range.clone()]);
        match tier {
            PenaltyTier::Psi if analytic_penalty_is_row_block_diagonal(penalty) => {
                self.add_ext_coord_penalty(penalty, target_t, rho_local);
                fingerprints.extend(analytic_penalty_row_hessian_fingerprint(
                    penalty, target_t, rho_local,
                ));
            }
            PenaltyTier::Psi => {
                self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
                self.cross_row_penalties.push(CrossRowLatentPenalty {
                    penalty: penalty.clone(),
                    rho_local: rho_local.to_owned(),
                    target_t: target_t.to_owned(),
                });
            }
            PenaltyTier::Beta => {
                self.add_beta_penalty(penalty, target_beta, rho_local);
            }
            PenaltyTier::Rho => {}
        }
    }
    fingerprints.extend(self.cross_row_penalties.iter().map(|cross| {
        cross_row_penalty_fingerprint(&cross.penalty, target_t, cross.rho_local.view())
    }));
    self.analytic_row_hessian_fingerprint = if fingerprints.is_empty() {
        0
    } else {
        let mut hasher = Fingerprinter::new();
        hasher.write_str("arrow-schur-row-hessian-registry-v1");
        hasher.write_usize(fingerprints.len());
        for fp in fingerprints {
            hasher.write_u64(fp);
        }
        hasher.finish_u64()
    };
    self.refresh_row_hessian_fingerprint();
    Ok(())
}
/// Converts each row's Euclidean `g_t`, `H_tt` and `H_tβ` into their Riemannian
/// counterparts on `latent`'s manifold, and records the manifold fingerprint.
///
/// Euclidean manifolds only refresh the fingerprints.
///
/// The Euclidean quantities are moved out of each row with `std::mem::take`
/// instead of being deep-cloned: they are only read, as views, to build the
/// replacements. The results are identical, without three array copies per row.
///
/// # Panics
/// Panics when `latent`'s observation count or latent dimension does not match the system.
pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
    let manifold = latent.manifold();
    self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
    if manifold.is_euclidean() {
        self.refresh_row_hessian_fingerprint();
        return;
    }
    assert_eq!(latent.n_obs(), self.rows.len());
    assert_eq!(latent.latent_dim(), self.d);
    for (i, row) in self.rows.iter_mut().enumerate() {
        let t_i = ArrayView1::from(latent.row(i));
        let gt_e = std::mem::take(&mut row.gt);
        let htt_e = std::mem::take(&mut row.htt);
        let htbeta_e = std::mem::take(&mut row.htbeta);
        row.gt = manifold.project_gradient_to_tangent(t_i, gt_e.view());
        row.htt = manifold.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
        row.htbeta = manifold.project_matrix_columns_to_gradient_tangent(
            t_i,
            gt_e.view(),
            htbeta_e.view(),
        );
    }
    self.refresh_row_hessian_fingerprint();
}
/// Scatters a row-block-diagonal latent penalty into the per-row gradients
/// (`gt`) and Hessian blocks (`htt`).
///
/// `target_t` is the flattened latent vector; entry `row * d + axis` belongs to
/// `rows[row]`. A diagonal majorizer, when the penalty provides one, lands on
/// the diagonal of each `htt`. Otherwise `d` Hessian-vector probes are issued,
/// one per latent axis, with that axis set in every row at once.
fn add_ext_coord_penalty(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let (n, d) = (self.rows.len(), self.d);
    apply_analytic_penalty(
        penalty,
        target_t,
        rho_local,
        n * d,
        d,
        self,
        |sys, flat, value| {
            let (row, axis) = (flat / d, flat % d);
            sys.rows[row].gt[axis] += value;
        },
        |sys, flat, value| {
            let (row, axis) = (flat / d, flat % d);
            sys.rows[row].htt[[axis, axis]] += value;
        },
        |axis, probe| {
            // The penalty does not couple rows, so entries row*d..row*d+d of
            // H·probe are column `axis` of that row's own block.
            for row in 0..n {
                probe[row * d + axis] = 1.0;
            }
        },
        |sys, axis, hv| {
            for (row, block) in sys.rows.iter_mut().enumerate() {
                for b in 0..d {
                    block.htt[[b, axis]] += hv[row * d + b];
                }
            }
        },
    );
}
/// Adds only the gradient of a latent (Psi-tier) penalty to the per-row `gt`
/// vectors. It is used for penalties that couple rows, whose Hessian cannot be
/// stored in the per-row blocks.
///
/// # Panics
/// Panics if `target_t.len()` differs from `rows.len() * d`.
fn add_ext_coord_penalty_gradient_only(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let d = self.d;
    let n = self.rows.len();
    assert_eq!(target_t.len(), n * d);
    let gradient = penalty.grad_target(target_t, rho_local);
    for (row_index, row) in self.rows.iter_mut().enumerate() {
        let base = row_index * d;
        for axis in 0..d {
            row.gt[axis] += gradient[base + axis];
        }
    }
}
/// Accumulates `out += Σ_p M_p v`, where `M_p` is the PSD majorizer Hessian
/// of each recorded cross-row latent penalty. Each majorizer is evaluated at
/// the `target_t` and `rho_local` captured when the penalty was registered.
///
/// # Panics
/// Panics if `v` or `out` does not match a penalty's target length.
fn apply_cross_row_penalty_hessian(&self, v: ArrayView1<'_, f64>, out: &mut Array1<f64>) {
    self.cross_row_penalties.iter().for_each(|cross| {
        assert_eq!(cross.target_t.len(), v.len());
        let product =
            cross
                .penalty
                .psd_majorizer_hvp(cross.target_t.view(), cross.rho_local.view(), v);
        assert_eq!(product.len(), out.len());
        for (slot, contribution) in out.iter_mut().zip(product.iter()) {
            *slot += *contribution;
        }
    });
}
/// Adds a β-tier analytic penalty to the β gradient `gb` and the β Hessian.
///
/// A diagonal majorizer, when available, is added to the diagonal of a dense
/// `hbb` and to `hbb_diag` if present. Otherwise the full majorizer is probed
/// column by column, but only when `hbb` is a dense `k × k` matrix. Without a
/// dense `hbb` no probes are issued, so a non-diagonal penalty contributes
/// only its gradient.
fn add_beta_penalty(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_beta: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let k = self.k;
    // The penalty updates below never change the shape of `hbb`.
    let dense_hbb = self.hbb.dim() == (k, k);
    let hvp_columns = if dense_hbb { k } else { 0 };
    apply_analytic_penalty(
        penalty,
        target_beta,
        rho_local,
        k,
        hvp_columns,
        self,
        |sys, j, value| sys.gb[j] += value,
        |sys, j, value| {
            if dense_hbb {
                sys.hbb[[j, j]] += value;
            }
            if let Some(diag) = sys.hbb_diag.as_mut() {
                diag[j] += value;
            }
        },
        |j, probe| probe[j] = 1.0,
        |sys, j, hv| {
            let mut column = sys.hbb.column_mut(j);
            for i in 0..k {
                column[i] += hv[i];
            }
            // Keep the diagonal summary consistent with the dense block.
            if let Some(diag) = sys.hbb_diag.as_mut() {
                diag[j] += hv[j];
            }
        },
    );
}
/// Solves the damped Arrow-Schur Newton system with
/// `ArrowSolveOptions::automatic(self.k)`.
///
/// Returns `(delta_t, delta_beta, diagnostics)`.
pub fn solve(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    solve_arrow_newton_step_core(self, ridge_t, ridge_beta, &ArrowSolveOptions::automatic(self.k))
}
/// Like `solve`, but routed through the Levenberg–Marquardt escalation driver
/// (`solve_with_lm_escalation_inner`) with the automatic options for this K.
pub fn solve_with_lm_escalation(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    solve_with_lm_escalation_inner(
        self,
        ridge_t,
        ridge_beta,
        &ArrowSolveOptions::automatic(self.k),
    )
}
/// Solves the damped Arrow-Schur Newton system with caller-supplied options.
///
/// Returns `(delta_t, delta_beta, diagnostics)`.
pub fn solve_with_options(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let sys = self;
    solve_arrow_newton_step_core(sys, ridge_t, ridge_beta, options)
}
}
/// Arrow-Schur solver that does not keep the latent row blocks in memory.
///
/// Rows are produced on demand by `row_builder` and visited in chunks of
/// `chunk_size`. Only the K×K reduced system (`s_acc`, `rhs_acc`) plus the
/// β block (`hbb`, `gb`) is stored.
pub struct StreamingArrowSchur {
/// Number of latent row blocks.
pub n_rows: usize,
/// Fallback row dimension, used when a row index is beyond `row_dims`.
pub d: usize,
/// Latent dimension of each row block.
pub row_dims: Arc<[usize]>,
/// Start of each row inside the flattened δt vector; entry `n_rows` is the
/// total δt length.
pub row_offsets: Arc<[usize]>,
/// Dimension of the shared β block.
pub k: usize,
/// Rows visited per chunk (`new` clamps it to at least 1).
pub chunk_size: usize,
/// K×K reduced Schur complement accumulator. It is reset to
/// `hbb + ridge_beta·I` and then has each row's contribution subtracted.
pub s_acc: Array2<f64>,
/// Length-K reduced right-hand-side accumulator.
rhs_acc: Array1<f64>,
/// Dense β-block Hessian (K×K; asserted in `new`).
hbb: Array2<f64>,
/// β-block gradient (length K).
gb: Array1<f64>,
/// Produces row block `i` on demand.
row_builder: StreamingArrowRowBuilder,
/// Optional operator form of H_tβ: maps a β-space vector into row `i`'s
/// latent space.
htbeta_matvec: Option<RowHtbetaMatvec>,
/// Optional operator mapping a row-space vector into β space (presumably
/// H_βt; TODO confirm against its definition).
htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
/// Forwarded to `factor_one_row`; the solve entry points copy it from
/// `ArrowSolveOptions::tolerate_ill_conditioning`.
tolerate_ill_conditioning: bool,
}
/// Prints only the scalar dimensions. The accumulators and row-building
/// closures are elided, which is marked by `..`.
impl std::fmt::Debug for StreamingArrowSchur {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StreamingArrowSchur");
        out.field("n_rows", &self.n_rows);
        out.field("d", &self.d);
        out.field("k", &self.k);
        out.field("chunk_size", &self.chunk_size);
        out.finish_non_exhaustive()
    }
}
impl StreamingArrowSchur {
    /// Creates a streaming solver over `n_rows` latent row blocks that share a
    /// `k`-dimensional β block.
    ///
    /// `hbb`/`gb` are the dense β-block Hessian and gradient. Row blocks are
    /// not stored; each one is requested from `row_builder` whenever a pass
    /// needs it. `chunk_size` is clamped to at least 1. No H_tβ operators are
    /// installed, so rows must carry a dense `htbeta` of shape `(d_i, k)`
    /// (checked by `validate_row`).
    ///
    /// # Panics
    /// Panics if `hbb` is not `(k, k)` or `gb` does not have length `k`.
    #[must_use]
    pub fn new(
        n_rows: usize,
        d: usize,
        row_dims: Arc<[usize]>,
        row_offsets: Arc<[usize]>,
        k: usize,
        hbb: Array2<f64>,
        gb: Array1<f64>,
        row_builder: StreamingArrowRowBuilder,
        chunk_size: usize,
    ) -> Self {
        assert_eq!(hbb.dim(), (k, k));
        assert_eq!(gb.len(), k);
        Self {
            n_rows,
            d,
            row_dims,
            row_offsets,
            k,
            chunk_size: chunk_size.max(1),
            s_acc: Array2::<f64>::zeros((k, k)),
            rhs_acc: Array1::<f64>::zeros(k),
            hbb,
            gb,
            row_builder,
            htbeta_matvec: None,
            htbeta_transpose_matvec: None,
            tolerate_ill_conditioning: false,
        }
    }

    /// Wraps an in-memory `ArrowSchurSystem` in the streaming interface.
    ///
    /// Rows are copied into a shared vector served by the row builder. When
    /// `sys` has an `htbeta_matvec`, each copied row's dense `htbeta` is
    /// replaced by an empty matrix because the operator supplies the coupling.
    /// The effective β penalty operator is densified into `hbb`, and both H_tβ
    /// operators are inherited from `sys`.
    #[must_use]
    pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
        let htbeta_matvec = sys.htbeta_matvec.clone();
        let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
            sys.rows
                .iter()
                .map(|row| ArrowRowBlock {
                    htt: row.htt.clone(),
                    htbeta: Array2::<f64>::zeros((0, 0)),
                    gt: row.gt.clone(),
                })
                .collect()
        } else {
            sys.rows.clone()
        };
        let rows = Arc::new(rows);
        let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
            rows.get(row)
                .cloned()
                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
                    reason: format!("streaming row {row} out of bounds"),
                })
        });
        let hbb_dense = sys.effective_penalty_op().to_dense();
        let mut streaming = Self::new(
            sys.rows.len(),
            sys.d,
            Arc::clone(&sys.row_dims),
            Arc::clone(&sys.row_offsets),
            sys.k,
            hbb_dense,
            sys.gb.clone(),
            row_builder,
            chunk_size,
        );
        streaming.htbeta_matvec = htbeta_matvec;
        streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
        streaming
    }

    /// Materializes row `row_idx`'s `(di × k)` H_tβ block.
    ///
    /// The transpose operator is preferred (one probe per latent axis), then
    /// the forward operator (one probe per β coordinate). Without either, the
    /// row's dense `htbeta` is cloned.
    fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
        if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
            let mut mat = Array2::<f64>::zeros((di, self.k));
            let mut e_c = Array1::<f64>::zeros(di);
            let mut beta_row = Array1::<f64>::zeros(self.k);
            for c in 0..di {
                e_c.fill(0.0);
                e_c[c] = 1.0;
                beta_row.fill(0.0);
                op_t(row_idx, e_c.view(), &mut beta_row);
                for a in 0..self.k {
                    mat[[c, a]] = beta_row[a];
                }
            }
            return mat;
        }
        match self.htbeta_matvec.as_ref() {
            Some(op) => {
                let mut mat = Array2::<f64>::zeros((di, self.k));
                let mut e_a = Array1::<f64>::zeros(self.k);
                let mut col = Array1::<f64>::zeros(di);
                for a in 0..self.k {
                    e_a.fill(0.0);
                    e_a[a] = 1.0;
                    col.fill(0.0);
                    op(row_idx, e_a.view(), &mut col);
                    for c in 0..di {
                        mat[[c, a]] = col[c];
                    }
                }
                mat
            }
            None => row.htbeta.clone(),
        }
    }

    /// Chunk width used by the row passes.
    ///
    /// `chunk_size` is a public field and can be set to 0 after `new`;
    /// clamping it here prevents `step_by(0)` from panicking.
    fn effective_chunk_size(&self) -> usize {
        self.chunk_size.max(1)
    }

    /// Moves the current `(S, rhs)` accumulators out and leaves zeroed ones.
    #[must_use]
    pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
        let s = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
        let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(self.k));
        (s, rhs)
    }

    /// Resets the accumulators to `S = hbb + ridge_beta·I` and `rhs = 0`.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` if `hbb` is not a dense `k × k` matrix.
    pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
        if self.hbb.dim() != (self.k, self.k) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
            });
        }
        self.s_acc.assign(&self.hbb);
        for j in 0..self.k {
            self.s_acc[[j, j]] += ridge_beta;
            self.rhs_acc[j] = 0.0;
        }
        Ok(())
    }

    /// Folds rows `[start, end)` into the reduced accumulators.
    ///
    /// For each row, `F_i` is the damped row factor from `factor_one_row`
    /// (with `ridge_t`). The row then contributes `rhs += H_βt^(i) F_i⁻¹ g_t^(i)`
    /// and `S -= H_βt^(i) F_i⁻¹ H_tβ^(i)`. The latter is computed either from a
    /// full block solve or, in `SqrtBA` mode, from the whitened product
    /// `W_iᵀ W_i`.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` if the range is invalid, and propagates row
    /// building, validation, and factorization errors.
    pub fn accumulate_chunk(
        &mut self,
        start: usize,
        end: usize,
        ridge_t: f64,
        mode: ArrowSolverMode,
    ) -> Result<(), ArrowSchurError> {
        if start > end || end > self.n_rows {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
                    self.n_rows
                ),
            });
        }
        let backend = CpuBatchedBlockSolver;
        for row_idx in start..end {
            let row = (self.row_builder)(row_idx)?;
            let di = row.htt.nrows();
            self.validate_row(row_idx, &row)?;
            let htbeta = self.row_htbeta(row_idx, &row, di);
            let factor =
                factor_one_row(&row, ridge_t, di, row_idx, self.tolerate_ill_conditioning)?;
            let v = backend.solve_block_vector(factor.view(), row.gt.view());
            for c in 0..di {
                let vc = v[c];
                if vc == 0.0 {
                    continue;
                }
                for a in 0..self.k {
                    self.rhs_acc[a] += htbeta[[c, a]] * vc;
                }
            }
            match mode {
                ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                    let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                    backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
                }
                ArrowSolverMode::SqrtBA => {
                    let whitened = backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                    backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
                }
            }
        }
        Ok(())
    }

    /// Builds the reduced Schur complement `S` and `log det` of the
    /// block-diagonal damped `H_tt`, computed as `Σ 2 ln diag(F_i)`.
    ///
    /// The rhs accumulator is reset but not filled. `S` is symmetrized from its
    /// lower triangle and moved out; `s_acc` is left zeroed.
    ///
    /// # Errors
    /// Propagates accumulator reset, row building, validation, and
    /// factorization errors.
    pub fn reduced_schur_and_log_det_tt(
        &mut self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<(f64, Array2<f64>), ArrowSchurError> {
        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
        self.reset_accumulator(ridge_beta)?;
        let backend = CpuBatchedBlockSolver;
        let mut log_det_tt = 0.0_f64;
        let chunk = self.effective_chunk_size();
        for start in (0..self.n_rows).step_by(chunk) {
            let end = start.saturating_add(chunk).min(self.n_rows);
            for row_idx in start..end {
                let row = (self.row_builder)(row_idx)?;
                let di = row.htt.nrows();
                self.validate_row(row_idx, &row)?;
                let htbeta = self.row_htbeta(row_idx, &row, di);
                let factor =
                    factor_one_row(&row, ridge_t, di, row_idx, self.tolerate_ill_conditioning)?;
                for axis in 0..di {
                    log_det_tt += 2.0 * factor[[axis, axis]].ln();
                }
                match options.mode {
                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
                    }
                    ArrowSolverMode::SqrtBA => {
                        let whitened =
                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
                    }
                }
            }
        }
        symmetrize_upper_from_lower(&mut self.s_acc);
        let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
        Ok((log_det_tt, schur))
    }

    /// Returns `log det S` from the dense reduced factor.
    ///
    /// The factor is obtained by solving the reduced system with a zero
    /// right-hand side.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` if the reduced solve reports iterations, a
    /// mis-sized solution, or no dense factor. Solver errors are propagated.
    pub fn reduced_schur_log_det(
        schur: &Array2<f64>,
        options: &ArrowSolveOptions,
    ) -> Result<f64, ArrowSchurError> {
        let rhs = Array1::<f64>::zeros(schur.nrows());
        let trust_metric_weights = None;
        let (delta, schur_factor, diag) =
            solve_dense_reduced_system(schur, &rhs, options, trust_metric_weights)?;
        if delta.len() != schur.nrows() || diag.iterations != 0 {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "streaming log-det reduced solve returned incoherent diagnostics"
                    .to_string(),
            });
        }
        let schur_factor = schur_factor.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
            reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
        })?;
        let mut log_det_schur = 0.0_f64;
        for axis in 0..schur_factor.nrows() {
            log_det_schur += 2.0 * schur_factor[[axis, axis]].ln();
        }
        Ok(log_det_schur)
    }

    /// Log-determinant of the full damped arrow matrix:
    /// `log det H_tt + log det S`.
    ///
    /// # Errors
    /// Propagates errors from the two component computations.
    pub fn exact_arrow_log_det(
        &mut self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<f64, ArrowSchurError> {
        let (log_det_tt, schur) =
            self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
        Ok(log_det_tt + Self::reduced_schur_log_det(&schur, options)?)
    }

    /// Streams all rows to form the reduced system, solves it for `δβ`, then
    /// streams the rows again to back-substitute `δt`.
    ///
    /// Returns `(delta_t, delta_beta, schur_factor)`.
    ///
    /// # Errors
    /// Propagates accumulation, reduced-solve, and back-substitution errors.
    pub fn solve(
        &mut self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
        self.reset_accumulator(ridge_beta)?;
        let chunk = self.effective_chunk_size();
        for start in (0..self.n_rows).step_by(chunk) {
            let end = start.saturating_add(chunk).min(self.n_rows);
            self.accumulate_chunk(start, end, ridge_t, options.mode)?;
        }
        // Reduced rhs = H_βt H_tt⁻¹ g_t − g_β.
        for j in 0..self.k {
            self.rhs_acc[j] -= self.gb[j];
        }
        symmetrize_upper_from_lower(&mut self.s_acc);
        let trust_metric_weights = None;
        let (delta_beta, schur_factor, _diag) =
            solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
        let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
        Ok((delta_t, delta_beta, schur_factor))
    }

    /// Recovers `δt_i = −F_i⁻¹ (g_t^(i) + H_tβ^(i) δβ)` row by row.
    ///
    /// The right-hand side is built with exactly `d_i` entries. A shared
    /// length-`d` scratch buffer would hand the `d_i × d_i` factor a
    /// mis-sized vector, carrying stale tail entries from earlier rows,
    /// whenever a row is narrower than `d`.
    fn back_substitute(
        &self,
        ridge_t: f64,
        delta_beta: ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, ArrowSchurError> {
        let backend = CpuBatchedBlockSolver;
        let total_len = self.row_offsets[self.n_rows];
        let mut delta_t = Array1::<f64>::zeros(total_len);
        let chunk = self.effective_chunk_size();
        for start in (0..self.n_rows).step_by(chunk) {
            let end = start.saturating_add(chunk).min(self.n_rows);
            for row_idx in start..end {
                let row = (self.row_builder)(row_idx)?;
                let di = row.htt.nrows();
                self.validate_row(row_idx, &row)?;
                let factor =
                    factor_one_row(&row, ridge_t, di, row_idx, self.tolerate_ill_conditioning)?;
                // First rhs = H_tβ δβ, then add g_t to it.
                let mut rhs = Array1::<f64>::zeros(di);
                if let Some(op) = self.htbeta_matvec.as_ref() {
                    op(row_idx, delta_beta, &mut rhs);
                } else {
                    for c in 0..di {
                        let mut acc = 0.0_f64;
                        for a in 0..self.k {
                            acc += row.htbeta[[c, a]] * delta_beta[a];
                        }
                        rhs[c] = acc;
                    }
                }
                for c in 0..di {
                    rhs[c] += row.gt[c];
                }
                let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
                let row_base = self.row_offsets[row_idx];
                for c in 0..di {
                    delta_t[row_base + c] = -dt_i[c];
                }
            }
        }
        Ok(delta_t)
    }

    /// Checks a streamed row's shapes against `row_dims` (or `d` past its end).
    ///
    /// Requires square `H_tt`, a `(d_i, k)` dense `H_tβ` when no forward
    /// operator is installed, and a length-`d_i` gradient.
    fn validate_row(&self, row_idx: usize, row: &ArrowRowBlock) -> Result<(), ArrowSchurError> {
        let expected_di = if row_idx < self.row_dims.len() {
            self.row_dims[row_idx]
        } else {
            self.d
        };
        let actual_di = row.htt.nrows();
        if actual_di != expected_di || row.htt.ncols() != expected_di {
            return Err(ArrowSchurError::PerRowFactorFailed {
                row: row_idx,
                reason: format!(
                    "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
                    row.htt.dim(),
                ),
            });
        }
        if self.htbeta_matvec.is_none() && row.htbeta.dim() != (expected_di, self.k) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "streaming row H_tβ shape {:?} != ({expected_di}, {})",
                    row.htbeta.dim(),
                    self.k
                ),
            });
        }
        if row.gt.len() != expected_di {
            return Err(ArrowSchurError::PerRowFactorFailed {
                row: row_idx,
                reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
            });
        }
        Ok(())
    }
}
/// Evaluates an analytic penalty at `target` and hands its gradient and PSD
/// Hessian majorizer to caller-provided scatter callbacks.
///
/// * The gradient is always scattered entry by entry through `grad_scatter`.
/// * If the penalty provides a diagonal majorizer, it is scattered through
///   `diag_scatter` and no Hessian-vector products are computed.
/// * Otherwise `hvp_columns` probes are issued. For each column a zeroed probe
///   is seeded by `seed_hvp_probe`, multiplied by the majorizer, and the
///   product is passed to `hvp_column_scatter`.
///
/// # Panics
/// Panics if `target` or a returned diagonal does not have length
/// `expected_target_len`.
fn apply_analytic_penalty<S, G, D, P, H>(
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    expected_target_len: usize,
    hvp_columns: usize,
    scatter_target: &mut S,
    mut grad_scatter: G,
    mut diag_scatter: D,
    seed_hvp_probe: P,
    mut hvp_column_scatter: H,
) where
    G: FnMut(&mut S, usize, f64),
    D: FnMut(&mut S, usize, f64),
    P: Fn(usize, &mut Array1<f64>),
    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
{
    assert_eq!(target.len(), expected_target_len);
    let gradient = penalty.grad_target(target, rho_local);
    for idx in 0..expected_target_len {
        grad_scatter(scatter_target, idx, gradient[idx]);
    }
    match penalty.psd_majorizer_diag(target, rho_local) {
        Some(diagonal) => {
            assert_eq!(diagonal.len(), expected_target_len);
            for idx in 0..expected_target_len {
                diag_scatter(scatter_target, idx, diagonal[idx]);
            }
        }
        None => {
            let mut probe = Array1::<f64>::zeros(expected_target_len);
            for column in 0..hvp_columns {
                probe.fill(0.0);
                seed_hvp_probe(column, &mut probe);
                let product = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
                hvp_column_scatter(scatter_target, column, product.view());
            }
        }
    }
}
/// Whether `penalty`'s Hessian is block-diagonal across latent rows, so that
/// it can be stored in the per-row `H_tt` blocks. Delegates to the penalty.
fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
    let kind = penalty;
    kind.is_row_block_diagonal()
}
/// Immutable storage for a sequence of square factor blocks, packed back to
/// back in one shared buffer. `Clone` only bumps `Arc` reference counts.
#[derive(Clone)]
pub struct ArrowFactorSlab {
/// Concatenated block entries, each block in logical row-major order.
data: Arc<[f64]>,
/// `offsets[i]..offsets[i + 1]` is block `i`'s range in `data`; `offsets[0] == 0`.
offsets: Arc<[usize]>,
/// Side length of each square block.
dims: Arc<[usize]>,
}
impl ArrowFactorSlab {
pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
let mut data = Vec::new();
let mut offsets = Vec::with_capacity(blocks.len() + 1);
let mut dims = Vec::with_capacity(blocks.len());
offsets.push(0);
for block in blocks {
let (rows, cols) = block.dim();
assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
dims.push(rows);
data.extend(block.iter().copied());
offsets.push(data.len());
}
Self {
data: data.into(),
offsets: offsets.into(),
dims: dims.into(),
}
}
pub fn len(&self) -> usize {
self.dims.len()
}
pub fn is_empty(&self) -> bool {
self.dims.is_empty()
}
pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
let dim = self.dims[row];
let range = self.offsets[row]..self.offsets[row + 1];
ArrayView2::from_shape((dim, dim), &self.data[range])
.expect("ArrowFactorSlab row offset/dim invariant violated")
}
pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
(0..self.len()).map(|row| self.factor(row))
}
}
/// Summarizes the slab by block count and total stored value count instead of
/// dumping the factor entries.
impl std::fmt::Debug for ArrowFactorSlab {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (block_count, value_count) = (self.len(), self.data.len());
        f.debug_struct("ArrowFactorSlab")
            .field("rows", &block_count)
            .field("values", &value_count)
            .finish()
    }
}
/// Source of the undamped per-row `H_tt` factors in an `ArrowFactorCache`.
#[derive(Clone)]
pub enum ArrowUndampedFactors {
/// The damped factors already are the undamped ones, e.g. when the cache was
/// built with `ridge_t == 0.0`.
SameAsDamped,
/// Separately computed factors without the latent ridge.
Owned(ArrowFactorSlab),
}
/// Prints the variant; for `Owned`, only the number of stored factors.
impl std::fmt::Debug for ArrowUndampedFactors {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Owned(slab) => {
                let block_count = slab.len();
                f.debug_tuple("Owned").field(&block_count).finish()
            }
            Self::SameAsDamped => f.write_str("SameAsDamped"),
        }
    }
}
/// Writes `out = H_tβ^(row_idx) · x` for one row of `sys`.
///
/// `out` is zeroed and then receives the operator's contribution, if an
/// `htbeta_matvec` is installed. The row's dense `htbeta` is added on top when
/// dense blocks are in use (the system has a dense supplement or no operator)
/// and the block has shape `(out.len(), k)`.
fn sys_htbeta_apply_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    x: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    out.fill(0.0);
    let operator = sys.htbeta_matvec.as_ref();
    if let Some(op) = operator {
        op(row_idx, x, out);
    }
    let dense_enabled = sys.htbeta_dense_supplement || operator.is_none();
    if !dense_enabled || row.htbeta.dim() != (out.len(), sys.k) {
        return;
    }
    for c in 0..row.htbeta.nrows() {
        let dot = (0..sys.k).fold(0.0_f64, |sum, a| sum + row.htbeta[[c, a]] * x[a]);
        out[c] += dot;
    }
}
/// Accumulates `out += H_βt^(row_idx) · v`, i.e. `(H_tβ^(row_idx))ᵀ v`, for one
/// row of `sys`.
///
/// The operator part, if installed, is applied by probing it with unit
/// vectors. The dense part is added under the same conditions as in
/// `sys_htbeta_apply_row`. Zero entries of `v` are skipped.
fn sys_htbeta_accumulate_transpose(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    let operator = sys.htbeta_matvec.as_ref();
    if let Some(op) = operator {
        htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
    }
    let dense_enabled = sys.htbeta_dense_supplement || operator.is_none();
    if !dense_enabled || row.htbeta.dim() != (v.len(), sys.k) {
        return;
    }
    for (c, &weight) in v.iter().enumerate() {
        if weight == 0.0 {
            continue;
        }
        for a in 0..sys.k {
            out[a] += row.htbeta[[c, a]] * weight;
        }
    }
}
/// Builds the explicit `(d_i × k)` H_tβ block for one row of `sys`.
///
/// The result starts from the row's dense `htbeta` when dense blocks are in
/// use and the shape matches, and from zeros otherwise. If an `htbeta_matvec`
/// is installed, its columns are added by probing it with unit β vectors.
///
/// # Panics
/// Panics when there is no operator and the dense block has the wrong shape.
fn sys_htbeta_materialize_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
) -> Array2<f64> {
    let (di, k) = (sys.row_dims[row_idx], sys.k);
    let dense_enabled = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
    let dense_matches = row.htbeta.dim() == (di, k);
    let mut block = if dense_enabled && dense_matches {
        row.htbeta.clone()
    } else {
        Array2::<f64>::zeros((di, k))
    };
    match sys.htbeta_matvec.as_ref() {
        Some(op) => {
            let mut unit = Array1::<f64>::zeros(k);
            let mut image = Array1::<f64>::zeros(di);
            for a in 0..k {
                unit.fill(0.0);
                unit[a] = 1.0;
                image.fill(0.0);
                op(row_idx, unit.view(), &mut image);
                for c in 0..di {
                    block[[c, a]] += image[c];
                }
            }
        }
        None if dense_enabled && !dense_matches => panic!(
            "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
            row.htbeta.dim()
        ),
        None => {}
    }
    block
}
/// Accumulates `out += (H_tβ^(row))ᵀ v` using only the forward operator `op`.
///
/// For each β coordinate `a`, `op` is applied to the unit vector `e_a` to get
/// column `a` (length `d`). Its dot product with `v` is added to `out[a]`.
/// This costs `k` operator calls.
fn htbeta_probe_transpose(
    row: usize,
    op: &RowHtbetaMatvec,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
    d: usize,
    k: usize,
) {
    let mut basis = Array1::<f64>::zeros(k);
    let mut image = Array1::<f64>::zeros(d);
    for a in 0..k {
        basis.fill(0.0);
        basis[a] = 1.0;
        image.fill(0.0);
        op(row, basis.view(), &mut image);
        out[a] += (0..d).fold(0.0_f64, |sum, c| sum + image[c] * v[c]);
    }
}
/// How an `ArrowFactorCache` can reapply the H_tβ coupling after the solve.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
/// Owned copies of every row's dense `(d_i × k)` H_tβ block.
Dense {
blocks: Arc<[Array2<f64>]>,
/// Estimated size of a dense H_tβ copy in bytes. At the construction site
/// visible here this is n·d·k·8, saturating to `usize::MAX`.
estimated_bytes: usize,
},
/// The system's row-wise H_tβ operator, shared via `Arc`.
Matvec {
op: RowHtbetaMatvec,
/// Same estimate as in `Dense`; kept for diagnostics.
estimated_bytes: usize,
},
/// Not available, e.g. the dense copy would exceed the cache budget.
Disabled {
/// The estimate that exceeded the budget.
estimated_bytes: usize,
},
}
/// Prints the variant name, the block count for `Dense`, and the byte estimate.
/// The operator closure of `Matvec` is not printed.
impl std::fmt::Debug for ArrowHtbetaCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, block_count, bytes) = match self {
            Self::Dense {
                blocks,
                estimated_bytes,
            } => ("Dense", Some(blocks.len()), estimated_bytes),
            Self::Matvec {
                estimated_bytes, ..
            } => ("Matvec", None, estimated_bytes),
            Self::Disabled { estimated_bytes } => ("Disabled", None, estimated_bytes),
        };
        let mut out = f.debug_struct(variant);
        if let Some(count) = block_count {
            out.field("blocks", &count);
        }
        out.field("estimated_bytes", bytes).finish()
    }
}
impl ArrowHtbetaCache {
    /// `true` unless the coupling is `Disabled`.
    fn is_available(&self) -> bool {
        !matches!(self, Self::Disabled { .. })
    }

    /// Overwrites `out` with `H_tβ^(row) · delta_beta`.
    ///
    /// Returns `false` when the coupling cannot be applied: the cache is
    /// `Disabled`, or the `Dense` cache has no block for `row` or its shape
    /// does not match `delta_beta`/`out`. The `Dense` arm assigns every entry
    /// of `out`. The `Matvec` arm zeroes `out` before calling the operator.
    /// Every other operator call site in this module zero-fills its output
    /// first, so the operator may accumulate, and without the reset a caller
    /// passing a non-zero `out` would get a different result than in `Dense`
    /// mode.
    fn apply_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let Some(block) = blocks.get(row) else {
                    return false;
                };
                if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
                    return false;
                }
                for c in 0..block.nrows() {
                    let mut acc = 0.0_f64;
                    for a in 0..block.ncols() {
                        acc += block[[c, a]] * delta_beta[a];
                    }
                    out[c] = acc;
                }
                true
            }
            Self::Matvec { op, .. } => {
                out.fill(0.0);
                op(row, delta_beta, out);
                true
            }
            Self::Disabled { .. } => false,
        }
    }

    /// Accumulates `out += (H_tβ^(row))ᵀ v`.
    ///
    /// `Dense` uses the stored block and skips zero entries of `v`. `Matvec`
    /// probes the operator column by column (`k` calls, columns of length
    /// `d`). `Disabled` falls back to probing `fallback_op` when one is given.
    /// Returns `false` when no source is available or the `Dense` shapes do
    /// not match.
    fn apply_row_transpose_accumulate(
        &self,
        row: usize,
        v: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
        d: usize,
        k: usize,
        fallback_op: Option<&RowHtbetaMatvec>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let Some(block) = blocks.get(row) else {
                    return false;
                };
                if block.nrows() != v.len() || block.ncols() != out.len() {
                    return false;
                }
                for c in 0..block.nrows() {
                    let vc = v[c];
                    if vc == 0.0 {
                        continue;
                    }
                    for a in 0..block.ncols() {
                        out[a] += block[[c, a]] * vc;
                    }
                }
                true
            }
            Self::Matvec { op, .. } => {
                htbeta_probe_transpose(row, op, v, out, d, k);
                true
            }
            Self::Disabled { .. } => {
                if let Some(op) = fallback_op {
                    htbeta_probe_transpose(row, op, v, out, d, k);
                    true
                } else {
                    false
                }
            }
        }
    }
}
/// Factorization artifacts kept from an Arrow-Schur Newton solve, so that the
/// inverse, its blocks, and the log-determinant can be evaluated later without
/// refactoring.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
/// Per-row lower Cholesky factors of the damped `H_tt` blocks (with `ridge_t`).
pub htt_factors: ArrowFactorSlab,
/// Per-row factors without the latent ridge, or `SameAsDamped` when they
/// coincide with `htt_factors`.
pub htt_factors_undamped: ArrowUndampedFactors,
/// Dense factor of the reduced Schur complement; `None` when the solve mode
/// did not form one (e.g. InexactPCG).
pub schur_factor: Option<Array2<f64>>,
/// Solver mode that produced this cache.
pub solver_mode: ArrowSolverMode,
/// Latent-block ridge used for `htt_factors`.
pub ridge_t: f64,
/// β-block ridge used when forming the Schur complement.
pub ridge_beta: f64,
/// How the H_tβ coupling can be reapplied.
pub htbeta: ArrowHtbetaCache,
/// Fallback row dimension, used when a row index is beyond `row_dims`.
pub d: usize,
/// Latent dimension of each row.
pub row_dims: Arc<[usize]>,
/// Start of each row in the flattened δt vector; entry `n_rows()` is its length.
pub row_offsets: Arc<[usize]>,
/// Dimension of the β block.
pub k: usize,
/// Fingerprint of the latent manifold mode the system had when solved.
pub manifold_mode_fingerprint: u64,
/// Fingerprint of the row Hessians the factors were built from.
pub row_hessian_fingerprint: u64,
/// Diagnostics reported by the reduced solve.
pub pcg_diagnostics: PcgDiagnostics,
/// Number of gauge directions deflated when the (undamped) row factors were formed.
pub gauge_deflated_directions: usize,
}
/// Smallest Cholesky pivots of an `ArrowFactorCache`. A pivot is the squared
/// diagonal entry of a lower factor, i.e. the corresponding LDLᵀ-style
/// diagonal.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
/// Smallest pivot over all per-row `H_tt` factors, if any exist.
pub min_row_pivot: Option<f64>,
/// Smallest pivot of the Schur factor, if one was formed.
pub min_schur_pivot: Option<f64>,
/// Minimum of the two values above (whichever are present).
pub min_pivot: Option<f64>,
}
impl ArrowFactorMinPivot {
    /// Bundles the row and Schur minima and derives the overall minimum, taken
    /// over whichever of the two is present.
    fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        let min_pivot = match (row, schur) {
            (Some(a), Some(b)) => Some(a.min(b)),
            (only, None) | (None, only) => only,
        };
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            min_pivot,
        }
    }
}
/// Smallest squared diagonal entry of a lower Cholesky factor, or `None` for
/// an empty factor. Uses `f64::min`, so a NaN entry is ignored when another
/// value is present.
fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
    let width = factor.nrows().min(factor.ncols());
    (0..width)
        .map(|idx| factor[[idx, idx]] * factor[[idx, idx]])
        .reduce(f64::min)
}
/// Largest squared diagonal entry of a lower Cholesky factor, or `None` for an
/// empty factor.
fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
    let width = factor.nrows().min(factor.ncols());
    (0..width)
        .map(|idx| factor[[idx, idx]] * factor[[idx, idx]])
        .reduce(f64::max)
}
pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
let mut min_row_pivot = None;
for factor in cache.htt_factors.iter() {
if let Some(pivot) = lower_cholesky_min_pivot(factor) {
min_row_pivot = Some(match min_row_pivot {
Some(current) => f64::min(current, pivot),
None => pivot,
});
}
}
let min_schur_pivot = cache
.schur_factor
.as_ref()
.and_then(|factor| lower_cholesky_min_pivot(factor.view()));
ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
}
pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
let mut max_pivot: Option<f64> = None;
for factor in cache.htt_factors.iter() {
if let Some(pivot) = lower_cholesky_max_pivot(factor) {
max_pivot = Some(match max_pivot {
Some(current) => f64::max(current, pivot),
None => pivot,
});
}
}
if let Some(factor) = cache.schur_factor.as_ref()
&& let Some(pivot) = lower_cholesky_max_pivot(factor.view())
{
max_pivot = Some(match max_pivot {
Some(current) => f64::max(current, pivot),
None => pivot,
});
}
max_pivot
}
impl ArrowFactorCache {
    /// Number of rows covered by the cached damped `H_tt` factors.
    pub fn n_rows(&self) -> usize {
        self.htt_factors.len()
    }

    /// `true` unless the cached H_tβ coupling is `Disabled`.
    pub fn htbeta_available(&self) -> bool {
        self.htbeta.is_available()
    }

    /// Lower Cholesky factor of row `row`'s undamped `H_tt` block. This is the
    /// damped factor itself when `htt_factors_undamped` is `SameAsDamped`.
    ///
    /// # Panics
    /// Panics if `row` is out of range.
    pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.factor(row),
            ArrowUndampedFactors::Owned(factors) => factors.factor(row),
        }
    }

    /// Number of undamped row factors available.
    pub fn undamped_factor_count(&self) -> usize {
        match &self.htt_factors_undamped {
            ArrowUndampedFactors::SameAsDamped => self.htt_factors.len(),
            ArrowUndampedFactors::Owned(factors) => factors.len(),
        }
    }

    /// Iterates over the undamped row factors in row order.
    pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
        (0..self.undamped_factor_count()).map(|row| self.undamped_factor(row))
    }

    /// Length of the flattened latent step δt.
    pub fn delta_t_len(&self) -> usize {
        self.row_offsets[self.n_rows()]
    }

    /// Latent dimension of `row`, falling back to `d` past the end of `row_dims`.
    fn row_dim_or_default(&self, row: usize) -> usize {
        self.row_dims.get(row).copied().unwrap_or(self.d)
    }

    /// Overwrites `out` with `H_tβ^(row) · delta_beta`.
    ///
    /// Returns `false` if the shapes do not match (`out` must have the row's
    /// dimension and `delta_beta` must have length K) or the coupling is
    /// unavailable.
    pub fn apply_htbeta_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        let di = self.row_dim_or_default(row);
        if out.len() != di || delta_beta.len() != self.k {
            return false;
        }
        self.htbeta.apply_row(row, delta_beta, out)
    }

    /// Accumulates `out += (H_tβ^(row))ᵀ v`.
    ///
    /// `fallback_op` is used only when the cached coupling is `Disabled`.
    /// Returns `false` on shape mismatch or when no source is available.
    pub fn apply_htbeta_row_transpose(
        &self,
        row: usize,
        v: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
        fallback_op: Option<&RowHtbetaMatvec>,
    ) -> bool {
        let di = self.row_dim_or_default(row);
        if v.len() != di || out.len() != self.k {
            return false;
        }
        self.htbeta
            .apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
    }

    /// Returns `(log det of the damped block-diagonal H_tt, log det S)`. The
    /// second value is `None` when no dense Schur factor was formed.
    pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
        let mut log_det_tt = 0.0_f64;
        for l in self.htt_factors.iter() {
            for i in 0..l.nrows() {
                log_det_tt += l[[i, i]].ln();
            }
        }
        log_det_tt *= 2.0;
        let log_det_schur = self.schur_factor.as_ref().map(|l| {
            let mut s = 0.0_f64;
            for i in 0..l.nrows() {
                s += l[[i, i]].ln();
            }
            2.0 * s
        });
        (log_det_tt, log_det_schur)
    }

    /// Diagonal of the latent (tt) block of the arrow inverse.
    ///
    /// For each row `i` and axis `j`, with `a = A_i⁻¹ e_j` and `w = H_βt^(i) a`,
    /// the entry is `a_j + wᵀ S⁻¹ w`. Here `A_i` is the undamped row factor and
    /// `S` is the cached Schur factor.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` without a dense Schur factor, when the
    /// coupling is disabled, or when a row's coupling cannot be applied.
    pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
        let Some(schur_factor) = self.schur_factor.as_ref() else {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
                         the InexactPCG mode does not form one"
                    .to_string(),
            });
        };
        if !self.htbeta_available() {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
                         but this cache's htbeta is Disabled"
                    .to_string(),
            });
        }
        let n = self.undamped_factor_count();
        let total_len = self.delta_t_len();
        let mut out = Array1::<f64>::zeros(total_len);
        let mut w = Array1::<f64>::zeros(self.k);
        for i in 0..n {
            let di = self.row_dims[i];
            let row_base = self.row_offsets[i];
            let factor = self.undamped_factor(i);
            // One d_i-sized unit-vector buffer per row, reused for every axis.
            let mut e_j = Array1::<f64>::zeros(di);
            for j in 0..di {
                e_j.fill(0.0);
                e_j[j] = 1.0;
                let a = cholesky_solve_vector(factor, &e_j);
                w.fill(0.0);
                if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
                    return Err(ArrowSchurError::SchurFactorFailed {
                        reason: format!(
                            "latent_block_inverse_diagonal: H_βt^({i}) apply failed \
                             (htbeta cache could not supply row {i})"
                        ),
                    });
                }
                let z = cholesky_solve_vector(schur_factor, &w);
                let mut corr = 0.0_f64;
                for c in 0..self.k {
                    corr += w[c] * z[c];
                }
                out[row_base + j] = a[j] + corr;
            }
        }
        Ok(out)
    }

    /// Applies the inverse of the (undamped-row) arrow matrix to `(w_t, w_β)`
    /// by block elimination:
    /// `y = A⁻¹ w_t`, `u_β = S⁻¹ (w_β − H_βt y)`, `u_t = y − A⁻¹ H_tβ u_β`.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` on mis-sized inputs, a missing Schur factor
    /// (when K > 0), or a row whose coupling cannot be applied.
    pub fn full_inverse_apply(
        &self,
        w_t: ArrayView1<'_, f64>,
        w_beta: ArrayView1<'_, f64>,
    ) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
        let total_len = self.delta_t_len();
        if w_t.len() != total_len || w_beta.len() != self.k {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
                    w_t.len(),
                    w_beta.len(),
                    total_len,
                    self.k
                ),
            });
        }
        let n = self.undamped_factor_count();
        let mut y = Array1::<f64>::zeros(total_len);
        let mut r_beta = w_beta.to_owned();
        for i in 0..n {
            let di = self.row_dims[i];
            let base = self.row_offsets[i];
            let factor = self.undamped_factor(i);
            let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
            let y_row = cholesky_solve_vector(factor, &w_row);
            if self.k > 0 {
                let mut acc = Array1::<f64>::zeros(self.k);
                if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
                    return Err(ArrowSchurError::SchurFactorFailed {
                        reason: format!(
                            "full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
                             could not supply row {i})"
                        ),
                    });
                }
                for c in 0..self.k {
                    r_beta[c] -= acc[c];
                }
            }
            for j in 0..di {
                y[base + j] = y_row[j];
            }
        }
        let u_beta = if self.k > 0 {
            self.schur_inverse_apply(r_beta.view())?
        } else {
            Array1::<f64>::zeros(0)
        };
        let mut u_t = y;
        if self.k > 0 {
            for i in 0..n {
                let di = self.row_dims[i];
                let base = self.row_offsets[i];
                let mut cross = Array1::<f64>::zeros(di);
                if !self.apply_htbeta_row(i, u_beta.view(), &mut cross) {
                    return Err(ArrowSchurError::SchurFactorFailed {
                        reason: format!(
                            "full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
                             could not supply row {i})"
                        ),
                    });
                }
                let factor = self.undamped_factor(i);
                let corr = cholesky_solve_vector(factor, &cross);
                for j in 0..di {
                    u_t[base + j] -= corr[j];
                }
            }
        }
        Ok((u_t, u_beta))
    }

    /// Solves `S x = rhs` with the cached dense Schur factor.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` without a dense Schur factor or when `rhs`
    /// does not have length K.
    pub fn schur_inverse_apply(
        &self,
        rhs: ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, ArrowSchurError> {
        let Some(schur_factor) = self.schur_factor.as_ref() else {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "schur_inverse_apply requires a dense Schur factor; \
                         the InexactPCG mode does not form one"
                    .to_string(),
            });
        };
        if rhs.len() != self.k {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "schur_inverse_apply: rhs length {} != K {}",
                    rhs.len(),
                    self.k
                ),
            });
        }
        let rhs_owned = rhs.to_owned();
        Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
    }

    /// Returns the `block × block` sub-matrix of `S⁻¹`, symmetrized by
    /// averaging mirrored entries.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` without a dense Schur factor or when
    /// `block.end > K`.
    pub fn schur_inverse_block(
        &self,
        block: std::ops::Range<usize>,
    ) -> Result<Array2<f64>, ArrowSchurError> {
        let Some(schur_factor) = self.schur_factor.as_ref() else {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "schur_inverse_block requires a dense Schur factor; \
                         the InexactPCG mode does not form one"
                    .to_string(),
            });
        };
        if block.end > self.k {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "schur_inverse_block: block end {} exceeds K {}",
                    block.end, self.k
                ),
            });
        }
        let w = block.len();
        let mut out = Array2::<f64>::zeros((w, w));
        let mut e_j = Array1::<f64>::zeros(self.k);
        for (jc, j) in block.clone().enumerate() {
            e_j.fill(0.0);
            e_j[j] = 1.0;
            let col = cholesky_solve_vector(schur_factor, &e_j);
            for (ic, i) in block.clone().enumerate() {
                out[[ic, jc]] = col[i];
            }
        }
        for ic in 0..w {
            for jc in (ic + 1)..w {
                let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
                out[[ic, jc]] = avg;
                out[[jc, ic]] = avg;
            }
        }
        Ok(out)
    }
}
pub fn solve_arrow_newton_step_with_options(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, ArrowFactorCache), ArrowSchurError> {
if options.streaming_chunk_size.is_some() {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "streaming Arrow-Schur solve does not materialize the factor cache required by this entry point".to_string(),
});
}
let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
let backend = CpuBatchedBlockSolver;
let htbeta_estimated_bytes =
estimated_htbeta_bytes(sys.rows.len(), sys.d, sys.k).unwrap_or(usize::MAX);
let htbeta = if let Some(op) = sys.htbeta_matvec.as_ref() {
ArrowHtbetaCache::Matvec {
op: Arc::clone(op),
estimated_bytes: htbeta_estimated_bytes,
}
} else if htbeta_estimated_bytes <= ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES {
ArrowHtbetaCache::Dense {
blocks: sys
.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>()
.into(),
estimated_bytes: htbeta_estimated_bytes,
}
} else {
ArrowHtbetaCache::Disabled {
estimated_bytes: htbeta_estimated_bytes,
}
};
let htt_factors = step.htt_factors;
let (htt_factors_undamped, gauge_deflated_directions) = if ridge_t == 0.0 {
(
ArrowUndampedFactors::SameAsDamped,
step.gauge_deflated_directions,
)
} else {
let undamped = factor_blocks_for_system(sys, 0.0, options, &backend)?;
(
ArrowUndampedFactors::Owned(undamped.factors),
undamped.gauge_deflated_directions,
)
};
let cache = ArrowFactorCache {
htt_factors,
htt_factors_undamped,
schur_factor: step.schur_factor,
solver_mode: options.mode,
ridge_t,
ridge_beta,
htbeta,
d: sys.d,
row_dims: Arc::clone(&sys.row_dims),
row_offsets: Arc::clone(&sys.row_offsets),
k: sys.k,
manifold_mode_fingerprint: sys.manifold_mode_fingerprint,
row_hessian_fingerprint: sys.current_row_hessian_fingerprint(),
pcg_diagnostics: step.pcg_diagnostics,
gauge_deflated_directions,
};
Ok((step.delta_t, step.delta_beta, cache))
}
/// Byte footprint of `n` dense `d × k` blocks of `f64`, computed as `n·d·k·size_of::<f64>()`.
///
/// Returns `None` if any intermediate product overflows `usize`.
fn estimated_htbeta_bytes(n: usize, d: usize, k: usize) -> Option<usize> {
    // Multiply left to right (((n·d)·k)·8), bailing out on the first overflow.
    [d, k, std::mem::size_of::<f64>()]
        .iter()
        .try_fold(n, |bytes, &factor| bytes.checked_mul(factor))
}
/// Computes one damped arrow-Schur Newton step `(Δt, Δβ)` together with solver diagnostics.
///
/// Dispatch order:
/// 1. Streaming solve when `options.streaming_chunk_size` is set. The streaming
///    mixed-precision default is applied and diagnostics are returned as default.
/// 2. The device direct solver, when `try_device_arrow_direct` accepts the system.
/// 3. The CPU path with an injected device Schur matvec, when eligible. The returned
///    diagnostics are tagged `used_device_arrow`.
/// 4. The plain CPU path.
pub fn solve_arrow_newton_step_core(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    if let Some(chunk_size) = options.streaming_chunk_size {
        let streaming_options = options.with_streaming_mixed_precision_default();
        let mut chunked = StreamingArrowSchur::from_system(sys, chunk_size);
        let (delta_t, delta_beta, _) = chunked.solve(ridge_t, ridge_beta, &streaming_options)?;
        return Ok((delta_t, delta_beta, PcgDiagnostics::default()));
    }
    if let Some(device_result) = try_device_arrow_direct(sys, ridge_t, ridge_beta, options) {
        return device_result;
    }
    match maybe_inject_gpu_schur_matvec(sys, ridge_t, ridge_beta, options) {
        Some(device_options) => {
            let step =
                solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, &device_options)?;
            let mut diagnostics = step.pcg_diagnostics;
            diagnostics.used_device_arrow = true;
            Ok((step.delta_t, step.delta_beta, diagnostics))
        }
        None => {
            let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
            Ok((step.delta_t, step.delta_beta, step.pcg_diagnostics))
        }
    }
}
/// Returns a copy of `options` carrying a device-backed Schur matvec, or `None` when the
/// CPU path should run unchanged.
///
/// All of the following must hold:
/// * mode is `InexactPCG`;
/// * no matvec is supplied yet;
/// * there are no cross-row penalties;
/// * streaming is not requested;
/// * a global GPU runtime exists;
/// * its policy routes this `(rows, k)` workload to the GPU;
/// * the device backend builds successfully.
fn maybe_inject_gpu_schur_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Option<ArrowSolveOptions> {
    let eligible = options.mode == ArrowSolverMode::InexactPCG
        && options.gpu_matvec.is_none()
        && sys.cross_row_penalties.is_empty()
        && options.streaming_chunk_size.is_none();
    if !eligible {
        return None;
    }
    let runtime = crate::gpu::runtime::GpuRuntime::global()?;
    let routed_to_gpu = runtime
        .policy()
        .dense_hessian_work_target_is_gpu(sys.rows.len(), sys.k);
    if !routed_to_gpu {
        return None;
    }
    let device_matvec =
        crate::gpu::arrow_schur::gpu_schur_matvec_backend(sys, ridge_t, ridge_beta).ok()?;
    let mut augmented = options.clone();
    augmented.gpu_matvec = Some(device_matvec);
    Some(augmented)
}
/// Attempts the dense direct arrow solve on the device.
///
/// Returns `None` in these cases, so the caller uses the host path:
/// * the device path does not apply: mode is not `Direct`, there are cross-row penalties,
///   streaming is requested, `H_ββ` or `H_tβ` is matvec-backed, or there is no global GPU
///   runtime;
/// * the policy keeps this workload on the CPU;
/// * the device fails with an error other than the two mapped below.
///
/// A device "ridge bump required" failure maps to `PerRowFactorFailed`, and a device Schur
/// failure maps to `SchurFactorFailed`.
fn try_device_arrow_direct(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Option<Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError>> {
    use crate::gpu::arrow_schur::ArrowSchurGpuFailure;

    let host_only = options.mode != ArrowSolverMode::Direct
        || !sys.cross_row_penalties.is_empty()
        || options.streaming_chunk_size.is_some()
        || sys.hbb_matvec.is_some()
        || sys.htbeta_matvec.is_some();
    if host_only {
        return None;
    }
    let runtime = crate::gpu::runtime::GpuRuntime::global()?;
    if !runtime
        .policy()
        .dense_hessian_work_target_is_gpu(sys.rows.len(), sys.k)
    {
        return None;
    }
    let outcome = match crate::gpu::arrow_schur::solve_arrow_newton_step(sys, ridge_t, ridge_beta)
    {
        Ok(solution) => Ok((
            solution.delta_t,
            solution.delta_beta,
            PcgDiagnostics {
                used_device_arrow: true,
                ..PcgDiagnostics::default()
            },
        )),
        Err(ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
            Err(ArrowSchurError::PerRowFactorFailed {
                row,
                reason: format!("device per-row block non-PD; suggested ridge bump {bump:e}"),
            })
        }
        Err(ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
            Err(ArrowSchurError::SchurFactorFailed { reason })
        }
        // Any other device failure falls back to the host solver.
        Err(_) => return None,
    };
    Some(outcome)
}
/// Solves a Newton step and, on a recoverable failure, retries with an extra proximal
/// ridge added to both `ridge_t` and `ridge_beta`.
///
/// The extra ridge starts at `DEFAULT_PROXIMAL_INITIAL_RIDGE` and is multiplied by
/// `DEFAULT_PROXIMAL_RIDGE_GROWTH` on each further retry. At most
/// `DEFAULT_PROXIMAL_MAX_ATTEMPTS` retries follow the first attempt. The number of retries
/// is reported in `ridge_escalations`.
///
/// These failures are recoverable:
/// * per-row factorization failure;
/// * per-row ill-conditioning;
/// * Schur factorization failure;
/// * PCG failure.
///
/// # Errors
/// Returns the first non-recoverable error, or the last error once retries run out.
pub fn solve_with_lm_escalation_inner(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let mut extra_ridge = 0.0_f64;
    let mut escalation_count: usize = 0;
    let mut final_error: Option<ArrowSchurError> = None;
    for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
        let err = match solve_arrow_newton_step_artifacts(
            sys,
            ridge_t + extra_ridge,
            ridge_beta + extra_ridge,
            options,
        ) {
            Ok(mut step) => {
                step.pcg_diagnostics.ridge_escalations = escalation_count;
                return Ok((step.delta_t, step.delta_beta, step.pcg_diagnostics));
            }
            Err(err) => err,
        };
        let retryable = matches!(
            err,
            ArrowSchurError::PerRowFactorFailed { .. }
                | ArrowSchurError::PerRowFactorIllConditioned { .. }
                | ArrowSchurError::SchurFactorFailed { .. }
                | ArrowSchurError::PcgFailed { .. }
        );
        final_error = Some(err);
        if !retryable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
            break;
        }
        extra_ridge = if extra_ridge == 0.0 {
            DEFAULT_PROXIMAL_INITIAL_RIDGE
        } else {
            extra_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
        };
        escalation_count += 1;
    }
    Err(final_error.expect("escalation loop set last_err on failure"))
}
/// Computes an arrow-Schur Newton step and accepts it only when it passes an Armijo
/// sufficient-decrease test on the true objective. On rejection it escalates a proximal
/// ridge `ρ` that is added to both the t and β ridges.
///
/// Attempt `j` solves with `ridge_t = base_ridge_t + ρ` and `ridge_β = base_ridge_beta + ρ`
/// through [`solve_arrow_newton_step_core`].
/// * `ρ` starts at `correction.initial_ridge`, clamped to ≥ 0.
/// * `ρ` is advanced by `next_proximal_ridge` after every attempt that was not accepted.
///
/// A candidate `p = (Δt, Δβ)` is accepted on the first attempt where both hold:
/// * `g·p` is finite and negative;
/// * `trial_objective(Δt, Δβ) <= f + armijo_c1·g·p`.
///
/// When all `correction.max_attempts` attempts are rejected, these fallbacks apply in order:
/// 1. Return the best candidate whose objective dropped by more than the objective resolution
///    `convergence_objective_rel_tol·(|f| + 1)`. `trial_objective` is evaluated on it once
///    more, and that value is reported when finite.
/// 2. If some finite trial changed the objective by no more than the resolution, return a
///    zero step (treated as stationary).
/// 3. Otherwise return `AdaptiveCorrectionFailed` with the last rejection reason.
///
/// A zero step is also returned up front when the gradient norm is within
/// `correction.gradient_tolerance`.
///
/// # Errors
/// Returns `AdaptiveCorrectionFailed` when any of these hold:
/// * `current_objective_value` is not finite;
/// * `ridge_growth` is not a finite value greater than 1;
/// * `armijo_c1` is not in `(0, 1)`;
/// * no step could be accepted as described above.
pub fn solve_arrow_newton_step_with_proximal_correction<F>(
sys: &ArrowSchurSystem,
base_ridge_t: f64,
base_ridge_beta: f64,
current_objective_value: f64,
options: &ArrowSolveOptions,
correction: &ArrowProximalCorrectionOptions,
mut trial_objective: F,
) -> Result<ArrowAcceptedProximalStep, ArrowSchurError>
where
F: for<'a, 'b> FnMut(ArrayView1<'a, f64>, ArrayView1<'b, f64>) -> f64,
{
// Input validation: every violation is reported as AdaptiveCorrectionFailed.
if !current_objective_value.is_finite() {
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: "current objective is not finite".to_string(),
});
}
if !(correction.ridge_growth.is_finite() && correction.ridge_growth > 1.0) {
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!(
"ridge_growth must be finite and > 1; got {}",
correction.ridge_growth
),
});
}
if !(correction.armijo_c1.is_finite()
&& correction.armijo_c1 > 0.0
&& correction.armijo_c1 < 1.0)
{
return Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!("armijo_c1 must be in (0, 1); got {}", correction.armijo_c1),
});
}
// Already stationary to tolerance: return a zero step without solving anything.
let grad_norm = arrow_gradient_norm(sys);
if grad_norm <= correction.gradient_tolerance.max(0.0) {
return Ok(ArrowAcceptedProximalStep {
delta_t: Array1::<f64>::zeros(sys.row_offsets[sys.rows.len()]),
delta_beta: Array1::<f64>::zeros(sys.k),
ridge_t: base_ridge_t,
ridge_beta: base_ridge_beta,
proximal_ridge: 0.0,
objective_value: current_objective_value,
trial_objective_value: current_objective_value,
gradient_dot_step: 0.0,
attempts: 0,
});
}
// Objective changes within this band (relative to |f| + 1) are treated as noise.
let objective_resolution =
correction.convergence_objective_rel_tol.max(0.0) * (current_objective_value.abs() + 1.0);
let mut proximal_ridge = correction.initial_ridge.max(0.0);
let mut last_reason = String::from("no attempts were made");
// Best significant decrease seen so far:
// (Δt, Δβ, trial value, g·p, ridge_t, ridge_β, proximal ridge).
let mut best_decrease: Option<(Array1<f64>, Array1<f64>, f64, f64, f64, f64, f64)> = None;
// Smallest objective change among finite trials that were not significant decreases.
let mut smallest_increase = f64::INFINITY;
for attempt in 0..correction.max_attempts {
let ridge_t = base_ridge_t + proximal_ridge;
let ridge_beta = base_ridge_beta + proximal_ridge;
match solve_arrow_newton_step_core(sys, ridge_t, ridge_beta, options) {
Ok((delta_t, delta_beta, _diag)) => {
let g_dot_p = arrow_gradient_dot_step(sys, delta_t.view(), delta_beta.view());
// A non-descent (or non-finite) direction is rejected without evaluating the objective.
if !(g_dot_p.is_finite() && g_dot_p < 0.0) {
last_reason =
format!("candidate was not a finite descent direction: g·p={g_dot_p}");
} else {
let trial_value = trial_objective(delta_t.view(), delta_beta.view());
let armijo_bound = current_objective_value + correction.armijo_c1 * g_dot_p;
// Armijo sufficient decrease: accept immediately.
if trial_value.is_finite() && trial_value <= armijo_bound {
return Ok(ArrowAcceptedProximalStep {
delta_t,
delta_beta,
ridge_t,
ridge_beta,
proximal_ridge,
objective_value: current_objective_value,
trial_objective_value: trial_value,
gradient_dot_step: g_dot_p,
attempts: attempt + 1,
});
}
// Armijo rejected: remember the trial for the post-loop fallbacks.
if trial_value.is_finite() {
let delta_obj = trial_value - current_objective_value;
if delta_obj < -objective_resolution {
let improves = best_decrease.as_ref().is_none_or(
|(_, _, best_value, _, _, _, _)| trial_value < *best_value,
);
if improves {
best_decrease = Some((
delta_t.clone(),
delta_beta.clone(),
trial_value,
g_dot_p,
ridge_t,
ridge_beta,
proximal_ridge,
));
}
} else if delta_obj < smallest_increase {
smallest_increase = delta_obj;
}
}
last_reason = {
let step_norm = (delta_t.iter().map(|v| v * v).sum::<f64>()
+ delta_beta.iter().map(|v| v * v).sum::<f64>())
.sqrt();
format!(
"Armijo rejected trial objective {trial_value}; bound {armijo_bound}; \
|g|={grad_norm:.4e} g.p={g_dot_p:.4e} |step|={step_norm:.4e} ridge={proximal_ridge:.3e}"
)
};
}
}
// A failed solve is not fatal here: record why and escalate the ridge.
Err(err) => {
last_reason = err.to_string();
}
}
proximal_ridge = next_proximal_ridge(proximal_ridge, correction.ridge_growth);
}
// Fallback 1: the best significant (non-Armijo) decrease.
if let Some((delta_t, delta_beta, trial_value, g_dot_p, ridge_t, ridge_beta, best_ridge)) =
best_decrease
{
// The closure is evaluated again on the returned step, presumably so any state it keeps
// reflects this step rather than the last trial. NOTE(review): confirm against callers.
let reapplied = trial_objective(delta_t.view(), delta_beta.view());
let final_value = if reapplied.is_finite() {
reapplied
} else {
trial_value
};
return Ok(ArrowAcceptedProximalStep {
delta_t,
delta_beta,
ridge_t,
ridge_beta,
proximal_ridge: best_ridge,
objective_value: current_objective_value,
trial_objective_value: final_value,
gradient_dot_step: g_dot_p,
attempts: correction.max_attempts,
});
}
// Fallback 2: some trial moved the objective by no more than the resolution, so treat
// the point as stationary and return a zero step.
if smallest_increase.is_finite() && smallest_increase <= objective_resolution {
return Ok(ArrowAcceptedProximalStep {
delta_t: Array1::<f64>::zeros(sys.row_offsets[sys.rows.len()]),
delta_beta: Array1::<f64>::zeros(sys.k),
ridge_t: base_ridge_t,
ridge_beta: base_ridge_beta,
proximal_ridge: 0.0,
objective_value: current_objective_value,
trial_objective_value: current_objective_value,
gradient_dot_step: 0.0,
attempts: correction.max_attempts,
});
}
Err(ArrowSchurError::AdaptiveCorrectionFailed {
reason: format!(
"failed after {} attempts; last rejection: {last_reason}",
correction.max_attempts
),
})
}
/// Predicted reduction `-(gᵀp + ½ pᵀ(H + D)p)` of the damped quadratic model for the step
/// `p = (Δt, Δβ)`.
///
/// `D` is the ridge: `ridge_t` on every t coordinate and `ridge_beta` on every β coordinate.
/// `H` is the full arrow Hessian:
/// * per-row `H_tt` blocks;
/// * per-row `H_tβ` coupling, counted twice for the symmetric off-diagonal;
/// * the β penalty block;
/// * the cross-row penalty Hessian on `Δt`, when the system has cross-row penalties. This is
///   the same operator the cross-row solver applies in `arrow_cross_row_matvec`.
///
/// # Panics
/// Panics if `delta_t`/`delta_beta` have the wrong length or `delta_beta` is not contiguous.
pub fn arrow_damped_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let total_len = sys.row_offsets[sys.rows.len()];
    assert_eq!(delta_t.len(), total_len);
    assert_eq!(delta_beta.len(), sys.k);
    let mut lin = sys.gb.dot(&delta_beta);
    let mut quad = ridge_beta * delta_beta.dot(&delta_beta);

    // β-β penalty block.
    let mut hbb_delta = Array1::<f64>::zeros(sys.k);
    {
        let x_slice = delta_beta
            .as_slice()
            .expect("delta_beta must be contiguous");
        let y_slice = hbb_delta
            .as_slice_mut()
            .expect("hbb_delta must be contiguous");
        sys.penalty_matvec_add(x_slice, y_slice);
    }
    quad += delta_beta.dot(&hbb_delta);

    // Per-row t-t blocks, t ridge and the t-β coupling.
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let row_base = sys.row_offsets[i];
        // Fresh zeroed per-row buffer for H_tβ,i · Δβ.
        let mut htbeta_x_i = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, delta_beta, &mut htbeta_x_i);
        for c in 0..di {
            let dt_c = delta_t[row_base + c];
            lin += row.gt[c] * dt_c;
            quad += ridge_t * dt_c * dt_c;
            for r in 0..di {
                quad += dt_c * row.htt[[c, r]] * delta_t[row_base + r];
            }
            // The coupling appears in both off-diagonal blocks of pᵀHp.
            quad += 2.0 * dt_c * htbeta_x_i[c];
        }
    }

    // Cross-row penalty Hessian: part of the operator the cross-row solver uses, so the
    // model must include it or the predicted reduction disagrees with the solved system.
    if !sys.cross_row_penalties.is_empty() {
        let mut cross_row_delta = Array1::<f64>::zeros(total_len);
        sys.apply_cross_row_penalty_hessian(delta_t, &mut cross_row_delta);
        quad += delta_t.dot(&cross_row_delta);
    }
    Ok(-(lin + 0.5 * quad))
}
/// Predicted reduction of the undamped quadratic model for the step `(Δt, Δβ)`.
///
/// The damped model includes `½ ridge_β‖Δβ‖²` and `½ ridge_t‖Δt‖²` inside its subtracted
/// quadratic term. This function computes the damped value and adds both terms back.
///
/// # Errors
/// Propagates any error from [`arrow_damped_quadratic_model_reduction`].
pub fn arrow_bare_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let damped_reduction =
        arrow_damped_quadratic_model_reduction(sys, delta_t, delta_beta, ridge_t, ridge_beta)?;
    let beta_norm_sq = delta_beta.dot(&delta_beta);
    let t_norm_sq = delta_t.iter().fold(0.0_f64, |acc, v| acc + v * v);
    Ok(damped_reduction + 0.5 * ridge_beta * beta_norm_sq + 0.5 * ridge_t * t_norm_sq)
}
/// Next proximal ridge in a geometric escalation schedule.
///
/// A positive ridge is multiplied by `growth`. A zero, negative or NaN ridge restarts at
/// `DEFAULT_PROXIMAL_INITIAL_RIDGE`.
fn next_proximal_ridge(current: f64, growth: f64) -> f64 {
    if current > 0.0 {
        return current * growth;
    }
    DEFAULT_PROXIMAL_INITIAL_RIDGE
}
fn arrow_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
let mut sum = 0.0;
for row in sys.rows.iter() {
for &v in row.gt.iter() {
sum += v * v;
}
}
for &v in sys.gb.iter() {
sum += v * v;
}
sum.sqrt()
}
/// Directional derivative `gᵀp` of the stacked gradient along the step `p = (Δt, Δβ)`.
///
/// # Panics
/// Panics if `delta_t` or `delta_beta` has the wrong length.
fn arrow_gradient_dot_step(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_t.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);
    let mut total = 0.0;
    for (i, row) in sys.rows.iter().enumerate() {
        let base = sys.row_offsets[i];
        total = (0..sys.row_dims[i]).fold(total, |acc, c| acc + row.gt[c] * delta_t[base + c]);
    }
    (0..sys.k).fold(total, |acc, a| acc + sys.gb[a] * delta_beta[a])
}
/// Everything one in-memory arrow-Schur Newton solve produces.
struct ArrowNewtonStepArtifacts {
/// Step in the stacked per-row t coordinates (length `row_offsets[n]`).
delta_t: Array1<f64>,
/// Step in the shared β coefficients.
delta_beta: Array1<f64>,
/// Per-row `H_tt` factors computed with `ridge_t`. Empty when produced by the streaming path.
htt_factors: ArrowFactorSlab,
/// Reduced-Schur Cholesky factor when the solve path produced one (`None` on the PCG paths).
schur_factor: Option<Array2<f64>>,
pcg_diagnostics: PcgDiagnostics,
/// Total gauge directions deflated across all per-row factorizations.
gauge_deflated_directions: usize,
}
/// Result of factoring every per-row `H_tt` block of a system.
struct ArrowBlockFactorization {
/// One factor per row, in row order.
factors: ArrowFactorSlab,
/// Sum of gauge directions deflated across rows (0 when no deflation is configured).
gauge_deflated_directions: usize,
}
/// Factors every per-row `H_tt` block with ridge `ridge_t`.
///
/// * Without row gauge deflation, this delegates to the batched backend and reports zero
///   deflated directions.
/// * With row gauge deflation, each row is factored individually with its own deflation
///   data, and the deflated direction counts are summed.
///
/// # Errors
/// Propagates the first per-row factorization failure.
fn factor_blocks_for_system<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    options: &ArrowSolveOptions,
    backend: &B,
) -> Result<ArrowBlockFactorization, ArrowSchurError> {
    match sys.row_gauge_deflation.as_ref() {
        None => {
            let factors = backend.factor_blocks(
                &sys.rows,
                ridge_t,
                sys.d,
                options.tolerate_ill_conditioning,
            )?;
            Ok(ArrowBlockFactorization {
                factors,
                gauge_deflated_directions: 0,
            })
        }
        Some(deflation) => {
            let mut per_row = Vec::with_capacity(sys.rows.len());
            let mut deflated_total = 0usize;
            for (row_idx, row) in sys.rows.iter().enumerate() {
                let factored = factor_one_row_result(
                    row,
                    ridge_t,
                    sys.row_dims[row_idx],
                    row_idx,
                    options.tolerate_ill_conditioning,
                    deflation.row(row_idx),
                )?;
                deflated_total += factored.gauge_deflated_directions;
                per_row.push(factored.factor);
            }
            Ok(ArrowBlockFactorization {
                factors: ArrowFactorSlab::from_blocks(per_row),
                gauge_deflated_directions: deflated_total,
            })
        }
    }
}
/// Outcome of `try_mixed_precision_arrow_solve` when mixed precision is enabled.
enum MixedPrecisionAttempt {
/// The f32-factored solve plus f64 refinement met the backward-error certificate.
Certified {
delta_t: Array1<f64>,
delta_beta: Array1<f64>,
/// f64 Cholesky factor of the reduced Schur complement.
schur_factor: Array2<f64>,
/// Number of refinement corrections applied before certification (0 = first solve).
refinement_steps: usize,
},
/// Mixed precision was refused or failed. The caller should solve in f64; `reason` is logged.
Fallback {
reason: String,
},
}
/// Recovers the per-row t-step from a β-step by back-substitution:
/// `Δt_i = -(H_tt,i)⁻¹ (g_t,i + H_tβ,i Δβ)`, using the supplied per-row factors.
///
/// Each row gets its own freshly sized buffers. The previous version zero-filled shared
/// `sys.d`-long scratch buffers and then copied them with `to_owned()` on every row, so
/// the "reuse" only added copies.
///
/// # Panics
/// Panics if a row's gradient length differs from its declared dimension.
fn back_substitute_delta_t<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    delta_beta: ArrayView1<'_, f64>,
    backend: &B,
) -> Array1<f64> {
    let n = sys.rows.len();
    let mut delta_t = Array1::<f64>::zeros(sys.row_offsets[n]);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let row_base = sys.row_offsets[i];
        assert_eq!(row.gt.len(), di);
        // H_tβ,i · Δβ into a zeroed per-row buffer.
        let mut coupling = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, delta_beta, &mut coupling);
        // rhs_i = g_t,i + H_tβ,i Δβ
        let rhs = Array1::from_shape_fn(di, |c| row.gt[c] + coupling[c]);
        let dt_i = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
        for c in 0..di {
            delta_t[row_base + c] = -dt_i[c];
        }
    }
    delta_t
}
/// Attempts the dense arrow solve with f32 Cholesky factors plus f64 iterative refinement.
///
/// Returns `Ok(None)` when `options.mixed_precision` is not `Certified`. Otherwise returns
/// `Certified` once the f64 backward-error certificate falls to at most
/// `max(residual_relative_tolerance, MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER·ε_f64)`.
///
/// Returns `Fallback` in any of these cases:
/// * a finite trust-region radius is configured;
/// * the kappa gate refuses f32 (see `mixed_precision_kappa_gate_failure`);
/// * a refinement correction is non-finite;
/// * the certificate is not met after `max_refinement_steps` corrections.
///
/// # Errors
/// `SchurFactorFailed` when the f64 Cholesky of `schur` fails. Unless
/// `options.tolerate_ill_conditioning` is set, it is also returned when the factor's
/// condition estimate is non-finite or exceeds `safe_spd_kappa_max`.
fn try_mixed_precision_arrow_solve(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
htt_factors: &ArrowFactorSlab,
schur: &Array2<f64>,
options: &ArrowSolveOptions,
) -> Result<Option<MixedPrecisionAttempt>, ArrowSchurError> {
let MixedPrecisionPolicy::Certified {
max_refinement_steps,
residual_relative_tolerance,
kappa_unit_roundoff_margin,
} = options.mixed_precision
else {
return Ok(None);
};
if options.trust_region.radius.is_finite() {
return Ok(Some(MixedPrecisionAttempt::Fallback {
reason: "trust-region-truncated dense solves are not certified by the mixed-precision refinement path".to_string(),
}));
}
// The f64 Schur factor is needed both for the gate and as the returned artifact.
let schur_factor =
cholesky_lower(schur).map_err(|e| ArrowSchurError::SchurFactorFailed { reason: e })?;
if !options.tolerate_ill_conditioning {
let schur_kappa = cholesky_factor_kappa_estimate(&schur_factor);
if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"reduced Schur complement Cholesky succeeded but is ill-conditioned \
(kappa_estimate={schur_kappa:e}); accumulated per-row \
(H_tt)^-1 contamination would yield an inaccurate delta_beta"
),
});
}
}
// Refuse f32 when κ·u_f32 is too large for refinement to be expected to converge.
if let Some(reason) =
mixed_precision_kappa_gate_failure(htt_factors, &schur_factor, kappa_unit_roundoff_margin)
{
return Ok(Some(MixedPrecisionAttempt::Fallback { reason }));
}
// Initial solve entirely in f32 factors.
let row_factors_f32 = arrow_factor_slab_to_f32(htt_factors);
let schur_factor_f32 = schur_factor.mapv(|v| v as f32);
let (rhs_t, rhs_beta) = arrow_rhs(sys);
let mut x = solve_arrow_system_f32(
sys,
&row_factors_f32,
&schur_factor_f32,
rhs_t.view(),
rhs_beta.view(),
);
// Never demand a certificate tighter than a small multiple of f64 epsilon.
let certificate_tol = residual_relative_tolerance
.max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
// Each pass: f64 residual -> certificate -> (accept | give up | f32 correction).
for refinement_steps in 0..=max_refinement_steps {
let (res_t, res_beta) = arrow_residual(
sys,
ridge_t,
ridge_beta,
x.0.view(),
x.1.view(),
rhs_t.view(),
rhs_beta.view(),
);
let certificate = arrow_backward_error_certificate(
sys,
ridge_t,
ridge_beta,
x.0.view(),
x.1.view(),
rhs_t.view(),
rhs_beta.view(),
res_t.view(),
res_beta.view(),
);
if certificate <= certificate_tol {
return Ok(Some(MixedPrecisionAttempt::Certified {
delta_t: x.0,
delta_beta: x.1,
schur_factor,
refinement_steps,
}));
}
if refinement_steps == max_refinement_steps {
return Ok(Some(MixedPrecisionAttempt::Fallback {
reason: format!(
"f64 residual certificate did not converge after {max_refinement_steps} refinement steps \
(backward_error={certificate:e}, tolerance={certificate_tol:e})"
),
}));
}
// Correction: solve A·c = r with the f32 factors, then x += c in f64.
let correction = solve_arrow_system_f32(
sys,
&row_factors_f32,
&schur_factor_f32,
res_t.view(),
res_beta.view(),
);
if !correction
.0
.iter()
.chain(correction.1.iter())
.all(|v| v.is_finite())
{
return Ok(Some(MixedPrecisionAttempt::Fallback {
reason: "f32 refinement correction produced a non-finite value".to_string(),
}));
}
for i in 0..x.0.len() {
x.0[i] += correction.0[i];
}
for i in 0..x.1.len() {
x.1[i] += correction.1[i];
}
}
// Not reached in practice: the loop returns on its final iteration. Kept for the type checker.
Ok(Some(MixedPrecisionAttempt::Fallback {
reason: "mixed refinement loop exhausted without certification".to_string(),
}))
}
/// Decides whether f32 refinement may be attempted, returning a refusal reason if not.
///
/// The condition estimate is the maximum of:
/// * the Schur factor's and every per-row factor's kappa estimates;
/// * the ratio of the largest to the smallest Cholesky pivot across all factors. This
///   becomes infinite if the smallest pivot is not positive or the largest is not finite.
///
/// f32 is allowed only when the estimate is finite and `κ·u_f32` is below `margin`. The
/// margin is clamped to `[u_f32, MIXED_PRECISION_KAPPA_MARGIN_CEILING]`.
fn mixed_precision_kappa_gate_failure(
    htt_factors: &ArrowFactorSlab,
    schur_factor: &Array2<f64>,
    margin: f64,
) -> Option<String> {
    let mut max_kappa = cholesky_factor_kappa_estimate(schur_factor);
    let mut smallest_pivot = lower_cholesky_min_pivot(schur_factor.view());
    let mut largest_pivot = lower_cholesky_max_pivot(schur_factor.view());
    for factor in htt_factors.iter() {
        let block = factor.to_owned();
        max_kappa = max_kappa.max(cholesky_factor_kappa_estimate(&block));
        smallest_pivot = match (smallest_pivot, lower_cholesky_min_pivot(block.view())) {
            (Some(current), Some(candidate)) => Some(current.min(candidate)),
            (current, None) => current,
            (None, candidate) => candidate,
        };
        largest_pivot = match (largest_pivot, lower_cholesky_max_pivot(block.view())) {
            (Some(current), Some(candidate)) => Some(current.max(candidate)),
            (current, None) => current,
            (None, candidate) => candidate,
        };
    }
    if let (Some(lo), Some(hi)) = (smallest_pivot, largest_pivot) {
        max_kappa = if lo > 0.0 && hi.is_finite() {
            max_kappa.max(hi / lo)
        } else {
            f64::INFINITY
        };
    }
    let kappa_u = max_kappa * F32_UNIT_ROUNDOFF;
    let threshold = margin
        .min(MIXED_PRECISION_KAPPA_MARGIN_CEILING)
        .max(F32_UNIT_ROUNDOFF);
    let admitted = max_kappa.is_finite() && kappa_u < threshold;
    (!admitted).then(|| {
        format!(
            "kappa gate refused f32 refinement: kappa_estimate={max_kappa:e}, \
             kappa*u_f32={kappa_u:e}, required < {threshold:e}"
        )
    })
}
/// Narrows every per-row Cholesky factor to an owned `f32` matrix, preserving row order.
fn arrow_factor_slab_to_f32(htt_factors: &ArrowFactorSlab) -> Vec<Array2<f32>> {
    let mut narrowed = Vec::new();
    for block in htt_factors.iter() {
        narrowed.push(block.mapv(|entry| entry as f32));
    }
    narrowed
}
fn arrow_rhs(sys: &ArrowSchurSystem) -> (Array1<f64>, Array1<f64>) {
let n = sys.rows.len();
let mut rhs_t = Array1::<f64>::zeros(sys.row_offsets[n]);
for i in 0..n {
let di = sys.row_dims[i];
let base = sys.row_offsets[i];
for c in 0..di {
rhs_t[base + c] = -sys.rows[i].gt[c];
}
}
let mut rhs_beta = Array1::<f64>::zeros(sys.k);
for c in 0..sys.k {
rhs_beta[c] = -sys.gb[c];
}
(rhs_t, rhs_beta)
}
/// Solves the arrow system for `(rhs_t, rhs_beta)` using f32 Cholesky factors and
/// returns the result widened to f64.
///
/// Steps, with `A_i = L_i L_iᵀ` the per-row factors and `S` the Schur factor:
/// 1. Row solves: `y_i = A_i⁻¹ r_t,i`.
/// 2. Reduced β right-hand side: `r_β − Σ_i H_tβ,iᵀ y_i`.
/// 3. β solve: `x_β = S⁻¹(reduced)`.
/// 4. Back-substitution: `x_t,i = y_i − A_i⁻¹ H_tβ,i x_β`.
///
/// All arithmetic, including the `H_tβ` products, is in f32. The caller
/// (`try_mixed_precision_arrow_solve`) uses this only as the inner solve of f64 refinement.
///
/// NOTE(review): each row's `H_tβ` is materialized twice (once per pass). Caching would
/// trade that work for `n·d·k` floats of memory.
fn solve_arrow_system_f32(
sys: &ArrowSchurSystem,
row_factors: &[Array2<f32>],
schur_factor: &Array2<f32>,
rhs_t: ArrayView1<'_, f64>,
rhs_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
let n = sys.rows.len();
let mut y_rows = Vec::<Array1<f32>>::with_capacity(n);
let mut reduced_beta = rhs_beta.mapv(|v| v as f32);
// Pass 1: per-row solves and elimination into the β right-hand side.
for i in 0..n {
let di = sys.row_dims[i];
let base = sys.row_offsets[i];
let rhs_i = rhs_t.slice(ndarray::s![base..base + di]).mapv(|v| v as f32);
let y_i = cholesky_solve_lower_f32(&row_factors[i], &rhs_i);
let htbeta = sys_htbeta_materialize_row(sys, i, &sys.rows[i]).mapv(|v| v as f32);
for beta_col in 0..sys.k {
let mut acc = 0.0_f32;
for row_axis in 0..di {
acc += htbeta[[row_axis, beta_col]] * y_i[row_axis];
}
reduced_beta[beta_col] -= acc;
}
y_rows.push(y_i);
}
let x_beta_f32 = cholesky_solve_lower_f32(schur_factor, &reduced_beta);
let mut x_t = Array1::<f64>::zeros(sys.row_offsets[n]);
// Pass 2: remove the β coupling from each row's solution.
for i in 0..n {
let di = sys.row_dims[i];
let base = sys.row_offsets[i];
let htbeta = sys_htbeta_materialize_row(sys, i, &sys.rows[i]).mapv(|v| v as f32);
let mut cross = Array1::<f32>::zeros(di);
for row_axis in 0..di {
let mut acc = 0.0_f32;
for beta_col in 0..sys.k {
acc += htbeta[[row_axis, beta_col]] * x_beta_f32[beta_col];
}
cross[row_axis] = acc;
}
let correction = cholesky_solve_lower_f32(&row_factors[i], &cross);
for row_axis in 0..di {
x_t[base + row_axis] = (y_rows[i][row_axis] - correction[row_axis]) as f64;
}
}
let x_beta = x_beta_f32.mapv(|v| v as f64);
(x_t, x_beta)
}
/// Solves `L Lᵀ x = b` for a lower-triangular Cholesky factor `L` in f32.
///
/// This is forward substitution followed by backward substitution done in place in the
/// same buffer. The summation order matches a textbook two-buffer implementation, so the
/// f32 results are identical.
fn cholesky_solve_lower_f32(l: &Array2<f32>, b: &Array1<f32>) -> Array1<f32> {
    let dim = l.nrows();
    let mut sol = Array1::<f32>::zeros(dim);
    // Forward: L y = b.
    for row in 0..dim {
        let partial = (0..row).fold(b[row], |acc, col| acc - l[[row, col]] * sol[col]);
        sol[row] = partial / l[[row, row]];
    }
    // Backward: Lᵀ x = y. sol[row] still holds y[row] when it is read, and sol[col] for
    // col > row already holds x[col].
    for row in (0..dim).rev() {
        let partial = ((row + 1)..dim).fold(sol[row], |acc, col| acc - l[[col, row]] * sol[col]);
        sol[row] = partial / l[[row, row]];
    }
    sol
}
/// f64 residual `rhs − A x` of the damped arrow operator applied by `arrow_operator_apply`.
fn arrow_residual(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let (applied_t, applied_beta) = arrow_operator_apply(sys, ridge_t, ridge_beta, x_t, x_beta);
    let mut residual_t = rhs_t.to_owned();
    residual_t -= &applied_t;
    let mut residual_beta = rhs_beta.to_owned();
    residual_beta -= &applied_beta;
    (residual_t, residual_beta)
}
/// Applies the damped arrow operator (no cross-row terms) to `(x_t, x_β)`.
///
/// * t rows: `(H_tt,i + ridge_t·I) x_t,i + H_tβ,i x_β`.
/// * β part: `H_ββ x_β + ridge_β x_β + Σ_i H_tβ,iᵀ x_t,i`, accumulated in that order.
///
/// # Panics
/// Panics if `x_beta` is not contiguous.
fn arrow_operator_apply(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let mut out_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    let mut out_beta = Array1::<f64>::zeros(sys.k);
    sys.penalty_matvec_add(
        x_beta.as_slice().expect("x_beta contiguous"),
        out_beta.as_slice_mut().expect("y_beta contiguous"),
    );
    for beta_col in 0..sys.k {
        out_beta[beta_col] += ridge_beta * x_beta[beta_col];
    }
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        let x_ti = x_t.slice(ndarray::s![base..base + di]);
        let mut coupled = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, x_beta, &mut coupled);
        for a in 0..di {
            let local =
                (0..di).fold(ridge_t * x_ti[a], |acc, b| acc + row.htt[[a, b]] * x_ti[b]);
            out_t[base + a] = local + coupled[a];
        }
        sys_htbeta_accumulate_transpose(sys, i, row, x_ti, &mut out_beta);
    }
    (out_t, out_beta)
}
/// Normwise backward error `‖r‖∞ / (‖A‖∞‖x‖∞ + ‖b‖∞)` of an approximate arrow solve.
///
/// Falls back to the raw residual norm when the denominator is not positive.
fn arrow_backward_error_certificate(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
    res_t: ArrayView1<'_, f64>,
    res_beta: ArrayView1<'_, f64>,
) -> f64 {
    let residual = infinity_norm_pair(res_t, res_beta);
    let scale = arrow_operator_infinity_norm(sys, ridge_t, ridge_beta)
        * infinity_norm_pair(x_t, x_beta)
        + infinity_norm_pair(rhs_t, rhs_beta);
    match scale > 0.0 {
        true => residual / scale,
        false => residual,
    }
}
/// Infinity norm of the concatenation `[lhs; rhs]` (0 for two empty views).
fn infinity_norm_pair(lhs: ArrayView1<'_, f64>, rhs: ArrayView1<'_, f64>) -> f64 {
    lhs.iter()
        .chain(rhs.iter())
        .fold(0.0_f64, |largest, v| largest.max(v.abs()))
}
/// Infinity norm (maximum absolute row sum) of the damped arrow operator applied by
/// `arrow_operator_apply`:
///
/// ```text
/// [ H_tt,i + ridge_t·I    H_tβ,i            ]
/// [ H_tβ,iᵀ               H_ββ + ridge_β·I  ]
/// ```
///
/// Each row's `H_tβ` block is materialized once. The previous version re-materialized
/// every row's block once per β row, i.e. `k·n` materializations. The β-row sums are
/// seeded with `Σ|H_ββ| + ridge_β` and then accumulate the `H_tβ` column sums in row order,
/// so the summation order (and the rounded result) is unchanged.
fn arrow_operator_infinity_norm(sys: &ArrowSchurSystem, ridge_t: f64, ridge_beta: f64) -> f64 {
    let hbb = sys.effective_penalty_op().to_dense();
    let mut beta_row_sums: Vec<f64> = (0..sys.k)
        .map(|beta_row| {
            let mut row_sum = 0.0_f64;
            for beta_col in 0..sys.k {
                row_sum += hbb[[beta_row, beta_col]].abs();
            }
            row_sum + ridge_beta
        })
        .collect();

    let mut out = 0.0_f64;
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let htbeta = sys_htbeta_materialize_row(sys, i, row);
        for a in 0..di {
            let mut row_sum = 0.0_f64;
            for b in 0..di {
                row_sum += row.htt[[a, b]].abs();
            }
            row_sum += ridge_t;
            for beta_col in 0..sys.k {
                let entry = htbeta[[a, beta_col]].abs();
                row_sum += entry;
                // The same entry is also part of β row `beta_col` via the transpose block.
                beta_row_sums[beta_col] += entry;
            }
            out = out.max(row_sum);
        }
    }
    beta_row_sums.into_iter().fold(out, f64::max)
}
fn solve_arrow_newton_step_artifacts(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
if !sys.cross_row_penalties.is_empty() {
return solve_arrow_newton_step_cross_row(sys, ridge_t, ridge_beta, options);
}
if let Some(chunk_size) = options.streaming_chunk_size {
let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
let (delta_t, delta_beta, schur_factor) = streaming.solve(ridge_t, ridge_beta, options)?;
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors: ArrowFactorSlab::from_blocks(Vec::new()),
schur_factor,
pcg_diagnostics: PcgDiagnostics::default(),
gauge_deflated_directions: 0,
});
}
let backend = CpuBatchedBlockSolver;
let block_factorization = factor_blocks_for_system(sys, ridge_t, options, &backend)?;
let htt_factors = block_factorization.factors;
let gauge_deflated_directions = block_factorization.gauge_deflated_directions;
let rhs_beta = reduced_rhs_beta(sys, &htt_factors, &backend);
let trust_metric_weights = None;
let mut mixed_precision_status = MixedPrecisionStatus::Off;
let (delta_beta, schur_factor, mut pcg_diagnostics) = match options.mode {
ArrowSolverMode::Direct => {
let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, &backend)?;
if let Some(attempt) = try_mixed_precision_arrow_solve(
sys,
ridge_t,
ridge_beta,
&htt_factors,
&schur,
options,
)? {
match attempt {
MixedPrecisionAttempt::Certified {
delta_t,
delta_beta,
schur_factor,
refinement_steps,
} => {
let mut pcg_diagnostics = PcgDiagnostics::default();
pcg_diagnostics.mixed_precision_status =
MixedPrecisionStatus::Certified { refinement_steps };
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor: Some(schur_factor),
pcg_diagnostics,
gauge_deflated_directions,
});
}
MixedPrecisionAttempt::Fallback { reason } => {
log::info!("arrow-Schur mixed precision fallback to f64: {reason}");
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
}
}
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::SqrtBA => {
let schur = build_dense_schur_sqrt_ba(sys, &htt_factors, ridge_beta, &backend)?;
if let Some(attempt) = try_mixed_precision_arrow_solve(
sys,
ridge_t,
ridge_beta,
&htt_factors,
&schur,
options,
)? {
match attempt {
MixedPrecisionAttempt::Certified {
delta_t,
delta_beta,
schur_factor,
refinement_steps,
} => {
let mut pcg_diagnostics = PcgDiagnostics::default();
pcg_diagnostics.mixed_precision_status =
MixedPrecisionStatus::Certified { refinement_steps };
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor: Some(schur_factor),
pcg_diagnostics,
gauge_deflated_directions,
});
}
MixedPrecisionAttempt::Fallback { reason } => {
log::info!("arrow-Schur mixed precision fallback to f64: {reason}");
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
}
}
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::InexactPCG => {
if options.mixed_precision.is_enabled() {
log::info!(
"arrow-Schur mixed precision fallback to f64: InexactPCG does not expose a dense Schur factor for certified f32 refinement"
);
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
if options.trust_region.radius == f64::INFINITY {
if let Some(device_data) = sys.device_sae_pcg.as_ref() {
let max_iterations = options
.pcg
.max_iterations
.min(options.trust_region.max_iterations);
let relative_tolerance = options
.pcg
.relative_tolerance
.max(options.trust_region.steihaug_relative_tolerance);
if let Ok((delta, mut diag)) =
crate::gpu::arrow_schur::solve_sae_matrix_free_pcg(
sys,
device_data.as_ref(),
ridge_t,
ridge_beta,
&rhs_beta,
max_iterations,
relative_tolerance,
)
{
diag.used_device_arrow = true;
return Ok(ArrowNewtonStepArtifacts {
delta_t: back_substitute_delta_t(
sys,
&htt_factors,
delta.view(),
&backend,
),
delta_beta: delta,
htt_factors,
schur_factor: None,
pcg_diagnostics: diag,
gauge_deflated_directions,
});
}
}
}
let (delta, diag) = steihaug_pcg_auto(
sys,
&htt_factors,
ridge_beta,
&rhs_beta,
&options.pcg,
&options.trust_region,
&backend,
options.gpu_matvec.as_ref(),
trust_metric_weights,
)?;
(delta, None, diag)
}
};
if mixed_precision_status != MixedPrecisionStatus::Off {
pcg_diagnostics.mixed_precision_status = mixed_precision_status;
}
let delta_t = back_substitute_delta_t(sys, &htt_factors, delta_beta.view(), &backend);
Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor,
pcg_diagnostics,
gauge_deflated_directions,
})
}
/// Exact block-elimination inverse of the arrow operator without cross-row penalty terms.
/// The cross-row Newton solve builds it as its preconditioner (`precond`).
struct ArrowBlockDiagInverse<'a, B: BatchedBlockSolver> {
sys: &'a ArrowSchurSystem,
backend: &'a B,
/// Per-row `H_tt` factors computed with `ridge_t`.
htt_factors: ArrowFactorSlab,
/// Cholesky factor of the dense direct Schur complement computed with `ridge_beta`.
schur_factor: Array2<f64>,
}
impl<'a, B: BatchedBlockSolver> ArrowBlockDiagInverse<'a, B> {
    /// Factors every per-row block with `ridge_t` and the dense direct Schur complement with
    /// `ridge_beta`.
    ///
    /// NOTE(review): unlike `factor_blocks_for_system`, this does not apply
    /// `sys.row_gauge_deflation`. Confirm that is intended for the preconditioner.
    ///
    /// # Errors
    /// * Per-row factorization and Schur assembly failures are propagated.
    /// * A failed Schur Cholesky becomes `SchurFactorFailed`.
    fn build(
        sys: &'a ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
        tolerate_ill_conditioning: bool,
        backend: &'a B,
    ) -> Result<Self, ArrowSchurError>
    where
        B: Sync,
    {
        let htt_factors =
            backend.factor_blocks(&sys.rows, ridge_t, sys.d, tolerate_ill_conditioning)?;
        let schur_factor = {
            let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, backend)?;
            cholesky_lower(&schur)
                .map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?
        };
        Ok(Self {
            sys,
            backend,
            htt_factors,
            schur_factor,
        })
    }

    /// Applies the inverse of the arrow operator (cross-row terms excluded) to
    /// `(r_t, r_beta)`.
    ///
    /// 1. Eliminate each row into the β right-hand side.
    /// 2. Solve with the Schur factor.
    /// 3. Back-substitute `x_t,i = (H_tt,i)⁻¹ (r_t,i − H_tβ,i x_β)`.
    fn apply(
        &self,
        r_t: ArrayView1<'_, f64>,
        r_beta: ArrayView1<'_, f64>,
    ) -> (Array1<f64>, Array1<f64>) {
        let sys = self.sys;
        let row_count = sys.rows.len();

        let mut rhs_beta = r_beta.to_owned();
        for (i, row) in sys.rows.iter().enumerate() {
            let base = sys.row_offsets[i];
            let r_ti = r_t.slice(ndarray::s![base..base + sys.row_dims[i]]);
            let u_i = self
                .backend
                .solve_block_vector(self.htt_factors.factor(i), r_ti);
            let mut coupled = Array1::<f64>::zeros(sys.k);
            sys_htbeta_accumulate_transpose(sys, i, row, u_i.view(), &mut coupled);
            rhs_beta -= &coupled;
        }

        let x_beta = cholesky_solve_lower(&self.schur_factor, &rhs_beta);

        let mut x_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
        for (i, row) in sys.rows.iter().enumerate() {
            let di = sys.row_dims[i];
            let base = sys.row_offsets[i];
            let mut coupled = Array1::<f64>::zeros(di);
            sys_htbeta_apply_row(sys, i, row, x_beta.view(), &mut coupled);
            let rhs_i = &r_t.slice(ndarray::s![base..base + di]) - &coupled;
            let x_i = self
                .backend
                .solve_block_vector(self.htt_factors.factor(i), rhs_i.view());
            for c in 0..di {
                x_t[base + c] = x_i[c];
            }
        }
        (x_t, x_beta)
    }
}
/// Applies the full damped arrow operator including cross-row penalty terms.
///
/// * t rows: `(H_tt,i + ridge_t·I) x_t,i + H_tβ,i x_β`, plus the cross-row penalty
///   Hessian applied to `x_t`.
/// * β part: `Σ_i H_tβ,iᵀ x_t,i + H_ββ x_β + ridge_β x_β`, accumulated in that order.
///
/// # Panics
/// Panics if `x_beta` is not contiguous.
fn arrow_cross_row_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let mut out_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    let mut out_beta = Array1::<f64>::zeros(sys.k);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        let x_ti = x_t.slice(ndarray::s![base..base + di]);
        let mut coupled = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, x_beta, &mut coupled);
        for a in 0..di {
            let local =
                (0..di).fold(ridge_t * x_ti[a], |acc, b| acc + row.htt[[a, b]] * x_ti[b]);
            out_t[base + a] = local + coupled[a];
        }
        sys_htbeta_accumulate_transpose(sys, i, row, x_ti, &mut out_beta);
    }
    sys.penalty_matvec_add(
        x_beta.as_slice().expect("x_beta contiguous"),
        out_beta.as_slice_mut().expect("y_beta contiguous"),
    );
    for a in 0..sys.k {
        out_beta[a] += ridge_beta * x_beta[a];
    }
    sys.apply_cross_row_penalty_hessian(x_t, &mut out_t);
    (out_t, out_beta)
}
/// Solves the full (un-reduced) arrow Newton system `A·[Δt; Δβ] = −[g_t; g_β]`
/// with preconditioned conjugate gradients.
///
/// `A` is applied through `arrow_cross_row_matvec`, which includes the
/// cross-row penalty Hessian. That coupling means the rows cannot be
/// eliminated exactly, so the whole stacked system is iterated.
/// `ArrowBlockDiagInverse` serves as the preconditioner; its `htt_factors` and
/// `schur_factor` are returned in the artifacts.
///
/// Convergence is declared when the unpreconditioned residual norm drops to
/// `max(1e-12, 1e-13·‖b‖)`. The iteration budget is
/// `4 · max(total_dt + k, 64)`.
///
/// # Errors
/// * Any error from `ArrowBlockDiagInverse::build`.
/// * `ArrowSchurError::PcgFailed` when `pᵀAp` is non-positive or non-finite
///   (operator not PD at this iterate).
/// * `ArrowSchurError::PcgFailed` when the iteration budget is exhausted.
fn solve_arrow_newton_step_cross_row(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
    const CROSS_ROW_CG_ABS_TOL: f64 = 1e-12;
    const CROSS_ROW_CG_REL_TOL: f64 = 1e-13;
    const CROSS_ROW_CG_MIN_ITER_BUDGET: usize = 64;
    const CROSS_ROW_CG_ITER_MULTIPLE: usize = 4;
    let backend = CpuBatchedBlockSolver;
    let precond = ArrowBlockDiagInverse::build(
        sys,
        ridge_t,
        ridge_beta,
        options.tolerate_ill_conditioning,
        &backend,
    )?;
    let n = sys.rows.len();
    let k = sys.k;
    let total_dt = sys.row_offsets[n];
    // Right-hand side: the negated gradient, row blocks first, then β.
    let mut b_t = Array1::<f64>::zeros(total_dt);
    for i in 0..n {
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        for c in 0..di {
            b_t[base + c] = -sys.rows[i].gt[c];
        }
    }
    let mut b_beta = Array1::<f64>::zeros(k);
    for a in 0..k {
        b_beta[a] = -sys.gb[a];
    }
    let mut x_t = Array1::<f64>::zeros(total_dt);
    let mut x_beta = Array1::<f64>::zeros(k);
    let mut r_t = b_t.clone();
    let mut r_beta = b_beta.clone();
    let (mut z_t, mut z_beta) = precond.apply(r_t.view(), r_beta.view());
    // Counted explicitly: the converging iteration exits before applying the
    // preconditioner, so the count is not a simple function of `iters`.
    let mut precond_applies = 1usize;
    let mut p_t = z_t.clone();
    let mut p_beta = z_beta.clone();
    let mut rz = dot2(&r_t, &r_beta, &z_t, &z_beta);
    let b_norm = (dot2(&b_t, &b_beta, &b_t, &b_beta)).sqrt();
    let tol = CROSS_ROW_CG_ABS_TOL.max(CROSS_ROW_CG_REL_TOL * b_norm);
    let max_iter = (total_dt + k).max(CROSS_ROW_CG_MIN_ITER_BUDGET) * CROSS_ROW_CG_ITER_MULTIPLE;
    let mut iters = 0usize;
    let mut converged = b_norm == 0.0;
    while iters < max_iter && !converged {
        let (ap_t, ap_beta) =
            arrow_cross_row_matvec(sys, ridge_t, ridge_beta, p_t.view(), p_beta.view());
        let pap = dot2(&p_t, &p_beta, &ap_t, &ap_beta);
        if !(pap.is_finite() && pap > 0.0) {
            return Err(ArrowSchurError::PcgFailed {
                reason: format!(
                    "cross-row full-system CG hit non-positive curvature pᵀAp={pap:e}; \
                     the cross-row penalty Hessian or arrow block is not PD at this iterate"
                ),
            });
        }
        let alpha = rz / pap;
        for i in 0..total_dt {
            x_t[i] += alpha * p_t[i];
            r_t[i] -= alpha * ap_t[i];
        }
        for a in 0..k {
            x_beta[a] += alpha * p_beta[a];
            r_beta[a] -= alpha * ap_beta[a];
        }
        let r_norm = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
        iters += 1;
        if r_norm <= tol {
            converged = true;
            break;
        }
        let (nz_t, nz_beta) = precond.apply(r_t.view(), r_beta.view());
        precond_applies += 1;
        z_t = nz_t;
        z_beta = nz_beta;
        let rz_new = dot2(&r_t, &r_beta, &z_t, &z_beta);
        let beta_cg = rz_new / rz;
        for i in 0..total_dt {
            p_t[i] = z_t[i] + beta_cg * p_t[i];
        }
        for a in 0..k {
            p_beta[a] = z_beta[a] + beta_cg * p_beta[a];
        }
        rz = rz_new;
    }
    if !converged {
        let r_norm = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
        return Err(ArrowSchurError::PcgFailed {
            reason: format!(
                "cross-row full-system CG did not converge in {iters} iters \
                 (‖r‖={r_norm:e}, tol={tol:e})"
            ),
        });
    }
    let final_residual = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
    let diag = PcgDiagnostics {
        iterations: iters,
        matvec_calls: iters,
        precond_apply_calls: precond_applies,
        ridge_escalations: 0,
        final_relative_residual: if b_norm > 0.0 {
            final_residual / b_norm
        } else {
            0.0
        },
        stopping_reason: PcgStopReason::Converged,
        mixed_precision_status: MixedPrecisionStatus::Off,
        used_device_arrow: false,
    };
    Ok(ArrowNewtonStepArtifacts {
        delta_t: x_t,
        delta_beta: x_beta,
        htt_factors: precond.htt_factors,
        schur_factor: Some(precond.schur_factor),
        pcg_diagnostics: diag,
        gauge_deflated_directions: 0,
    })
}
/// Inner product of two stacked vectors `[a_t; a_beta]·[b_t; b_beta]`.
///
/// The row part is summed first, then the β part, strictly left to right.
/// Lengths are taken from the `a_*` operands; indexing panics if a `b_*`
/// operand is shorter.
fn dot2(a_t: &Array1<f64>, a_beta: &Array1<f64>, b_t: &Array1<f64>, b_beta: &Array1<f64>) -> f64 {
    let row_part = (0..a_t.len()).fold(0.0_f64, |sum, i| sum + a_t[i] * b_t[i]);
    (0..a_beta.len()).fold(row_part, |sum, j| sum + a_beta[j] * b_beta[j])
}
/// Solves `L·Lᵀ·x = b` given the lower-triangular Cholesky factor `l`.
///
/// Performs forward substitution (`L·y = b`) and then back substitution
/// (`Lᵀ·x = y`) in a single work vector. Only the lower triangle of `l` is
/// read. Panics if `b` has fewer than `l.nrows()` entries.
fn cholesky_solve_lower(l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let n = l.nrows();
    let mut work = Array1::<f64>::zeros(n);
    // Forward pass: work ← L⁻¹·b.
    for i in 0..n {
        let partial = (0..i).fold(b[i], |sum, j| sum - l[[i, j]] * work[j]);
        work[i] = partial / l[[i, i]];
    }
    // Backward pass in place: entries after `i` already hold the final
    // solution, entry `i` still holds the forward-pass value.
    for i in (0..n).rev() {
        let partial = ((i + 1)..n).fold(work[i], |sum, j| sum - l[[j, i]] * work[j]);
        work[i] = partial / l[[i, i]];
    }
    work
}
/// Right-hand side of the reduced (Schur) β system:
/// `Σ_i H_tβ,iᵀ · solve(H_tt,i, g_t,i) − g_β`.
///
/// `solve` is `backend.solve_block_vector` with row `i`'s factor from
/// `htt_factors`.
fn reduced_rhs_beta<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    backend: &B,
) -> Array1<f64> {
    let mut accumulated = Array1::<f64>::zeros(sys.k);
    sys.rows.iter().enumerate().for_each(|(row_idx, row)| {
        let solved_gt = backend.solve_block_vector(htt_factors.factor(row_idx), row.gt.view());
        sys_htbeta_accumulate_transpose(sys, row_idx, row, solved_gt.view(), &mut accumulated);
    });
    for (j, entry) in accumulated.iter_mut().enumerate() {
        *entry -= sys.gb[j];
    }
    accumulated
}
/// How a row's contribution to the reduced Schur complement is factored.
///
/// * `Direct`: the pair `(H_tβ, solve(H_tt, H_tβ))`.
/// * `SqrtBa`: a whitened `W` from `sqrt_solve_block_matrix`, used as `(W, W)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SchurReductionKind {
    Direct,
    SqrtBa,
}
/// Returns the factor pair `(left, right)` for row `row_idx`.
///
/// `leftᵀ · right` is the row's contribution that is subtracted from the
/// reduced Schur complement.
/// * `Direct`: `(H_tβ, solve(H_tt, H_tβ))`.
/// * `SqrtBa`: `(W, W)` with `W = sqrt_solve(H_tt, H_tβ)`.
#[inline]
fn row_schur_contribution_factors<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    htt_factor: ArrayView2<'_, f64>,
    backend: &B,
    kind: SchurReductionKind,
) -> (Array2<f64>, Array2<f64>) {
    let coupling = sys_htbeta_materialize_row(sys, row_idx, row);
    match kind {
        SchurReductionKind::SqrtBa => {
            let whitened = backend.sqrt_solve_block_matrix(htt_factor, coupling.view());
            (whitened.clone(), whitened)
        }
        SchurReductionKind::Direct => {
            let inverse_applied = backend.solve_block_matrix(htt_factor, coupling.view());
            (coupling, inverse_applied)
        }
    }
}
/// Subtracts row `row_idx`'s Schur contribution from `schur` in place, using
/// the factor pair produced by `row_schur_contribution_factors`.
#[inline]
fn subtract_row_schur_contribution<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    htt_factor: ArrayView2<'_, f64>,
    backend: &B,
    kind: SchurReductionKind,
    schur: &mut Array2<f64>,
) {
    let pair = row_schur_contribution_factors(sys, row_idx, row, htt_factor, backend, kind);
    backend.block_gemm_subtract(schur, &pair.0, &pair.1);
}
/// Computes the negated Schur contribution `−Σ_{i∈range} leftᵢᵀ·rightᵢ` of a
/// contiguous tile of rows as a fresh `k×k` matrix.
///
/// The per-row factors are stacked vertically and offered to the device
/// identified by `ordinal` via `try_fast_atb_on_ordinal`. If that returns
/// `None`, or the stack is empty, the contributions are subtracted row by row
/// on the CPU backend.
fn tile_schur_partial<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    backend: &B,
    kind: SchurReductionKind,
    ordinal: usize,
    range: Range<usize>,
) -> Array2<f64> {
    let k = sys.k;
    let factors: Vec<(Array2<f64>, Array2<f64>)> = range
        .map(|i| {
            row_schur_contribution_factors(
                sys,
                i,
                &sys.rows[i],
                htt_factors.factor(i),
                backend,
                kind,
            )
        })
        .collect();
    let stacked_rows: usize = factors.iter().map(|(left, _)| left.nrows()).sum();
    if stacked_rows > 0 && k > 0 {
        let mut left_stack = Array2::<f64>::zeros((stacked_rows, k));
        let mut right_stack = Array2::<f64>::zeros((stacked_rows, k));
        let mut offset = 0usize;
        for (left, right) in &factors {
            let height = left.nrows();
            left_stack
                .slice_mut(ndarray::s![offset..offset + height, ..])
                .assign(left);
            right_stack
                .slice_mut(ndarray::s![offset..offset + height, ..])
                .assign(right);
            offset += height;
        }
        if let Some(product) =
            crate::gpu::try_fast_atb_on_ordinal(ordinal, left_stack.view(), right_stack.view())
        {
            return product.mapv(|v| -v);
        }
    }
    factors
        .iter()
        .fold(Array2::<f64>::zeros((k, k)), |mut partial, (left, right)| {
            backend.block_gemm_subtract(&mut partial, left, right);
            partial
        })
}
/// Subtracts every row's Schur contribution from `schur` in place.
///
/// Tiled path: when a GPU runtime is present and `balanced_partition` yields
/// more than one tile, each tile is reduced on its own scoped thread by
/// `tile_schur_partial`. On Linux the thread first tries to bind the tile's
/// CUDA context. The partial sums are then added to `schur` in tile order.
///
/// Serial path: otherwise rows are subtracted one at a time.
///
/// # Panics
/// Panics if a tile thread panics.
fn reduce_row_schur_contributions<B: BatchedBlockSolver + Sync>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    backend: &B,
    kind: SchurReductionKind,
    schur: &mut Array2<f64>,
) {
    let n = sys.rows.len();
    let tiles = crate::gpu::runtime::GpuRuntime::global()
        .map(|rt| crate::gpu::pool::balanced_partition(rt, n))
        .filter(|tiles| tiles.len() > 1);
    let Some(tiles) = tiles else {
        for (i, row) in sys.rows.iter().enumerate() {
            subtract_row_schur_contribution(
                sys,
                i,
                row,
                htt_factors.factor(i),
                backend,
                kind,
                schur,
            );
        }
        return;
    };
    let partials: Vec<Array2<f64>> = std::thread::scope(|scope| {
        let handles: Vec<_> = tiles
            .iter()
            .map(|(ordinal, range)| {
                let ordinal = *ordinal;
                let range = range.clone();
                scope.spawn(move || {
                    #[cfg(target_os = "linux")]
                    {
                        if let Some(ctx) = crate::gpu::runtime::cuda_context_for(ordinal) {
                            // Best effort. A failed bind is deliberately ignored:
                            // `tile_schur_partial` already has a CPU path for when
                            // the device product is unavailable.
                            // NOTE(review): confirm the device call reports
                            // failure (returns None) on an unbound thread.
                            let _ = ctx.bind_to_thread();
                        }
                    }
                    tile_schur_partial(sys, htt_factors, backend, kind, ordinal, range)
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|handle| handle.join().expect("schur-reduction tile thread panicked"))
            .collect()
    });
    // Each partial is the (already negated) k×k contribution of one tile.
    for partial in &partials {
        *schur += partial;
    }
}
fn build_dense_schur_direct<B: BatchedBlockSolver + Sync>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
) -> Result<Array2<f64>, ArrowSchurError> {
let k = sys.k;
let op = sys.effective_penalty_op();
if op.dim() != k {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "Direct BA requires a K×K shared H_ββ penalty operator".to_string(),
});
}
let mut schur = op.to_dense();
for j in 0..k {
schur[[j, j]] += ridge_beta;
}
reduce_row_schur_contributions(
sys,
htt_factors,
backend,
SchurReductionKind::Direct,
&mut schur,
);
symmetrize_upper_from_lower(&mut schur);
Ok(schur)
}
fn build_dense_schur_sqrt_ba<B: BatchedBlockSolver + Sync>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
) -> Result<Array2<f64>, ArrowSchurError> {
let k = sys.k;
let op = sys.effective_penalty_op();
if op.dim() != k {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "Square-Root BA direct solve requires a K×K shared H_ββ penalty operator"
.to_string(),
});
}
let mut schur = op.to_dense();
for j in 0..k {
schur[[j, j]] += ridge_beta;
}
reduce_row_schur_contributions(
sys,
htt_factors,
backend,
SchurReductionKind::SqrtBa,
&mut schur,
);
symmetrize_upper_from_lower(&mut schur);
Ok(schur)
}
fn mixed_precision_reduced_beta(
schur: &Array2<f64>,
factor: &Array2<f64>,
rhs: &Array1<f64>,
options: &ArrowSolveOptions,
) -> Option<Array1<f64>> {
let MixedPrecisionPolicy::Certified {
max_refinement_steps,
residual_relative_tolerance,
kappa_unit_roundoff_margin,
} = options.mixed_precision
else {
return None;
};
if options.trust_region.radius.is_finite() {
return None;
}
let n = schur.nrows();
if n == 0 {
return None;
}
let kappa = cholesky_factor_kappa_estimate(factor);
if !kappa.is_finite() || kappa * F32_UNIT_ROUNDOFF >= kappa_unit_roundoff_margin {
return None;
}
let factor_f32 = factor.mapv(|v| v as f32);
let s_inf = matrix_inf_norm(schur);
let rhs_inf = rhs.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
let certificate_tol = residual_relative_tolerance
.max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
let mut x = cholesky_solve_lower_f32(&factor_f32, &rhs.mapv(|v| v as f32)).mapv(|v| v as f64);
let mut last_residual = f64::INFINITY;
for _ in 0..=max_refinement_steps {
let sx = schur.dot(&x);
let mut r = rhs.clone();
r -= &sx;
let r_inf = r.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
let x_inf = x.iter().fold(0.0_f64, |a, &b| a.max(b.abs()));
let denom = s_inf * x_inf + rhs_inf;
let backward_error = if denom > 0.0 { r_inf / denom } else { 0.0 };
if backward_error <= certificate_tol {
return Some(x);
}
if !(r_inf < last_residual) {
return None;
}
last_residual = r_inf;
let delta = cholesky_solve_lower_f32(&factor_f32, &r.mapv(|v| v as f32)).mapv(|v| v as f64);
x += δ
}
None
}
/// Matrix ∞-norm: the largest absolute row sum of `a`.
///
/// Returns `0.0` for a matrix with no rows. A row whose sum is NaN never
/// replaces the running maximum.
fn matrix_inf_norm(a: &Array2<f64>) -> f64 {
    a.rows().into_iter().fold(0.0_f64, |largest, row| {
        let row_sum: f64 = row.iter().map(|v| v.abs()).sum();
        if row_sum > largest {
            row_sum
        } else {
            largest
        }
    })
}
/// Solves the dense reduced β system `schur · Δβ = rhs_beta`.
///
/// The system is factored with Cholesky. Unless ill-conditioning is
/// tolerated, the factor is rejected when its condition estimate is
/// non-finite or exceeds `safe_spd_kappa_max`.
///
/// The Newton step is taken from the certified mixed-precision solver when it
/// succeeds, and from an f64 Cholesky solve otherwise. If that step leaves the
/// trust region, an unpreconditioned Steihaug-CG solve on `schur` replaces it.
///
/// Returns `(Δβ, Some(factor), diagnostics)`.
///
/// # Errors
/// * `ArrowSchurError::SchurFactorFailed` if factorization fails or the
///   condition check fails.
/// * Any error from the Steihaug fallback.
fn solve_dense_reduced_system(
    schur: &Array2<f64>,
    rhs_beta: &Array1<f64>,
    options: &ArrowSolveOptions,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, Option<Array2<f64>>, PcgDiagnostics), ArrowSchurError> {
    let factor =
        cholesky_lower(schur).map_err(|reason| ArrowSchurError::SchurFactorFailed { reason })?;
    if !options.tolerate_ill_conditioning {
        let schur_kappa = cholesky_factor_kappa_estimate(&factor);
        let ill_conditioned =
            !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows());
        if ill_conditioned {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "reduced Schur complement Cholesky succeeded but is ill-conditioned \
                     (kappa_estimate={schur_kappa:e}); accumulated per-row \
                     (H_tt)⁻¹ contamination would yield an inaccurate Δβ"
                ),
            });
        }
    }
    let newton_step = match mixed_precision_reduced_beta(schur, &factor, rhs_beta, options) {
        Some(refined) => refined,
        None => cholesky_solve_vector(&factor, rhs_beta),
    };
    if step_inside_trust_region(newton_step.view(), options.trust_region.radius, metric_weights) {
        return Ok((newton_step, Some(factor), PcgDiagnostics::default()));
    }
    let steihaug_options = ArrowPcgOptions {
        max_iterations: options.trust_region.max_iterations,
        relative_tolerance: options.trust_region.steihaug_relative_tolerance,
    };
    let (delta, diag) = steihaug_dense_system(
        schur,
        rhs_beta,
        &IdentityPreconditioner,
        &steihaug_options,
        &options.trust_region,
        metric_weights,
    )?;
    Ok((delta, Some(factor), diag))
}
/// Solves the reduced β system from a streamed (lower-triangle) Schur
/// accumulator, escalating a proximal diagonal ridge on recoverable failures.
///
/// On each attempt:
/// 1. `s_acc` is symmetrised and the current ridge is added to its diagonal.
/// 2. If a GPU runtime is available, the GPU PCG solve is tried first. Any
///    GPU error falls through to the CPU path.
/// 3. The CPU path is `solve_dense_reduced_system`.
///
/// Ridge schedule: starts at `0`, then `DEFAULT_PROXIMAL_INITIAL_RIDGE`, then
/// grows by `DEFAULT_PROXIMAL_RIDGE_GROWTH` per attempt, for up to
/// `DEFAULT_PROXIMAL_MAX_ATTEMPTS` escalations.
///
/// NOTE(review): the GPU call receives only `max_iterations` and
/// `steihaug_relative_tolerance`, not `trust_region.radius`. Confirm a GPU
/// step is acceptable when a finite trust radius is set.
///
/// # Errors
/// Returns the last CPU error when it is not `SchurFactorFailed`/`PcgFailed`,
/// or when all attempts are exhausted.
pub fn solve_streaming_reduced_beta(
    s_acc: &Array2<f64>,
    rhs_beta: &Array1<f64>,
    options: &ArrowSolveOptions,
) -> Result<Array1<f64>, ArrowSchurError> {
    let mut ridge = 0.0_f64;
    let mut last_err: Option<ArrowSchurError> = None;
    for attempt in 0..=DEFAULT_PROXIMAL_MAX_ATTEMPTS {
        let mut schur = s_acc.clone();
        symmetrize_upper_from_lower(&mut schur);
        if ridge > 0.0 {
            (0..schur.nrows()).for_each(|j| schur[[j, j]] += ridge);
        }
        if crate::gpu::runtime::GpuRuntime::is_available() {
            // Any GPU outcome other than success (including `Unavailable`)
            // falls through to the CPU solve below.
            if let Ok(delta_beta) = crate::gpu::arrow_schur::solve_reduced_beta_pcg(
                &schur,
                rhs_beta,
                options.trust_region.max_iterations,
                options.trust_region.steihaug_relative_tolerance,
            ) {
                return Ok(delta_beta);
            }
        }
        let err = match solve_dense_reduced_system(&schur, rhs_beta, options, None) {
            Ok((delta_beta, _factor, _diag)) => return Ok(delta_beta),
            Err(err) => err,
        };
        let recoverable = matches!(
            err,
            ArrowSchurError::SchurFactorFailed { .. } | ArrowSchurError::PcgFailed { .. }
        );
        last_err = Some(err);
        if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
            break;
        }
        ridge = if ridge == 0.0 {
            DEFAULT_PROXIMAL_INITIAL_RIDGE
        } else {
            ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
        };
    }
    Err(last_err.expect("escalation loop set last_err on failure"))
}
/// Reports whether `step` lies within the trust region of radius `radius`,
/// measured with `metric_norm`.
///
/// A non-finite radius (infinite or NaN) means "unbounded". In that case this
/// returns `true` without evaluating the norm.
fn step_inside_trust_region(
    step: ArrayView1<'_, f64>,
    radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> bool {
    if radius.is_finite() {
        metric_norm(step, metric_weights) <= radius
    } else {
        true
    }
}
/// Matrix-free product `out = S·x` with the reduced Schur complement
/// `S = P + ridge_beta·I − Σ_i H_tβ,iᵀ · solve(H_tt,i, H_tβ,i·x)`.
///
/// `P` is applied via `sys.penalty_matvec_add`. `x` and `out` must be
/// contiguous, length `k` vectors.
///
/// # Panics
/// Panics if `x` or `out` is not contiguous.
fn schur_matvec<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    ridge_beta: f64,
    x: &Array1<f64>,
    out: &mut Array1<f64>,
    backend: &B,
) {
    out.fill(0.0);
    let k = sys.k;
    {
        let x_slice = x.as_slice().expect("x must be contiguous");
        let out_slice = out.as_slice_mut().expect("out must be contiguous");
        sys.penalty_matvec_add(x_slice, out_slice);
        for a in 0..k {
            out_slice[a] += ridge_beta * x_slice[a];
        }
    }
    let mut neg_contrib = Array1::<f64>::zeros(k);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        // Fresh zeroed row buffer for H_tβ,i·x. A shared scratch would be
        // copied (and thus re-allocated) per row anyway, so none is kept.
        let mut local_i = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, x.view(), &mut local_i);
        let solved = backend.solve_block_vector(htt_factors.factor(i), local_i.view());
        neg_contrib.fill(0.0);
        sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), &mut neg_contrib);
        // Subtract row by row to keep the same summation order per entry.
        for a in 0..k {
            out[a] -= neg_contrib[a];
        }
    }
}
/// One diagonal block of the β block-Jacobi preconditioner, covering the
/// global β indices in `range`.
#[derive(Clone)]
enum BlockFactor {
    /// Cholesky factor of the dense Schur sub-block (used when it factors).
    Chol {
        range: Range<usize>,
        factor: FaerLlt<f64>,
    },
    /// Reciprocal diagonal entries, one per index in `range`.
    Scalar {
        range: Range<usize>,
        inv: Array1<f64>,
    },
}
/// Compact debug output: shows the covered range and, for scalar blocks, the
/// number of stored reciprocals. The Cholesky factor itself is not printed.
impl std::fmt::Debug for BlockFactor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BlockFactor::Scalar { inv, range } => {
                let inv_len = inv.len();
                write!(
                    f,
                    "BlockFactor::Scalar {{ inv.len: {inv_len}, range: {range:?} }}"
                )
            }
            BlockFactor::Chol { range, .. } => {
                write!(f, "BlockFactor::Chol {{ range: {range:?} }}")
            }
        }
    }
}
/// Jacobi / block-Jacobi preconditioner for the reduced Schur complement.
///
/// The output index ranges of `blocks` are applied independently; indices not
/// covered by any block map to zero.
#[derive(Debug, Clone)]
pub struct JacobiPreconditioner {
    /// Per-block factors, applied independently.
    blocks: Vec<BlockFactor>,
}
/// Widest β block for which block Jacobi is used; larger blocks select scalar Jacobi.
const BLOCK_JACOBI_MAX_BLOCK: usize = 256;
/// Diagonal entries at or below this value are rejected as non-PD.
const JACOBI_DIAGONAL_PD_FLOOR: f64 = 1.0e-18;
/// Builders and application for the Jacobi-style preconditioner of the reduced
/// Schur complement.
///
/// Every builder starts from the penalty contribution plus `ridge_beta` on the
/// diagonal. It then subtracts, for each row `i`, the terms
/// `H_tβ,iᵀ · solve(H_tt,i, H_tβ,i)` restricted to the entries it keeps.
/// Here `solve` is `backend.solve_block_vector` with that row's `H_tt` factor.
impl JacobiPreconditioner {
/// Chooses between block Jacobi and scalar (diagonal) Jacobi.
///
/// Block Jacobi is used when `sys.block_offsets` is non-empty and its widest
/// block has at most `BLOCK_JACOBI_MAX_BLOCK` columns. Otherwise every β
/// coordinate gets its own scalar entry.
///
/// # Errors
/// Returns `ArrowSchurError::PcgFailed` (from the chosen builder) when a
/// diagonal entry that must be inverted is non-finite or not above
/// `JACOBI_DIAGONAL_PD_FLOOR`.
pub fn from_arrow_schur<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
) -> Result<Self, ArrowSchurError> {
let use_block = !sys.block_offsets.is_empty()
&& sys
.block_offsets
.iter()
.map(|r| r.end.saturating_sub(r.start))
.max()
.unwrap_or(0)
<= BLOCK_JACOBI_MAX_BLOCK;
if use_block {
Self::build_block_jacobi(sys, htt_factors, ridge_beta, backend)
} else {
Self::build_scalar_jacobi(sys, htt_factors, ridge_beta, backend)
}
}
/// Builds `k` scalar blocks; entry `a` is `1 / S[a, a]`.
///
/// # Errors
/// `ArrowSchurError::PcgFailed` if any `S[a, a]` is non-finite or not above
/// `JACOBI_DIAGONAL_PD_FLOOR`.
fn build_scalar_jacobi<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
) -> Result<Self, ArrowSchurError> {
let k = sys.k;
let mut diag = Array1::<f64>::zeros(k);
{
let diag_slice = diag.as_slice_mut().expect("diag must be contiguous");
sys.penalty_diagonal_add(diag_slice);
}
for a in 0..k {
diag[a] += ridge_beta;
}
// NOTE(review): `col_i` below is an owned copy (`to_owned`), so `col` only
// supplies a zeroed prefix and is never written through.
let mut col = Array1::<f64>::zeros(sys.d);
let mut e_a = Array1::<f64>::zeros(k);
for (i, row) in sys.rows.iter().enumerate() {
let di = sys.row_dims[i];
let mut col_i = col.slice_mut(ndarray::s![..di]).to_owned();
for a in 0..k {
// Column `a` of this row's H_tβ. It is obtained by applying the coupling
// to the unit vector e_a when the coupling is matrix-free or the stored
// `htbeta` is not `di × k`; otherwise it is read from the dense matrix.
if sys.htbeta_matvec.is_some() || row.htbeta.dim() != (di, k) {
e_a.fill(0.0);
e_a[a] = 1.0;
col_i.fill(0.0);
sys_htbeta_apply_row(sys, i, row, e_a.view(), &mut col_i);
} else {
for c in 0..di {
col_i[c] = row.htbeta[[c, a]];
}
}
let solved = backend.solve_block_vector(htt_factors.factor(i), col_i.view());
// acc = colᵀ · solve(H_tt, col): this row's contribution to S[a, a].
let mut acc = 0.0;
for c in 0..di {
acc += col_i[c] * solved[c];
}
diag[a] -= acc;
}
}
let mut blocks = Vec::with_capacity(k);
for a in 0..k {
let v = diag[a];
// Only strictly positive, finite diagonals can be inverted safely.
if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
return Err(ArrowSchurError::PcgFailed {
reason: format!(
"invalid Schur Jacobi diagonal at index {a}: {v}; \
operator regularization is required"
),
});
}
blocks.push(BlockFactor::Scalar {
inv: Array1::from_elem(1, 1.0 / v),
range: a..a + 1,
});
}
Ok(Self { blocks })
}
/// Builds one dense Schur sub-block per entry of `sys.block_offsets` and
/// factors it with Cholesky.
///
/// When a block's factorization fails, the block falls back to its
/// reciprocal diagonal.
///
/// # Errors
/// `ArrowSchurError::PcgFailed` if the diagonal fallback meets a non-finite
/// entry or one not above `JACOBI_DIAGONAL_PD_FLOOR`.
fn build_block_jacobi<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
) -> Result<Self, ArrowSchurError> {
let block_offsets = &sys.block_offsets;
// Start each block from its penalty sub-block plus the ridge.
let mut schur_blocks: Vec<Array2<f64>> = Vec::with_capacity(block_offsets.len());
for (block_idx, range) in block_offsets.iter().enumerate() {
let b = range.end - range.start;
let mut schur_block = Array2::<f64>::zeros((b, b));
sys.penalty_block_add(
BetaBlockId(block_idx),
block_offsets.as_ref(),
&mut schur_block,
);
for bi in 0..b {
schur_block[[bi, bi]] += ridge_beta;
}
schur_blocks.push(schur_block);
}
// Subtract each row's H_tβᵀ·solve(H_tt, H_tβ), restricted to every block.
for (i, row) in sys.rows.iter().enumerate() {
let di = sys.row_dims[i];
let htbeta_full = sys_htbeta_materialize_row(sys, i, row);
for (block_idx, range) in block_offsets.iter().enumerate() {
let b = range.end - range.start;
let mut solved_cols = Array2::<f64>::zeros((di, b));
for bj in 0..b {
let gj = range.start + bj;
let rhs = htbeta_full.column(gj).to_owned();
let solved = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
for c in 0..di {
solved_cols[[c, bj]] = solved[c];
}
}
let schur_block = &mut schur_blocks[block_idx];
for bi in 0..b {
let gi = range.start + bi;
for bj in 0..b {
let mut acc = 0.0;
for c in 0..di {
acc += htbeta_full[[c, gi]] * solved_cols[[c, bj]];
}
schur_block[[bi, bj]] -= acc;
}
}
}
}
let mut blocks = Vec::with_capacity(block_offsets.len());
for (block_idx, range) in block_offsets.iter().enumerate() {
let b = range.end - range.start;
let schur_block = &schur_blocks[block_idx];
// `.ok()`: a failed factorization is not an error here; it selects the
// diagonal fallback below.
let factor_opt = {
use faer::Side;
let view = FaerArrayView::new(schur_block);
FaerLlt::new(view.as_ref(), Side::Lower).ok()
};
if let Some(llt) = factor_opt {
blocks.push(BlockFactor::Chol {
factor: llt,
range: range.clone(),
});
} else {
let mut inv = Array1::<f64>::zeros(b);
for bi in 0..b {
let v = schur_block[[bi, bi]];
if !v.is_finite() || v <= JACOBI_DIAGONAL_PD_FLOOR {
return Err(ArrowSchurError::PcgFailed {
reason: format!(
"block Jacobi scalar fallback: non-PD diagonal at \
global index {}: {v}; regularization required",
range.start + bi
),
});
}
inv[bi] = 1.0 / v;
}
blocks.push(BlockFactor::Scalar {
inv,
range: range.clone(),
});
}
}
Ok(Self { blocks })
}
/// Applies the preconditioner block by block.
///
/// Each block writes only its own range of the output; indices of `r`
/// covered by no block map to zero.
fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(r.len());
for block in &self.blocks {
match block {
BlockFactor::Scalar { inv, range } => {
for (local, gi) in range.clone().enumerate() {
out[gi] = inv[local] * r[gi];
}
}
BlockFactor::Chol { factor, range } => {
let b = range.end - range.start;
let mut rhs = Array1::<f64>::zeros(b);
for (local, gi) in range.clone().enumerate() {
rhs[local] = r[gi];
}
use faer::linalg::solvers::Solve;
let stride = rhs.strides()[0];
let len = rhs.len();
// SAFETY: `rhs` is a freshly allocated owned array of `len` elements
// with element stride `stride`. It outlives `rhs_mat`, which is only
// used by `factor.solve` below. The view is `len × 1`, so the column
// stride (0) is never stepped.
let rhs_mat =
unsafe { faer::MatRef::from_raw_parts(rhs.as_ptr(), len, 1, stride, 0) };
let solved = factor.solve(rhs_mat);
for (local, gi) in range.clone().enumerate() {
out[gi] = solved[(local, 0)];
}
}
}
}
out
}
}
/// Preconditioner families available for the reduced Schur PCG.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SchurPreconditionerKind {
    /// Scalar diagonal of the Schur complement.
    Diagonal,
    /// Dense blocks following `block_offsets`.
    BetaBlockJacobi,
    /// Dense blocks over groups of coupled β blocks.
    ClusterJacobi,
    /// Overlapping cluster blocks.
    /// NOTE(review): `overlap` presumably matches the one-hop expansion count
    /// taken by `AdditiveSchwarzPreconditioner::from_arrow_schur`; confirm.
    AdditiveSchwarz { overlap: usize },
}
/// Problems with `k` at or below this size never escalate past Jacobi in
/// `steihaug_pcg_auto`.
const PRECOND_ESCALATE_K_THRESHOLD: usize = 100;
/// Factored Schur sub-block for one column cluster (the global β indices in
/// `cols`).
#[derive(Clone)]
enum ClusterFactor {
    /// Cholesky factor of the dense cluster block.
    Chol {
        factor: FaerLlt<f64>,
        cols: Vec<usize>,
    },
    /// Reciprocal Schur diagonal, one entry per element of `cols`.
    Scalar {
        inv: Vec<f64>,
        cols: Vec<usize>,
    },
}
/// Compact debug output: only sizes are shown, not the stored factors.
impl std::fmt::Debug for ClusterFactor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ClusterFactor::Scalar { cols, inv } => {
                let (cols_len, inv_len) = (cols.len(), inv.len());
                write!(
                    f,
                    "ClusterFactor::Scalar {{ cols.len: {cols_len}, inv.len: {inv_len} }}"
                )
            }
            ClusterFactor::Chol { cols, .. } => {
                let cols_len = cols.len();
                write!(f, "ClusterFactor::Chol {{ cols.len: {cols_len} }}")
            }
        }
    }
}
/// Clusters wider than this use the scalar Schur diagonal instead of a dense factor.
const CLUSTER_JACOBI_MAX_CLUSTER: usize = 512;
/// Block-Jacobi preconditioner whose blocks are groups of coupled β columns.
#[derive(Debug, Clone)]
pub struct ClusterJacobiPreconditioner {
    /// One factor per non-empty column group, in group order.
    clusters: Vec<ClusterFactor>,
}
impl ClusterJacobiPreconditioner {
pub fn from_arrow_schur<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
) -> Result<Self, ArrowSchurError> {
if sys.block_offsets.is_empty() {
let cols: Vec<usize> = (0..sys.k).collect();
return Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &[cols]);
}
let graph = BetaCouplingGraph::build(
&sys.block_offsets,
&sys.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>(),
);
let col_groups: Vec<Vec<usize>> = graph
.component_partition()
.iter()
.map(|comp_blocks| {
let mut cols: Vec<usize> = comp_blocks
.iter()
.flat_map(|&b| sys.block_offsets[b].clone())
.collect();
cols.sort_unstable();
cols
})
.collect();
Self::build_from_column_groups(sys, htt_factors, ridge_beta, backend, &col_groups)
}
fn build_from_column_groups<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
col_groups: &[Vec<usize>],
) -> Result<Self, ArrowSchurError> {
let d = sys.d;
let mut clusters = Vec::with_capacity(col_groups.len());
for cols in col_groups {
let b = cols.len();
if b == 0 {
continue;
}
if b > CLUSTER_JACOBI_MAX_CLUSTER {
let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
clusters.push(ClusterFactor::Scalar {
cols: cols.clone(),
inv,
});
continue;
}
let mut s_block = Array2::<f64>::zeros((b, b));
sys.penalty_subblock_add(cols, &mut s_block);
for bi in 0..b {
s_block[[bi, bi]] += ridge_beta;
}
let mut col_vec = Array1::<f64>::zeros(d);
let mut solved_cols = Array2::<f64>::zeros((d, b));
for (row_idx, row) in sys.rows.iter().enumerate() {
for bj in 0..b {
let gj = cols[bj];
for c in 0..d {
col_vec[c] = row.htbeta[[c, gj]];
}
let solved =
backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
for c in 0..d {
solved_cols[[c, bj]] = solved[c];
}
}
for bi in 0..b {
let gi = cols[bi];
for bj in 0..b {
let mut acc = 0.0;
for c in 0..d {
acc += row.htbeta[[c, gi]] * solved_cols[[c, bj]];
}
s_block[[bi, bj]] -= acc;
}
}
}
symmetrize_upper_from_lower(&mut s_block);
let factor_opt = {
use faer::Side;
let view = FaerArrayView::new(&s_block);
FaerLlt::new(view.as_ref(), Side::Lower).ok()
};
if let Some(llt) = factor_opt {
clusters.push(ClusterFactor::Chol {
cols: cols.clone(),
factor: llt,
});
} else {
let inv = build_schur_scalar_inv(sys, htt_factors, ridge_beta, backend, cols)?;
clusters.push(ClusterFactor::Scalar {
cols: cols.clone(),
inv,
});
}
}
Ok(Self { clusters })
}
fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(r.len());
for cluster in &self.clusters {
apply_cluster(cluster, r, &mut out, &ClusterApplyMode::Overwrite);
}
out
}
}
/// Additive Schwarz preconditioner: overlapping cluster solves combined with
/// per-column weights.
#[derive(Debug, Clone)]
pub struct AdditiveSchwarzPreconditioner {
    /// Factored (possibly overlapping) clusters.
    clusters: Vec<ClusterFactor>,
    /// Per-column scale: `1 / (number of clusters covering the column)`, or
    /// `1` for uncovered columns.
    weights: Vec<f64>,
}
impl AdditiveSchwarzPreconditioner {
pub fn from_arrow_schur<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
overlap: usize,
) -> Result<Self, ArrowSchurError> {
if sys.block_offsets.is_empty() {
let cols: Vec<usize> = (0..sys.k).collect();
let inner = ClusterJacobiPreconditioner::build_from_column_groups(
sys,
htt_factors,
ridge_beta,
backend,
&[cols],
)?;
return Ok(Self {
clusters: inner.clusters,
weights: vec![1.0f64; sys.k],
});
}
let graph = BetaCouplingGraph::build(
&sys.block_offsets,
&sys.rows
.iter()
.map(|r| r.htbeta.clone())
.collect::<Vec<_>>(),
);
let col_groups: Vec<Vec<usize>> = graph
.component_partition()
.iter()
.map(|seed| {
let mut current = seed.clone();
for _ in 0..overlap {
current = graph.expand_one_hop(¤t);
}
let mut cols: Vec<usize> = current
.iter()
.flat_map(|&b| sys.block_offsets[b].clone())
.collect();
cols.sort_unstable();
cols.dedup();
cols
})
.collect();
let mut counts = vec![0u32; sys.k];
for cols in &col_groups {
for &gi in cols {
counts[gi] += 1;
}
}
let weights: Vec<f64> = counts
.iter()
.map(|&c| if c == 0 { 1.0 } else { 1.0 / c as f64 })
.collect();
let inner = ClusterJacobiPreconditioner::build_from_column_groups(
sys,
htt_factors,
ridge_beta,
backend,
&col_groups,
)?;
Ok(Self {
clusters: inner.clusters,
weights,
})
}
fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
let mut out = Array1::<f64>::zeros(r.len());
for cluster in &self.clusters {
apply_cluster(
cluster,
r,
&mut out,
&ClusterApplyMode::Accumulate {
weights: &self.weights,
},
);
}
out
}
}
/// How a cluster solve is written into the output vector.
enum ClusterApplyMode<'w> {
    /// `out[gi] = value` (non-overlapping clusters).
    Overwrite,
    /// `out[gi] += weights[gi] * value` (overlapping clusters).
    Accumulate {
        /// Per-column weights, indexed by global β index.
        weights: &'w [f64],
    },
}
impl ClusterApplyMode<'_> {
    /// Stores `value` for global index `gi`: a plain overwrite, or a weighted
    /// accumulation using `weights[gi]`.
    #[inline]
    fn write(&self, out: &mut Array1<f64>, gi: usize, value: f64) {
        match self {
            ClusterApplyMode::Accumulate { weights } => out[gi] += weights[gi] * value,
            ClusterApplyMode::Overwrite => out[gi] = value,
        }
    }
}
/// Solves one cluster's sub-system for the entries of `r` at the cluster's
/// columns and writes the result into `out` according to `mode`.
fn apply_cluster(
    cluster: &ClusterFactor,
    r: &Array1<f64>,
    out: &mut Array1<f64>,
    mode: &ClusterApplyMode<'_>,
) {
    match cluster {
        ClusterFactor::Chol { cols, factor } => {
            use faer::linalg::solvers::Solve;
            // Gather the cluster's entries into a fresh, owned column.
            let gathered: Array1<f64> = cols.iter().map(|&gi| r[gi]).collect();
            let len = gathered.len();
            let stride = gathered.strides()[0];
            // SAFETY: `gathered` is an owned array of `len` elements with
            // element stride `stride`. It outlives `rhs_view`, which is only
            // read by `factor.solve`. The view is `len × 1`, so the column
            // stride (0) is never stepped.
            let rhs_view =
                unsafe { faer::MatRef::from_raw_parts(gathered.as_ptr(), len, 1, stride, 0) };
            let solved = factor.solve(rhs_view);
            for (local, &gi) in cols.iter().enumerate() {
                mode.write(out, gi, solved[(local, 0)]);
            }
        }
        ClusterFactor::Scalar { cols, inv } => {
            for (&gi, &scale) in cols.iter().zip(inv.iter()) {
                mode.write(out, gi, scale * r[gi]);
            }
        }
    }
}
fn build_schur_scalar_inv<B: BatchedBlockSolver>(
sys: &ArrowSchurSystem,
htt_factors: &ArrowFactorSlab,
ridge_beta: f64,
backend: &B,
cols: &[usize],
) -> Result<Vec<f64>, ArrowSchurError> {
let d = sys.d;
let mut result = Vec::with_capacity(cols.len());
let mut col_vec = Array1::<f64>::zeros(d);
let mut full_diag = Array1::<f64>::zeros(sys.k);
{
let fd_slice = full_diag.as_slice_mut().expect("full_diag contiguous");
sys.penalty_diagonal_add(fd_slice);
}
for &gi in cols {
let mut s = full_diag[gi] + ridge_beta;
for (row_idx, row) in sys.rows.iter().enumerate() {
for c in 0..d {
col_vec[c] = row.htbeta[[c, gi]];
}
let solved = backend.solve_block_vector(htt_factors.factor(row_idx), col_vec.view());
let mut acc = 0.0;
for c in 0..d {
acc += col_vec[c] * solved[c];
}
s -= acc;
}
if !s.is_finite() || s <= JACOBI_DIAGONAL_PD_FLOOR {
return Err(ArrowSchurError::PcgFailed {
reason: format!(
"cluster Schur scalar fallback: non-PD diagonal at index {gi}: {s}"
),
});
}
result.push(1.0 / s);
}
Ok(result)
}
/// Steihaug-PCG on the reduced Schur system with tiered preconditioner
/// escalation.
///
/// Tiers are tried in order: Jacobi, then cluster Jacobi, then additive
/// Schwarz with overlap 1.
/// * Escalation happens only when `k > PRECOND_ESCALATE_K_THRESHOLD` and the
///   current tier stopped at `MaxIter`.
/// * Each tier restarts from scratch.
/// * A small problem that hits `MaxIter` under Jacobi is returned as is.
///
/// # Errors
/// * Preconditioner construction and PCG errors are propagated.
/// * `ArrowSchurError::PcgFailed` if the last tier also stops at `MaxIter`.
fn steihaug_pcg_auto<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    ridge_beta: f64,
    rhs: &Array1<f64>,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    backend: &B,
    gpu_matvec: Option<&GpuSchurMatvec>,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    // One PCG run; only the preconditioner differs between tiers.
    let solve_with = |precondition: &dyn Fn(&Array1<f64>) -> Array1<f64>| {
        run_pcg_with_preconditioner(
            sys,
            htt_factors,
            ridge_beta,
            rhs,
            |r: &Array1<f64>| precondition(r),
            pcg,
            trust,
            backend,
            gpu_matvec,
            metric_weights,
        )
    };
    let jacobi = JacobiPreconditioner::from_arrow_schur(sys, htt_factors, ridge_beta, backend)?;
    let (step, diag) = solve_with(&|r: &Array1<f64>| jacobi.apply(r))?;
    let small_problem = sys.k <= PRECOND_ESCALATE_K_THRESHOLD;
    if small_problem || diag.stopping_reason != PcgStopReason::MaxIter {
        return Ok((step, diag));
    }
    let cluster =
        ClusterJacobiPreconditioner::from_arrow_schur(sys, htt_factors, ridge_beta, backend)?;
    let (step, diag) = solve_with(&|r: &Array1<f64>| cluster.apply(r))?;
    if diag.stopping_reason != PcgStopReason::MaxIter {
        return Ok((step, diag));
    }
    let schwarz =
        AdditiveSchwarzPreconditioner::from_arrow_schur(sys, htt_factors, ridge_beta, backend, 1)?;
    let (step, diag) = solve_with(&|r: &Array1<f64>| schwarz.apply(r))?;
    if diag.stopping_reason == PcgStopReason::MaxIter {
        return Err(ArrowSchurError::PcgFailed {
            reason: format!(
                "Schur PCG exhausted all preconditioner tiers (Jacobi, ClusterJacobi, \
                 AdditiveSchwarz) at MaxIter; final relative residual = {:e}",
                diag.final_relative_residual
            ),
        });
    }
    Ok((step, diag))
}
/// Runs Steihaug-CG on the reduced Schur system with the given preconditioner.
///
/// The iteration cap is the smaller of the PCG and trust-region caps. The
/// relative tolerance is the larger of the two tolerances. The operator is
/// the device matvec when one is supplied, otherwise the CPU `schur_matvec`.
fn run_pcg_with_preconditioner<ApplyPrec, B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    ridge_beta: f64,
    rhs: &Array1<f64>,
    apply_prec: ApplyPrec,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    backend: &B,
    gpu_matvec: Option<&GpuSchurMatvec>,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
where
    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
{
    let iteration_cap = usize::min(pcg.max_iterations, trust.max_iterations);
    let tolerance = f64::max(pcg.relative_tolerance, trust.steihaug_relative_tolerance);
    match gpu_matvec {
        Some(device_matvec) => {
            let device_matvec = Arc::clone(device_matvec);
            steihaug_cg(
                rhs,
                move |p, out| device_matvec(p, out),
                apply_prec,
                iteration_cap,
                tolerance,
                trust.radius,
                metric_weights,
            )
        }
        None => steihaug_cg(
            rhs,
            |p, out| schur_matvec(sys, htt_factors, ridge_beta, p, out, backend),
            apply_prec,
            iteration_cap,
            tolerance,
            trust.radius,
            metric_weights,
        ),
    }
}
#[derive(Debug, Clone, Copy)]
struct IdentityPreconditioner;
impl IdentityPreconditioner {
fn apply(&self, r: &Array1<f64>) -> Array1<f64> {
r.clone()
}
}
/// Steihaug-CG on an explicit dense matrix `schur`, preconditioned by the
/// identity. Uses `pcg.max_iterations`, `pcg.relative_tolerance` and
/// `trust.radius`.
fn steihaug_dense_system(
    schur: &Array2<f64>,
    rhs: &Array1<f64>,
    preconditioner: &IdentityPreconditioner,
    pcg: &ArrowPcgOptions,
    trust: &ArrowTrustRegionOptions,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let multiply = |p: &Array1<f64>, out: &mut Array1<f64>| dense_matvec(schur, p, out);
    let precondition = |r: &Array1<f64>| preconditioner.apply(r);
    steihaug_cg(
        rhs,
        multiply,
        precondition,
        pcg.max_iterations,
        pcg.relative_tolerance,
        trust.radius,
        metric_weights,
    )
}
/// Steihaug–Toint truncated preconditioned conjugate gradient for `A x = rhs`, starting at `x = 0`.
///
/// `matvec(p, out)` must write `A p` into `out`. `apply_preconditioner(r)` returns the
/// preconditioned residual `z`. Every inner product and norm uses the optional diagonal metric
/// `metric_weights` (`Σ wᵢ aᵢ bᵢ`); with `None` they are Euclidean. A `trust_radius` that is
/// non-finite or non-positive is treated as unbounded.
///
/// Stopping rules (reported in `PcgDiagnostics::stopping_reason`):
/// - `Converged`: the metric residual norm is at most
///   `max(relative_tolerance·‖rhs‖, PCG_ABSOLUTE_TOLERANCE_FLOOR)`.
/// - `TrustRegion`: the next CG iterate would reach or leave the trust region, or `pᵀAp` is
///   non-positive or non-finite while the radius is finite. The returned step is extended
///   along `p` to the boundary.
///   If the initial `rᵀz` is non-positive or non-finite under a finite radius, the step is
///   taken along `r` instead.
/// - `MaxIter`: `max_iterations` matvecs were spent; the current iterate is returned.
///
/// A zero right-hand side returns the zero vector with default diagnostics.
///
/// # Errors
/// `ArrowSchurError::PcgFailed` in any of these cases:
/// - the initial `rᵀz` is non-positive or non-finite with an unbounded radius;
/// - negative (or non-finite) curvature is met with an unbounded radius;
/// - a later `rᵀz` is non-positive or non-finite.
///
/// # Panics
/// If `metric_weights` is present and its length differs from `rhs.len()`.
fn steihaug_cg<MatVec, ApplyPrec>(
    rhs: &Array1<f64>,
    mut matvec: MatVec,
    mut apply_preconditioner: ApplyPrec,
    max_iterations: usize,
    relative_tolerance: f64,
    trust_radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> Result<(Array1<f64>, PcgDiagnostics), ArrowSchurError>
where
    MatVec: FnMut(&Array1<f64>, &mut Array1<f64>),
    ApplyPrec: FnMut(&Array1<f64>) -> Array1<f64>,
{
    let n = rhs.len();
    if let Some(weights) = metric_weights {
        assert_eq!(
            weights.len(),
            n,
            "Steihaug-CG metric weight length must match solve dimension"
        );
    }
    // Non-finite or non-positive radii mean "no trust region".
    let radius = if trust_radius.is_finite() && trust_radius > 0.0 {
        trust_radius
    } else {
        f64::INFINITY
    };
    let rhs_norm = metric_norm(rhs.view(), metric_weights);
    if rhs_norm == 0.0 {
        return Ok((Array1::<f64>::zeros(n), PcgDiagnostics::default()));
    }
    let tol = (relative_tolerance.max(0.0) * rhs_norm).max(PCG_ABSOLUTE_TOLERANCE_FLOOR);
    let mut x = Array1::<f64>::zeros(n);
    let mut r = rhs.clone();
    let mut z = apply_preconditioner(&r);
    let mut diag = PcgDiagnostics {
        precond_apply_calls: 1,
        ..PcgDiagnostics::default()
    };
    let mut p = z.clone();
    let mut rz = metric_dot(&r, &z, metric_weights);
    if rz <= 0.0 || !rz.is_finite() {
        // The preconditioner is not positive on the initial residual. Inside a trust region
        // fall back to a boundary step along the (unpreconditioned) residual.
        if radius.is_finite() {
            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
            diag.stopping_reason = PcgStopReason::TrustRegion;
            return Ok((step_to_trust_boundary(&x, &r, radius, metric_weights), diag));
        }
        return Err(ArrowSchurError::PcgFailed {
            reason: "non-positive preconditioned residual in Schur PCG".to_string(),
        });
    }
    // Here x = 0 and r = rhs, so the relative residual is exactly ‖r‖/‖rhs‖ (= 1). Only the
    // absolute floor, or a relative tolerance >= 1, can accept this point. Report the measured
    // value rather than claiming an exact (0.0) solve.
    let initial_residual_norm = metric_norm(r.view(), metric_weights);
    if initial_residual_norm <= tol {
        diag.final_relative_residual = initial_residual_norm / rhs_norm;
        diag.stopping_reason = PcgStopReason::Converged;
        return Ok((x, diag));
    }
    let mut ap = Array1::<f64>::zeros(n);
    // Scratch buffer for the tentative iterate; reused across iterations.
    let mut candidate = Array1::<f64>::zeros(n);
    for _ in 0..max_iterations {
        matvec(&p, &mut ap);
        diag.matvec_calls += 1;
        diag.iterations += 1;
        let pap = metric_dot(&p, &ap, metric_weights);
        if pap <= 0.0 || !pap.is_finite() {
            // Non-positive curvature: follow p to the trust boundary when one exists.
            if radius.is_finite() {
                diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
                diag.stopping_reason = PcgStopReason::TrustRegion;
                return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
            }
            return Err(ArrowSchurError::PcgFailed {
                reason: "negative curvature in unbounded Schur PCG".to_string(),
            });
        }
        let alpha = rz / pap;
        for ((cand, &xi), &pi) in candidate.iter_mut().zip(x.iter()).zip(p.iter()) {
            *cand = xi + alpha * pi;
        }
        if radius.is_finite() && metric_norm(candidate.view(), metric_weights) >= radius {
            // The full CG step would leave the region: truncate at the boundary instead.
            diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
            diag.stopping_reason = PcgStopReason::TrustRegion;
            return Ok((step_to_trust_boundary(&x, &p, radius, metric_weights), diag));
        }
        x.assign(&candidate);
        for (ri, &api) in r.iter_mut().zip(ap.iter()) {
            *ri -= alpha * api;
        }
        let residual_norm = metric_norm(r.view(), metric_weights);
        if residual_norm <= tol {
            diag.final_relative_residual = residual_norm / rhs_norm;
            diag.stopping_reason = PcgStopReason::Converged;
            return Ok((x, diag));
        }
        z = apply_preconditioner(&r);
        diag.precond_apply_calls += 1;
        let rz_next = metric_dot(&r, &z, metric_weights);
        if rz_next <= 0.0 || !rz_next.is_finite() {
            return Err(ArrowSchurError::PcgFailed {
                reason: "non-positive or non-finite PCG residual".to_string(),
            });
        }
        let beta = rz_next / rz;
        for (pi, &zi) in p.iter_mut().zip(z.iter()) {
            *pi = zi + beta * *pi;
        }
        rz = rz_next;
    }
    diag.final_relative_residual = metric_norm(r.view(), metric_weights) / rhs_norm;
    diag.stopping_reason = PcgStopReason::MaxIter;
    Ok((x, diag))
}
/// Extends `x` along `p` to the trust-region boundary in the metric norm.
///
/// Solves `pᵀWp·τ² + 2·xᵀWp·τ + (xᵀWx − radius²) = 0` and uses
/// `τ = (√disc − xᵀWp) / pᵀWp`, which is the larger root when `pᵀWp > 0`. It returns
/// `x + τ·p`. The discriminant is clamped at zero, which absorbs rounding and the case where
/// `x` already lies outside the region. When `p` has zero metric length, `x` is returned
/// unchanged.
fn step_to_trust_boundary(
    x: &Array1<f64>,
    p: &Array1<f64>,
    radius: f64,
    metric_weights: Option<&MetricWeights>,
) -> Array1<f64> {
    let p_sq = metric_dot(p, p, metric_weights);
    if p_sq == 0.0 {
        return x.clone();
    }
    // metric_dot asserts x.len() == p.len(), so the zip below covers every coordinate.
    let x_dot_p = metric_dot(x, p, metric_weights);
    let x_sq = metric_dot(x, x, metric_weights);
    let discriminant = (x_dot_p * x_dot_p + p_sq * (radius * radius - x_sq)).max(0.0);
    let tau = (discriminant.sqrt() - x_dot_p) / p_sq;
    let mut boundary_point = x.clone();
    for (coord, &dir) in boundary_point.iter_mut().zip(p.iter()) {
        *coord += tau * dir;
    }
    boundary_point
}
/// Dense product `out = a · x`.
///
/// The matrix is treated as square of side `a.nrows()`: only the leading `a.nrows()`
/// columns of `a` and the leading `a.nrows()` entries of `x` and `out` are used. Each output
/// entry is accumulated left to right, starting from `0.0`.
fn dense_matvec(a: &Array2<f64>, x: &Array1<f64>, out: &mut Array1<f64>) {
    let dim = a.nrows();
    for row in 0..dim {
        out[row] = (0..dim).fold(0.0_f64, |acc, col| acc + a[[row, col]] * x[col]);
    }
}
/// Euclidean inner product `Σ aᵢ bᵢ` over the length of `a`, accumulated left to right.
///
/// Panics via indexing if `b` is shorter than `a`.
fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    (0..a.len()).fold(0.0_f64, |acc, i| acc + a[i] * b[i])
}
/// Inner product under an optional diagonal metric.
///
/// With weights `w` it returns `Σ wᵢ aᵢ bᵢ`. Without weights it falls back to the plain
/// Euclidean `dot`.
///
/// # Panics
/// If `a` and `b` differ in length, or if the weights' length differs from `a`'s.
fn metric_dot(a: &Array1<f64>, b: &Array1<f64>, metric_weights: Option<&MetricWeights>) -> f64 {
    assert_eq!(a.len(), b.len());
    match metric_weights {
        None => dot(a, b),
        Some(weights) => {
            assert_eq!(weights.len(), a.len());
            (0..a.len()).fold(0.0_f64, |acc, i| acc + weights[i] * a[i] * b[i])
        }
    }
}
fn metric_norm(v: ArrayView1<'_, f64>, metric_weights: Option<&MetricWeights>) -> f64 {
let mut acc = 0.0;
match metric_weights {
Some(weights) => {
assert_eq!(weights.len(), v.len());
for i in 0..v.len() {
acc += weights[i] * v[i] * v[i];
}
}
None => {
for x in v.iter() {
acc += x * x;
}
}
}
acc.sqrt()
}
/// Symmetrises the leading `min(nrows, ncols)` square of `a` in place.
///
/// Each off-diagonal pair `a[i][j]` / `a[j][i]` is replaced by its average. Diagonal entries
/// and anything outside the leading square are left untouched.
///
/// NOTE(review): despite the name, this averages both triangles rather than copying the lower
/// triangle into the upper one. The two coincide only when `a` is already symmetric.
fn symmetrize_upper_from_lower(a: &mut Array2<f64>) {
    let side = a.nrows().min(a.ncols());
    for col in 0..side {
        for row in (col + 1)..side {
            let mean = 0.5 * (a[[row, col]] + a[[col, row]]);
            a[[row, col]] = mean;
            a[[col, row]] = mean;
        }
    }
}
/// Failure modes of the arrow-structured Schur-complement Newton solve.
#[derive(Debug, Clone)]
pub enum ArrowSchurError {
    /// Cholesky factorisation of the per-row block `H_tt` for `row` failed; `reason` holds the detail.
    PerRowFactorFailed { row: usize, reason: String },
    /// The per-row `H_tt` Cholesky succeeded, but its condition estimate `kappa_estimate`
    /// failed the safe-inversion guard.
    PerRowFactorIllConditioned { row: usize, kappa_estimate: f64 },
    /// Cholesky factorisation of the Schur complement failed.
    SchurFactorFailed { reason: String },
    /// The iterative (PCG / Steihaug-CG) Schur solve failed.
    PcgFailed { reason: String },
    /// The adaptive proximal correction failed.
    AdaptiveCorrectionFailed { reason: String },
}
impl std::fmt::Display for ArrowSchurError {
    /// Renders a one-line, `arrow-Schur:`-prefixed description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PerRowFactorFailed { row, reason } => {
                write!(f, "arrow-Schur: per-row H_tt^({row}) Cholesky failed: {reason}")
            }
            Self::PerRowFactorIllConditioned {
                row,
                kappa_estimate,
            } => write!(
                f,
                "arrow-Schur: per-row H_tt^({row}) Cholesky succeeded but failed \
                 the safe-inversion guard (kappa_estimate={kappa_estimate:e}); \
                 Schur reduction would be numerically contaminated"
            ),
            Self::SchurFactorFailed { reason } => {
                write!(f, "arrow-Schur: Schur complement Cholesky failed: {reason}")
            }
            Self::PcgFailed { reason } => write!(f, "arrow-Schur: Schur PCG failed: {reason}"),
            Self::AdaptiveCorrectionFailed { reason } => write!(
                f,
                "arrow-Schur: adaptive proximal correction failed: {reason}"
            ),
        }
    }
}
impl std::error::Error for ArrowSchurError {}
/// Dense lower-triangular Cholesky factor `L` with `a = L·Lᵀ`.
///
/// First it hands a copy of `a` to `crate::gpu::try_cholesky_lower_inplace` and returns that
/// copy when the call reports `Some`. Otherwise it runs the row-by-row
/// (Cholesky–Banachiewicz) recurrence on the CPU against the original `a`.
///
/// # Errors
/// Returns a message when:
/// - `a` is not square;
/// - `a` contains a non-finite entry (reported by linear index);
/// - a CPU pivot is non-finite or non-positive, i.e. `a` is not positive definite.
fn cholesky_lower(a: &Array2<f64>) -> Result<Array2<f64>, String> {
    let (n, cols) = a.dim();
    if cols != n {
        return Err(format!("cholesky_lower: non-square {}×{}", n, cols));
    }
    if let Some(idx) = a.iter().position(|v| !v.is_finite()) {
        return Err(format!(
            "cholesky_lower: non-finite entry at linear index {idx}"
        ));
    }
    // The accelerator works in place, so give it a scratch copy and keep `a` intact for the
    // CPU fallback.
    let mut device_attempt = a.clone();
    if crate::gpu::try_cholesky_lower_inplace(&mut device_attempt).is_some() {
        return Ok(device_attempt);
    }
    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            // a[i][j] minus the dot product of the already-computed parts of rows i and j.
            let sum = (0..j).fold(a[[i, j]], |acc, kk| acc - l[[i, kk]] * l[[j, kk]]);
            l[[i, j]] = if i != j {
                sum / l[[j, j]]
            } else if sum.is_finite() && sum > 0.0 {
                sum.sqrt()
            } else {
                return Err(format!(
                    "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
                ));
            };
        }
    }
    Ok(l)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
#[test]
fn sparse_block_kronecker_matches_dense_kronecker() {
let p = 2usize;
let dim_a = 5usize;
let k = dim_a * p;
let g_dense = array![
[3.0_f64, 0.5, 0.2, -0.1, 0.0],
[0.5, 4.0, 0.0, 0.3, 0.1],
[0.2, 0.0, 2.0, 0.4, -0.2],
[-0.1, 0.3, 0.4, 5.0, 0.6],
[0.0, 0.1, -0.2, 0.6, 1.5],
];
let dense = KroneckerPenaltyOp {
factor_a: g_dense.clone(),
factor_b: Array2::<f64>::eye(p),
global_offset: 0,
k,
};
let block_00 = g_dense.slice(ndarray::s![0..2, 0..2]).to_owned();
let block_01 = g_dense.slice(ndarray::s![0..2, 2..5]).to_owned();
let block_10 = g_dense.slice(ndarray::s![2..5, 0..2]).to_owned();
let block_11 = g_dense.slice(ndarray::s![2..5, 2..5]).to_owned();
let sparse = SparseBlockKroneckerPenaltyOp {
p,
dim_a,
k,
blocks: vec![
SparseGBlock {
row_off: 0,
col_off: 0,
data: block_00,
},
SparseGBlock {
row_off: 0,
col_off: 2,
data: block_01,
},
SparseGBlock {
row_off: 2,
col_off: 0,
data: block_10,
},
SparseGBlock {
row_off: 2,
col_off: 2,
data: block_11,
},
],
};
let d_dense = dense.to_dense();
let d_sparse = sparse.to_dense();
for i in 0..k {
for j in 0..k {
assert!(
(d_dense[[i, j]] - d_sparse[[i, j]]).abs() < 1e-12,
"to_dense mismatch at ({i},{j}): {} vs {}",
d_dense[[i, j]],
d_sparse[[i, j]]
);
}
}
let x: Vec<f64> = (0..k).map(|i| 0.1 * (i as f64) - 0.3).collect();
let mut y_dense = vec![0.0_f64; k];
let mut y_sparse = vec![0.0_f64; k];
dense.matvec(&x, &mut y_dense);
sparse.matvec(&x, &mut y_sparse);
for i in 0..k {
assert!(
(y_dense[i] - y_sparse[i]).abs() < 1e-12,
"matvec mismatch at {i}: {} vs {}",
y_dense[i],
y_sparse[i]
);
}
let mut diag_dense = vec![0.0_f64; k];
let mut diag_sparse = vec![0.0_f64; k];
dense.diagonal(&mut diag_dense);
sparse.diagonal(&mut diag_sparse);
for i in 0..k {
assert!(
(diag_dense[i] - diag_sparse[i]).abs() < 1e-12,
"diagonal mismatch at {i}: {} vs {}",
diag_dense[i],
diag_sparse[i]
);
}
let offsets = [0..(2 * p), (2 * p)..k];
for id in 0..offsets.len() {
let b = offsets[id].end - offsets[id].start;
let mut blk_dense = Array2::<f64>::zeros((b, b));
let mut blk_sparse = Array2::<f64>::zeros((b, b));
dense.block(BetaBlockId(id), &offsets, &mut blk_dense);
sparse.block(BetaBlockId(id), &offsets, &mut blk_sparse);
for i in 0..b {
for j in 0..b {
assert!(
(blk_dense[[i, j]] - blk_sparse[[i, j]]).abs() < 1e-12,
"block {id} mismatch at ({i},{j})"
);
}
}
}
}
fn factored_reference_dense(
ranks: &[usize],
basis_sizes: &[usize],
blocks: &[FactoredFrameGBlock],
) -> Array2<f64> {
let n_atoms = ranks.len();
let mut offsets = vec![0usize; n_atoms + 1];
for k in 0..n_atoms {
offsets[k + 1] = offsets[k] + basis_sizes[k] * ranks[k];
}
let dim = offsets[n_atoms];
let mut h = Array2::<f64>::zeros((dim, dim));
for blk in blocks {
let (r_i, r_j) = (ranks[blk.atom_i], ranks[blk.atom_j]);
let (off_i, off_j) = (offsets[blk.atom_i], offsets[blk.atom_j]);
let (m_i, m_j) = blk.g.dim();
for li in 0..m_i {
for lj in 0..m_j {
for a in 0..r_i {
for b in 0..r_j {
h[[off_i + li * r_i + a, off_j + lj * r_j + b]] +=
blk.g[[li, lj]] * blk.w[[a, b]];
}
}
}
}
}
h
}
#[test]
fn factored_frame_kronecker_matches_dense_reference() {
let ranks = vec![2usize, 3];
let basis_sizes = vec![2usize, 3];
let g00 = array![[3.0_f64, 0.5], [0.5, 4.0]];
let g11 = array![[2.0_f64, 0.4, -0.2], [0.4, 5.0, 0.6], [-0.2, 0.6, 1.5]];
let g01 = array![[0.2_f64, -0.1, 0.0], [0.3, 0.1, -0.2]];
let g10 = g01.t().to_owned();
let w00 = Array2::<f64>::eye(2);
let w11 = Array2::<f64>::eye(3);
let w01 = array![[0.8_f64, 0.1, -0.05], [0.0, 0.7, 0.2]];
let w10 = w01.t().to_owned();
let blocks = vec![
FactoredFrameGBlock {
atom_i: 0,
atom_j: 0,
g: g00.clone(),
w: w00.clone(),
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 1,
g: g11.clone(),
w: w11.clone(),
},
FactoredFrameGBlock {
atom_i: 0,
atom_j: 1,
g: g01.clone(),
w: w01.clone(),
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 0,
g: g10.clone(),
w: w10.clone(),
},
];
let op = FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), blocks.clone())
.expect("op");
assert_eq!(op.dim(), 13);
let reference = factored_reference_dense(&ranks, &basis_sizes, &blocks);
let dense = op.to_dense();
for i in 0..13 {
for j in 0..13 {
assert!(
(dense[[i, j]] - reference[[i, j]]).abs() < 1e-12,
"to_dense mismatch at ({i},{j}): {} vs {}",
dense[[i, j]],
reference[[i, j]]
);
}
}
let x: Vec<f64> = (0..13).map(|i| 0.13 * (i as f64) - 0.4).collect();
let mut y = vec![0.0_f64; 13];
op.matvec(&x, &mut y);
for i in 0..13 {
let mut expect = 0.0;
for j in 0..13 {
expect += reference[[i, j]] * x[j];
}
assert!(
(y[i] - expect).abs() < 1e-10,
"matvec mismatch at {i}: {} vs {expect}",
y[i]
);
}
let mut diag = vec![0.0_f64; 13];
op.diagonal(&mut diag);
for i in 0..13 {
assert!(
(diag[i] - reference[[i, i]]).abs() < 1e-12,
"diagonal mismatch at {i}"
);
}
let offsets_ranges = [0..4usize, 4..13usize];
for id in 0..2 {
let b = offsets_ranges[id].end - offsets_ranges[id].start;
let mut blk = Array2::<f64>::zeros((b, b));
op.block(BetaBlockId(id), &offsets_ranges, &mut blk);
for bi in 0..b {
for bj in 0..b {
let gi = offsets_ranges[id].start + bi;
let gj = offsets_ranges[id].start + bj;
assert!(
(blk[[bi, bj]] - reference[[gi, gj]]).abs() < 1e-12,
"block {id} mismatch at ({bi},{bj})"
);
}
}
}
}
#[test]
fn factored_frame_kronecker_reduces_to_sparse_block_at_full_rank() {
let p = 2usize;
let g00 = array![[3.0_f64, 0.5], [0.5, 4.0]];
let g11 = array![[2.0_f64, 0.4], [0.4, 5.0]];
let g01 = array![[0.2_f64, -0.1], [0.3, 0.1]];
let g10 = g01.t().to_owned();
let ident = Array2::<f64>::eye(p);
let factored = FactoredFrameKroneckerOp::new(
vec![p, p],
vec![2, 2],
vec![
FactoredFrameGBlock {
atom_i: 0,
atom_j: 0,
g: g00.clone(),
w: ident.clone(),
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 1,
g: g11.clone(),
w: ident.clone(),
},
FactoredFrameGBlock {
atom_i: 0,
atom_j: 1,
g: g01.clone(),
w: ident.clone(),
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 0,
g: g10.clone(),
w: ident.clone(),
},
],
)
.expect("factored op");
let sparse = SparseBlockKroneckerPenaltyOp {
p,
dim_a: 4,
k: 8,
blocks: vec![
SparseGBlock {
row_off: 0,
col_off: 0,
data: g00,
},
SparseGBlock {
row_off: 2,
col_off: 2,
data: g11,
},
SparseGBlock {
row_off: 0,
col_off: 2,
data: g01,
},
SparseGBlock {
row_off: 2,
col_off: 0,
data: g10,
},
],
};
assert_eq!(factored.dim(), sparse.dim());
let x: Vec<f64> = (0..8).map(|i| 0.2 * (i as f64) - 0.5).collect();
let mut yf = vec![0.0_f64; 8];
let mut ys = vec![0.0_f64; 8];
factored.matvec(&x, &mut yf);
sparse.matvec(&x, &mut ys);
for i in 0..8 {
assert!(
(yf[i] - ys[i]).abs() < 1e-12,
"full-rank factored op must equal SparseBlockKronecker at {i}: {} vs {}",
yf[i],
ys[i]
);
}
}
fn mgs_orthonormalize(a: &Array2<f64>) -> Array2<f64> {
let (p, r) = a.dim();
let mut q = a.clone();
for j in 0..r {
for i in 0..j {
let mut dot = 0.0;
for c in 0..p {
dot += q[[c, i]] * q[[c, j]];
}
for c in 0..p {
q[[c, j]] -= dot * q[[c, i]];
}
}
let mut nrm = 0.0;
for c in 0..p {
nrm += q[[c, j]] * q[[c, j]];
}
let nrm = nrm.sqrt();
assert!(nrm > 1e-9, "mgs column {j} degenerate");
for c in 0..p {
q[[c, j]] /= nrm;
}
}
q
}
#[test]
fn frame_output_gram_orthonormal_is_identity() {
let p = 5usize;
let r = 3usize;
let mut seed = Array2::<f64>::zeros((p, r));
for c in 0..p {
for a in 0..r {
seed[[c, a]] = ((c as f64) * 0.37 + (a as f64) * 1.31).sin() + 0.1 * (a as f64);
}
}
let u = mgs_orthonormalize(&seed);
let g = frame_output_gram(u.view(), u.view());
assert_eq!(g.dim(), (r, r));
for a in 0..r {
for b in 0..r {
let expect = if a == b { 1.0 } else { 0.0 };
assert!(
(g[[a, b]] - expect).abs() < 1e-12,
"UᵀU not identity at ({a},{b}): {}",
g[[a, b]]
);
}
}
}
#[test]
fn from_frames_and_blocks_matches_dense_reference() {
let p = 4usize;
let basis_sizes = vec![2usize, 3];
let mut seed0 = Array2::<f64>::zeros((p, 2));
let mut seed1 = Array2::<f64>::zeros((p, 3));
for c in 0..p {
for a in 0..2 {
seed0[[c, a]] = ((c as f64) * 0.91 - (a as f64) * 0.5).cos() + 0.2 * (c as f64);
}
for a in 0..3 {
seed1[[c, a]] = ((c as f64) * 0.23 + (a as f64) * 1.7).sin() - 0.3 * (a as f64);
}
}
let u0 = mgs_orthonormalize(&seed0);
let u1 = mgs_orthonormalize(&seed1);
let g00 = array![[3.0_f64, 0.5], [0.5, 4.0]];
let g11 = array![[2.0_f64, 0.4, -0.2], [0.4, 5.0, 0.6], [-0.2, 0.6, 1.5]];
let g01 = array![[0.2_f64, -0.1, 0.0], [0.3, 0.1, -0.2]];
let g10 = g01.t().to_owned();
let mut g_blocks: std::collections::BTreeMap<(usize, usize), Array2<f64>> =
std::collections::BTreeMap::new();
g_blocks.insert((0, 0), g00.clone());
g_blocks.insert((1, 1), g11.clone());
g_blocks.insert((0, 1), g01.clone());
g_blocks.insert((1, 0), g10.clone());
let frames = vec![Some(u0.clone()), Some(u1.clone())];
let op =
FactoredFrameKroneckerOp::from_frames_and_blocks(&frames, &basis_sizes, p, &g_blocks)
.expect("from_frames_and_blocks");
assert_eq!(op.dim(), 13);
let ranks = vec![2usize, 3];
let w00 = frame_output_gram(u0.view(), u0.view());
let w11 = frame_output_gram(u1.view(), u1.view());
let w01 = frame_output_gram(u0.view(), u1.view());
let w10 = frame_output_gram(u1.view(), u0.view());
let ref_blocks = vec![
FactoredFrameGBlock {
atom_i: 0,
atom_j: 0,
g: g00,
w: w00,
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 1,
g: g11,
w: w11,
},
FactoredFrameGBlock {
atom_i: 0,
atom_j: 1,
g: g01,
w: w01,
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 0,
g: g10,
w: w10,
},
];
let reference = factored_reference_dense(&ranks, &basis_sizes, &ref_blocks);
let dense = op.to_dense();
for i in 0..13 {
for j in 0..13 {
assert!(
(dense[[i, j]] - reference[[i, j]]).abs() < 1e-12,
"to_dense mismatch at ({i},{j}): {} vs {}",
dense[[i, j]],
reference[[i, j]]
);
}
}
let x: Vec<f64> = (0..13).map(|i| 0.17 * (i as f64) - 0.6).collect();
let mut y = vec![0.0_f64; 13];
op.matvec(&x, &mut y);
for i in 0..13 {
let mut expect = 0.0;
for j in 0..13 {
expect += reference[[i, j]] * x[j];
}
assert!(
(y[i] - expect).abs() < 1e-10,
"matvec mismatch at {i}: {} vs {expect}",
y[i]
);
}
}
#[test]
fn from_frames_and_blocks_mixed_framed_unframed() {
let p = 4usize;
let basis_sizes = vec![2usize, 2]; let mut seed0 = Array2::<f64>::zeros((p, 2));
for c in 0..p {
for a in 0..2 {
seed0[[c, a]] = ((c as f64) * 0.61 + (a as f64) * 0.9).cos() - 0.15 * (c as f64);
}
}
let u0 = mgs_orthonormalize(&seed0);
let g00 = array![[3.0_f64, 0.5], [0.5, 4.0]];
let g11 = array![[2.0_f64, 0.4], [0.4, 5.0]];
let g01 = array![[0.2_f64, -0.1], [0.3, 0.1]];
let g10 = g01.t().to_owned();
let mut g_blocks: std::collections::BTreeMap<(usize, usize), Array2<f64>> =
std::collections::BTreeMap::new();
g_blocks.insert((0, 0), g00.clone());
g_blocks.insert((1, 1), g11.clone());
g_blocks.insert((0, 1), g01.clone());
g_blocks.insert((1, 0), g10.clone());
let frames = vec![Some(u0.clone()), None];
let op =
FactoredFrameKroneckerOp::from_frames_and_blocks(&frames, &basis_sizes, p, &g_blocks)
.expect("from_frames_and_blocks mixed");
assert_eq!(op.ranks, vec![2usize, 4]);
assert_eq!(op.dim(), 12);
let dense = op.to_dense();
let off1 = 4usize;
for li in 0..2 {
for lj in 0..2 {
for a in 0..4 {
for b in 0..4 {
let gi = off1 + li * 4 + a;
let gj = off1 + lj * 4 + b;
let expect = if a == b { g11[[li, lj]] } else { 0.0 };
assert!(
(dense[[gi, gj]] - expect).abs() < 1e-12,
"g_11 ⊗ I_4 mismatch at ({gi},{gj}): {} vs {expect}",
dense[[gi, gj]]
);
}
}
}
}
let ranks = vec![2usize, 4];
let ident_p = Array2::<f64>::eye(p);
let w00 = frame_output_gram(u0.view(), u0.view());
let w11 = frame_output_gram(ident_p.view(), ident_p.view());
let w01 = frame_output_gram(u0.view(), ident_p.view());
let w10 = frame_output_gram(ident_p.view(), u0.view());
let ref_blocks = vec![
FactoredFrameGBlock {
atom_i: 0,
atom_j: 0,
g: g00,
w: w00,
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 1,
g: g11.clone(),
w: w11,
},
FactoredFrameGBlock {
atom_i: 0,
atom_j: 1,
g: g01,
w: w01,
},
FactoredFrameGBlock {
atom_i: 1,
atom_j: 0,
g: g10,
w: w10,
},
];
let reference = factored_reference_dense(&ranks, &basis_sizes, &ref_blocks);
let x: Vec<f64> = (0..12).map(|i| 0.11 * (i as f64) - 0.4).collect();
let mut y = vec![0.0_f64; 12];
op.matvec(&x, &mut y);
for i in 0..12 {
let mut expect = 0.0;
for j in 0..12 {
expect += reference[[i, j]] * x[j];
}
assert!(
(y[i] - expect).abs() < 1e-10,
"mixed matvec mismatch at {i}: {} vs {expect}",
y[i]
);
}
}
#[test]
fn arrow_schur_matches_dense_reference_2x2() {
let n = 2;
let d = 2;
let k = 3;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[2.0_f64, 0.1], [0.1, 3.0]];
sys.rows[0].htbeta = array![[1.0_f64, 0.0, 0.5], [0.2, 1.0, 0.0]];
sys.rows[0].gt = array![0.3_f64, -0.2];
sys.rows[1].htt = array![[1.5_f64, -0.1], [-0.1, 2.0]];
sys.rows[1].htbeta = array![[0.1_f64, 0.5, 0.0], [0.0, 0.3, 1.0]];
sys.rows[1].gt = array![-0.1_f64, 0.4];
sys.hbb = array![[4.0_f64, 0.2, 0.0], [0.2, 5.0, 0.1], [0.0, 0.1, 6.0],];
sys.gb = array![0.5_f64, -0.3, 0.2];
let (delta_t, delta_beta, _diag) = sys.solve(0.0, 0.0).expect("arrow-schur solve");
let streaming_options = ArrowSolveOptions::direct().with_streaming_chunk_size(Some(1));
let (delta_t_stream, delta_beta_stream, _diag_stream) = sys
.solve_with_options(0.0, 0.0, &streaming_options)
.expect("streaming arrow-schur solve");
assert_eq!(delta_beta, delta_beta_stream);
assert_eq!(delta_t, delta_t_stream);
let total = k + n * d;
let mut hjoint = Array2::<f64>::zeros((total, total));
let mut gjoint = Array1::<f64>::zeros(total);
for a in 0..k {
for b in 0..k {
hjoint[[a, b]] = sys.hbb[[a, b]];
}
gjoint[a] = sys.gb[a];
}
for i in 0..n {
let toff = k + i * d;
for a in 0..d {
for b in 0..d {
hjoint[[toff + a, toff + b]] = sys.rows[i].htt[[a, b]];
}
gjoint[toff + a] = sys.rows[i].gt[a];
for a2 in 0..k {
hjoint[[toff + a, a2]] = sys.rows[i].htbeta[[a, a2]];
hjoint[[a2, toff + a]] = sys.rows[i].htbeta[[a, a2]];
}
}
}
let lj = cholesky_lower(&hjoint).expect("dense ref PD");
let neg_g = gjoint.mapv(|v| -v);
let xref = cholesky_solve_vector(&lj, &neg_g);
for a in 0..k {
assert!(
(xref[a] - delta_beta[a]).abs() < 1e-10,
"β[{a}] mismatch: dense {} vs arrow {}",
xref[a],
delta_beta[a]
);
}
for i in 0..n {
for a in 0..d {
let dense = xref[k + i * d + a];
let arrow = delta_t[i * d + a];
assert!(
(dense - arrow).abs() < 1e-10,
"t[{i},{a}] mismatch: dense {dense} vs arrow {arrow}"
);
}
}
}
fn diagonal_arrow_fixture(row_min: f64, schur_min: f64) -> ArrowSchurSystem {
let mut sys = ArrowSchurSystem::new(2, 2, 2);
sys.rows[0].htt = array![[row_min, 0.0], [0.0, row_min + 1.0]];
sys.rows[1].htt = array![[row_min + 2.0, 0.0], [0.0, row_min + 3.0]];
for row in sys.rows.iter_mut() {
row.htbeta.fill(0.0);
row.gt.fill(0.0);
}
sys.hbb = array![[schur_min, 0.0], [0.0, schur_min + 1.0]];
sys.gb.fill(0.0);
sys
}
fn diagonal_fixture_dense_lambda_min(sys: &ArrowSchurSystem) -> f64 {
let mut out = f64::INFINITY;
for row in &sys.rows {
for axis in 0..row.htt.nrows() {
out = out.min(row.htt[[axis, axis]]);
}
}
for axis in 0..sys.hbb.nrows() {
out = out.min(sys.hbb[[axis, axis]]);
}
out
}
#[test]
fn arrow_factor_min_pivot_matches_dense_lambda_min_ordering() {
let weak = diagonal_arrow_fixture(0.2, 0.8);
let strong = diagonal_arrow_fixture(0.7, 1.2);
let options = ArrowSolveOptions::direct();
let (_dt_w, _db_w, weak_cache) =
solve_arrow_newton_step_with_options(&weak, 0.0, 0.0, &options)
.expect("weak diagonal fixture should factor");
let (_dt_s, _db_s, strong_cache) =
solve_arrow_newton_step_with_options(&strong, 0.0, 0.0, &options)
.expect("strong diagonal fixture should factor");
let weak_lambda = diagonal_fixture_dense_lambda_min(&weak);
let strong_lambda = diagonal_fixture_dense_lambda_min(&strong);
assert!(weak_lambda < strong_lambda);
let weak_pivot = arrow_factor_min_pivot(&weak_cache)
.min_pivot
.expect("weak pivot");
let strong_pivot = arrow_factor_min_pivot(&strong_cache)
.min_pivot
.expect("strong pivot");
assert_abs_diff_eq!(weak_pivot, weak_lambda, epsilon = 1.0e-14);
assert_abs_diff_eq!(strong_pivot, strong_lambda, epsilon = 1.0e-14);
assert!(weak_pivot < strong_pivot);
}
fn quartic_counterexample_value(t: f64) -> f64 {
0.25 * t.powi(4) - t * t + 2.0 * t
}
fn quartic_counterexample_system(t: f64) -> ArrowSchurSystem {
let mut sys = ArrowSchurSystem::new(1, 1, 0);
sys.rows[0].gt = array![t.powi(3) - 2.0 * t + 2.0];
sys.rows[0].htt = array![[3.0 * t * t - 2.0]];
sys
}
#[test]
fn proximal_correction_breaks_scalar_newton_cycle() {
let options = ArrowSolveOptions::direct();
let correction = ArrowProximalCorrectionOptions {
initial_ridge: 1e-8,
ridge_growth: 10.0,
max_attempts: 16,
armijo_c1: 1e-4,
gradient_tolerance: 1e-12,
convergence_objective_rel_tol: DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL,
};
let mut t = 0.0_f64;
let mut previous_value = quartic_counterexample_value(t);
for _ in 0..32 {
let sys = quartic_counterexample_system(t);
let accepted = solve_arrow_newton_step_with_proximal_correction(
&sys,
0.0,
0.0,
previous_value,
&options,
&correction,
|delta_t, _delta_beta| quartic_counterexample_value(t + delta_t[0]),
)
.expect("proximal correction should accept a descent step");
assert!(
accepted.trial_objective_value <= previous_value,
"accepted step must not increase the objective"
);
t += accepted.delta_t[0];
previous_value = accepted.trial_objective_value;
}
let final_grad = t.powi(3) - 2.0 * t + 2.0;
assert!(
final_grad.abs() < 1e-7,
"corrected iteration should reach the scalar critical point; t={t}, g={final_grad}"
);
}
#[test]
fn factor_one_row_conditions_barely_pd_block_via_ridge() {
let d = 2;
let k = 2;
let mut row = ArrowRowBlock::new(d, k);
row.htt = array![[1.0_f64, 1.0], [1.0, 1.0 + 1e-14]];
row.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row.gt = array![0.0_f64, 0.0];
let factor = factor_one_row(&row, 0.0, d, 0, false).expect(
"barely-PD H_tt must be CONDITIONED by per-row ridge escalation, not rejected (gam#578)",
);
let kappa = cholesky_factor_kappa_estimate(&factor);
assert!(
kappa.is_finite() && kappa <= safe_spd_kappa_max(d),
"conditioned factor must be within the safe-inversion κ ceiling; got κ={kappa:e}"
);
for i in 0..d {
for j in 0..d {
let mut acc = 0.0_f64;
for kk in 0..d {
acc += factor[[i, kk]] * factor[[j, kk]];
}
if i == j {
assert!(
acc >= row.htt[[i, j]] - 1e-12,
"diagonal of L Lᵀ must be H_tt + (nonneg ridge) at ({i},{j}): \
{acc} vs {}",
row.htt[[i, j]]
);
} else {
assert!(
(acc - row.htt[[i, j]]).abs() < 1e-9,
"off-diagonal of L Lᵀ must equal H_tt at ({i},{j}): {acc} vs {}",
row.htt[[i, j]]
);
}
}
}
let factor = factor_one_row(&row, 0.0, d, 0, true)
.expect("tolerate_ill_conditioning must accept a barely-PD-but-PD block");
for i in 0..d {
for j in 0..d {
let mut acc = 0.0_f64;
for kk in 0..d {
acc += factor[[i, kk]] * factor[[j, kk]];
}
assert!(
(acc - row.htt[[i, j]]).abs() < 1e-12,
"tolerated factor must satisfy L Lᵀ = H_tt at ({i},{j})"
);
}
}
let mut row_npd = ArrowRowBlock::new(d, k);
row_npd.htt = array![[1.0_f64, 2.0], [2.0, 1.0]]; row_npd.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row_npd.gt = array![0.0_f64, 0.0];
let npd = factor_one_row(&row_npd, 0.0, d, 0, true);
assert!(
matches!(npd, Err(ArrowSchurError::PerRowFactorFailed { .. })),
"non-PD block must error even with tolerate_ill_conditioning; got {npd:?}"
);
let mut row_ok = ArrowRowBlock::new(d, k);
row_ok.htt = array![[2.0_f64, 0.1], [0.1, 3.0]];
row_ok.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row_ok.gt = array![0.0_f64, 0.0];
factor_one_row(&row_ok, 0.0, d, 0, false)
.expect("well-conditioned block must still factor at ridge_t=0");
let mut row_nan = ArrowRowBlock::new(d, k);
row_nan.htt = array![[f64::NAN, 0.0], [0.0, 1.0]];
row_nan.htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
row_nan.gt = array![0.0_f64, 0.0];
let nan = factor_one_row(&row_nan, 1.0e-6, d, 0, false);
assert!(
matches!(nan, Err(ArrowSchurError::PerRowFactorFailed { .. })),
"non-finite block must surface PerRowFactorFailed, not loop or condition; got {nan:?}"
);
}
#[test]
fn factor_one_row_conditions_scalar_tiny_pivot_via_ridge() {
let d = 1;
let k = 1;
let mut row = ArrowRowBlock::new(d, k);
row.htt = array![[1.0e-20_f64]];
row.htbeta = array![[1.0_f64]];
row.gt = array![0.0_f64];
let factor = factor_one_row(&row, 0.0, d, 0, false)
.expect("tiny positive scalar pivot must be ridge-conditioned");
let pivot = factor[[0, 0]] * factor[[0, 0]];
assert!(
pivot >= safe_spd_pivot_min(1.0),
"scalar pivot must be lifted above the absolute safe floor; got {pivot:e}"
);
assert!(
pivot > row.htt[[0, 0]],
"scalar block must not be accepted at the raw tiny pivot"
);
let tolerated = factor_one_row(&row, 0.0, d, 0, true)
.expect("tolerated log-det path must accept a positive scalar block");
let raw_pivot = tolerated[[0, 0]] * tolerated[[0, 0]];
assert!(
(raw_pivot - row.htt[[0, 0]]).abs() < 1.0e-30,
"tolerated factor must remain the raw scalar Cholesky"
);
}
#[test]
fn sys_htbeta_materialize_row_sums_operator_and_dense_slab() {
let mut sys = ArrowSchurSystem::new(1, 1, 3);
sys.rows[0].htbeta = array![[0.25_f64, 0.5, 0.75]];
sys.activate_dense_htbeta_supplement();
sys.set_row_htbeta_operator(
|row_idx, x, out| {
assert_eq!(row_idx, 0);
out[0] += 2.0 * x[0] - x[1] + 0.5 * x[2];
},
|row_idx, v, out| {
assert_eq!(row_idx, 0);
out[0] += 2.0 * v[0];
out[1] -= v[0];
out[2] += 0.5 * v[0];
},
);
let htbeta = sys_htbeta_materialize_row(&sys, 0, &sys.rows[0]);
assert_eq!(htbeta, array![[2.25_f64, -0.5, 1.25]]);
}
#[test]
fn lm_escalation_recovers_from_ill_conditioned_row() {
let n = 1;
let d = 2;
let k = 2;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[1.0_f64, 1.0], [1.0, 1.0 + 1e-14]];
sys.rows[0].htbeta = array![[1.0_f64, 0.0], [0.0, 1.0]];
sys.rows[0].gt = array![0.1_f64, -0.2];
sys.hbb = array![[4.0_f64, 0.2], [0.2, 5.0]];
sys.gb = array![0.3_f64, -0.1];
let factor = factor_one_row(&sys.rows[0], 0.0, d, 0, false)
.expect("barely-PD row must be conditioned, not rejected (gam#578)");
let kappa = cholesky_factor_kappa_estimate(&factor);
assert!(
kappa.is_finite() && kappa <= safe_spd_kappa_max(d),
"conditioned per-row factor must satisfy the κ ceiling; got κ={kappa:e}"
);
let options = ArrowSolveOptions::direct();
let (delta_t, delta_beta, diag) = solve_with_lm_escalation_inner(&sys, 0.0, 0.0, &options)
.expect("LM escalation must recover from a barely-PD per-row block");
for v in delta_t.iter().chain(delta_beta.iter()) {
assert!(v.is_finite(), "recovered step must be finite: {v}");
}
assert!(
diag.ridge_escalations <= DEFAULT_PROXIMAL_MAX_ATTEMPTS,
"recovery must use a bounded number of outer ridge escalations; got {}",
diag.ridge_escalations
);
}
#[test]
fn latent_block_inverse_diagonal_matches_dense() {
let n = 3usize;
let d = 2usize;
let k = 2usize;
let mut sys = ArrowSchurSystem::new(n, d, k);
sys.rows[0].htt = array![[4.0_f64, 0.5], [0.5, 3.0]];
sys.rows[0].htbeta = array![[1.0_f64, 0.2], [-0.3, 0.7]];
sys.rows[1].htt = array![[5.0_f64, -0.4], [-0.4, 2.5]];
sys.rows[1].htbeta = array![[0.6_f64, -0.1], [0.4, 0.9]];
sys.rows[2].htt = array![[3.5_f64, 0.2], [0.2, 4.5]];
sys.rows[2].htbeta = array![[-0.2_f64, 0.5], [0.8, -0.6]];
for row in sys.rows.iter_mut() {
row.gt = array![0.0_f64, 0.0];
}
sys.hbb = array![[12.0_f64, 0.7], [0.7, 10.0]];
sys.gb = array![0.0_f64, 0.0];
let options = ArrowSolveOptions::direct();
let (_delta_t, _delta_beta, cache) =
solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options)
.expect("direct arrow solve should factor this SPD system");
let dim = n * d + k;
let mut h = Array2::<f64>::zeros((dim, dim));
for i in 0..n {
let base = i * d;
for r in 0..d {
for c in 0..d {
h[[base + r, base + c]] = sys.rows[i].htt[[r, c]];
}
}
for r in 0..d {
for c in 0..k {
let v = sys.rows[i].htbeta[[r, c]];
h[[base + r, n * d + c]] = v;
h[[n * d + c, base + r]] = v;
}
}
}
for r in 0..k {
for c in 0..k {
h[[n * d + r, n * d + c]] = sys.hbb[[r, c]];
}
}
let l = cholesky_lower(&h).expect("assembled bordered H must be SPD");
let h_inv = cholesky_solve_matrix(&l, &Array2::<f64>::eye(dim));
let diag = cache
.latent_block_inverse_diagonal()
.expect("dense Schur cache must support the selected-inverse diagonal");
assert_eq!(diag.len(), n * d);
for i in 0..n {
for j in 0..d {
let idx = i * d + j; let expected = h_inv[[idx, idx]];
let got = diag[idx];
assert!(
(got - expected).abs() < 1e-9,
"row {i} axis {j}: selected-inverse diag {got} vs dense {expected}"
);
}
}
let trace_selected: f64 = diag.iter().sum();
let trace_dense: f64 = (0..n * d).map(|idx| h_inv[[idx, idx]]).sum();
assert!(
(trace_selected - trace_dense).abs() < 1e-9,
"full latent trace {trace_selected} vs dense {trace_dense}"
);
}
/// `full_inverse_apply` on the ridge-0 direct cache must (a) equal the negated Newton step
/// when fed the stacked gradient, and (b) match a dense Cholesky solve of the assembled
/// bordered Hessian for an arbitrary right-hand side.
#[test]
fn full_inverse_apply_matches_dense_inverse_and_newton_step() {
    let (n, d, k) = (3usize, 2usize, 2usize);
    let latent_len = n * d;
    let dim = latent_len + k;

    let mut sys = ArrowSchurSystem::new(n, d, k);
    sys.rows[0].htt = array![[4.0_f64, 0.5], [0.5, 3.0]];
    sys.rows[0].htbeta = array![[1.0_f64, 0.2], [-0.3, 0.7]];
    sys.rows[0].gt = array![0.4_f64, -0.7];
    sys.rows[1].htt = array![[5.0_f64, -0.4], [-0.4, 2.5]];
    sys.rows[1].htbeta = array![[0.6_f64, -0.1], [0.4, 0.9]];
    sys.rows[1].gt = array![-0.2_f64, 0.9];
    sys.rows[2].htt = array![[3.5_f64, 0.2], [0.2, 4.5]];
    sys.rows[2].htbeta = array![[-0.2_f64, 0.5], [0.8, -0.6]];
    sys.rows[2].gt = array![1.1_f64, 0.3];
    sys.hbb = array![[12.0_f64, 0.7], [0.7, 10.0]];
    sys.gb = array![0.5_f64, -0.8];

    let options = ArrowSolveOptions::direct();
    let (delta_t, delta_beta, cache) =
        solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options)
            .expect("direct arrow solve should factor this SPD system");

    // Part (a): applying H⁻¹ to the stacked gradient gives −Δ of the Newton step.
    let g_t = Array1::<f64>::from_shape_fn(latent_len, |flat| sys.rows[flat / d].gt[flat % d]);
    let (u_t, u_beta) = cache
        .full_inverse_apply(g_t.view(), sys.gb.view())
        .expect("full_inverse_apply on the ridge-0 Direct cache");
    for idx in 0..latent_len {
        let (applied, newton) = (u_t[idx], delta_t[idx]);
        assert!(
            (applied + newton).abs() < 1e-10,
            "t[{idx}]: full_inverse_apply {} vs −(Newton step) {}",
            applied,
            -newton
        );
    }
    for c in 0..k {
        let (applied, newton) = (u_beta[c], delta_beta[c]);
        assert!(
            (applied + newton).abs() < 1e-10,
            "beta[{c}]: full_inverse_apply {} vs −(Newton step) {}",
            applied,
            -newton
        );
    }

    // Part (b): dense reference built from the same blocks.
    let mut h = Array2::<f64>::zeros((dim, dim));
    for i in 0..n {
        let row = &sys.rows[i];
        let base = i * d;
        for r in 0..d {
            for c in 0..d {
                h[[base + r, base + c]] = row.htt[[r, c]];
            }
            for c in 0..k {
                let coupling = row.htbeta[[r, c]];
                h[[base + r, latent_len + c]] = coupling;
                h[[latent_len + c, base + r]] = coupling;
            }
        }
    }
    for r in 0..k {
        for c in 0..k {
            h[[latent_len + r, latent_len + c]] = sys.hbb[[r, c]];
        }
    }
    let l = cholesky_lower(&h).expect("assembled bordered H must be SPD");

    // Deterministic RHS with alternating-sign growth so every coordinate is exercised.
    let w_full = Array1::<f64>::from_shape_fn(dim, |idx| {
        let sign = if idx % 2 == 0 { 1.0 } else { -1.0 };
        0.3 + 0.17 * (idx as f64) * sign
    });
    let dense_u = cholesky_solve_vector(&l, &w_full);
    let (u_t2, u_beta2) = cache
        .full_inverse_apply(
            w_full.slice(ndarray::s![..latent_len]),
            w_full.slice(ndarray::s![latent_len..]),
        )
        .expect("full_inverse_apply on arbitrary RHS");
    for idx in 0..latent_len {
        let (applied, dense) = (u_t2[idx], dense_u[idx]);
        assert!(
            (applied - dense).abs() < 1e-10,
            "t[{idx}]: full_inverse_apply {} vs dense {}",
            applied,
            dense
        );
    }
    for c in 0..k {
        let (applied, dense) = (u_beta2[c], dense_u[latent_len + c]);
        assert!(
            (applied - dense).abs() < 1e-10,
            "beta[{c}]: full_inverse_apply {} vs dense {}",
            applied,
            dense
        );
    }
}
/// Checks the dense Schur cache's β-block inverse operations against the β-β block of `H⁻¹`.
/// `H⁻¹` comes from a dense Cholesky of the assembled bordered Hessian. The checks cover
/// unit-vector applies, a scaled-identity trace, the full `k × k` block (including symmetry),
/// an interior sub-block, and rejection of an out-of-range block.
#[test]
fn schur_inverse_beta_block_matches_dense() {
    let (n, d, k) = (3usize, 2usize, 2usize);
    let beta_off = n * d;
    let dim = beta_off + k;

    let mut sys = ArrowSchurSystem::new(n, d, k);
    sys.rows[0].htt = array![[4.0_f64, 0.5], [0.5, 3.0]];
    sys.rows[0].htbeta = array![[1.0_f64, 0.2], [-0.3, 0.7]];
    sys.rows[1].htt = array![[5.0_f64, -0.4], [-0.4, 2.5]];
    sys.rows[1].htbeta = array![[0.6_f64, -0.1], [0.4, 0.9]];
    sys.rows[2].htt = array![[3.5_f64, 0.2], [0.2, 4.5]];
    sys.rows[2].htbeta = array![[-0.2_f64, 0.5], [0.8, -0.6]];
    sys.rows
        .iter_mut()
        .for_each(|row| row.gt = array![0.0_f64, 0.0]);
    sys.hbb = array![[12.0_f64, 0.7], [0.7, 10.0]];
    sys.gb = array![0.0_f64, 0.0];

    let options = ArrowSolveOptions::direct();
    let (_dt, _db, cache) = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &options)
        .expect("direct arrow solve should factor this SPD system");

    // Dense reference: assemble the bordered H and invert it through its Cholesky factor.
    let mut h = Array2::<f64>::zeros((dim, dim));
    for i in 0..n {
        let row = &sys.rows[i];
        let base = i * d;
        for r in 0..d {
            for c in 0..d {
                h[[base + r, base + c]] = row.htt[[r, c]];
            }
            for c in 0..k {
                let coupling = row.htbeta[[r, c]];
                h[[base + r, beta_off + c]] = coupling;
                h[[beta_off + c, base + r]] = coupling;
            }
        }
    }
    for r in 0..k {
        for c in 0..k {
            h[[beta_off + r, beta_off + c]] = sys.hbb[[r, c]];
        }
    }
    let l = cholesky_lower(&h).expect("assembled bordered H must be SPD");
    let h_inv = cholesky_solve_matrix(&l, &Array2::<f64>::eye(dim));

    // Column by column: S_β⁻¹ e_col must reproduce the β-β block of H⁻¹.
    for col in 0..k {
        let e = Array1::<f64>::from_shape_fn(k, |j| if j == col { 1.0 } else { 0.0 });
        let x = cache
            .schur_inverse_apply(e.view())
            .expect("dense Schur cache must support schur_inverse_apply");
        for r in 0..k {
            let expected = h_inv[[beta_off + r, beta_off + col]];
            assert!(
                (x[r] - expected).abs() < 1e-9,
                "S_β⁻¹[{r},{col}] {} vs dense {expected}",
                x[r]
            );
        }
    }

    // tr(S_β⁻¹ · aI) accumulated one column at a time vs a · tr of the dense β-β block.
    let a_scalar = 0.75_f64;
    let trace = (0..k).fold(0.0_f64, |acc, col| {
        let m_col = Array1::<f64>::from_shape_fn(k, |j| if j == col { a_scalar } else { 0.0 });
        let z = cache
            .schur_inverse_apply(m_col.view())
            .expect("schur_inverse_apply");
        acc + z[col]
    });
    let trace_dense: f64 = a_scalar
        * (0..k)
            .map(|j| h_inv[[beta_off + j, beta_off + j]])
            .sum::<f64>();
    assert!(
        (trace - trace_dense).abs() < 1e-9,
        "Kron-block trace {trace} vs dense {trace_dense}"
    );

    // The explicit full block must match the dense inverse and be symmetric.
    let full = cache
        .schur_inverse_block(0..k)
        .expect("dense Schur cache must support schur_inverse_block");
    assert_eq!(full.dim(), (k, k));
    for (r, c) in (0..k).flat_map(|r| (0..k).map(move |c| (r, c))) {
        let expected = h_inv[[beta_off + r, beta_off + c]];
        assert!(
            (full[[r, c]] - expected).abs() < 1e-9,
            "block[{r},{c}] {} vs dense {expected}",
            full[[r, c]]
        );
        assert!(
            (full[[r, c]] - full[[c, r]]).abs() < 1e-12,
            "schur_inverse_block must be symmetric at [{r},{c}]"
        );
    }

    // A trailing sub-range is supported, while a range past k is rejected.
    let sub = cache
        .schur_inverse_block(1..k)
        .expect("interior block must be supported");
    assert_eq!(sub.dim(), (k - 1, k - 1));
    let dense_tail = h_inv[[beta_off + 1, beta_off + 1]];
    assert!(
        (sub[[0, 0]] - dense_tail).abs() < 1e-9,
        "interior block [1,1] {} vs dense {}",
        sub[[0, 0]],
        dense_tail
    );
    assert!(cache.schur_inverse_block(0..(k + 1)).is_err());
}
/// Exercises a system whose per-row latent blocks are barely positive definite, in five stages:
///
/// 1. Per-row factorization must condition each row (finite κ within `safe_spd_kappa_max`).
/// 2. A single strict direct solve must fail with a recoverable error.
/// 3. `solve_with_lm_escalation_inner` must recover a finite step within the escalation bound.
/// 4. Tolerate mode must still refuse the indefinite assembled H.
/// 5. For a variant with tiny couplings on the weak axis, tolerate mode must return a cache
///    whose `arrow_log_det` matches a dense Cholesky log-determinant.
#[test]
fn ill_conditioning_tolerated_returns_cache_with_exact_logdet() {
    let (n, d, k) = (2usize, 2usize, 2usize);
    let dim = n * d + k;

    // Weak second latent axis (curvature ~1e-9) with O(0.1) coupling to β on that axis.
    let mut sys = ArrowSchurSystem::new(n, d, k);
    sys.rows[0].htt = array![[1.0_f64, 0.0], [0.0, 1e-9]];
    sys.rows[0].htbeta = array![[0.3_f64, 0.1], [0.05, 0.2]];
    sys.rows[1].htt = array![[2.0_f64, 0.0], [0.0, 2e-9]];
    sys.rows[1].htbeta = array![[0.2_f64, -0.1], [0.1, 0.15]];
    sys.rows
        .iter_mut()
        .for_each(|row| row.gt = array![0.0_f64, 0.0]);
    sys.hbb = array![[5.0_f64, 0.3], [0.3, 4.0]];
    sys.gb = array![0.0_f64, 0.0];

    // Stage 1: every row factors and its conditioned factor respects the κ ceiling.
    for i in 0..n {
        let row = &sys.rows[i];
        let factor = factor_one_row(row, 0.0, d, i, false)
            .expect("barely-PD row must be conditioned, not rejected (gam#578)");
        let kappa = cholesky_factor_kappa_estimate(&factor);
        assert!(
            kappa.is_finite() && kappa <= safe_spd_kappa_max(d),
            "conditioned per-row factor {i} must satisfy the safe-Schur κ ceiling; got κ={kappa:e}"
        );
    }

    // Stage 2: a single strict direct solve reports a recoverable factorization error.
    let single_shot =
        solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &ArrowSolveOptions::direct());
    let recoverable = matches!(
        single_shot,
        Err(ArrowSchurError::SchurFactorFailed { .. })
            | Err(ArrowSchurError::PerRowFactorIllConditioned { .. })
            | Err(ArrowSchurError::PcgFailed { .. })
    );
    assert!(
        recoverable,
        "single-shot strict direct() cannot keep the dense Schur PD with per-row \
         conditioning alone; expected a recoverable factorization error, got {single_shot:?}"
    );

    // Stage 3: LM escalation recovers a finite step within a bounded number of attempts.
    let (strict_dt, strict_db, strict_diag) =
        solve_with_lm_escalation_inner(&sys, 0.0, 0.0, &ArrowSolveOptions::direct())
            .expect("LM escalation must recover the ill-conditioned strict solve (gam#845)");
    strict_dt
        .iter()
        .chain(strict_db.iter())
        .for_each(|v| assert!(v.is_finite(), "recovered strict step must be finite: {v}"));
    let escalations = strict_diag.ridge_escalations;
    assert!(
        escalations <= DEFAULT_PROXIMAL_MAX_ATTEMPTS,
        "recovery must use a bounded number of outer ridge escalations; got {}",
        escalations
    );

    // Stage 4: tolerating ill-conditioning must not extend to an indefinite assembled H.
    let opts = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let tolerate_indefinite = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, &opts);
    assert!(
        matches!(
            tolerate_indefinite,
            Err(ArrowSchurError::SchurFactorFailed { .. })
        ),
        "tolerate mode must refuse the indefinite assembled H rather than fabricate \
         a log-determinant; got {tolerate_indefinite:?}"
    );

    // Stage 5: same weak axes, but with tiny couplings on them so the assembled H stays PD.
    let mut pd_sys = ArrowSchurSystem::new(n, d, k);
    pd_sys.rows[0].htt = array![[1.0_f64, 0.0], [0.0, 1e-9]];
    pd_sys.rows[0].htbeta = array![[0.3_f64, 0.1], [3e-6, 1e-6]];
    pd_sys.rows[1].htt = array![[2.0_f64, 0.0], [0.0, 2e-9]];
    pd_sys.rows[1].htbeta = array![[0.2_f64, -0.1], [2e-6, 4e-6]];
    pd_sys
        .rows
        .iter_mut()
        .for_each(|row| row.gt = array![0.0_f64, 0.0]);
    pd_sys.hbb = array![[5.0_f64, 0.3], [0.3, 4.0]];
    pd_sys.gb = array![0.0_f64, 0.0];

    let (_dt, _db, cache) = solve_arrow_newton_step_with_options(&pd_sys, 0.0, 0.0, &opts)
        .expect("tolerate mode must factor the ill-conditioned-but-PD system");
    let (log_det_tt, log_det_schur) = cache.arrow_log_det();
    let log_det_cache = log_det_tt + log_det_schur.expect("dense Schur factor present");

    // Dense reference log|H| = 2 Σ ln L_ii from the assembled bordered matrix.
    let beta_off = n * d;
    let mut h = Array2::<f64>::zeros((dim, dim));
    for i in 0..n {
        let row = &pd_sys.rows[i];
        let base = i * d;
        for r in 0..d {
            for c in 0..d {
                h[[base + r, base + c]] = row.htt[[r, c]];
            }
            for c in 0..k {
                let coupling = row.htbeta[[r, c]];
                h[[base + r, beta_off + c]] = coupling;
                h[[beta_off + c, base + r]] = coupling;
            }
        }
    }
    for r in 0..k {
        for c in 0..k {
            h[[beta_off + r, beta_off + c]] = pd_sys.hbb[[r, c]];
        }
    }
    let lh = cholesky_lower(&h).expect("assembled bordered H must be SPD");
    let log_det_dense: f64 = 2.0 * (0..dim).map(|i| lh[[i, i]].ln()).sum::<f64>();
    assert!(
        (log_det_cache - log_det_dense).abs() < 1e-6,
        "tolerated-cache log|H| {log_det_cache} vs dense {log_det_dense}"
    );

    // The tolerated cache must also serve the latent selected-inverse diagonal.
    let tdiag = cache
        .latent_block_inverse_diagonal()
        .expect("tolerated cache must support latent_block_inverse_diagonal");
    assert_eq!(tdiag.len(), n * d);
    assert!(tdiag.iter().all(|v| v.is_finite()));
}
/// `ArrowFactorSlab::factor(row)` must return a view with the same shape and bit-identical
/// entries as the block passed to `from_blocks`. The blocks used here are 1×1, 2×2 and 3×3
/// lower-triangular matrices.
#[test]
fn arrow_factor_slab_accessor_matches_array_blocks_bitwise() {
    let expected_blocks = vec![
        array![[1.0_f64]],
        array![[2.0_f64, 0.0], [0.25, 3.0]],
        array![[4.0_f64, 0.0, 0.0], [0.5, 5.0, 0.0], [-0.25, 0.75, 6.0]],
    ];
    let slab = ArrowFactorSlab::from_blocks(expected_blocks.clone());
    assert_eq!(slab.len(), expected_blocks.len());
    for (row, expected) in expected_blocks.iter().enumerate() {
        let view = slab.factor(row);
        assert_eq!(view.dim(), expected.dim());
        for ((r, c), &want) in expected.indexed_iter() {
            assert_eq!(view[[r, c]].to_bits(), want.to_bits());
        }
    }
}
/// Builds a `D`-dimensional row block and a deterministic right-hand side for comparing the
/// fixed-size and dynamic row kernels.
///
/// `htt[r][c]` is `4 + r` on the diagonal and `0.03125 · (r + c + 1)` off it. The matrix is
/// symmetric, with a dominant diagonal for small `D`. `rhs[i] = 0.5 + 0.25 · i`.
fn fixed_row_kernel_fixture<const D: usize>() -> (ArrowRowBlock, Array1<f64>) {
    let entry = |r: usize, c: usize| -> f64 {
        if r != c {
            0.03125 * ((r + c + 1) as f64)
        } else {
            4.0 + r as f64
        }
    };
    let mut block = ArrowRowBlock::new(D, 0);
    for r in 0..D {
        for c in 0..D {
            block.htt[[r, c]] = entry(r, c);
        }
    }
    let rhs = Array1::from_shape_fn(D, |i| 0.5 + 0.25 * i as f64);
    (block, rhs)
}
/// Factors the `D`-dimensional fixture with both the const-generic and the dynamic Cholesky
/// kernels (ridge `0.125`). Asserts that the two factors, and the triangular solves built on
/// them, agree bit-for-bit, then returns `D` so callers can confirm which sizes ran.
fn assert_fixed_row_kernels_match_dynamic<const D: usize>() -> usize {
    const RIDGE: f64 = 0.125;
    let (row, rhs) = fixed_row_kernel_fixture::<D>();
    let fixed = factor_row_block_cholesky_fixed::<D>(&row, RIDGE).expect("fixed factor");
    let dynamic = factor_row_block_cholesky_dynamic(&row, RIDGE, D).expect("dynamic factor");

    for (r, c) in (0..D).flat_map(|r| (0..D).map(move |c| (r, c))) {
        let (lhs_bits, rhs_bits) = (fixed[[r, c]].to_bits(), dynamic[[r, c]].to_bits());
        assert_eq!(lhs_bits, rhs_bits, "factor mismatch at D={D} ({r},{c})");
    }

    let fixed_solve = cholesky_solve_vector_fixed::<D>(fixed.view(), rhs.view());
    let dynamic_solve = cholesky_solve_vector(dynamic.view(), rhs.view());
    for i in 0..D {
        let (lhs_bits, rhs_bits) = (fixed_solve[i].to_bits(), dynamic_solve[i].to_bits());
        assert_eq!(lhs_bits, rhs_bits, "solve mismatch at D={D} index {i}");
    }

    D
}
/// Runs the fixed-vs-dynamic kernel comparison for D = 1..=4. The sum of the returned
/// dimensions (1 + 2 + 3 + 4 = 10) confirms that every size actually executed.
#[test]
fn fixed_row_kernels_match_dynamic_path_bitwise() {
    let dims_checked = [
        assert_fixed_row_kernels_match_dynamic::<1>(),
        assert_fixed_row_kernels_match_dynamic::<2>(),
        assert_fixed_row_kernels_match_dynamic::<3>(),
        assert_fixed_row_kernels_match_dynamic::<4>(),
    ];
    assert_eq!(dims_checked.iter().sum::<usize>(), 10);
}
/// Builds a deterministic `n`-row arrow system with `d` latent axes and `k` coefficients.
///
/// - Row `i`: `htt` has `4 + (i mod 3)` on the diagonal and `0.1` off it.
/// - Row gradients: `gt[r] = 0.05 · sin(i + r + 1)`.
/// - Couplings: `htbeta[r][c] = 0.01 · cos((i + 1)(c + 1))`.
/// - β block: `gb[r] = 0.02 · cos(r + 1)`, and the `k × k` entries of `hbb` are `6` on the
///   diagonal and `0` elsewhere.
///
/// The row-Hessian fingerprint is refreshed after all entries are written.
fn dense_direct_system(n: usize, d: usize, k: usize) -> ArrowSchurSystem {
    let mut sys = ArrowSchurSystem::new(n, d, k);
    for (i, row) in sys.rows.iter_mut().enumerate() {
        let diag = 4.0 + (i % 3) as f64;
        for r in 0..d {
            for c in 0..d {
                row.htt[[r, c]] = if r != c { 0.1 } else { diag };
            }
        }
        for r in 0..d {
            row.gt[r] = 0.05 * ((i + r + 1) as f64).sin();
            for c in 0..k {
                row.htbeta[[r, c]] = 0.01 * (((i + 1) * (c + 1)) as f64).cos();
            }
        }
    }
    for r in 0..k {
        sys.gb[r] = 0.02 * ((r + 1) as f64).cos();
    }
    for r in 0..k {
        for c in 0..k {
            sys.hbb[[r, c]] = if r != c { 0.0 } else { 6.0 };
        }
    }
    // Must run after every Hessian entry above has been written.
    sys.refresh_row_hessian_fingerprint();
    sys
}
/// With the default dispatch policy, a small (300, 8) dense-Hessian workload must stay on the
/// CPU, while a large (2000, 4096) one must be routed to the GPU.
#[test]
fn device_dispatch_predicate_gates_on_work_not_rows() {
    let policy = crate::gpu::policy::GpuDispatchPolicy::default();
    let small_goes_to_gpu = policy.dense_hessian_work_target_is_gpu(300, 8);
    let large_goes_to_gpu = policy.dense_hessian_work_target_is_gpu(2_000, 4_096);
    assert!(!small_goes_to_gpu);
    assert!(large_goes_to_gpu);
}
/// With no GPU runtime present, both device seams must decline, and the core solve must not be
/// flagged as device-served. The artifacts solve must reproduce the core solve's Δt and Δβ
/// bit-for-bit, with matching lengths.
#[test]
fn device_seam_declines_without_gpu_and_matches_cpu() {
    // A live device would legitimately take the device path; this test only covers the
    // CPU fallback.
    if crate::gpu::runtime::GpuRuntime::global().is_some() {
        return;
    }
    let sys = dense_direct_system(6, 2, 4);
    let options = ArrowSolveOptions::direct();
    assert!(try_device_arrow_direct(&sys, 0.0, 0.0, &options).is_none());
    assert!(maybe_inject_gpu_schur_matvec(&sys, 0.0, 0.0, &options).is_none());
    let (dt_core, db_core, diag) =
        solve_arrow_newton_step_core(&sys, 0.0, 0.0, &options).expect("core solve");
    assert!(
        !diag.used_device_arrow,
        "no device present, so the solve must not be flagged device-served"
    );
    let artifacts =
        solve_arrow_newton_step_artifacts(&sys, 0.0, 0.0, &options).expect("artifacts solve");
    // `zip` stops at the shorter side, so pin the lengths first. Otherwise a truncated
    // artifacts vector would pass the bitwise comparison without comparing anything.
    assert_eq!(
        dt_core.len(),
        artifacts.delta_t.len(),
        "Δt length must match the CPU core solve"
    );
    for (a, b) in dt_core.iter().zip(artifacts.delta_t.iter()) {
        assert_eq!(a.to_bits(), b.to_bits(), "Δt must be bit-identical to CPU");
    }
    assert_eq!(
        db_core.len(),
        artifacts.delta_beta.len(),
        "Δβ length must match the CPU core solve"
    );
    for (a, b) in db_core.iter().zip(artifacts.delta_beta.iter()) {
        assert_eq!(a.to_bits(), b.to_bits(), "Δβ must be bit-identical to CPU");
    }
}
/// Streaming solves with chunk size 8 under the certified mixed-precision policy must give a
/// Δβ within 1e-7 of the pure-f64 solve. The exact arrow log-determinant must be bit-identical
/// between the two policies.
#[test]
fn streaming_mixed_precision_matches_f64_and_keeps_logdet_f64() {
    let sys = dense_direct_system(40, 3, 6);
    let f64_options = ArrowSolveOptions::direct().with_streaming_chunk_size(Some(8));
    let mp_options = f64_options
        .clone()
        .with_mixed_precision_policy(MixedPrecisionPolicy::certified());
    assert!(matches!(
        f64_options.mixed_precision,
        MixedPrecisionPolicy::Off
    ));
    let mut s_f64 = StreamingArrowSchur::from_system(&sys, 8);
    let (_, db_f64, _) = s_f64
        .solve(0.0, 0.0, &f64_options)
        .expect("f64 streaming solve");
    let mut s_mp = StreamingArrowSchur::from_system(&sys, 8);
    let (_, db_mp, _) = s_mp
        .solve(0.0, 0.0, &mp_options)
        .expect("mp streaming solve");
    // `zip` would silently truncate a shorter result, so pin the lengths first.
    assert_eq!(
        db_f64.len(),
        db_mp.len(),
        "mixed-precision Δβ length must match the f64 solve"
    );
    // `f64::max` returns the non-NaN operand, which would silently drop a NaN difference.
    // Fold explicitly so a NaN propagates into `max_abs` and fails the tolerance check.
    let max_abs = db_f64
        .iter()
        .zip(db_mp.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f64, |acc, diff| {
            if acc.is_nan() || diff.is_nan() {
                f64::NAN
            } else {
                acc.max(diff)
            }
        });
    assert!(
        max_abs < 1e-7,
        "mixed-precision Δβ deviates from f64 by {max_abs:e}, above the certified tolerance"
    );
    let mut ld_f64 = StreamingArrowSchur::from_system(&sys, 8);
    let logdet_f64 = ld_f64
        .exact_arrow_log_det(0.0, 0.0, &f64_options)
        .expect("f64 logdet");
    let mut ld_mp = StreamingArrowSchur::from_system(&sys, 8);
    let logdet_mp = ld_mp
        .exact_arrow_log_det(0.0, 0.0, &mp_options)
        .expect("mp logdet");
    assert_eq!(
        logdet_f64.to_bits(),
        logdet_mp.to_bits(),
        "evidence log|H| must stay bit-for-bit f64 under the mixed-precision policy"
    );
}
/// `with_streaming_mixed_precision_default` must lift the default `Off` policy of `direct()` to
/// `Certified`. An explicitly configured `Certified` policy must keep its own parameters and
/// not be replaced.
#[test]
fn streaming_mixed_precision_default_upgrades_only_off() {
    let off = ArrowSolveOptions::direct();
    let pinned =
        ArrowSolveOptions::direct().with_mixed_precision_policy(MixedPrecisionPolicy::Off);
    let custom_policy = MixedPrecisionPolicy::Certified {
        max_refinement_steps: 1,
        residual_relative_tolerance: 1e-6,
        kappa_unit_roundoff_margin: 0.25,
    };
    let custom = ArrowSolveOptions::direct().with_mixed_precision_policy(custom_policy);

    assert!(matches!(
        off.with_streaming_mixed_precision_default().mixed_precision,
        MixedPrecisionPolicy::Certified { .. }
    ));

    match custom.with_streaming_mixed_precision_default().mixed_precision {
        MixedPrecisionPolicy::Off => panic!("explicit Certified must not be downgraded"),
        MixedPrecisionPolicy::Certified {
            max_refinement_steps,
            ..
        } => assert_eq!(max_refinement_steps, 1, "explicit policy preserved"),
    }

    // NOTE(review): `pinned` is never passed through `with_streaming_mixed_precision_default`,
    // so this only checks that the builder stored `Off`. Confirm whether an explicitly pinned
    // `Off` is meant to survive the upgrade, and if so, exercise that here.
    assert!(matches!(pinned.mixed_precision, MixedPrecisionPolicy::Off));
}
}