use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use std::ops::Range;
use std::sync::Arc;
use crate::cache::Fingerprinter;
use crate::linalg::faer_ndarray::{FaerArrayView, FaerLlt};
use crate::linalg::triangular::{
cholesky_solve_matrix, cholesky_solve_vector, forward_substitution_lower_matrix,
};
use crate::terms::analytic_penalties::{AnalyticPenaltyKind, AnalyticPenaltyRegistry, PenaltyTier};
use crate::terms::latent_coord::{LatentCoordValues, LatentManifold};
/// Largest Schur dimension `k` for which `ArrowSolverMode::automatic` selects
/// `Direct` (and `automatic_for_single_precision` selects `SqrtBA`); larger `k`
/// selects `InexactPCG`.
const DIRECT_SOLVE_MAX_K: usize = 2_000;
/// Default iteration cap used by both `ArrowPcgOptions` and
/// `ArrowTrustRegionOptions`.
const DEFAULT_PCG_MAX_ITERATIONS: usize = 200;
/// Default relative tolerance for `ArrowPcgOptions` and the default Steihaug
/// tolerance of `ArrowTrustRegionOptions`.
const DEFAULT_PCG_RELATIVE_TOLERANCE: f64 = 1e-4;
// NOTE(review): not referenced in this part of the file; presumably a lower
// bound on the absolute PCG stopping tolerance — confirm at the use site.
const PCG_ABSOLUTE_TOLERANCE_FLOOR: f64 = 1e-14;
/// Default trust-region radius (infinite). NOTE(review): presumably this means
/// the radius never binds unless a caller sets one — confirm at the solver.
const DEFAULT_TRUST_REGION_RADIUS: f64 = f64::INFINITY;
/// Default `initial_ridge` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_INITIAL_RIDGE: f64 = 1e-8;
/// Unit roundoff of IEEE single precision: half of `f32::EPSILON`, i.e. 2^-24.
const F32_UNIT_ROUNDOFF: f64 = (f32::EPSILON as f64) * 0.5;
/// Default `max_refinement_steps` of `MixedPrecisionPolicy::certified()`.
const DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS: usize = 6;
/// Default `residual_relative_tolerance` of `MixedPrecisionPolicy::certified()`.
const DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE: f64 = 1e-11;
/// Default `kappa_unit_roundoff_margin` of `MixedPrecisionPolicy::certified()`.
const DEFAULT_MIXED_PRECISION_KAPPA_MARGIN: f64 = 0.5;
// NOTE(review): the next two constants are not referenced in this part of the
// file; their roles in the mixed-precision certificate should be confirmed at
// the use site.
const MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER: f64 = 64.0;
const MIXED_PRECISION_KAPPA_MARGIN_CEILING: f64 = 1.0;
/// Undirected edge between two distinct β blocks, stored with `a < b`
/// (`BetaCouplingGraph::build` normalizes each pair with `min`/`max`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct BetaEdge {
/// Smaller block index.
a: usize,
/// Larger block index.
b: usize,
}
/// Undirected graph over β blocks. Two blocks are adjacent when a single
/// `htbeta_rows` matrix has non-zero entries in columns of both blocks.
#[derive(Debug, Clone)]
struct BetaCouplingGraph {
/// Number of β blocks (graph nodes).
num_blocks: usize,
/// Sorted, de-duplicated edge list.
edges: Vec<BetaEdge>,
/// Compressed adjacency offsets (length `num_blocks + 1`): node `n`'s
/// neighbours are `adj_targets[adj_start[n]..adj_start[n + 1]]`.
adj_start: Vec<usize>,
/// Concatenated neighbour lists, indexed through `adj_start`.
adj_targets: Vec<usize>,
}
impl BetaCouplingGraph {
fn build(block_offsets: &[Range<usize>], htbeta_rows: &[Array2<f64>]) -> Self {
let num_blocks = block_offsets.len();
if num_blocks == 0 {
return Self {
num_blocks: 0,
edges: Vec::new(),
adj_start: vec![0],
adj_targets: Vec::new(),
};
}
let mut edge_set = Vec::<(usize, usize)>::new();
for row in htbeta_rows {
let mut active = Vec::<usize>::new();
for (block, range) in block_offsets.iter().enumerate() {
if range
.clone()
.any(|col| (0..row.nrows()).any(|axis| row[[axis, col]] != 0.0))
{
active.push(block);
}
}
for i in 0..active.len() {
for j in (i + 1)..active.len() {
edge_set.push((active[i].min(active[j]), active[i].max(active[j])));
}
}
}
edge_set.sort_unstable();
edge_set.dedup();
let edges: Vec<_> = edge_set.iter().map(|&(a, b)| BetaEdge { a, b }).collect();
let mut degree = vec![0usize; num_blocks];
for &BetaEdge { a, b } in &edges {
degree[a] += 1;
degree[b] += 1;
}
let mut adj_start = vec![0usize; num_blocks + 1];
for block in 0..num_blocks {
adj_start[block + 1] = adj_start[block] + degree[block];
}
let mut adj_targets = vec![0usize; adj_start[num_blocks]];
let mut cursor = adj_start[..num_blocks].to_vec();
for &BetaEdge { a, b } in &edges {
adj_targets[cursor[a]] = b;
cursor[a] += 1;
adj_targets[cursor[b]] = a;
cursor[b] += 1;
}
Self {
num_blocks,
edges,
adj_start,
adj_targets,
}
}
fn neighbours(&self, node: usize) -> &[usize] {
&self.adj_targets[self.adj_start[node]..self.adj_start[node + 1]]
}
fn component_partition(&self) -> Vec<Vec<usize>> {
let mut parent: Vec<usize> = (0..self.num_blocks).collect();
let mut rank = vec![0u8; self.num_blocks];
fn find(parent: &mut [usize], mut x: usize) -> usize {
while parent[x] != x {
parent[x] = parent[parent[x]];
x = parent[x];
}
x
}
for &BetaEdge { a, b } in &self.edges {
let lhs = find(&mut parent, a);
let rhs = find(&mut parent, b);
if lhs != rhs {
if rank[lhs] < rank[rhs] {
parent[lhs] = rhs;
} else if rank[lhs] > rank[rhs] {
parent[rhs] = lhs;
} else {
parent[rhs] = lhs;
rank[lhs] += 1;
}
}
}
let mut label_map = vec![usize::MAX; self.num_blocks];
let mut parts = Vec::<Vec<usize>>::new();
for block in 0..self.num_blocks {
let root = find(&mut parent, block);
let label = if label_map[root] == usize::MAX {
label_map[root] = parts.len();
parts.push(Vec::new());
label_map[root]
} else {
label_map[root]
};
parts[label].push(block);
}
parts
}
fn expand_one_hop(&self, seed: &[usize]) -> Vec<usize> {
let mut expanded = seed.to_vec();
for &block in seed {
expanded.extend_from_slice(self.neighbours(block));
}
expanded.sort_unstable();
expanded.dedup();
expanded
}
}
/// Default `ridge_growth` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_RIDGE_GROWTH: f64 = 10.0;
/// Default `max_attempts` of `ArrowProximalCorrectionOptions`.
pub const DEFAULT_PROXIMAL_MAX_ATTEMPTS: usize = 22;
/// Default `armijo_c1` of `ArrowProximalCorrectionOptions`.
const DEFAULT_ARMIJO_C1: f64 = 1e-4;
/// Default `gradient_tolerance` of `ArrowProximalCorrectionOptions`.
const DEFAULT_GRADIENT_TOLERANCE: f64 = 1e-10;
/// Default `convergence_objective_rel_tol` of `ArrowProximalCorrectionOptions`.
const DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL: f64 = 8e-12;
// NOTE(review): not referenced in this part of the file; presumably the
// manifold-mode tag hashed for Euclidean latent coordinates — confirm.
const EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT: u64 = 0;
// NOTE(review): 256 MiB; not referenced in this part of the file — presumably
// caps the H_tβ bytes retained by an arrow factor cache. Confirm at the use site.
const ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES: usize = 256 * 1024 * 1024;
/// Shared closure applying a β-space linear map: reads the input view and
/// writes into the output array. `MatvecDiagPenaltyOp` always passes a zeroed
/// output of length `k`.
pub type SharedBetaMatvec =
Arc<dyn for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Per-row closure `(row, input, output)`. NOTE(review): presumably applies
/// row `row`'s H_tβ block — confirm at the call site.
pub type RowHtbetaMatvec =
Arc<dyn for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Per-row closure `(row, input, output)`. NOTE(review): presumably applies
/// the transpose of row `row`'s H_tβ block — confirm at the call site.
pub type RowHtbetaTransposeMatvec =
Arc<dyn for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync>;
/// Fallible builder producing the `ArrowRowBlock` for a given row index.
pub type StreamingArrowRowBuilder =
Arc<dyn Fn(usize) -> Result<ArrowRowBlock, ArrowSchurError> + Send + Sync>;
/// Optional externally supplied matvec carried in `ArrowSolveOptions::gpu_matvec`.
/// NOTE(review): the name suggests a device-side Schur-complement product — confirm.
pub type GpuSchurMatvec = Arc<dyn Fn(&Array1<f64>, &mut Array1<f64>) + Send + Sync>;
/// Per-coordinate weights, borrowed as a plain `f64` slice.
type MetricWeights = [f64];
/// Index of a β block: `BetaPenaltyOp::block` reads its column range as
/// `offsets[id.0]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BetaBlockId(pub usize);
/// Linear penalty operator `P` acting on the β coefficient vector.
///
/// Every `matvec`, `gradient`, `diagonal` and `block` implementation in this
/// file *accumulates* into its output (`+=`) rather than overwriting it, so
/// callers must zero any buffer that should receive `P`'s contribution alone.
pub trait BetaPenaltyOp: Send + Sync {
/// Dimension `k` of the operator (`P` is `k × k`).
fn dim(&self) -> usize;
/// Adds `P · x` to `y`.
fn matvec(&self, x: &[f64], y: &mut [f64]);
/// Adds `P · beta` to `out`; the implementations here use the same
/// arithmetic as `matvec`.
fn gradient(&self, beta: &[f64], out: &mut [f64]);
/// Adds the diagonal of `P` to `diag`.
fn diagonal(&self, diag: &mut [f64]);
/// Adds the square sub-block `P[r, r]` with `r = offsets[id.0]` to `out`,
/// which must be at least `r.len() × r.len()`.
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>);
/// Materializes `P` as a dense `k × k` matrix.
fn to_dense(&self) -> Array2<f64>;
/// Feeds a description of the operator's contents into `hasher`.
fn fingerprint(&self, hasher: &mut Fingerprinter);
}
/// Penalty stored as an explicit dense matrix. It is treated as square, with
/// dimension equal to its row count.
pub struct DensePenaltyOp(pub Array2<f64>);
impl BetaPenaltyOp for DensePenaltyOp {
    /// Row count of the stored matrix.
    fn dim(&self) -> usize {
        self.0.nrows()
    }

    /// Adds `P · x` to `y`, treating the stored matrix as square.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let n = self.0.nrows();
        for row in 0..n {
            let dot = (0..n).fold(0.0_f64, |sum, col| sum + self.0[[row, col]] * x[col]);
            y[row] += dot;
        }
    }

    /// Adds `P · beta` to `out` (same arithmetic as `matvec`).
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds the leading diagonal entries to `diag`, stopping at whichever of
    /// the matrix or `diag` is shorter.
    fn diagonal(&self, diag: &mut [f64]) {
        let n = self.0.nrows().min(diag.len());
        for (idx, slot) in diag[..n].iter_mut().enumerate() {
            *slot += self.0[[idx, idx]];
        }
    }

    /// Adds the dense sub-block `P[r, r]`, `r = offsets[id.0]`, into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        for bi in 0..width {
            let gi = span.start + bi;
            for bj in 0..width {
                out[[bi, bj]] += self.0[[gi, span.start + bj]];
            }
        }
    }

    /// Returns a copy of the stored matrix.
    fn to_dense(&self) -> Array2<f64> {
        self.0.clone()
    }

    /// Hashes a version tag followed by the full matrix.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        let tag = "dense-penalty-op-v1";
        hasher.write_str(tag);
        hasher.write_f64_array2(&self.0);
    }
}
/// Block-diagonal penalty. Each `(offset, local)` entry places the square
/// matrix `local` at rows and columns `offset..offset + local.nrows()` of a
/// `k × k` operator; overlapping entries add.
pub struct BlockPenaltyOp {
/// Total operator dimension.
pub k: usize,
/// `(global offset, square local block)` pairs.
pub blocks: Vec<(usize, Array2<f64>)>,
}
impl BetaPenaltyOp for BlockPenaltyOp {
    /// Total operator dimension.
    fn dim(&self) -> usize {
        self.k
    }

    /// Adds `P · x` to `y`, one diagonal block at a time.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        for (start, mat) in &self.blocks {
            let n = mat.nrows();
            for r in 0..n {
                let dot = (0..n).fold(0.0_f64, |sum, c| sum + mat[[r, c]] * x[start + c]);
                y[start + r] += dot;
            }
        }
    }

    /// Adds `P · beta` to `out` (same arithmetic as `matvec`).
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds each block's diagonal at its global offset.
    fn diagonal(&self, diag: &mut [f64]) {
        for (start, mat) in &self.blocks {
            for idx in 0..mat.nrows() {
                diag[start + idx] += mat[[idx, idx]];
            }
        }
    }

    /// Adds the part of every stored block that lies inside the square
    /// sub-block `offsets[id.0] × offsets[id.0]`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        for (start, mat) in &self.blocks {
            let local = *start..*start + mat.nrows();
            if local.end <= span.start || local.start >= span.end {
                continue;
            }
            for bi in 0..width {
                let gi = span.start + bi;
                if !local.contains(&gi) {
                    continue;
                }
                for bj in 0..width {
                    let gj = span.start + bj;
                    if local.contains(&gj) {
                        out[[bi, bj]] += mat[[gi - start, gj - start]];
                    }
                }
            }
        }
    }

    /// Assembles the `k × k` block-diagonal matrix.
    fn to_dense(&self) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for (start, mat) in &self.blocks {
            let n = mat.nrows();
            for r in 0..n {
                for c in 0..n {
                    dense[[start + r, start + c]] += mat[[r, c]];
                }
            }
        }
        dense
    }

    /// Hashes the tag, `k`, the block count, then each offset and block.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("block-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for (start, mat) in self.blocks.iter() {
            hasher.write_usize(*start);
            hasher.write_f64_array2(mat);
        }
    }
}
/// Penalty `factor_a ⊗ factor_b` embedded at `global_offset` in a `k × k`
/// operator. Local index `i_a * p_b + i_b` (with `p_b = factor_b.nrows()`)
/// maps to global index `global_offset + i_a * p_b + i_b`.
pub struct KroneckerPenaltyOp {
/// Left (outer) Kronecker factor, treated as square.
pub factor_a: Array2<f64>,
/// Right (inner) Kronecker factor, treated as square.
pub factor_b: Array2<f64>,
/// First global index covered by the Kronecker product.
pub global_offset: usize,
/// Total operator dimension.
pub k: usize,
}
impl BetaPenaltyOp for KroneckerPenaltyOp {
    /// Total operator dimension.
    fn dim(&self) -> usize {
        self.k
    }

    /// Adds `(A ⊗ B) · x` to `y` over the span starting at `global_offset`.
    /// Zero entries of `A` are skipped.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        for ia in 0..rows_a {
            for ib in 0..rows_b {
                let dst = base + ia * rows_b + ib;
                let mut sum = 0.0_f64;
                for ja in 0..rows_a {
                    let coef = self.factor_a[[ia, ja]];
                    if coef != 0.0 {
                        let src = base + ja * rows_b;
                        for jb in 0..rows_b {
                            sum += coef * self.factor_b[[ib, jb]] * x[src + jb];
                        }
                    }
                }
                y[dst] += sum;
            }
        }
    }

    /// Adds `(A ⊗ B) · beta` to `out` (same arithmetic as `matvec`).
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds `A[ia, ia] * B[ib, ib]` at every covered diagonal position.
    fn diagonal(&self, diag: &mut [f64]) {
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        for ia in 0..rows_a {
            for ib in 0..rows_b {
                let value = self.factor_a[[ia, ia]] * self.factor_b[[ib, ib]];
                diag[base + ia * rows_b + ib] += value;
            }
        }
    }

    /// Adds the overlap of the Kronecker span with `offsets[id.0]` (on both
    /// axes) into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        let local = base..base + self.factor_a.nrows() * rows_b;
        if local.end <= span.start || local.start >= span.end {
            return;
        }
        for bi in 0..width {
            let gi = span.start + bi;
            if !local.contains(&gi) {
                continue;
            }
            let (ia, ib) = ((gi - base) / rows_b, (gi - base) % rows_b);
            for bj in 0..width {
                let gj = span.start + bj;
                if local.contains(&gj) {
                    let (ja, jb) = ((gj - base) / rows_b, (gj - base) % rows_b);
                    out[[bi, bj]] += self.factor_a[[ia, ja]] * self.factor_b[[ib, jb]];
                }
            }
        }
    }

    /// Materializes the `k × k` matrix holding `A ⊗ B` at `global_offset`.
    fn to_dense(&self) -> Array2<f64> {
        let rows_a = self.factor_a.nrows();
        let rows_b = self.factor_b.nrows();
        let base = self.global_offset;
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for ia in 0..rows_a {
            for ib in 0..rows_b {
                let row = base + ia * rows_b + ib;
                for ja in 0..rows_a {
                    let coef = self.factor_a[[ia, ja]];
                    if coef != 0.0 {
                        for jb in 0..rows_b {
                            dense[[row, base + ja * rows_b + jb]] +=
                                coef * self.factor_b[[ib, jb]];
                        }
                    }
                }
            }
        }
        dense
    }

    /// Hashes the tag, placement, dimension and both factors.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("kronecker-penalty-op-v1");
        hasher.write_usize(self.global_offset);
        hasher.write_usize(self.k);
        hasher.write_f64_array2(&self.factor_a);
        hasher.write_f64_array2(&self.factor_b);
    }
}
/// Penalty `factor_a ⊗ I_p` embedded at `global_offset` in a `k × k` operator.
/// Local index `i_a * p + lane` maps to `global_offset + i_a * p + lane`, and
/// only equal lanes couple.
pub struct IdentityRightKroneckerPenaltyOp {
/// Left Kronecker factor, treated as square.
pub factor_a: Array2<f64>,
/// Size of the identity factor.
pub p: usize,
/// First global index covered by the product.
pub global_offset: usize,
/// Total operator dimension.
pub k: usize,
}
impl BetaPenaltyOp for IdentityRightKroneckerPenaltyOp {
    /// Total operator dimension.
    fn dim(&self) -> usize {
        self.k
    }

    /// Adds `(A ⊗ I_p) · x` to `y`. Each lane is an independent `A`-product,
    /// and zero entries of `A` are skipped.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let rows = self.factor_a.nrows();
        let stride = self.p;
        let base = self.global_offset;
        for ia in 0..rows {
            for lane in 0..stride {
                let mut sum = 0.0_f64;
                for ja in 0..rows {
                    let coef = self.factor_a[[ia, ja]];
                    if coef != 0.0 {
                        sum += coef * x[base + ja * stride + lane];
                    }
                }
                y[base + ia * stride + lane] += sum;
            }
        }
    }

    /// Adds `(A ⊗ I_p) · beta` to `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds `A[ia, ia]` to every lane of row block `ia`.
    fn diagonal(&self, diag: &mut [f64]) {
        let stride = self.p;
        let base = self.global_offset;
        for ia in 0..self.factor_a.nrows() {
            let value = self.factor_a[[ia, ia]];
            let row_start = base + ia * stride;
            for slot in &mut diag[row_start..row_start + stride] {
                *slot += value;
            }
        }
    }

    /// Adds the overlap of the product with the square sub-block
    /// `offsets[id.0] × offsets[id.0]`; only same-lane entries are non-zero.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let span = &offsets[id.0];
        let width = span.end - span.start;
        let stride = self.p;
        let base = self.global_offset;
        let local = base..base + self.factor_a.nrows() * stride;
        if local.end <= span.start || local.start >= span.end {
            return;
        }
        for bi in 0..width {
            let gi = span.start + bi;
            if !local.contains(&gi) {
                continue;
            }
            let (ia, lane_i) = ((gi - base) / stride, (gi - base) % stride);
            for bj in 0..width {
                let gj = span.start + bj;
                if !local.contains(&gj) {
                    continue;
                }
                let (ja, lane_j) = ((gj - base) / stride, (gj - base) % stride);
                if lane_i == lane_j {
                    out[[bi, bj]] += self.factor_a[[ia, ja]];
                }
            }
        }
    }

    /// Materializes the `k × k` matrix holding `A ⊗ I_p` at `global_offset`.
    fn to_dense(&self) -> Array2<f64> {
        let rows = self.factor_a.nrows();
        let stride = self.p;
        let base = self.global_offset;
        let mut dense = Array2::<f64>::zeros((self.k, self.k));
        for ia in 0..rows {
            for ja in 0..rows {
                let coef = self.factor_a[[ia, ja]];
                if coef != 0.0 {
                    for lane in 0..stride {
                        dense[[base + ia * stride + lane, base + ja * stride + lane]] += coef;
                    }
                }
            }
        }
        dense
    }

    /// Hashes the tag, placement, dimension, lane count and `factor_a`.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("identity-right-kronecker-penalty-op-v1");
        hasher.write_usize(self.global_offset);
        hasher.write_usize(self.k);
        hasher.write_usize(self.p);
        hasher.write_f64_array2(&self.factor_a);
    }
}
/// One dense block of an un-expanded matrix `A`. Entry `data[li, lj]` sits at
/// row `row_off + li` and column `col_off + lj` of `A` (see
/// `SparseBlockKroneckerPenaltyOp`, which expands `A` by `⊗ I_p`).
#[derive(Debug, Clone)]
pub struct SparseGBlock {
/// Row of `A` where this block starts.
pub row_off: usize,
/// Column of `A` where this block starts.
pub col_off: usize,
/// Dense block values.
pub data: Array2<f64>,
}
/// Block-sparse penalty `A ⊗ I_p`, where `A` is given as a list of dense
/// blocks. Row `i` of `A`, lane `oc` maps to global index `i * p + oc`.
pub struct SparseBlockKroneckerPenaltyOp {
/// Size of the identity factor.
pub p: usize,
/// NOTE(review): only hashed in `fingerprint` here; presumably the dimension
/// of `A` — confirm.
pub dim_a: usize,
/// Total operator dimension.
pub k: usize,
/// Dense blocks of `A`; overlapping blocks add.
pub blocks: Vec<SparseGBlock>,
}
/// A smooth-penalty factor plus its global placement, held in
/// `DeviceSaePcgData`. NOTE(review): presumably mirrors an
/// `IdentityRightKroneckerPenaltyOp` for a device-side PCG path — confirm.
#[derive(Debug, Clone)]
pub struct DeviceSaeSmoothBlock {
/// First global index covered by this block.
pub global_offset: usize,
/// Left Kronecker factor.
pub factor_a: Array2<f64>,
}
/// Host-side data bundle for a device PCG path.
/// NOTE(review): consumed outside this part of the file; the field meanings
/// below are read off the names and types and should be confirmed.
#[derive(Debug, Clone)]
pub struct DeviceSaePcgData {
/// Identity-factor size shared by the Kronecker penalties.
pub p: usize,
/// Length of the β vector.
pub beta_dim: usize,
/// Per-row sparse `(index, value)` lists; shared via `a_phi_shared`.
pub a_phi: Vec<Vec<(usize, f64)>>,
/// Per-row dense local Jacobian values.
pub local_jac: Vec<Vec<f64>>,
/// Smooth penalty blocks.
pub smooth_blocks: Vec<DeviceSaeSmoothBlock>,
/// Block-sparse penalty blocks.
pub sparse_g_blocks: Vec<SparseGBlock>,
}
impl DeviceSaePcgData {
    /// Copies `a_phi` into a reference-counted slice so that later clones of
    /// it are cheap pointer copies.
    fn a_phi_shared(&self) -> Arc<[Vec<(usize, f64)>]> {
        self.a_phi.iter().cloned().collect()
    }
}
impl BetaPenaltyOp for SparseBlockKroneckerPenaltyOp {
    /// Total operator dimension.
    fn dim(&self) -> usize {
        self.k
    }

    /// Adds `(A ⊗ I_p) · x` to `y`, visiting every stored block of `A`.
    /// Zero entries of `A` are skipped.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                for lj in 0..m_j {
                    let a_ij = blk.data[[li, lj]];
                    if a_ij == 0.0 {
                        continue;
                    }
                    let gj_base = (blk.col_off + lj) * p;
                    for oc in 0..p {
                        y[gi_base + oc] += a_ij * x[gj_base + oc];
                    }
                }
            }
        }
    }

    /// Adds `(A ⊗ I_p) · beta` to `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds the diagonal of `A ⊗ I_p` to `diag`.
    ///
    /// A block contributes at every row `g` of `A` lying in both its row span
    /// `row_off..row_off + m_i` and its column span `col_off..col_off + m_j`.
    /// Earlier code only looked at blocks with `row_off == col_off`. That
    /// silently dropped diagonal entries of blocks that straddle the main
    /// diagonal with different offsets, which `matvec`, `block` and
    /// `to_dense` all include. This version always agrees with
    /// `to_dense().diag()`.
    fn diagonal(&self, diag: &mut [f64]) {
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            let lo = blk.row_off.max(blk.col_off);
            let hi = (blk.row_off + m_i).min(blk.col_off + m_j);
            // An empty `lo..hi` (disjoint spans) contributes nothing.
            for g in lo..hi {
                let a_gg = blk.data[[g - blk.row_off, g - blk.col_off]];
                let base = g * p;
                for oc in 0..p {
                    diag[base + oc] += a_gg;
                }
            }
        }
    }

    /// Adds the overlap of every block with the square sub-block
    /// `offsets[id.0] × offsets[id.0]`; only same-lane entries are non-zero.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let range = &offsets[id.0];
        let b = range.end - range.start;
        let p = self.p;
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            let row_start = blk.row_off * p;
            let row_end = (blk.row_off + m_i) * p;
            let col_start = blk.col_off * p;
            let col_end = (blk.col_off + m_j) * p;
            if row_end <= range.start
                || row_start >= range.end
                || col_end <= range.start
                || col_start >= range.end
            {
                continue;
            }
            for bi in 0..b {
                let gi = range.start + bi;
                if gi < row_start || gi >= row_end {
                    continue;
                }
                let li = (gi - row_start) / p;
                let oc_i = (gi - row_start) % p;
                for bj in 0..b {
                    let gj = range.start + bj;
                    if gj < col_start || gj >= col_end {
                        continue;
                    }
                    let oc_j = (gj - col_start) % p;
                    if oc_i != oc_j {
                        continue;
                    }
                    let lj = (gj - col_start) / p;
                    out[[bi, bj]] += blk.data[[li, lj]];
                }
            }
        }
    }

    /// Materializes the `k × k` matrix `A ⊗ I_p`.
    fn to_dense(&self) -> Array2<f64> {
        let p = self.p;
        let mut out = Array2::<f64>::zeros((self.k, self.k));
        for blk in &self.blocks {
            let (m_i, m_j) = blk.data.dim();
            for li in 0..m_i {
                let gi_base = (blk.row_off + li) * p;
                for lj in 0..m_j {
                    let a_ij = blk.data[[li, lj]];
                    if a_ij == 0.0 {
                        continue;
                    }
                    let gj_base = (blk.col_off + lj) * p;
                    for oc in 0..p {
                        out[[gi_base + oc, gj_base + oc]] += a_ij;
                    }
                }
            }
        }
        out
    }

    /// Hashes the tag, sizes, block count and every block's placement and data.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("sparse-block-kronecker-penalty-op-v1");
        hasher.write_usize(self.p);
        hasher.write_usize(self.dim_a);
        hasher.write_usize(self.k);
        hasher.write_usize(self.blocks.len());
        for blk in &self.blocks {
            hasher.write_usize(blk.row_off);
            hasher.write_usize(blk.col_off);
            hasher.write_f64_array2(&blk.data);
        }
    }
}
/// One `g ⊗ w` coupling block between atoms `atom_i` and `atom_j`.
/// `FactoredFrameKroneckerOp::new` checks that `g` is
/// `basis_sizes[atom_i] × basis_sizes[atom_j]` and `w` is
/// `ranks[atom_i] × ranks[atom_j]`.
#[derive(Debug, Clone)]
pub struct FactoredFrameGBlock {
/// Row atom.
pub atom_i: usize,
/// Column atom.
pub atom_j: usize,
/// Basis-space coupling matrix.
pub g: Array2<f64>,
/// Frame-space factor; `from_frames_and_blocks` fills it with `U_iᵀ U_j`.
pub w: Array2<f64>,
}
/// Penalty assembled from atom-pair blocks `g ⊗ w`.
///
/// Atom `k` owns `basis_sizes[k] * ranks[k]` consecutive coefficients starting
/// at `offsets[k]`, laid out as `basis index * ranks[k] + frame component`.
pub struct FactoredFrameKroneckerOp {
/// Frame rank of each atom.
pub ranks: Vec<usize>,
/// Basis size of each atom.
pub basis_sizes: Vec<usize>,
/// Running start of each atom plus the total as the last entry
/// (length `n_atoms + 1`).
pub offsets: Vec<usize>,
/// Total operator dimension (the last entry of `offsets`).
pub dim: usize,
/// Coupling blocks; atom pairs without a block contribute nothing.
pub blocks: Vec<FactoredFrameGBlock>,
}
/// Gram matrix `U_iᵀ U_j` of two frames sharing the same ambient dimension.
///
/// Entry `(a, b)` is the dot product of column `a` of `u_i` with column `b`
/// of `u_j`, summed in ascending row order.
///
/// # Panics
/// Panics if the frames have different row counts.
pub fn frame_output_gram(u_i: ArrayView2<f64>, u_j: ArrayView2<f64>) -> Array2<f64> {
    let (p_i, r_i) = u_i.dim();
    let (p_j, r_j) = u_j.dim();
    assert_eq!(
        p_i, p_j,
        "frame_output_gram: frames live in different ambient dims ({p_i} vs {p_j})"
    );
    Array2::from_shape_fn((r_i, r_j), |(a, b)| {
        u_i.column(a)
            .iter()
            .zip(u_j.column(b).iter())
            .fold(0.0, |acc, (x, y)| acc + x * y)
    })
}
impl FactoredFrameKroneckerOp {
    /// Validates `blocks` against `ranks`/`basis_sizes` and builds the operator.
    ///
    /// `offsets` becomes the running sum of `basis_sizes[k] * ranks[k]`,
    /// starting at 0 and ending with the total dimension.
    ///
    /// # Errors
    /// Returns a message if the two size vectors differ in length, if a block
    /// names an atom out of range, or if a block's `g`/`w` shape differs from
    /// `(basis_sizes[i], basis_sizes[j])` / `(ranks[i], ranks[j])`.
    pub fn new(
        ranks: Vec<usize>,
        basis_sizes: Vec<usize>,
        blocks: Vec<FactoredFrameGBlock>,
    ) -> Result<Self, String> {
        if ranks.len() != basis_sizes.len() {
            return Err(format!(
                "FactoredFrameKroneckerOp: {} ranks but {} basis sizes",
                ranks.len(),
                basis_sizes.len()
            ));
        }
        let n_atoms = ranks.len();

        let mut offsets = Vec::with_capacity(n_atoms + 1);
        offsets.push(0usize);
        let mut running = 0usize;
        for (&m, &r) in basis_sizes.iter().zip(ranks.iter()) {
            running += m * r;
            offsets.push(running);
        }

        for blk in &blocks {
            let (i, j) = (blk.atom_i, blk.atom_j);
            if i >= n_atoms || j >= n_atoms {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block atom indices ({}, {}) out of range (n_atoms = {n_atoms})",
                    i, j
                ));
            }
            let want_g = (basis_sizes[i], basis_sizes[j]);
            if blk.g.dim() != want_g {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block ({}, {}) g has shape {:?} but expected ({}, {})",
                    i,
                    j,
                    blk.g.dim(),
                    want_g.0,
                    want_g.1
                ));
            }
            let want_w = (ranks[i], ranks[j]);
            if blk.w.dim() != want_w {
                return Err(format!(
                    "FactoredFrameKroneckerOp: block ({}, {}) w has shape {:?} but expected ({}, {})",
                    i,
                    j,
                    blk.w.dim(),
                    want_w.0,
                    want_w.1
                ));
            }
        }

        Ok(Self {
            dim: running,
            ranks,
            basis_sizes,
            offsets,
            blocks,
        })
    }

    /// Builds the operator from per-atom frames and a map of basis couplings.
    ///
    /// An atom with `Some(u)` uses frame `u` (`p` rows, rank = column count).
    /// An atom with `None` uses the `p × p` identity and rank `p`. Each block's
    /// `w` is `frame_output_gram(U_i, U_j)`.
    ///
    /// # Errors
    /// Returns a message on a frame/basis count mismatch, a frame with the
    /// wrong row count or rank above `p`, an out-of-range atom pair, or any
    /// error reported by `new`.
    pub fn from_frames_and_blocks(
        frames: &[Option<Array2<f64>>],
        basis_sizes: &[usize],
        p: usize,
        g_blocks: &std::collections::BTreeMap<(usize, usize), Array2<f64>>,
    ) -> Result<Self, String> {
        if frames.len() != basis_sizes.len() {
            return Err(format!(
                "FactoredFrameKroneckerOp::from_frames_and_blocks: {} frames but {} basis sizes",
                frames.len(),
                basis_sizes.len()
            ));
        }
        let n_atoms = frames.len();

        let ranks = frames
            .iter()
            .enumerate()
            .map(|(k, frame)| -> Result<usize, String> {
                let u = match frame {
                    Some(u) => u,
                    None => return Ok(p),
                };
                let (pr, r) = u.dim();
                if pr != p {
                    return Err(format!(
                        "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has {pr} rows but ambient dim is {p}"
                    ));
                }
                if r > p {
                    return Err(format!(
                        "FactoredFrameKroneckerOp::from_frames_and_blocks: frame {k} has rank {r} > ambient dim {p}"
                    ));
                }
                Ok(r)
            })
            .collect::<Result<Vec<usize>, String>>()?;

        let identity = Array2::<f64>::eye(p);
        let view_of = |atom: usize| -> ArrayView2<f64> {
            if let Some(frame) = &frames[atom] {
                frame.view()
            } else {
                identity.view()
            }
        };

        let blocks = g_blocks
            .iter()
            .map(|(&(atom_i, atom_j), g)| -> Result<FactoredFrameGBlock, String> {
                if atom_i >= n_atoms || atom_j >= n_atoms {
                    return Err(format!(
                        "FactoredFrameKroneckerOp::from_frames_and_blocks: block atom indices ({atom_i}, {atom_j}) out of range (n_atoms = {n_atoms})"
                    ));
                }
                let w = frame_output_gram(view_of(atom_i), view_of(atom_j));
                Ok(FactoredFrameGBlock {
                    atom_i,
                    atom_j,
                    g: g.clone(),
                    w,
                })
            })
            .collect::<Result<Vec<_>, String>>()?;

        Self::new(ranks, basis_sizes.to_vec(), blocks)
    }
}
impl BetaPenaltyOp for FactoredFrameKroneckerOp {
    /// Total operator dimension.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Adds `Σ_blocks (g ⊗ w) · x` to `y`.
    ///
    /// For a block, output slice `(atom_i, li)` receives
    /// `Σ_lj g[li, lj] · (w · x_(atom_j, lj))`. The product `w · x_lj` does not
    /// depend on `li`, so it is computed once per column `lj` of `g` instead of
    /// once per `(li, lj)` pair. The per-element summation order is unchanged,
    /// so results are bit-identical to the naive loop. Columns of `g` that are
    /// entirely zero are never read and are skipped.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        // Scratch buffer reused across blocks: `w_x[lj * r_i + a] = (w · x_lj)[a]`.
        let mut w_x: Vec<f64> = Vec::new();
        for blk in &self.blocks {
            let r_i = self.ranks[blk.atom_i];
            let r_j = self.ranks[blk.atom_j];
            let off_i = self.offsets[blk.atom_i];
            let off_j = self.offsets[blk.atom_j];
            let (m_i, m_j) = blk.g.dim();

            w_x.clear();
            w_x.resize(m_j * r_i, 0.0);
            for lj in 0..m_j {
                if (0..m_i).all(|li| blk.g[[li, lj]] == 0.0) {
                    continue;
                }
                let xj_base = off_j + lj * r_j;
                for a in 0..r_i {
                    let mut acc = 0.0;
                    for b in 0..r_j {
                        acc += blk.w[[a, b]] * x[xj_base + b];
                    }
                    w_x[lj * r_i + a] = acc;
                }
            }

            for li in 0..m_i {
                let yi_base = off_i + li * r_i;
                for lj in 0..m_j {
                    let g = blk.g[[li, lj]];
                    if g == 0.0 {
                        continue;
                    }
                    let wx = &w_x[lj * r_i..(lj + 1) * r_i];
                    for (a, &value) in wx.iter().enumerate() {
                        y[yi_base + a] += g * value;
                    }
                }
            }
        }
    }

    /// Adds `P · beta` to `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.matvec(beta, out);
    }

    /// Adds the diagonal of `P` to `diag`. Only same-atom blocks can touch the
    /// diagonal because atoms occupy disjoint coefficient ranges.
    fn diagonal(&self, diag: &mut [f64]) {
        for blk in &self.blocks {
            if blk.atom_i != blk.atom_j {
                continue;
            }
            let r = self.ranks[blk.atom_i];
            let off = self.offsets[blk.atom_i];
            let (m_i, m_j) = blk.g.dim();
            let m = m_i.min(m_j);
            for li in 0..m {
                let gii = blk.g[[li, li]];
                let base = off + li * r;
                for a in 0..r {
                    diag[base + a] += gii * blk.w[[a, a]];
                }
            }
        }
    }

    /// Adds the part of every block that falls inside the square sub-block
    /// `offsets[id.0] × offsets[id.0]` into `out`.
    ///
    /// Blocks whose row span or column span misses the requested range are
    /// skipped outright. Every index that passes the range test is already
    /// below `range.len()`, so no further bounds check is needed.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        let range = &offsets[id.0];
        for blk in &self.blocks {
            let r_i = self.ranks[blk.atom_i];
            let r_j = self.ranks[blk.atom_j];
            let off_i = self.offsets[blk.atom_i];
            let off_j = self.offsets[blk.atom_j];
            let (m_i, m_j) = blk.g.dim();
            let row_end = off_i + m_i * r_i;
            let col_end = off_j + m_j * r_j;
            if row_end <= range.start
                || off_i >= range.end
                || col_end <= range.start
                || off_j >= range.end
            {
                continue;
            }
            for li in 0..m_i {
                for a in 0..r_i {
                    let gi = off_i + li * r_i + a;
                    if !range.contains(&gi) {
                        continue;
                    }
                    let bi = gi - range.start;
                    for lj in 0..m_j {
                        let g = blk.g[[li, lj]];
                        if g == 0.0 {
                            continue;
                        }
                        for b in 0..r_j {
                            let gj = off_j + lj * r_j + b;
                            if range.contains(&gj) {
                                out[[bi, gj - range.start]] += g * blk.w[[a, b]];
                            }
                        }
                    }
                }
            }
        }
    }

    /// Materializes the `dim × dim` matrix.
    fn to_dense(&self) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((self.dim, self.dim));
        for blk in &self.blocks {
            let r_i = self.ranks[blk.atom_i];
            let r_j = self.ranks[blk.atom_j];
            let off_i = self.offsets[blk.atom_i];
            let off_j = self.offsets[blk.atom_j];
            let (m_i, m_j) = blk.g.dim();
            for li in 0..m_i {
                for lj in 0..m_j {
                    let g = blk.g[[li, lj]];
                    if g == 0.0 {
                        continue;
                    }
                    for a in 0..r_i {
                        let gi = off_i + li * r_i + a;
                        for b in 0..r_j {
                            let gj = off_j + lj * r_j + b;
                            out[[gi, gj]] += g * blk.w[[a, b]];
                        }
                    }
                }
            }
        }
        out
    }

    /// Hashes the tag, dimension, per-atom sizes and every block.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("factored-frame-kronecker-op-v1");
        hasher.write_usize(self.dim);
        for &r in &self.ranks {
            hasher.write_usize(r);
        }
        for &m in &self.basis_sizes {
            hasher.write_usize(m);
        }
        hasher.write_usize(self.blocks.len());
        for blk in &self.blocks {
            hasher.write_usize(blk.atom_i);
            hasher.write_usize(blk.atom_j);
            hasher.write_f64_array2(&blk.g);
            hasher.write_f64_array2(&blk.w);
        }
    }
}
/// Sum of several penalty operators that share dimension `k`; every trait
/// method forwards to each operator in turn.
pub struct CompositePenaltyOp {
/// Shared operator dimension.
pub k: usize,
/// Summands, applied in order.
pub ops: Vec<Arc<dyn BetaPenaltyOp>>,
}
impl BetaPenaltyOp for CompositePenaltyOp {
    /// Shared operator dimension.
    fn dim(&self) -> usize {
        self.k
    }

    /// Lets every summand add its product into `y`.
    fn matvec(&self, x: &[f64], y: &mut [f64]) {
        self.ops.iter().for_each(|op| op.matvec(x, y));
    }

    /// Lets every summand add its gradient into `out`.
    fn gradient(&self, beta: &[f64], out: &mut [f64]) {
        self.ops.iter().for_each(|op| op.gradient(beta, out));
    }

    /// Lets every summand add its diagonal into `diag`.
    fn diagonal(&self, diag: &mut [f64]) {
        self.ops.iter().for_each(|op| op.diagonal(diag));
    }

    /// Lets every summand add its sub-block into `out`.
    fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
        self.ops.iter().for_each(|op| op.block(id, offsets, out));
    }

    /// Sums the dense form of every summand into a `k × k` matrix.
    fn to_dense(&self) -> Array2<f64> {
        let mut total = Array2::<f64>::zeros((self.k, self.k));
        for dense in self.ops.iter().map(|op| op.to_dense()) {
            total += &dense;
        }
        total
    }

    /// Hashes the tag, `k`, the summand count, then each summand's fingerprint.
    fn fingerprint(&self, hasher: &mut Fingerprinter) {
        hasher.write_str("composite-penalty-op-v1");
        hasher.write_usize(self.k);
        hasher.write_usize(self.ops.len());
        self.ops.iter().for_each(|op| op.fingerprint(hasher));
    }
}
/// Penalty known only through a matvec closure plus an explicit diagonal.
/// `to_dense` and `block` are reconstructed by probing with unit vectors.
pub struct MatvecDiagPenaltyOp {
/// Operator dimension.
k: usize,
/// Closure producing `P · x` into a zeroed output.
matvec: SharedBetaMatvec,
/// Diagonal of `P` (length `k`, checked in `new`).
diagonal_vec: Array1<f64>,
}
impl MatvecDiagPenaltyOp {
    /// Wraps a matvec closure together with the operator's diagonal.
    ///
    /// # Panics
    /// Panics if `diagonal_vec.len() != k`.
    pub fn new(k: usize, matvec: SharedBetaMatvec, diagonal_vec: Array1<f64>) -> Self {
        assert_eq!(diagonal_vec.len(), k);
        Self {
            diagonal_vec,
            matvec,
            k,
        }
    }
}
impl BetaPenaltyOp for MatvecDiagPenaltyOp {
fn dim(&self) -> usize {
self.k
}
fn matvec(&self, x: &[f64], y: &mut [f64]) {
let x_arr = Array1::from_iter(x.iter().copied());
let mut out = Array1::<f64>::zeros(self.k);
(self.matvec)(x_arr.view(), &mut out);
for a in 0..self.k {
y[a] += out[a];
}
}
fn gradient(&self, beta: &[f64], out: &mut [f64]) {
let beta_arr = Array1::from_iter(beta.iter().copied());
let mut hb = Array1::<f64>::zeros(self.k);
(self.matvec)(beta_arr.view(), &mut hb);
for a in 0..self.k {
out[a] += hb[a];
}
}
fn diagonal(&self, diag: &mut [f64]) {
for j in 0..self.k.min(diag.len()) {
diag[j] += self.diagonal_vec[j];
}
}
fn block(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
let range = &offsets[id.0];
let b = range.end - range.start;
let mut probe = Array1::<f64>::zeros(self.k);
for bj in 0..b {
probe.fill(0.0);
probe[range.start + bj] = 1.0;
let mut col = Array1::<f64>::zeros(self.k);
(self.matvec)(probe.view(), &mut col);
for bi in 0..b {
out[[bi, bj]] += col[range.start + bi];
}
}
}
fn to_dense(&self) -> Array2<f64> {
let k = self.k;
let mut out = Array2::<f64>::zeros((k, k));
let mut probe = Array1::<f64>::zeros(k);
for j in 0..k {
probe.fill(0.0);
probe[j] = 1.0;
let mut col = Array1::<f64>::zeros(k);
(self.matvec)(probe.view(), &mut col);
for i in 0..k {
out[[i, j]] = col[i];
}
}
out
}
fn fingerprint(&self, hasher: &mut Fingerprinter) {
hasher.write_str("matvec-diag-penalty-op-v1");
hasher.write_usize(self.k);
for &value in self.diagonal_vec.iter() {
hasher.write_f64(value);
}
}
}
/// Strategy used to solve the arrow (Schur-complement) system.
/// `automatic` picks `Direct` for `k <= DIRECT_SOLVE_MAX_K` and `InexactPCG`
/// above it. `automatic_for_single_precision` picks `SqrtBA` in place of `Direct`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
/// Direct factorization of the reduced system.
Direct,
/// NOTE(review): presumably a square-root (factor-based) formulation, given
/// `BatchedBlockSolver::sqrt_solve_block_matrix` — confirm.
SqrtBA,
/// Iterative preconditioned conjugate gradients.
InexactPCG,
}
impl ArrowSolverMode {
    /// `Direct` while `k` is at most `DIRECT_SOLVE_MAX_K`, otherwise `InexactPCG`.
    pub const fn automatic(k: usize) -> Self {
        match k {
            0..=DIRECT_SOLVE_MAX_K => Self::Direct,
            _ => Self::InexactPCG,
        }
    }

    /// Like `automatic`, but picks `SqrtBA` instead of `Direct` for small `k`.
    pub const fn automatic_for_single_precision(k: usize) -> Self {
        if k > DIRECT_SOLVE_MAX_K {
            Self::InexactPCG
        } else {
            Self::SqrtBA
        }
    }
}
/// Why an iterative solve stopped; the default is `Converged`.
/// NOTE(review): set outside this part of the file — confirm the exact trigger
/// of each variant at the solver.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
#[default]
Converged,
MaxIter,
TrustRegion,
Indefinite,
Stagnation,
}
/// Counters and outcome reported by an arrow-system solve; all fields default
/// to zero / `false` / the enums' defaults.
/// NOTE(review): populated outside this part of the file; field meanings are
/// read off the names and should be confirmed against the solver.
#[derive(Debug, Default, Clone, Copy)]
pub struct PcgDiagnostics {
pub iterations: usize,
pub matvec_calls: usize,
pub precond_apply_calls: usize,
pub ridge_escalations: usize,
pub final_relative_residual: f64,
pub stopping_reason: PcgStopReason,
pub mixed_precision_status: MixedPrecisionStatus,
pub used_device_arrow: bool,
}
/// Outcome of the mixed-precision path; the default is `Off`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MixedPrecisionStatus {
#[default]
Off,
/// Accepted after `refinement_steps` refinement iterations.
Certified { refinement_steps: usize },
/// NOTE(review): presumably the certificate failed and the solve fell back
/// to f64 — confirm at the solver.
F64Fallback,
}
/// Stopping controls for PCG. Defaults are `DEFAULT_PCG_MAX_ITERATIONS` and
/// `DEFAULT_PCG_RELATIVE_TOLERANCE`.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
/// Iteration cap.
pub max_iterations: usize,
/// Relative residual tolerance.
pub relative_tolerance: f64,
}
impl Default for ArrowPcgOptions {
    /// Module-default iteration cap and relative tolerance.
    fn default() -> Self {
        let max_iterations = DEFAULT_PCG_MAX_ITERATIONS;
        let relative_tolerance = DEFAULT_PCG_RELATIVE_TOLERANCE;
        Self {
            max_iterations,
            relative_tolerance,
        }
    }
}
/// Trust-region controls. By default the radius is infinite and the
/// tolerance/iteration cap match the PCG defaults.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
/// Trust-region radius.
pub radius: f64,
/// Relative tolerance for the Steihaug inner solve.
pub steihaug_relative_tolerance: f64,
/// Iteration cap for the inner solve.
pub max_iterations: usize,
}
impl Default for ArrowTrustRegionOptions {
    /// Infinite radius with the PCG default tolerance and iteration cap.
    fn default() -> Self {
        let radius = DEFAULT_TRUST_REGION_RADIUS;
        let steihaug_relative_tolerance = DEFAULT_PCG_RELATIVE_TOLERANCE;
        let max_iterations = DEFAULT_PCG_MAX_ITERATIONS;
        Self {
            radius,
            steihaug_relative_tolerance,
            max_iterations,
        }
    }
}
/// Whether a solve may use the mixed-precision path, and with which limits.
///
/// The default is `Off`. `Certified` carries the refinement-step budget,
/// the residual tolerance and the kappa margin; `MixedPrecisionPolicy::certified()`
/// fills them from the module defaults. `Default` is derived with `#[default]`,
/// matching `PcgStopReason` and `MixedPrecisionStatus`, instead of being
/// hand-written.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum MixedPrecisionPolicy {
    /// Mixed precision disabled.
    #[default]
    Off,
    /// Mixed precision allowed under the given certification limits.
    Certified {
        max_refinement_steps: usize,
        residual_relative_tolerance: f64,
        kappa_unit_roundoff_margin: f64,
    },
}
impl MixedPrecisionPolicy {
    /// A `Certified` policy filled with the module's default refinement budget,
    /// residual tolerance and kappa margin.
    pub fn certified() -> Self {
        let max_refinement_steps = DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS;
        let residual_relative_tolerance = DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE;
        let kappa_unit_roundoff_margin = DEFAULT_MIXED_PRECISION_KAPPA_MARGIN;
        Self::Certified {
            max_refinement_steps,
            residual_relative_tolerance,
            kappa_unit_roundoff_margin,
        }
    }

    /// `true` for any `Certified` policy and `false` for `Off`. The match is
    /// exhaustive so a new variant must be classified here explicitly.
    fn is_enabled(self) -> bool {
        match self {
            Self::Off => false,
            Self::Certified { .. } => true,
        }
    }
}
/// Full configuration of an arrow-system solve. Build it with the
/// `ArrowSolveOptions` constructors and `with_*` helpers. `Debug` is
/// implemented by hand because `gpu_matvec` holds a closure.
#[derive(Clone)]
pub struct ArrowSolveOptions {
/// Solver strategy.
pub mode: ArrowSolverMode,
/// PCG stopping controls.
pub pcg: ArrowPcgOptions,
/// Trust-region controls.
pub trust_region: ArrowTrustRegionOptions,
/// Optional streaming chunk size; `with_streaming_chunk_size` maps 0 to `None`.
pub streaming_chunk_size: Option<usize>,
/// Whether the Riemannian trust-region variant is requested (default `false`).
pub riemannian_trust_region: bool,
/// Optional externally supplied matvec; `Debug` only prints whether it is set.
pub gpu_matvec: Option<GpuSchurMatvec>,
/// Forwarded to `BatchedBlockSolver::factor_blocks` (default `false`).
pub tolerate_ill_conditioning: bool,
/// Mixed-precision policy (default `Off`).
pub mixed_precision: MixedPrecisionPolicy,
}
impl std::fmt::Debug for ArrowSolveOptions {
    /// Prints every field; `gpu_matvec` appears only as a presence flag
    /// because closures are not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructure: adding a field forces this impl to be updated.
        let Self {
            mode,
            pcg,
            trust_region,
            streaming_chunk_size,
            riemannian_trust_region,
            gpu_matvec,
            tolerate_ill_conditioning,
            mixed_precision,
        } = self;
        let has_gpu_matvec = gpu_matvec.is_some();
        let mut dbg = f.debug_struct("ArrowSolveOptions");
        dbg.field("mode", mode);
        dbg.field("pcg", pcg);
        dbg.field("trust_region", trust_region);
        dbg.field("streaming_chunk_size", streaming_chunk_size);
        dbg.field("riemannian_trust_region", riemannian_trust_region);
        dbg.field("gpu_matvec", &has_gpu_matvec);
        dbg.field("tolerate_ill_conditioning", tolerate_ill_conditioning);
        dbg.field("mixed_precision", mixed_precision);
        dbg.finish()
    }
}
/// Controls for the proximal (ridge-regularized) correction step. Each field
/// defaults to the matching `DEFAULT_*` module constant.
/// NOTE(review): the algorithm using these fields is outside this part of the
/// file; per-field roles are inferred from names — confirm there.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
pub initial_ridge: f64,
pub ridge_growth: f64,
pub max_attempts: usize,
pub armijo_c1: f64,
pub gradient_tolerance: f64,
pub convergence_objective_rel_tol: f64,
}
impl Default for ArrowProximalCorrectionOptions {
    /// Every field taken from its `DEFAULT_*` module constant.
    fn default() -> Self {
        let initial_ridge = DEFAULT_PROXIMAL_INITIAL_RIDGE;
        let ridge_growth = DEFAULT_PROXIMAL_RIDGE_GROWTH;
        let max_attempts = DEFAULT_PROXIMAL_MAX_ATTEMPTS;
        let armijo_c1 = DEFAULT_ARMIJO_C1;
        let gradient_tolerance = DEFAULT_GRADIENT_TOLERANCE;
        let convergence_objective_rel_tol = DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL;
        Self {
            initial_ridge,
            ridge_growth,
            max_attempts,
            armijo_c1,
            gradient_tolerance,
            convergence_objective_rel_tol,
        }
    }
}
/// Record of an accepted proximal step: the `t`- and `β`-space increments plus
/// the ridges, objective values and attempt count behind the acceptance.
/// NOTE(review): produced outside this part of the file; the exact meaning of
/// each scalar (e.g. which objective is "trial") should be confirmed there.
#[derive(Debug, Clone)]
pub struct ArrowAcceptedProximalStep {
pub delta_t: Array1<f64>,
pub delta_beta: Array1<f64>,
pub ridge_t: f64,
pub ridge_beta: f64,
pub proximal_ridge: f64,
pub objective_value: f64,
pub trial_objective_value: f64,
pub gradient_dot_step: f64,
pub attempts: usize,
}
impl ArrowSolveOptions {
pub fn automatic(k: usize) -> Self {
Self {
mode: ArrowSolverMode::automatic(k),
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn direct() -> Self {
Self {
mode: ArrowSolverMode::Direct,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn sqrt_ba() -> Self {
Self {
mode: ArrowSolverMode::SqrtBA,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn inexact_pcg() -> Self {
Self {
mode: ArrowSolverMode::InexactPCG,
pcg: ArrowPcgOptions::default(),
trust_region: ArrowTrustRegionOptions::default(),
streaming_chunk_size: None,
riemannian_trust_region: false,
gpu_matvec: None,
tolerate_ill_conditioning: false,
mixed_precision: MixedPrecisionPolicy::Off,
}
}
pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
self
}
pub fn with_ill_conditioning_tolerated(mut self) -> Self {
self.tolerate_ill_conditioning = true;
self
}
pub fn with_mixed_precision_policy(mut self, policy: MixedPrecisionPolicy) -> Self {
self.mixed_precision = policy;
self
}
#[must_use]
pub fn with_streaming_mixed_precision_default(&self) -> Self {
let mut out = self.clone();
if matches!(out.mixed_precision, MixedPrecisionPolicy::Off) {
out.mixed_precision = MixedPrecisionPolicy::certified();
}
out
}
}
/// Per-row dense block kernels used by the arrow solver. `CpuBatchedBlockSolver`
/// is the CPU implementation.
pub trait BatchedBlockSolver {
/// Factors every row block with ridge `ridge_t` and block size `d`.
/// The CPU implementation tries a batched path first, then factors each row
/// individually and returns the first error encountered.
fn factor_blocks(
&self,
rows: &[ArrowRowBlock],
ridge_t: f64,
d: usize,
tolerate_ill_conditioning: bool,
) -> Result<ArrowFactorSlab, ArrowSchurError>;
/// Solves with a block factor for one right-hand side. The CPU implementation
/// uses Cholesky solves, specialized for 1–4 dimensional blocks.
fn solve_block_vector(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView1<'_, f64>,
) -> Array1<f64>;
/// Solves with a block factor for a matrix of right-hand sides
/// (`cholesky_solve_matrix` on CPU).
fn solve_block_matrix(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Array2<f64>;
/// Applies only the lower-triangular half-solve
/// (`forward_substitution_lower_matrix` on CPU).
fn sqrt_solve_block_matrix(
&self,
factor: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
) -> Array2<f64>;
/// Subtracts a `left`/`right` product from `schur`. The CPU implementation
/// asserts both operands have `schur.nrows()` columns.
fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
}
/// Per-row gauge (null-space) directions, indexed by row.
///
/// The directions live behind an `Arc`, so the derived `Clone` shares them instead of
/// copying. NOTE(review): presumably these are fed to the evidence-mode gauge deflation
/// (`factor_gauge_deflated_evidence_row`); confirm at the call site.
#[derive(Debug, Clone)]
pub struct ArrowRowGaugeDeflation {
/// `directions[row]` lists that row's gauge directions (possibly none).
pub directions: Arc<[Vec<Array1<f64>>]>,
}
impl ArrowRowGaugeDeflation {
    /// Wraps per-row gauge directions (outer index = row) in a shared, cheaply clonable slice.
    pub fn new(directions: Vec<Vec<Array1<f64>>>) -> Self {
        let directions: Arc<[Vec<Array1<f64>>]> = directions.into();
        Self { directions }
    }
    /// Gauge directions registered for `row`.
    ///
    /// Rows beyond the stored range have no directions.
    fn row(&self, row: usize) -> &[Array1<f64>] {
        match self.directions.get(row) {
            Some(dirs) => dirs.as_slice(),
            None => &[],
        }
    }
}
/// Stateless CPU implementation of `BatchedBlockSolver`.
///
/// Its `factor_blocks` first offers the work to the batched GPU path and then falls back to
/// per-row CPU factorization.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;
impl BatchedBlockSolver for CpuBatchedBlockSolver {
    /// Factors every row's `H_tt + ridge_t * I` into a lower Cholesky factor.
    ///
    /// The batched GPU path (`try_factor_blocks_batched`) is tried first. When it declines,
    /// rows are factored one at a time with `factor_one_row`, and the first row error is
    /// returned.
    fn factor_blocks(
        &self,
        rows: &[ArrowRowBlock],
        ridge_t: f64,
        d: usize,
        tolerate_ill_conditioning: bool,
    ) -> Result<ArrowFactorSlab, ArrowSchurError> {
        if let Some(batched) =
            try_factor_blocks_batched(rows, ridge_t, d, tolerate_ill_conditioning)
        {
            return Ok(batched);
        }
        let factors = rows
            .iter()
            .enumerate()
            .map(|(row_idx, row)| {
                factor_one_row(row, ridge_t, d, row_idx, tolerate_ill_conditioning)
            })
            .collect::<Result<Vec<_>, ArrowSchurError>>()?;
        Ok(ArrowFactorSlab::from_blocks(factors))
    }
    /// Solves `L Lᵀ x = rhs` for one vector.
    ///
    /// Square factors of size 1–4 with a matching `rhs` use the stack-allocated kernel; all
    /// other shapes go through `cholesky_solve_vector`.
    fn solve_block_vector(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        match (factor.nrows(), factor.ncols(), rhs.len()) {
            (1, 1, 1) => cholesky_solve_vector_fixed::<1>(factor, rhs),
            (2, 2, 2) => cholesky_solve_vector_fixed::<2>(factor, rhs),
            (3, 3, 3) => cholesky_solve_vector_fixed::<3>(factor, rhs),
            (4, 4, 4) => cholesky_solve_vector_fixed::<4>(factor, rhs),
            _ => cholesky_solve_vector(factor, rhs),
        }
    }
    /// Solves `L Lᵀ X = rhs` for a matrix right-hand side.
    fn solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        cholesky_solve_matrix(factor, rhs)
    }
    /// Applies only the lower factor to `rhs` via forward substitution.
    fn sqrt_solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        forward_substitution_lower_matrix(factor, rhs)
    }
    /// In-place `schur -= leftᵀ · right`.
    ///
    /// # Panics
    /// Panics unless `schur` is `k × k` and both `left` and `right` are `d × k`, where
    /// `k = schur.nrows()` and `d = left.nrows()`.
    fn block_gemm_subtract(
        &self,
        schur: &mut Array2<f64>,
        left: &Array2<f64>,
        right: &Array2<f64>,
    ) {
        let k = schur.nrows();
        let d = left.nrows();
        assert_eq!(left.ncols(), k);
        assert_eq!(right.ncols(), k);
        // Without this check, a `right` with more than `d` rows was silently truncated.
        assert_eq!(right.nrows(), d);
        assert_eq!(schur.ncols(), k);
        for c in 0..d {
            for a in 0..k {
                let lca = left[[c, a]];
                // Sparse shortcut: zero coefficients contribute nothing, so the matching row
                // of `right` is not read at all.
                if lca == 0.0 {
                    continue;
                }
                for b in 0..k {
                    schur[[a, b]] -= lca * right[[c, b]];
                }
            }
        }
    }
}
/// Outcome of factoring one row block.
#[derive(Debug, Clone)]
struct ArrowRowFactorResult {
/// Lower Cholesky factor of the (possibly ridged or gauge-deflated) row block.
factor: Array2<f64>,
/// Number of gauge directions added before factoring.
/// This is 0 unless the gauge-deflation path was taken.
gauge_deflated_directions: usize,
}
fn try_factor_blocks_batched(
rows: &[ArrowRowBlock],
ridge_t: f64,
d: usize,
tolerate_ill_conditioning: bool,
) -> Option<ArrowFactorSlab> {
if d == 0 || rows.is_empty() {
return None;
}
if rows
.iter()
.any(|row| row.htt.dim() != (d, d) || row.gt.len() != d)
{
return None;
}
if !crate::gpu::runtime::GpuRuntime::is_available() {
return None;
}
let mut blocks: Vec<Array2<f64>> = Vec::with_capacity(rows.len());
for row in rows {
let mut block = row.htt.clone();
for a in 0..d {
block[[a, a]] += ridge_t;
}
blocks.push(block);
}
crate::gpu::try_cholesky_batched_lower_inplace(&mut blocks)?;
if !tolerate_ill_conditioning {
for (row, factor) in rows.iter().zip(blocks.iter()) {
let diag_scale = row_block_diag_scale(row, d);
let kappa_est = cholesky_factor_kappa_estimate(factor);
if !cholesky_factor_passes_safe_inversion(factor, d, diag_scale, kappa_est) {
return None;
}
}
}
Some(ArrowFactorSlab::from_blocks(blocks))
}
/// Largest absolute diagonal entry of the row's `H_tt` over the first `d` axes, floored at 1.0.
///
/// NaN diagonal entries are ignored, because `f64::max` prefers the non-NaN operand.
fn row_block_diag_scale(row: &ArrowRowBlock, d: usize) -> f64 {
    let mut scale = 0.0_f64;
    for a in 0..d {
        scale = scale.max(row.htt[[a, a]].abs());
    }
    scale.max(1.0)
}
/// Cheap condition-number estimate for an SPD matrix `A = L Lᵀ`, computed from its lower
/// Cholesky factor as `(max_i L_ii / min_i L_ii)²`.
///
/// The diagonal of a triangular `L` holds its eigenvalues, so this ratio is a lower bound on
/// `κ₂(L)`. Squaring it gives a lower bound on `κ₂(A)`.
///
/// Returns `f64::INFINITY` when:
/// - any diagonal entry is NaN;
/// - the smallest diagonal entry is not positive;
/// - the largest diagonal entry is infinite.
///
/// An empty factor yields 0.0.
fn cholesky_factor_kappa_estimate(factor: &Array2<f64>) -> f64 {
    let d = factor.nrows();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for a in 0..d {
        let v = factor[[a, a]];
        if v.is_nan() {
            // Both comparisons below are false for NaN, so a NaN pivot would otherwise be
            // skipped and could leave a small, finite estimate.
            return f64::INFINITY;
        }
        if v < min_diag {
            min_diag = v;
        }
        if v > max_diag {
            max_diag = v;
        }
    }
    if min_diag > 0.0 && max_diag.is_finite() {
        let ratio = max_diag / min_diag;
        ratio * ratio
    } else {
        f64::INFINITY
    }
}
/// Smallest Cholesky pivot `min_i L_ii²` of a lower factor.
///
/// Returns 0.0 for an empty factor, or as soon as any diagonal entry is non-positive or
/// non-finite.
fn cholesky_factor_min_pivot_estimate(factor: &Array2<f64>) -> f64 {
    let d = factor.nrows();
    if d == 0 {
        return 0.0;
    }
    (0..d)
        .map(|a| factor[[a, a]])
        .try_fold(f64::INFINITY, |smallest, diag| {
            if diag > 0.0 && diag.is_finite() {
                Some(smallest.min(diag * diag))
            } else {
                None
            }
        })
        .unwrap_or(0.0)
}
/// Minimum acceptable Cholesky pivot for "safe inversion".
///
/// Computed as `sqrt(ε)` times the diagonal scale, with the scale floored at 1.0. A NaN
/// scale is treated as 1.0 by `f64::max`.
fn safe_spd_pivot_min(diag_scale: f64) -> f64 {
    let floored_scale = diag_scale.max(1.0);
    floored_scale * f64::EPSILON.sqrt()
}
/// Decides whether a row factor is conditioned well enough to invert safely.
///
/// The factor passes only if both hold:
/// - the κ estimate is finite and at most `safe_spd_kappa_max(dim)`;
/// - the smallest pivot reaches `safe_spd_pivot_min(diag_scale)`.
fn cholesky_factor_passes_safe_inversion(
    factor: &Array2<f64>,
    dim: usize,
    diag_scale: f64,
    kappa_est: f64,
) -> bool {
    if !kappa_est.is_finite() || kappa_est > safe_spd_kappa_max(dim) {
        return false;
    }
    let pivot_floor = safe_spd_pivot_min(diag_scale);
    cholesky_factor_min_pivot_estimate(factor) >= pivot_floor
}
/// Largest acceptable condition-number estimate for a `dim`-dimensional block.
///
/// Computed as `1 / (sqrt(ε) · max(dim, 1))`.
fn safe_spd_kappa_max(dim: usize) -> f64 {
    let dim_factor = (dim as f64).max(1.0);
    (f64::EPSILON.sqrt() * dim_factor).recip()
}
fn factor_row_block_cholesky(
row: &ArrowRowBlock,
ridge_eff: f64,
d: usize,
) -> Result<Array2<f64>, String> {
match d {
1 => factor_row_block_cholesky_fixed::<1>(row, ridge_eff),
2 => factor_row_block_cholesky_fixed::<2>(row, ridge_eff),
3 => factor_row_block_cholesky_fixed::<3>(row, ridge_eff),
4 => factor_row_block_cholesky_fixed::<4>(row, ridge_eff),
_ => factor_row_block_cholesky_dynamic(row, ridge_eff, d),
}
}
/// General-dimension path: shifts the first `d` diagonal entries of a copy of `H_tt` by
/// `ridge_eff`, then factors the copy with `cholesky_lower`.
fn factor_row_block_cholesky_dynamic(
    row: &ArrowRowBlock,
    ridge_eff: f64,
    d: usize,
) -> Result<Array2<f64>, String> {
    let mut shifted = row.htt.clone();
    (0..d).for_each(|a| shifted[[a, a]] += ridge_eff);
    cholesky_lower(&shifted)
}
/// Stack-allocated Cholesky kernel for `D × D` blocks.
///
/// Returns the lower factor `L` with `L Lᵀ = H_tt + ridge_eff * I`. Only the lower triangle
/// feeds the factorization, but every entry is first checked for finiteness. The scan is
/// row-major, and the first offending linear index is reported.
///
/// # Errors
/// Returns a message when an entry is non-finite, or when a pivot is non-finite or not
/// positive.
fn factor_row_block_cholesky_fixed<const D: usize>(
    row: &ArrowRowBlock,
    ridge_eff: f64,
) -> Result<Array2<f64>, String> {
    // Entry (i, j) of the ridge-shifted block.
    let shifted = |i: usize, j: usize| -> f64 {
        let entry = row.htt[[i, j]];
        if i == j {
            entry + ridge_eff
        } else {
            entry
        }
    };
    if let Some(idx) = (0..D * D).find(|&idx| !shifted(idx / D, idx % D).is_finite()) {
        return Err(format!(
            "cholesky_lower: non-finite entry at linear index {idx}"
        ));
    }
    let mut lower = [[0.0_f64; D]; D];
    for i in 0..D {
        for j in 0..=i {
            // Subtract the already-computed columns in order, matching the textbook recurrence.
            let sum = (0..j).fold(shifted(i, j), |acc, kk| acc - lower[i][kk] * lower[j][kk]);
            if i != j {
                lower[i][j] = sum / lower[j][j];
                continue;
            }
            if !sum.is_finite() || sum <= 0.0 {
                return Err(format!(
                    "non-PD pivot {sum} at index {i} (matrix is not positive definite)"
                ));
            }
            lower[i][j] = sum.sqrt();
        }
    }
    // The strict upper triangle of `lower` was never written, so it is still 0.0.
    Ok(Array2::from_shape_fn((D, D), |(i, j)| lower[i][j]))
}
/// Quadratic form `gᵀ H_tt g` of a gauge direction against the row's latent Hessian.
///
/// Returns `None` when `gauge` does not have length `d`, or when the result is non-finite.
fn row_gauge_curvature(row: &ArrowRowBlock, d: usize, gauge: &Array1<f64>) -> Option<f64> {
    if gauge.len() != d {
        return None;
    }
    let quadratic = (0..d).fold(0.0_f64, |outer, i| {
        let gi = gauge[i];
        (0..d).fold(outer, |acc, j| acc + gi * row.htt[[i, j]] * gauge[j])
    });
    Some(quadratic).filter(|value| value.is_finite())
}
/// Evidence-mode fallback for a row whose `H_tt` failed to factor.
///
/// It factors `H_tt + Σ_q q qᵀ`, where the `q` are orthonormalised gauge directions along
/// which the row has (near-)zero curvature. The result is therefore a factor of a modified
/// matrix. NOTE(review): callers presumably account for the added directions through
/// `gauge_deflated_directions`; confirm.
///
/// A gauge `g` is used only when all of these hold:
/// - it has length `d`;
/// - `|g|² > 1e-24`;
/// - `|gᵀ H_tt g| <= GAUGE_RAYLEIGH_EPS * diag_scale * |g|²`, where `diag_scale` is the
///   largest |diag(H_tt)| floored at 1.
///
/// Accepted gauges are Gram–Schmidt orthonormalised in input order. A gauge whose residual
/// squared norm falls to `1e-24` or below is skipped as numerically dependent.
///
/// Returns `None` when:
/// - `gauges` is empty;
/// - the diagonal scale is non-finite;
/// - any length-matching, non-negligible gauge has non-finite curvature (the `?` below);
/// - no direction survives;
/// - `cholesky_lower` rejects the deflated matrix.
fn factor_gauge_deflated_evidence_row(
row: &ArrowRowBlock,
d: usize,
gauges: &[Array1<f64>],
) -> Option<ArrowRowFactorResult> {
const GAUGE_RAYLEIGH_EPS: f64 = 1.0e-8;
if gauges.is_empty() {
return None;
}
let max_diag = row_block_diag_scale(row, d);
if !(max_diag.is_finite() && max_diag > 0.0) {
return None;
}
// Orthonormal basis of the accepted gauge directions.
let mut basis: Vec<Array1<f64>> = Vec::new();
for gauge in gauges {
if gauge.len() != d {
continue;
}
let norm_sq = gauge.iter().map(|&v| v * v).sum::<f64>();
if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
continue;
}
// A non-finite curvature aborts the whole fallback, not just this gauge.
let curvature = row_gauge_curvature(row, d, gauge)?;
// Keep only directions that are (scale-relatively) flat for this row.
if curvature.abs() > GAUGE_RAYLEIGH_EPS * max_diag * norm_sq {
continue;
}
// Classical Gram–Schmidt against the unit vectors already in `basis`.
let mut direction = gauge.clone();
for existing in &basis {
let coeff = direction.dot(existing);
for idx in 0..d {
direction[idx] -= coeff * existing[idx];
}
}
let residual_norm_sq = direction.iter().map(|&v| v * v).sum::<f64>();
if !(residual_norm_sq.is_finite() && residual_norm_sq > 1.0e-24) {
continue;
}
let inv_norm = residual_norm_sq.sqrt().recip();
for value in direction.iter_mut() {
*value *= inv_norm;
}
basis.push(direction);
}
if basis.is_empty() {
return None;
}
// Add a unit-weight rank-one term q qᵀ for each basis direction.
let mut deflated = row.htt.clone();
for direction in &basis {
for i in 0..d {
for j in 0..d {
deflated[[i, j]] += direction[i] * direction[j];
}
}
}
let factor = cholesky_lower(&deflated).ok()?;
Some(ArrowRowFactorResult {
factor,
gauge_deflated_directions: basis.len(),
})
}
/// Solves `L Lᵀ x = b` for a `D × D` lower factor `L`, using forward then back substitution on
/// stack arrays.
///
/// # Panics
/// Panics if any diagonal entry of `L` is non-finite or has magnitude below
/// `f64::MIN_POSITIVE`.
fn cholesky_solve_vector_fixed<const D: usize>(
    l: ArrayView2<'_, f64>,
    b: ArrayView1<'_, f64>,
) -> Array1<f64> {
    let diagonal_ok = (0..D).all(|i| {
        let pivot = l[[i, i]];
        pivot.is_finite() && pivot.abs() >= f64::MIN_POSITIVE
    });
    assert!(
        diagonal_ok,
        "cholesky_solve_vector_fixed: factor diagonal must be finite and non-subnormal"
    );
    // Forward substitution: L y = b.
    let mut y = [0.0_f64; D];
    for i in 0..D {
        let partial = (0..i).fold(b[i], |acc, k| acc - l[[i, k]] * y[k]);
        y[i] = partial / l[[i, i]];
    }
    // Back substitution: Lᵀ x = y.
    let mut x = [0.0_f64; D];
    for i in (0..D).rev() {
        let partial = ((i + 1)..D).fold(y[i], |acc, k| acc - l[[k, i]] * x[k]);
        x[i] = partial / l[[i, i]];
    }
    Array1::from_vec(x.to_vec())
}
fn factor_one_row(
row: &ArrowRowBlock,
ridge_t: f64,
d: usize,
row_idx: usize,
tolerate_ill_conditioning: bool,
) -> Result<Array2<f64>, ArrowSchurError> {
factor_one_row_result(row, ridge_t, d, row_idx, tolerate_ill_conditioning, &[])
.map(|result| result.factor)
}
/// Factors one row's `H_tt` into a lower Cholesky factor, escalating a diagonal ridge when
/// needed.
///
/// Shape checks: `H_tt` must be `d × d` and `g_t` must have length `d`. Otherwise
/// `PerRowFactorFailed` is returned.
///
/// Ridge schedule:
/// - The ridge starts at `ridge_t`.
/// - Each retry multiplies it by `RIDGE_GROWTH_FACTOR`. If the current ridge is not
///   positive, the retry seeds it at `RIDGE_SEED_DIAG_FRACTION * diag_scale`, where
///   `diag_scale` is the largest |diag(H_tt)| floored at 1.
/// - Escalation gives up once the next ridge is non-finite or exceeds
///   `max(ridge_t, RIDGE_CAP_DIAG_FRACTION * diag_scale) * RIDGE_CAP_SCALE`.
///
/// When `tolerate_ill_conditioning` is false:
/// - A successful factorization is accepted only if it passes
///   `cholesky_factor_passes_safe_inversion`.
/// - Both poor conditioning and Cholesky failure trigger escalation.
/// - Hitting the cap yields `PerRowFactorIllConditioned` (conditioning) or
///   `PerRowFactorFailed` (Cholesky).
///
/// When `tolerate_ill_conditioning` is true:
/// - The first successful factorization is accepted as-is.
/// - A Cholesky failure is never ridged. If `ridge_t == 0.0`, a gauge-deflated factor built
///   from `row_gauges` is tried first; otherwise `PerRowFactorFailed` is returned.
fn factor_one_row_result(
row: &ArrowRowBlock,
ridge_t: f64,
d: usize,
row_idx: usize,
tolerate_ill_conditioning: bool,
row_gauges: &[Array1<f64>],
) -> Result<ArrowRowFactorResult, ArrowSchurError> {
if row.htt.dim() != (d, d) {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt shape {:?} does not match per_point_hessian_block dimension ({d}, {d})",
row.htt.dim()
),
});
}
if row.gt.len() != d {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} g_t length {} does not match latent dimension {d}",
row.gt.len()
),
});
}
const RIDGE_GROWTH_FACTOR: f64 = 10.0;
const RIDGE_SEED_DIAG_FRACTION: f64 = 1.0e-10;
const RIDGE_CAP_DIAG_FRACTION: f64 = 1.0e-12;
const RIDGE_CAP_SCALE: f64 = 1.0e12;
let diag_scale = row_block_diag_scale(row, d);
// Upper bound on the ridge; relative to the block scale so the cap is unit-aware.
let ridge_cap = ridge_t.max(RIDGE_CAP_DIAG_FRACTION * diag_scale) * RIDGE_CAP_SCALE;
let mut ridge_eff = ridge_t;
let factor = loop {
match factor_row_block_cholesky(row, ridge_eff, d) {
Ok(factor) => {
// Evidence mode: keep the genuine factor regardless of conditioning.
if tolerate_ill_conditioning {
break ArrowRowFactorResult {
factor,
gauge_deflated_directions: 0,
};
}
let kappa_est = cholesky_factor_kappa_estimate(&factor);
if cholesky_factor_passes_safe_inversion(&factor, d, diag_scale, kappa_est) {
break ArrowRowFactorResult {
factor,
gauge_deflated_directions: 0,
};
}
// Factorization succeeded but is too ill-conditioned: grow (or seed) the ridge.
let next = if ridge_eff > 0.0 {
ridge_eff * RIDGE_GROWTH_FACTOR
} else {
RIDGE_SEED_DIAG_FRACTION * diag_scale
};
if !next.is_finite() || next > ridge_cap {
return Err(ArrowSchurError::PerRowFactorIllConditioned {
row: row_idx,
kappa_estimate: kappa_est,
});
}
ridge_eff = next;
}
Err(e) => {
if tolerate_ill_conditioning {
// Only an unridged evidence solve may fall back to gauge deflation.
if ridge_t == 0.0 {
if let Some(deflated) =
factor_gauge_deflated_evidence_row(row, d, row_gauges)
{
return Ok(deflated);
}
}
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt is non-PD at base ridge {ridge_t:e}; \
evidence mode preserves the genuine Cholesky of \
H_tt and does not condition non-PD blocks: {e}"
),
});
}
// Non-PD block: grow (or seed) the ridge and retry.
let next = if ridge_eff > 0.0 {
ridge_eff * RIDGE_GROWTH_FACTOR
} else {
RIDGE_SEED_DIAG_FRACTION * diag_scale
};
if !next.is_finite() || next > ridge_cap {
return Err(ArrowSchurError::PerRowFactorFailed {
row: row_idx,
reason: format!(
"row {row_idx} H_tt remained non-PD up to ridge {ridge_eff:e} \
(base ridge_t={ridge_t}); last cholesky error: {e}"
),
});
}
ridge_eff = next;
}
}
};
Ok(factor)
}
/// Fingerprint of the latent manifold "mode", used to detect geometry changes.
///
/// Euclidean latents all share `EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT`. Any other manifold
/// hashes, in order:
/// - the observation count and latent dimension;
/// - the manifold description;
/// - the flattened per-axis metric weights.
fn manifold_mode_fingerprint(latent: &LatentCoordValues) -> u64 {
    let manifold = latent.manifold();
    if manifold.is_euclidean() {
        return EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT;
    }
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-manifold-mode-v1");
    hasher.write_usize(latent.n_obs());
    hasher.write_usize(latent.latent_dim());
    write_latent_manifold(&mut hasher, manifold);
    let metric_weights = {
        let mut buf = Vec::new();
        append_latent_metric_weights(&mut buf, manifold);
        buf
    };
    hasher.write_usize(metric_weights.len());
    metric_weights.iter().for_each(|&weight| {
        hasher.write_f64(weight);
    });
    hasher.finish_u64()
}
/// Content fingerprint of everything that shapes the per-row and shared Hessian blocks.
///
/// The hash covers, in order:
/// - the row count, `d` and `k`;
/// - per row, `H_tt` followed by either the dense `H_tβ` or, when a row `H_tβ` operator is
///   installed, the operator's address (plus the dense block if `htbeta_dense_supplement`
///   is set);
/// - the `H_ββ` source: the penalty operator's own fingerprint, or the dense `hbb`;
/// - the optional stored `H_ββ` diagonal.
///
/// Gradients are not hashed.
fn row_hessian_fingerprint_for_system(sys: &ArrowSchurSystem) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-row-hessian-v2");
    for &extent in &[sys.rows.len(), sys.d, sys.k] {
        hasher.write_usize(extent);
    }
    // Closure state is opaque, so the operator's allocation address stands in for it.
    let operator_addr: Option<usize> = sys
        .htbeta_matvec
        .as_ref()
        .map(|op| Arc::as_ptr(op) as *const () as usize);
    for row in &sys.rows {
        hasher.write_f64_array2(&row.htt);
        if let Some(addr) = operator_addr {
            hasher.write_usize(addr);
            if sys.htbeta_dense_supplement {
                hasher.write_f64_array2(&row.htbeta);
            }
        } else {
            hasher.write_f64_array2(&row.htbeta);
        }
    }
    if let Some(op) = &sys.penalty_op {
        hasher.write_bool(true);
        op.fingerprint(&mut hasher);
    } else {
        hasher.write_bool(false);
        hasher.write_f64_array2(&sys.hbb);
    }
    if let Some(diag) = &sys.hbb_diag {
        hasher.write_bool(true);
        hasher.write_usize(diag.len());
        for value in diag.iter().copied() {
            hasher.write_f64(value);
        }
    } else {
        hasher.write_bool(false);
    }
    hasher.finish_u64()
}
/// Mixes the analytic-penalty registry fingerprint into the row-Hessian fingerprint.
///
/// A registry fingerprint of 0 means "no analytic row contributions", and the row
/// fingerprint is returned unchanged.
fn combine_row_and_registry_fingerprints(row: u64, registry: u64) -> u64 {
    match registry {
        0 => row,
        _ => {
            let mut hasher = Fingerprinter::new();
            hasher.write_str("arrow-schur-row-hessian-with-penalties-v1");
            hasher.write_u64(row);
            hasher.write_u64(registry);
            hasher.finish_u64()
        }
    }
}
/// Fingerprint of the row-block-diagonal Hessian that an analytic latent (`Psi`-tier)
/// penalty contributes.
///
/// Returns `None` for penalties that are not `Psi`-tier or not row-block-diagonal.
///
/// The hash always covers the penalty name, the input lengths and every local rho. It then
/// adds the penalty's effective curvature inputs:
/// - **Row-precision priors:** the per-row precision tensor and the (possibly learnable)
///   weight.
/// - **Parametric row-precision priors:** the auxiliary covariates, plus the raw and
///   rho-shifted `log_alpha`, `raw_beta` and `mu` parameters.
/// - **Other penalties:** their Hessian diagonal.
///
/// Parametric rho layout: `log_alpha`, then `raw_beta`, then `mu` in row-major order over
/// `mu`'s own shape, then an optional weight slot.
fn analytic_penalty_row_hessian_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> Option<u64> {
    if penalty.tier() != PenaltyTier::Psi || !analytic_penalty_is_row_block_diagonal(penalty) {
        return None;
    }
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-analytic-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    for &rho in rho_local.iter() {
        hasher.write_f64(rho);
    }
    match penalty {
        AnalyticPenaltyKind::RowPrecisionPrior(p) => {
            let (n, rows, cols) = p.lambda_per_row.dim();
            hasher.write_str("row-precision-fixed");
            hasher.write_usize(n);
            hasher.write_usize(rows);
            hasher.write_usize(cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            if p.learnable_weight {
                hasher.write_usize(p.rho_index);
                hasher.write_f64(p.weight * rho_local[p.rho_index].exp());
            }
            for &value in p.lambda_per_row.iter() {
                hasher.write_f64(value);
            }
        }
        AnalyticPenaltyKind::ParametricRowPrecisionPrior(p) => {
            let (aux_n, aux_dim) = p.aux.dim();
            let (mu_rows, mu_cols) = p.mu.dim();
            let raw_beta_offset = p.log_alpha.len();
            let mu_offset = raw_beta_offset + p.raw_beta.len();
            let weight_offset = mu_offset + p.mu.len();
            hasher.write_str("row-precision-parametric");
            hasher.write_usize(aux_n);
            hasher.write_usize(aux_dim);
            hasher.write_usize(mu_rows);
            hasher.write_usize(mu_cols);
            hasher.write_f64(p.weight);
            hasher.write_bool(p.learnable_weight);
            for &value in p.aux.iter() {
                hasher.write_f64(value);
            }
            for k in 0..p.log_alpha.len() {
                let active_log_alpha = p.log_alpha[k] + rho_local[k];
                hasher.write_f64(p.log_alpha[k]);
                hasher.write_f64(active_log_alpha);
                hasher.write_f64(active_log_alpha.exp());
            }
            for k in 0..p.raw_beta.len() {
                let active_raw_beta = p.raw_beta[k] + rho_local[raw_beta_offset + k];
                hasher.write_f64(p.raw_beta[k]);
                hasher.write_f64(active_raw_beta);
                hasher.write_f64(crate::linalg::utils::stable_softplus(active_raw_beta));
            }
            for k in 0..mu_rows {
                for a in 0..mu_cols {
                    // The mu segment is `mu_rows × mu_cols` long (see `weight_offset`), so the
                    // row stride is `mu_cols`. The previous `p.aux.ncols()` stride could step
                    // outside the segment whenever the two widths differ.
                    let idx = mu_offset + k * mu_cols + a;
                    hasher.write_f64(p.mu[[k, a]]);
                    hasher.write_f64(p.mu[[k, a]] + rho_local[idx]);
                }
            }
            if p.learnable_weight {
                hasher.write_usize(weight_offset);
                hasher.write_f64(p.weight * rho_local[weight_offset].exp());
            }
        }
        _ => {
            hasher.write_str("row-block-diagonal");
            if let Some(diag) = penalty.hessian_diag(target_t, rho_local) {
                hasher.write_usize(diag.len());
                for &value in diag.iter() {
                    hasher.write_f64(value);
                }
            } else {
                hasher.write_usize(0);
            }
        }
    }
    Some(hasher.finish_u64())
}
/// Fingerprint of a cross-row (non-block-diagonal) latent penalty.
///
/// The penalty's full Hessian is not materialised here. Instead, the hash covers the
/// penalty's identity and rho values, plus one PSD-majoriser Hessian-vector product taken
/// along `target_t` itself.
fn cross_row_penalty_fingerprint(
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) -> u64 {
    let mut hasher = Fingerprinter::new();
    hasher.write_str("arrow-schur-analytic-cross-row-hessian-v1");
    hasher.write_str(penalty.name());
    hasher.write_usize(target_t.len());
    hasher.write_usize(rho_local.len());
    for rho in rho_local.iter().copied() {
        hasher.write_f64(rho);
    }
    let response = penalty.psd_majorizer_hvp(target_t, rho_local, target_t);
    hasher.write_usize(response.len());
    for &value in response.iter() {
        hasher.write_f64(value);
    }
    hasher.finish_u64()
}
/// Writes a structural description of `manifold` into `hasher`.
///
/// Each variant writes a tag string followed by its parameters. Product manifolds recurse
/// into their parts in order and prefix collections with their lengths.
fn write_latent_manifold(hasher: &mut Fingerprinter, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::Euclidean => {
            hasher.write_str("euclidean");
        }
        LatentManifold::Circle { period } => {
            hasher.write_str("circle");
            hasher.write_f64(*period);
        }
        LatentManifold::Sphere { dim } => {
            hasher.write_str("sphere");
            hasher.write_usize(*dim);
        }
        LatentManifold::Interval { lo, hi } => {
            hasher.write_str("interval");
            hasher.write_f64(*lo);
            hasher.write_f64(*hi);
        }
        LatentManifold::Product(parts) => {
            hasher.write_str("product");
            hasher.write_usize(parts.len());
            parts
                .iter()
                .for_each(|part| write_latent_manifold(hasher, part));
        }
        LatentManifold::ProductWithMetric { manifolds, weights } => {
            hasher.write_str("product-with-metric");
            hasher.write_usize(manifolds.len());
            manifolds
                .iter()
                .for_each(|part| write_latent_manifold(hasher, part));
            hasher.write_usize(weights.len());
            for &weight in weights.iter() {
                hasher.write_f64(weight);
            }
        }
    }
}
/// Appends per-axis metric weights for `manifold` to `out`.
///
/// Each weight is `1 / scale²`, with the scale chosen per variant:
/// - Euclidean: 1;
/// - circle: its period;
/// - each sphere axis: π;
/// - interval: its width.
///
/// Products recurse into their parts. `ProductWithMetric` contributes its explicit weights
/// verbatim.
fn append_latent_metric_weights(out: &mut Vec<f64>, manifold: &LatentManifold) {
    match manifold {
        LatentManifold::Euclidean => out.push(1.0),
        LatentManifold::Circle { period } => out.push(1.0 / (period * period)),
        LatentManifold::Sphere { dim } => {
            let pi = std::f64::consts::PI;
            let axis_weight = 1.0 / (pi * pi);
            out.extend(std::iter::repeat(axis_weight).take(*dim));
        }
        LatentManifold::Interval { lo, hi } => {
            let width = hi - lo;
            out.push(1.0 / (width * width));
        }
        LatentManifold::Product(parts) => parts
            .iter()
            .for_each(|part| append_latent_metric_weights(out, part)),
        LatentManifold::ProductWithMetric { weights, .. } => out.extend(weights.iter().copied()),
    }
}
/// Per-observation blocks of the arrow-structured system for one row's latent coordinates.
#[derive(Debug, Clone)]
pub struct ArrowRowBlock {
/// Latent Hessian block `H_tt` (`d × d`).
pub htt: Array2<f64>,
/// Latent–coefficient coupling block `H_tβ` (`d × htbeta_cols`).
pub htbeta: Array2<f64>,
/// Latent gradient `g_t` (length `d`).
pub gt: Array1<f64>,
}
impl ArrowRowBlock {
    /// Zeroed row block with a `d × d` `H_tt`, a `d × k` `H_tβ` and a length-`d` gradient.
    pub fn new(d: usize, k: usize) -> Self {
        let htbeta_cols = k;
        Self::new_with_htbeta_cols(d, htbeta_cols)
    }
    /// Zeroed row block whose `H_tβ` has an explicit column count.
    pub fn new_with_htbeta_cols(d: usize, htbeta_cols: usize) -> Self {
        let htt = Array2::zeros((d, d));
        let htbeta = Array2::zeros((d, htbeta_cols));
        let gt = Array1::zeros(d);
        Self { htt, htbeta, gt }
    }
}
/// Arrow-structured linear system.
///
/// The system has per-row latent blocks (`rows`) coupled to `k` shared coefficients β,
/// together with the shared-coefficient curvature sources and cache fingerprints.
pub struct ArrowSchurSystem {
/// Per-observation row blocks.
pub rows: Vec<ArrowRowBlock>,
/// Dense `H_ββ` (`k × k`). The empty and matrix-free constructors leave it as a 0 × 0
/// placeholder.
pub hbb: Array2<f64>,
/// Matrix-free `H_ββ` mat-vec, set by `set_shared_beta_operator` or
/// `new_matrix_free_shared`.
pub hbb_matvec: Option<SharedBetaMatvec>,
/// Per-row `H_tβ` operator callback, called as `(row, input, output)`.
pub htbeta_matvec: Option<RowHtbetaMatvec>,
/// Transpose of `htbeta_matvec`.
pub htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
/// When an `H_tβ` operator is installed, whether the dense `htbeta` blocks are also
/// hashed into the row-Hessian fingerprint.
/// NOTE(review): presumably the dense blocks supplement the operator; confirm.
pub htbeta_dense_supplement: bool,
/// Stored diagonal of `H_ββ`, kept for matrix-free operators.
pub hbb_diag: Option<Array1<f64>>,
/// Gradient with respect to the shared coefficients (length `k`).
pub gb: Array1<f64>,
/// Latent dimension (the maximum over rows for per-row-dimension systems).
pub d: usize,
/// Latent dimension of each row.
pub row_dims: Arc<[usize]>,
/// Prefix sums of `row_dims` (`n + 1` entries, starting at 0).
pub row_offsets: Arc<[usize]>,
/// Number of shared coefficients.
pub k: usize,
/// Fingerprint of the latent manifold mode; see `apply_riemannian_latent_geometry`.
pub manifold_mode_fingerprint: u64,
/// Cached row-Hessian fingerprint; see `refresh_row_hessian_fingerprint`.
pub row_hessian_fingerprint: u64,
/// Fingerprint of the analytic-penalty row contributions (0 when there are none).
pub analytic_row_hessian_fingerprint: u64,
/// Column ranges of the β blocks (empty by default).
pub block_offsets: Arc<[Range<usize>]>,
/// Operator form of `H_ββ`. When `None`, the dense `hbb` is used.
pub penalty_op: Option<Arc<dyn BetaPenaltyOp>>,
/// Device-side data for the SAE PCG path.
pub device_sae_pcg: Option<Arc<DeviceSaePcgData>>,
/// Latent penalties that couple rows; only their gradients are folded into `rows`.
pub cross_row_penalties: Vec<CrossRowLatentPenalty>,
/// Optional per-row gauge directions.
pub row_gauge_deflation: Option<ArrowRowGaugeDeflation>,
/// Optional sparse cross-row source; `None` when empty.
pub ibp_cross_row: Option<IbpCrossRowSource>,
}
impl Clone for ArrowSchurSystem {
    /// Field-wise clone.
    ///
    /// Owned arrays and vectors are copied; `Arc`-held operators and layout slices are
    /// shared by reference count. The exhaustive destructure makes the compiler flag any
    /// field added later but forgotten here.
    fn clone(&self) -> Self {
        let Self {
            rows,
            hbb,
            hbb_matvec,
            htbeta_matvec,
            htbeta_transpose_matvec,
            htbeta_dense_supplement,
            hbb_diag,
            gb,
            d,
            row_dims,
            row_offsets,
            k,
            manifold_mode_fingerprint,
            row_hessian_fingerprint,
            analytic_row_hessian_fingerprint,
            block_offsets,
            penalty_op,
            device_sae_pcg,
            cross_row_penalties,
            row_gauge_deflation,
            ibp_cross_row,
        } = self;
        Self {
            rows: rows.clone(),
            hbb: hbb.clone(),
            hbb_matvec: hbb_matvec.clone(),
            htbeta_matvec: htbeta_matvec.clone(),
            htbeta_transpose_matvec: htbeta_transpose_matvec.clone(),
            htbeta_dense_supplement: *htbeta_dense_supplement,
            hbb_diag: hbb_diag.clone(),
            gb: gb.clone(),
            d: *d,
            row_dims: Arc::clone(row_dims),
            row_offsets: Arc::clone(row_offsets),
            k: *k,
            manifold_mode_fingerprint: *manifold_mode_fingerprint,
            row_hessian_fingerprint: *row_hessian_fingerprint,
            analytic_row_hessian_fingerprint: *analytic_row_hessian_fingerprint,
            block_offsets: Arc::clone(block_offsets),
            penalty_op: penalty_op.clone(),
            device_sae_pcg: device_sae_pcg.clone(),
            cross_row_penalties: cross_row_penalties.clone(),
            row_gauge_deflation: row_gauge_deflation.clone(),
            ibp_cross_row: ibp_cross_row.clone(),
        }
    }
}
/// A `Psi`-tier analytic penalty that is not row-block-diagonal.
///
/// It is recorded by `add_analytic_penalty_contributions` together with the inputs it was
/// evaluated at. Only its gradient was added to the row blocks.
#[derive(Clone)]
pub struct CrossRowLatentPenalty {
/// The penalty itself.
pub penalty: AnalyticPenaltyKind,
/// This penalty's slice of the global rho vector.
pub rho_local: Array1<f64>,
/// Latent target the penalty was evaluated at.
pub target_t: Array1<f64>,
}
/// Sparse cross-row source stored as `(g, k, z)` triplets.
///
/// - `dense_u` scatters the triplets into a dense `delta_t_len × r` matrix `U`, summing
///   duplicates.
/// - `self_term_downdate` accumulates `d[k] * z²` into slot `g`.
///
/// NOTE(review): the "IBP" meaning and the reading of `d` as per-column weights are
/// inferred from names; confirm.
#[derive(Clone, Debug, Default)]
pub struct IbpCrossRowSource {
/// Number of columns of `U`.
pub r: usize,
/// Per-column coefficients, indexed by the `k` of each triplet.
pub d: Array1<f64>,
/// `(row g, column k, value z)` triplets.
pub entries: Vec<(usize, usize, f64)>,
}
impl IbpCrossRowSource {
    /// Dense `delta_t_len × r` matrix built from the triplets; duplicate positions are summed.
    fn dense_u(&self, delta_t_len: usize) -> Array2<f64> {
        let mut dense = Array2::<f64>::zeros((delta_t_len, self.r));
        self.entries
            .iter()
            .for_each(|&(row, col, value)| dense[[row, col]] += value);
        dense
    }
    /// Per-row downdate `Σ d[k] · z²` over the triplets that land in each row.
    fn self_term_downdate(&self, delta_t_len: usize) -> Array1<f64> {
        let mut downdate = Array1::<f64>::zeros(delta_t_len);
        for &(row, col, value) in self.entries.iter() {
            downdate[row] += self.d[col] * value * value;
        }
        downdate
    }
}
impl ArrowSchurSystem {
/// Uniform-dimension system with `n` rows of latent dimension `d` and `k` shared
/// coefficients, backed by a zero dense `k × k` `H_ββ`.
pub fn new(n: usize, d: usize, k: usize) -> Self {
    let hbb = Array2::<f64>::zeros((k, k));
    Self::new_with_hbb(n, d, k, hbb)
}
/// Uniform-dimension system without a dense `H_ββ`, whose row `H_tβ` blocks have `k` columns.
pub fn new_with_empty_hbb(n: usize, d: usize, k: usize) -> Self {
    let htbeta_cols = k;
    Self::new_with_empty_hbb_and_htbeta_cols(n, d, k, htbeta_cols)
}
pub fn new_with_empty_hbb_and_htbeta_cols(
n: usize,
d: usize,
k: usize,
htbeta_cols: usize,
) -> Self {
let rows = (0..n)
.map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
.collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
/// Uniform-dimension system that reuses the caller's `k × k` buffer for `H_ββ`, zero-filled,
/// with `k`-column `H_tβ` blocks.
pub fn new_with_hbb(n: usize, d: usize, k: usize, hbb: Array2<f64>) -> Self {
    let htbeta_cols = k;
    Self::new_with_hbb_and_htbeta_cols(n, d, k, hbb, htbeta_cols)
}
pub fn new_with_hbb_and_htbeta_cols(
n: usize,
d: usize,
k: usize,
mut hbb: Array2<f64>,
htbeta_cols: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
hbb.fill(0.0);
let rows = (0..n)
.map(|_| ArrowRowBlock::new_with_htbeta_cols(d, htbeta_cols))
.collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
Self {
rows,
hbb,
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
pub fn new_matrix_free_shared<F>(
n: usize,
d: usize,
k: usize,
matvec: F,
diag: Array1<f64>,
) -> Self
where
F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
assert_eq!(diag.len(), k);
let rows = (0..n).map(|_| ArrowRowBlock::new(d, k)).collect();
let row_dims: Arc<[usize]> = (0..n).map(|_| d).collect::<Vec<_>>().into();
let row_offsets: Arc<[usize]> = (0..=n).map(|i| i * d).collect::<Vec<_>>().into();
let matvec_arc: SharedBetaMatvec = Arc::new(matvec);
let penalty_op: Option<Arc<dyn BetaPenaltyOp>> = Some(Arc::new(MatvecDiagPenaltyOp::new(
k,
Arc::clone(&matvec_arc),
diag.clone(),
)));
Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: Some(matvec_arc),
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: Some(diag),
gb: Array1::<f64>::zeros(k),
d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
/// Variable-dimension system (one latent dimension per row) with a zero dense `k × k` `H_ββ`.
pub fn new_with_per_row_dims(per_row_dims: Vec<usize>, k: usize) -> Self {
    let hbb = Array2::<f64>::zeros((k, k));
    Self::new_with_per_row_dims_and_hbb(per_row_dims, k, hbb)
}
/// Variable-dimension system without a dense `H_ββ`, whose row `H_tβ` blocks have `k` columns.
pub fn new_with_per_row_dims_empty_hbb(per_row_dims: Vec<usize>, k: usize) -> Self {
    let htbeta_cols = k;
    Self::new_with_per_row_dims_empty_hbb_and_htbeta_cols(per_row_dims, k, htbeta_cols)
}
pub fn new_with_per_row_dims_empty_hbb_and_htbeta_cols(
per_row_dims: Vec<usize>,
k: usize,
htbeta_cols: usize,
) -> Self {
let n = per_row_dims.len();
let d = per_row_dims.iter().copied().max().unwrap_or(0);
let mut offsets = Vec::with_capacity(n + 1);
let mut cursor = 0usize;
offsets.push(cursor);
for &dim in &per_row_dims {
cursor += dim;
offsets.push(cursor);
}
let rows = per_row_dims
.iter()
.map(|&dim| ArrowRowBlock::new_with_htbeta_cols(dim, htbeta_cols))
.collect();
Self {
rows,
hbb: Array2::<f64>::zeros((0, 0)),
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d,
row_dims: Arc::from(per_row_dims.into_boxed_slice()),
row_offsets: Arc::from(offsets.into_boxed_slice()),
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
/// Variable-dimension system that reuses the caller's `k × k` buffer for `H_ββ`, zero-filled,
/// with `k`-column `H_tβ` blocks.
pub fn new_with_per_row_dims_and_hbb(
    per_row_dims: Vec<usize>,
    k: usize,
    hbb: Array2<f64>,
) -> Self {
    let htbeta_cols = k;
    Self::new_with_per_row_dims_and_hbb_and_htbeta_cols(per_row_dims, k, hbb, htbeta_cols)
}
pub fn new_with_per_row_dims_and_hbb_and_htbeta_cols(
per_row_dims: Vec<usize>,
k: usize,
mut hbb: Array2<f64>,
htbeta_cols: usize,
) -> Self {
assert_eq!(hbb.dim(), (k, k));
hbb.fill(0.0);
let n = per_row_dims.len();
let max_d = per_row_dims.iter().copied().max().unwrap_or(0);
let row_dims: Arc<[usize]> = per_row_dims.iter().copied().collect::<Vec<_>>().into();
let mut off_vec = Vec::with_capacity(n + 1);
let mut cursor = 0usize;
for &di in &per_row_dims {
off_vec.push(cursor);
cursor += di;
}
off_vec.push(cursor);
let row_offsets: Arc<[usize]> = off_vec.into();
let rows = per_row_dims
.iter()
.map(|&di| ArrowRowBlock::new_with_htbeta_cols(di, htbeta_cols))
.collect();
Self {
rows,
hbb,
hbb_matvec: None,
htbeta_matvec: None,
htbeta_transpose_matvec: None,
htbeta_dense_supplement: false,
hbb_diag: None,
gb: Array1::<f64>::zeros(k),
d: max_d,
row_dims,
row_offsets,
k,
manifold_mode_fingerprint: EUCLIDEAN_MANIFOLD_MODE_FINGERPRINT,
row_hessian_fingerprint: 0,
analytic_row_hessian_fingerprint: 0,
block_offsets: Arc::from([] as [Range<usize>; 0]),
penalty_op: None,
device_sae_pcg: None,
cross_row_penalties: Vec::new(),
row_gauge_deflation: None,
ibp_cross_row: None,
}
}
/// Installs per-row gauge directions, replacing any previous set.
pub fn set_row_gauge_deflation(&mut self, deflation: ArrowRowGaugeDeflation) {
    self.row_gauge_deflation = deflation.into();
}
/// Installs the sparse cross-row source.
///
/// An empty source (`r == 0` or no entries) clears the slot instead.
pub fn set_ibp_cross_row_source(&mut self, source: IbpCrossRowSource) {
    let non_empty = source.r != 0 && !source.entries.is_empty();
    self.ibp_cross_row = if non_empty { Some(source) } else { None };
}
/// Number of observation rows.
pub fn n(&self) -> usize {
    self.rows.as_slice().len()
}
pub fn compute_row_hessian_fingerprint(&self) -> u64 {
row_hessian_fingerprint_for_system(self)
}
/// Full row-Hessian fingerprint: the structural fingerprint combined with the stored
/// analytic-penalty fingerprint.
pub fn current_row_hessian_fingerprint(&self) -> u64 {
    let structural = self.compute_row_hessian_fingerprint();
    let analytic = self.analytic_row_hessian_fingerprint;
    combine_row_and_registry_fingerprints(structural, analytic)
}
/// Stores the current full row-Hessian fingerprint in `row_hessian_fingerprint`.
pub fn refresh_row_hessian_fingerprint(&mut self) {
    let fingerprint = self.current_row_hessian_fingerprint();
    self.row_hessian_fingerprint = fingerprint;
}
/// Switches `H_ββ` to matrix-free form, given a mat-vec closure and its diagonal.
///
/// The closure is shared between `hbb_matvec` and a new `MatvecDiagPenaltyOp`, which
/// replaces `penalty_op`. The dense `hbb` is left untouched.
///
/// # Panics
/// Panics if `diag.len() != self.k`.
pub fn set_shared_beta_operator<F>(&mut self, matvec: F, diag: Array1<f64>)
where
    F: for<'a> Fn(ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
    assert_eq!(diag.len(), self.k);
    let shared: SharedBetaMatvec = Arc::new(matvec);
    let diag_op = MatvecDiagPenaltyOp::new(self.k, Arc::clone(&shared), diag.clone());
    self.penalty_op = Some(Arc::new(diag_op));
    self.hbb_matvec = Some(shared);
    self.hbb_diag = Some(diag);
}
/// Marks the dense `htbeta` blocks as a supplement to an installed `H_tβ` operator.
///
/// The fingerprint then hashes the dense blocks in addition to the operator.
pub fn activate_dense_htbeta_supplement(&mut self) {
    let flag = &mut self.htbeta_dense_supplement;
    *flag = true;
}
/// Installs per-row `H_tβ` operator callbacks.
///
/// `forward(row, input, output)` applies the row's `H_tβ`, and `transpose` applies its
/// transpose. The dense `htbeta` blocks are left untouched.
pub fn set_row_htbeta_operator<F, T>(&mut self, forward: F, transpose: T)
where
    F: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
    T: for<'a> Fn(usize, ArrayView1<'a, f64>, &mut Array1<f64>) + Send + Sync + 'static,
{
    self.htbeta_transpose_matvec = Some(Arc::new(transpose));
    self.htbeta_matvec = Some(Arc::new(forward));
}
/// Replaces the β block layout (the column range of each block).
pub fn set_block_offsets(&mut self, offsets: Arc<[Range<usize>]>) {
    let _previous = std::mem::replace(&mut self.block_offsets, offsets);
}
/// Installs an operator form of `H_ββ`, which takes precedence over the dense `hbb`.
pub fn set_penalty_op(&mut self, op: Arc<dyn BetaPenaltyOp>) {
    let _previous = self.penalty_op.replace(op);
}
/// Attaches device-side SAE PCG data after checking it matches this system's sizes.
///
/// # Panics
/// Panics if `beta_dim != k`, or if `a_phi` or `local_jac` does not have one entry per row.
pub fn set_device_sae_pcg_data(&mut self, data: DeviceSaePcgData) {
    let n_rows = self.rows.len();
    assert_eq!(data.beta_dim, self.k);
    assert_eq!(data.a_phi.len(), n_rows);
    assert_eq!(data.local_jac.len(), n_rows);
    self.device_sae_pcg = Some(Arc::new(data));
}
/// The `H_ββ` operator in effect.
///
/// Returns the installed `penalty_op` (shared), or otherwise a dense operator wrapping a
/// copy of `hbb`.
pub fn effective_penalty_op(&self) -> Arc<dyn BetaPenaltyOp> {
    if let Some(op) = &self.penalty_op {
        return Arc::clone(op);
    }
    Arc::new(DensePenaltyOp(self.hbb.clone()))
}
/// Accumulates `y += H_ββ x`.
///
/// Uses the penalty operator when one is installed, otherwise the dense `hbb`, which is
/// iterated over `hbb.nrows()` in both dimensions.
#[inline]
fn penalty_matvec_add(&self, x: &[f64], y: &mut [f64]) {
    match self.penalty_op.as_ref() {
        Some(op) => {
            op.matvec(x, y);
        }
        None => {
            let hbb = &self.hbb;
            let k = hbb.nrows();
            for a in 0..k {
                let row_dot = (0..k).fold(0.0_f64, |acc, b| acc + hbb[[a, b]] * x[b]);
                y[a] += row_dot;
            }
        }
    }
}
/// Writes `y = (H_ββ + ridge · I) x`, the penalty-plus-ridge term that opens a Schur mat-vec.
///
/// `y` is overwritten on every path. When all of the following hold, rows are computed in
/// parallel with rayon:
/// - `parallel` is set;
/// - no `penalty_op` is installed;
/// - `hbb` is a dense `k × k` matrix;
/// - `k >= SCHUR_PROLOGUE_PARALLEL_K_MIN`.
///
/// Otherwise the penalty operator (or dense fallback) is applied serially. The ridge covers
/// every coefficient in `x`/`y`, including matrix-free systems whose dense `hbb` is 0 × 0.
fn penalty_ridge_prologue_into(&self, x: &[f64], ridge: f64, y: &mut [f64], parallel: bool) {
    let k = self.k;
    let dense_parallel = parallel
        && self.penalty_op.is_none()
        && self.hbb.dim() == (k, k)
        && k >= SCHUR_PROLOGUE_PARALLEL_K_MIN;
    if dense_parallel {
        use rayon::prelude::*;
        let hbb = &self.hbb;
        y.par_iter_mut().enumerate().for_each(|(a, ya)| {
            let mut acc = 0.0_f64;
            for b in 0..k {
                acc += hbb[[a, b]] * x[b];
            }
            *ya = acc + ridge * x[a];
        });
    } else {
        // `penalty_matvec_add` accumulates. Clear `y` first so this path overwrites, like
        // the parallel branch does.
        y.fill(0.0);
        self.penalty_matvec_add(x, y);
        // Bounding this loop by `hbb.nrows()` silently dropped the ridge whenever the dense
        // `hbb` was a 0 × 0 placeholder (matrix-free / empty-hbb systems).
        for (ya, &xa) in y.iter_mut().zip(x.iter()) {
            *ya += ridge * xa;
        }
    }
}
/// Adds the diagonal of `H_ββ` into `diag`.
///
/// Sources are tried in priority order:
/// 1. the penalty operator's own diagonal;
/// 2. the stored `hbb_diag`;
/// 3. the dense `hbb` diagonal.
///
/// Both fallbacks touch only the prefix of `diag` that overlaps their source.
#[inline]
fn penalty_diagonal_add(&self, diag: &mut [f64]) {
    if let Some(op) = &self.penalty_op {
        op.diagonal(diag);
        return;
    }
    match &self.hbb_diag {
        Some(stored) => {
            for (dst, &src) in diag.iter_mut().zip(stored.iter()) {
                *dst += src;
            }
        }
        None => {
            let k = self.hbb.nrows().min(diag.len());
            for (j, dst) in diag.iter_mut().enumerate().take(k) {
                *dst += self.hbb[[j, j]];
            }
        }
    }
}
/// Adds the diagonal `H_ββ` block for β block `id` (columns `offsets[id.0]`) into `out`.
///
/// Sources are tried in priority order:
/// 1. the penalty operator;
/// 2. the dense `hbb`, if it is `k × k`;
/// 3. the stored diagonal alone.
///
/// If none of these is available, `out` is left unchanged.
#[inline]
fn penalty_block_add(&self, id: BetaBlockId, offsets: &[Range<usize>], out: &mut Array2<f64>) {
    if let Some(op) = self.penalty_op.as_ref() {
        op.block(id, offsets, out);
        return;
    }
    let range = &offsets[id.0];
    let start = range.start;
    let width = range.end - range.start;
    if self.hbb.dim() == (self.k, self.k) {
        for bi in 0..width {
            for bj in 0..width {
                out[[bi, bj]] += self.hbb[[start + bi, start + bj]];
            }
        }
    } else if let Some(stored) = &self.hbb_diag {
        for bi in 0..width {
            out[[bi, bi]] += stored[start + bi];
        }
    }
}
/// Adds the `H_ββ` sub-block indexed by the arbitrary column set `cols` into `out`.
///
/// With a penalty operator, each needed column is extracted by applying the operator to a
/// unit probe vector. Otherwise the dense `hbb` is used if it is `k × k`, and failing that
/// only the stored diagonal.
#[inline]
fn penalty_subblock_add(&self, cols: &[usize], out: &mut Array2<f64>) {
    if let Some(op) = self.penalty_op.as_ref() {
        let mut probe = Array1::<f64>::zeros(self.k);
        let mut response = Array1::<f64>::zeros(self.k);
        for (bj, &col) in cols.iter().enumerate() {
            probe.fill(0.0);
            probe[col] = 1.0;
            // `matvec` accumulates, so the response buffer must start from zero.
            response.fill(0.0);
            op.matvec(
                probe.as_slice().expect("probe contiguous"),
                response.as_slice_mut().expect("result contiguous"),
            );
            for (bi, &row) in cols.iter().enumerate() {
                out[[bi, bj]] += response[row];
            }
        }
    } else if self.hbb.dim() == (self.k, self.k) {
        for (bi, &row) in cols.iter().enumerate() {
            for (bj, &col) in cols.iter().enumerate() {
                out[[bi, bj]] += self.hbb[[row, col]];
            }
        }
    } else if let Some(stored) = self.hbb_diag.as_ref() {
        for (bi, &row) in cols.iter().enumerate() {
            out[[bi, bi]] += stored[row];
        }
    }
}
/// Folds the registry's analytic penalties into this system.
///
/// Penalties are evaluated at (`target_t`, `target_beta`), each with its own slice of
/// `rho_global`. The penalties are taken from `registry.penalties`, zipped with
/// `registry.rho_layout()` (extra entries on either side are dropped by the zip), and
/// handled by tier:
/// - `Psi`, row-block-diagonal: full contribution via `add_ext_coord_penalty`, plus a
///   per-penalty row-Hessian fingerprint when one is available.
/// - `Psi`, cross-row: gradient only (`add_ext_coord_penalty_gradient_only`). The penalty
///   and its inputs are recorded in `cross_row_penalties` and fingerprinted separately.
/// - `Beta`: `add_beta_penalty` on the shared coefficients.
/// - `Rho`: nothing is added here.
///
/// `cross_row_penalties` is cleared first. `analytic_row_hessian_fingerprint` is recomputed,
/// and set to 0 when no penalty contributed a fingerprint. `row_hessian_fingerprint` is not
/// refreshed here. The code shown always returns `Ok(())`.
pub fn add_analytic_penalty_contributions(
&mut self,
registry: &AnalyticPenaltyRegistry,
target_t: ArrayView1<'_, f64>,
target_beta: ArrayView1<'_, f64>,
rho_global: ArrayView1<'_, f64>,
) -> Result<(), ArrowSchurError> {
let layout = registry.rho_layout();
let mut penalty_fingerprints = Vec::new();
self.cross_row_penalties.clear();
for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
let rho_local = rho_global.slice(ndarray::s![rho_slice.clone()]);
match tier {
PenaltyTier::Psi => {
if analytic_penalty_is_row_block_diagonal(penalty) {
self.add_ext_coord_penalty(penalty, target_t, rho_local);
if let Some(fingerprint) =
analytic_penalty_row_hessian_fingerprint(penalty, target_t, rho_local)
{
penalty_fingerprints.push(fingerprint);
}
} else {
// Cross-row curvature cannot live in per-row blocks: add the gradient now
// and keep the penalty so its Hessian can be applied elsewhere.
self.add_ext_coord_penalty_gradient_only(penalty, target_t, rho_local);
self.cross_row_penalties.push(CrossRowLatentPenalty {
penalty: penalty.clone(),
rho_local: rho_local.to_owned(),
target_t: target_t.to_owned(),
});
}
}
PenaltyTier::Beta => {
self.add_beta_penalty(penalty, target_beta, rho_local);
}
PenaltyTier::Rho => {
// Rho-tier penalties contribute nothing to the arrow system.
}
}
}
for cross in &self.cross_row_penalties {
penalty_fingerprints.push(cross_row_penalty_fingerprint(
&cross.penalty,
target_t,
cross.rho_local.view(),
));
}
self.analytic_row_hessian_fingerprint = if penalty_fingerprints.is_empty() {
0
} else {
let mut hasher = Fingerprinter::new();
hasher.write_str("arrow-schur-row-hessian-registry-v1");
hasher.write_usize(penalty_fingerprints.len());
for fingerprint in penalty_fingerprints {
hasher.write_u64(fingerprint);
}
hasher.finish_u64()
};
Ok(())
}
/// Converts each row's latent gradient and Hessian blocks from Euclidean to Riemannian form.
///
/// The manifold fingerprint of `latent` is always recorded. For a Euclidean manifold the
/// function then returns without touching any row. Otherwise each row's `gt`, `htt`, and
/// `htbeta` are replaced by the manifold's projections at the point `latent.row(i)`. All
/// three projections are evaluated against the row's original (Euclidean) values before
/// any of them is overwritten.
///
/// # Panics
/// Panics if `latent.n_obs()` differs from the number of rows, or if
/// `latent.latent_dim() != self.d`.
pub fn apply_riemannian_latent_geometry(&mut self, latent: &LatentCoordValues) {
    let manifold = latent.manifold();
    self.manifold_mode_fingerprint = manifold_mode_fingerprint(latent);
    if manifold.is_euclidean() {
        return;
    }
    assert_eq!(latent.n_obs(), self.rows.len());
    assert_eq!(latent.latent_dim(), self.d);
    for (i, row) in self.rows.iter_mut().enumerate() {
        let t_i = ArrayView1::from(latent.row(i));
        // Compute everything from borrowed views of the untouched Euclidean values,
        // then assign. This gives the same result as the old copy-then-overwrite
        // sequence without three deep clones per row.
        let gt = manifold.project_gradient_to_tangent(t_i, row.gt.view());
        let htt = manifold.riemannian_hessian_matrix(t_i, row.gt.view(), row.htt.view());
        let htbeta = manifold.project_matrix_columns_to_gradient_tangent(
            t_i,
            row.gt.view(),
            row.htbeta.view(),
        );
        row.gt = gt;
        row.htt = htt;
        row.htbeta = htbeta;
    }
}
/// Adds a row-block-diagonal latent penalty's gradient and PSD-majorizer curvature to
/// each row's `gt`/`htt`.
///
/// The flat latent target is treated as `n` rows of uniform width `d` (flat index
/// `i * d + a` ↔ row `i`, axis `a`); `apply_analytic_penalty` asserts
/// `target_t.len() == n * d`. When the penalty exposes a diagonal majorizer, only the
/// `htt` diagonals are updated. Otherwise `d` Hessian-vector products are formed.
/// Probe `a` sets axis `a` of every row at once, and the response of row `i` is written
/// into column `a` of that row's `htt`. Batching all rows into one probe is exact only
/// when the penalty has no cross-row coupling; the visible caller
/// (`add_analytic_penalty_contributions`) uses this path only for penalties reported
/// as row-block-diagonal.
fn add_ext_coord_penalty(
&mut self,
penalty: &AnalyticPenaltyKind,
target_t: ArrayView1<'_, f64>,
rho_local: ArrayView1<'_, f64>,
) {
let d = self.d;
let n = self.rows.len();
apply_analytic_penalty(
penalty,
target_t,
rho_local,
n * d,
// One HVP probe per latent axis.
d,
self,
// Gradient: flat index -> (row, axis).
|sys, flat, value| sys.rows[flat / d].gt[flat % d] += value,
// Diagonal majorizer: add onto the row's htt diagonal.
|sys, flat, value| sys.rows[flat / d].htt[[flat % d, flat % d]] += value,
// Seed axis `a` in every row simultaneously (see doc comment).
|a, probe| {
for i in 0..n {
probe[i * d + a] = 1.0;
}
},
// Row i's response to probe `a` becomes column `a` of that row's htt.
|sys, a, hv| {
for i in 0..n {
for b in 0..d {
sys.rows[i].htt[[b, a]] += hv[i * d + b];
}
}
},
);
}
/// Adds only the gradient of a (cross-row) latent penalty into each row's `gt`.
///
/// The gradient is evaluated at `target_t` with `rho_local` and scattered assuming a
/// uniform row width of `self.d` (flat index `i * d + a` → row `i`, axis `a`).
/// No curvature is added here.
///
/// # Panics
/// Panics if `target_t.len() != rows.len() * d`.
fn add_ext_coord_penalty_gradient_only(
    &mut self,
    penalty: &AnalyticPenaltyKind,
    target_t: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
) {
    let width = self.d;
    assert_eq!(target_t.len(), self.rows.len() * width);
    let gradient = penalty.grad_target(target_t, rho_local);
    for (i, row) in self.rows.iter_mut().enumerate() {
        let base = i * width;
        for axis in 0..width {
            row.gt[axis] += gradient[base + axis];
        }
    }
}
/// Accumulates into `out` the PSD-majorizer Hessian-vector product of every stored
/// cross-row latent penalty, each evaluated at its own captured `target_t`/`rho_local`.
///
/// # Panics
/// Panics if `v` does not match a penalty's target length, or a product's length does
/// not match `out`.
fn apply_cross_row_penalty_hessian(&self, v: ArrayView1<'_, f64>, out: &mut Array1<f64>) {
    for stored in self.cross_row_penalties.iter() {
        assert_eq!(stored.target_t.len(), v.len());
        let product = stored.penalty.psd_majorizer_hvp(
            stored.target_t.view(),
            stored.rho_local.view(),
            v,
        );
        assert_eq!(product.len(), out.len());
        for (dst, src) in out.iter_mut().zip(product.iter()) {
            *dst += *src;
        }
    }
}
/// Adds a β-tier analytic penalty's gradient and PSD-majorizer curvature to the β block.
///
/// The gradient always goes into `gb`. A diagonal majorizer, when the penalty provides
/// one, is added to the dense `hbb` diagonal (only if `hbb` is `k × k`) and to
/// `hbb_diag` (if present). Without a diagonal majorizer, `k` Hessian-vector products
/// are formed only when `hbb` is dense. Column `j` of the product goes into `hbb`
/// column `j`, and its `j`-th entry into `hbb_diag[j]`.
///
/// NOTE(review): when `hbb` is not dense and the penalty has no diagonal majorizer,
/// `hvp_columns` is 0 and the penalty curvature is silently dropped. Separately, the
/// penalty accessors (`penalty_block_add`, `penalty_subblock_add`) prefer `penalty_op`
/// over `hbb`/`hbb_diag`, so curvature written here is bypassed when an operator is
/// installed. Confirm both cases are intended.
fn add_beta_penalty(
&mut self,
penalty: &AnalyticPenaltyKind,
target_beta: ArrayView1<'_, f64>,
rho_local: ArrayView1<'_, f64>,
) {
let k = self.k;
// Full HVP columns are only materialized when there is a dense hbb to receive them.
let hvp_columns = if self.hbb.dim() == (k, k) { k } else { 0 };
apply_analytic_penalty(
penalty,
target_beta,
rho_local,
k,
hvp_columns,
self,
|sys, j, value| sys.gb[j] += value,
// Diagonal majorizer: keep dense and diagonal-only storage in step.
|sys, j, value| {
if sys.hbb.dim() == (k, k) {
sys.hbb[[j, j]] += value;
}
if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
hbb_diag[j] += value;
}
},
|j, probe| probe[j] = 1.0,
// HVP column j -> hbb column j; its diagonal entry also feeds hbb_diag.
|sys, j, hv| {
for i in 0..k {
sys.hbb[[i, j]] += hv[i];
}
if let Some(hbb_diag) = sys.hbb_diag.as_mut() {
hbb_diag[j] += hv[j];
}
},
);
}
/// Solves the Arrow-Schur Newton system using options picked automatically from `k`.
///
/// Same as calling `solve_with_options` with `ArrowSolveOptions::automatic(self.k)`.
/// Returns `(delta_t, delta_beta, diagnostics)`.
///
/// # Errors
/// Propagates any failure from the underlying Newton-step solver.
pub fn solve(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    self.solve_with_options(ridge_t, ridge_beta, &ArrowSolveOptions::automatic(self.k))
}
/// Solves the Newton system with Levenberg–Marquardt ridge escalation, using options
/// picked automatically from `k`.
///
/// Returns `(delta_t, delta_beta, diagnostics)`.
///
/// # Errors
/// Propagates any failure from `solve_with_lm_escalation_inner`.
pub fn solve_with_lm_escalation(
    &self,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    solve_with_lm_escalation_inner(
        self,
        ridge_t,
        ridge_beta,
        &ArrowSolveOptions::automatic(self.k),
    )
}
/// Solves the Arrow-Schur Newton system with caller-supplied solver `options`.
///
/// `ridge_t`/`ridge_beta` are the ridges passed to the core solver for the latent and β
/// blocks. Returns `(delta_t, delta_beta, diagnostics)`.
///
/// # Errors
/// Propagates any failure from `solve_arrow_newton_step_core`.
pub fn solve_with_options(
&self,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
solve_arrow_newton_step_core(self, ridge_t, ridge_beta, options)
}
}
/// Arrow-Schur solver that visits latent rows in chunks instead of keeping them resident.
///
/// Rows are produced on demand by `row_builder`. Only the `k × k` reduced Schur
/// accumulator and β-side vectors stay in memory.
pub struct StreamingArrowSchur {
/// Number of latent rows.
pub n_rows: usize,
/// Fallback latent width, used when a row index lies beyond `row_dims`.
pub d: usize,
/// Per-row latent widths.
pub row_dims: Arc<[usize]>,
/// Start of each row in the flattened latent vector; entry `n_rows` is the total length.
pub row_offsets: Arc<[usize]>,
/// Number of β coordinates.
pub k: usize,
/// Rows visited per chunk (at least 1).
pub chunk_size: usize,
/// Reduced Schur accumulator (`k × k`), seeded from `hbb + ridge_beta·I`.
pub s_acc: Array2<f64>,
/// Reduced right-hand-side accumulator (length `k`).
rhs_acc: Array1<f64>,
/// Dense β Hessian block (`k × k`).
hbb: Array2<f64>,
/// β gradient (length `k`).
gb: Array1<f64>,
/// Produces row `i` on demand.
row_builder: StreamingArrowRowBuilder,
/// Optional matrix-free `H_tβ` operator.
htbeta_matvec: Option<RowHtbetaMatvec>,
/// Optional matrix-free `H_tβᵀ` operator; preferred when materializing a row's `H_tβ`.
htbeta_transpose_matvec: Option<RowHtbetaTransposeMatvec>,
/// Forwarded to `factor_one_row`; set from the solve options on each solve/log-det.
tolerate_ill_conditioning: bool,
/// When set, solve and log-det refuse to run (the cross-row IBP correction is unsupported).
ibp_cross_row_active: bool,
}
impl std::fmt::Debug for StreamingArrowSchur {
    /// Shows the size parameters only; buffers and the row builder are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamingArrowSchur");
        builder.field("n_rows", &self.n_rows);
        builder.field("d", &self.d);
        builder.field("k", &self.k);
        builder.field("chunk_size", &self.chunk_size);
        builder.finish_non_exhaustive()
    }
}
impl StreamingArrowSchur {
    /// Creates a streaming Arrow-Schur solver that pulls rows from `row_builder`.
    ///
    /// `hbb`/`gb` are the dense β Hessian block and β gradient. `row_dims` and
    /// `row_offsets` give each row's latent width and its start in the flattened latent
    /// vector. `chunk_size` is clamped to at least 1.
    ///
    /// # Panics
    /// Panics if `hbb` is not `k × k` or `gb.len() != k`.
    #[must_use]
    pub fn new(
        n_rows: usize,
        d: usize,
        row_dims: Arc<[usize]>,
        row_offsets: Arc<[usize]>,
        k: usize,
        hbb: Array2<f64>,
        gb: Array1<f64>,
        row_builder: StreamingArrowRowBuilder,
        chunk_size: usize,
    ) -> Self {
        assert_eq!(hbb.dim(), (k, k));
        assert_eq!(gb.len(), k);
        Self {
            n_rows,
            d,
            row_dims,
            row_offsets,
            k,
            // `step_by(0)` in the chunk loops would panic.
            chunk_size: chunk_size.max(1),
            s_acc: Array2::<f64>::zeros((k, k)),
            rhs_acc: Array1::<f64>::zeros(k),
            hbb,
            gb,
            row_builder,
            htbeta_matvec: None,
            htbeta_transpose_matvec: None,
            tolerate_ill_conditioning: false,
            ibp_cross_row_active: false,
        }
    }

    /// Builds a streaming solver that replays the rows of a resident `ArrowSchurSystem`.
    ///
    /// The β penalty is densified through `effective_penalty_op().to_dense()`. When the
    /// system carries a matrix-free `H_tβ` operator, the rows' dense `htbeta` blocks are
    /// replaced by empty matrices and the operator is used instead.
    /// NOTE(review): `sys.htbeta_dense_supplement` is not carried over, so a dense
    /// supplement used alongside the operator would be dropped here. Confirm streaming
    /// never sees that configuration.
    #[must_use]
    pub fn from_system(sys: &ArrowSchurSystem, chunk_size: usize) -> Self {
        let htbeta_matvec = sys.htbeta_matvec.clone();
        let rows: Vec<ArrowRowBlock> = if htbeta_matvec.is_some() {
            sys.rows
                .iter()
                .map(|row| ArrowRowBlock {
                    htt: row.htt.clone(),
                    htbeta: Array2::<f64>::zeros((0, 0)),
                    gt: row.gt.clone(),
                })
                .collect()
        } else {
            sys.rows.clone()
        };
        let rows = Arc::new(rows);
        let row_builder: StreamingArrowRowBuilder = Arc::new(move |row| {
            rows.get(row)
                .cloned()
                .ok_or_else(|| ArrowSchurError::SchurFactorFailed {
                    reason: format!("streaming row {row} out of bounds"),
                })
        });
        let hbb_dense = sys.effective_penalty_op().to_dense();
        let mut streaming = Self::new(
            sys.rows.len(),
            sys.d,
            Arc::clone(&sys.row_dims),
            Arc::clone(&sys.row_offsets),
            sys.k,
            hbb_dense,
            sys.gb.clone(),
            row_builder,
            chunk_size,
        );
        streaming.htbeta_matvec = htbeta_matvec;
        streaming.htbeta_transpose_matvec = sys.htbeta_transpose_matvec.clone();
        streaming.ibp_cross_row_active = sys.ibp_cross_row.is_some();
        streaming
    }

    /// Materializes row `row_idx`'s `H_tβ` as a dense `di × k` matrix.
    ///
    /// The transpose operator is preferred (one probe per latent axis), then the forward
    /// operator (one probe per β coordinate); otherwise the row's dense block is cloned.
    fn row_htbeta(&self, row_idx: usize, row: &ArrowRowBlock, di: usize) -> Array2<f64> {
        if let Some(op_t) = self.htbeta_transpose_matvec.as_ref() {
            let mut mat = Array2::<f64>::zeros((di, self.k));
            let mut e_c = Array1::<f64>::zeros(di);
            let mut beta_row = Array1::<f64>::zeros(self.k);
            for c in 0..di {
                e_c.fill(0.0);
                e_c[c] = 1.0;
                beta_row.fill(0.0);
                op_t(row_idx, e_c.view(), &mut beta_row);
                for a in 0..self.k {
                    mat[[c, a]] = beta_row[a];
                }
            }
            return mat;
        }
        match self.htbeta_matvec.as_ref() {
            Some(op) => {
                let mut mat = Array2::<f64>::zeros((di, self.k));
                let mut e_a = Array1::<f64>::zeros(self.k);
                let mut col = Array1::<f64>::zeros(di);
                for a in 0..self.k {
                    e_a.fill(0.0);
                    e_a[a] = 1.0;
                    col.fill(0.0);
                    op(row_idx, e_a.view(), &mut col);
                    for c in 0..di {
                        mat[[c, a]] = col[c];
                    }
                }
                mat
            }
            None => row.htbeta.clone(),
        }
    }

    /// Moves out the current `(s_acc, rhs_acc)` and leaves zeroed accumulators behind.
    #[must_use]
    pub fn take_accumulators(&mut self) -> (Array2<f64>, Array1<f64>) {
        let s = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
        let rhs = std::mem::replace(&mut self.rhs_acc, Array1::<f64>::zeros(self.k));
        (s, rhs)
    }

    /// Re-seeds the accumulators: `s_acc = hbb + ridge_beta·I`, `rhs_acc = 0`.
    ///
    /// # Errors
    /// Returns `SchurFactorFailed` if `hbb` is not `k × k`.
    pub fn reset_accumulator(&mut self, ridge_beta: f64) -> Result<(), ArrowSchurError> {
        if self.hbb.dim() != (self.k, self.k) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "streaming Arrow-Schur requires a dense beta block accumulator".to_string(),
            });
        }
        self.s_acc.assign(&self.hbb);
        for j in 0..self.k {
            self.s_acc[[j, j]] += ridge_beta;
            self.rhs_acc[j] = 0.0;
        }
        Ok(())
    }

    /// Folds rows `start..end` into the accumulators.
    ///
    /// For each row, `rhs_acc += H_βt · (H_tt + ridge)⁻¹ · g_t`, and the row's Schur
    /// contribution is subtracted from `s_acc` via `block_gemm_subtract`. `SqrtBA` uses
    /// the whitened form.
    ///
    /// # Errors
    /// Fails on an out-of-range chunk, a row-builder error, a shape mismatch, or a
    /// per-row factorization failure.
    pub fn accumulate_chunk(
        &mut self,
        start: usize,
        end: usize,
        ridge_t: f64,
        mode: ArrowSolverMode,
    ) -> Result<(), ArrowSchurError> {
        if start > end || end > self.n_rows {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "streaming Arrow-Schur chunk [{start}, {end}) outside 0..{}",
                    self.n_rows
                ),
            });
        }
        let backend = CpuBatchedBlockSolver;
        for row_idx in start..end {
            let row = (self.row_builder)(row_idx)?;
            let di = row.htt.nrows();
            self.validate_row(row_idx, &row)?;
            let htbeta = self.row_htbeta(row_idx, &row, di);
            let factor =
                factor_one_row(&row, ridge_t, di, row_idx, self.tolerate_ill_conditioning)?;
            let v = backend.solve_block_vector(factor.view(), row.gt.view());
            for c in 0..di {
                let vc = v[c];
                if vc == 0.0 {
                    continue;
                }
                for a in 0..self.k {
                    self.rhs_acc[a] += htbeta[[c, a]] * vc;
                }
            }
            match mode {
                ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                    let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                    backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
                }
                ArrowSolverMode::SqrtBA => {
                    let whitened = backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                    backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
                }
            }
        }
        Ok(())
    }

    /// Streams every row once and returns `(log det H_tt, reduced Schur complement)`.
    ///
    /// `log det H_tt` is `Σ 2·ln L_ii` over the per-row factors. The Schur complement is
    /// symmetrized with `symmetrize_upper_from_lower` and moved out of `s_acc`.
    ///
    /// # Errors
    /// Refuses when the cross-row IBP correction is active; otherwise propagates
    /// accumulator, row-builder, validation, and factorization failures.
    pub fn reduced_schur_and_log_det_tt(
        &mut self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<(f64, Array2<f64>), ArrowSchurError> {
        if self.ibp_cross_row_active {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "streaming arrow log-det cannot carry the exact cross-row IBP \
                         Woodbury correction (#1038): U's columns span all rows, so the \
                         rank-R capacitance needs the per-row factors retained — the very \
                         (N·K) residency the streaming path avoids. Route IBP-active fits \
                         through the dense resident ArrowFactorCache::arrow_log_det instead."
                    .to_string(),
            });
        }
        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
        self.reset_accumulator(ridge_beta)?;
        let backend = CpuBatchedBlockSolver;
        let mut log_det_tt = 0.0_f64;
        for start in (0..self.n_rows).step_by(self.chunk_size) {
            let end = (start + self.chunk_size).min(self.n_rows);
            for row_idx in start..end {
                let row = (self.row_builder)(row_idx)?;
                let di = row.htt.nrows();
                self.validate_row(row_idx, &row)?;
                let htbeta = self.row_htbeta(row_idx, &row, di);
                let factor =
                    factor_one_row(&row, ridge_t, di, row_idx, self.tolerate_ill_conditioning)?;
                for axis in 0..di {
                    log_det_tt += 2.0 * factor[[axis, axis]].ln();
                }
                match options.mode {
                    ArrowSolverMode::Direct | ArrowSolverMode::InexactPCG => {
                        let solved = backend.solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(&mut self.s_acc, &htbeta, &solved);
                    }
                    ArrowSolverMode::SqrtBA => {
                        let whitened =
                            backend.sqrt_solve_block_matrix(factor.view(), htbeta.view());
                        backend.block_gemm_subtract(&mut self.s_acc, &whitened, &whitened);
                    }
                }
            }
        }
        symmetrize_upper_from_lower(&mut self.s_acc);
        let schur = std::mem::replace(&mut self.s_acc, Array2::<f64>::zeros((self.k, self.k)));
        Ok((log_det_tt, schur))
    }

    /// Returns `log det` of a reduced Schur complement as `2·Σ ln L_ii` of its dense factor.
    ///
    /// The factor is obtained from `solve_dense_reduced_system` with a zero right-hand side.
    ///
    /// # Errors
    /// Propagates solver errors. Also fails if the solver reports iterations (no direct
    /// factor), returns a wrongly sized step, or yields no dense factor.
    pub fn reduced_schur_log_det(
        schur: &Array2<f64>,
        options: &ArrowSolveOptions,
    ) -> Result<f64, ArrowSchurError> {
        let rhs = Array1::<f64>::zeros(schur.nrows());
        let trust_metric_weights = None;
        let (delta, schur_factor, diag) =
            solve_dense_reduced_system(schur, &rhs, options, trust_metric_weights)?;
        if delta.len() != schur.nrows() || diag.iterations != 0 {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "streaming log-det reduced solve returned incoherent diagnostics"
                    .to_string(),
            });
        }
        let schur_factor = schur_factor.ok_or_else(|| ArrowSchurError::SchurFactorFailed {
            reason: "streaming log-det requires a dense reduced Schur factor".to_string(),
        })?;
        let mut log_det_schur = 0.0_f64;
        for axis in 0..schur_factor.nrows() {
            log_det_schur += 2.0 * schur_factor[[axis, axis]].ln();
        }
        Ok(log_det_schur)
    }

    /// Full arrow log-determinant: `log det H_tt + log det S`.
    ///
    /// # Errors
    /// Propagates failures from the Schur accumulation and the Schur factorization.
    pub fn exact_arrow_log_det(
        &mut self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<f64, ArrowSchurError> {
        let (log_det_tt, schur) =
            self.reduced_schur_and_log_det_tt(ridge_t, ridge_beta, options)?;
        Ok(log_det_tt + Self::reduced_schur_log_det(&schur, options)?)
    }

    /// Computes the Newton step `(delta_t, delta_beta)` by streaming the rows twice.
    ///
    /// The first pass builds the reduced system `S·Δβ = Σ H_βt H_tt⁻¹ g_t − g_β`. After
    /// the dense solve, the second pass back-substitutes each row. The optional third
    /// element is the dense Schur factor returned by the reduced solver.
    ///
    /// # Errors
    /// Refuses when the cross-row IBP correction is active; otherwise propagates
    /// accumulation, reduced-solve, and back-substitution failures.
    pub fn solve(
        &mut self,
        ridge_t: f64,
        ridge_beta: f64,
        options: &ArrowSolveOptions,
    ) -> Result<(Array1<f64>, Array1<f64>, Option<Array2<f64>>), ArrowSchurError> {
        if self.ibp_cross_row_active {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: "streaming arrow solve cannot carry the exact cross-row IBP \
                         Woodbury correction (#1038); route IBP-active fits through the \
                         dense resident solve_arrow_newton_step_with_options instead."
                    .to_string(),
            });
        }
        self.tolerate_ill_conditioning = options.tolerate_ill_conditioning;
        self.reset_accumulator(ridge_beta)?;
        for start in (0..self.n_rows).step_by(self.chunk_size) {
            let end = (start + self.chunk_size).min(self.n_rows);
            self.accumulate_chunk(start, end, ridge_t, options.mode)?;
        }
        for j in 0..self.k {
            self.rhs_acc[j] -= self.gb[j];
        }
        symmetrize_upper_from_lower(&mut self.s_acc);
        let trust_metric_weights = None;
        let (delta_beta, schur_factor, _diag) =
            solve_dense_reduced_system(&self.s_acc, &self.rhs_acc, options, trust_metric_weights)?;
        let delta_t = self.back_substitute(ridge_t, delta_beta.view())?;
        Ok((delta_t, delta_beta, schur_factor))
    }

    /// Recovers `delta_t` row by row: `Δt_i = −(H_tt,i + ridge)⁻¹ · (g_t,i + H_tβ,i·Δβ)`.
    ///
    /// # Errors
    /// Propagates row-builder, validation, and factorization failures.
    fn back_substitute(
        &self,
        ridge_t: f64,
        delta_beta: ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, ArrowSchurError> {
        let backend = CpuBatchedBlockSolver;
        let total_len = self.row_offsets[self.n_rows];
        let mut delta_t = Array1::<f64>::zeros(total_len);
        for start in (0..self.n_rows).step_by(self.chunk_size) {
            let end = (start + self.chunk_size).min(self.n_rows);
            for row_idx in start..end {
                let row = (self.row_builder)(row_idx)?;
                let di = row.htt.nrows();
                self.validate_row(row_idx, &row)?;
                let factor =
                    factor_one_row(&row, ridge_t, di, row_idx, self.tolerate_ill_conditioning)?;
                let mut htbeta_delta = Array1::<f64>::zeros(di);
                if let Some(op) = self.htbeta_matvec.as_ref() {
                    op(row_idx, delta_beta, &mut htbeta_delta);
                } else {
                    for c in 0..di {
                        let mut acc = 0.0_f64;
                        for a in 0..self.k {
                            acc += row.htbeta[[c, a]] * delta_beta[a];
                        }
                        htbeta_delta[c] = acc;
                    }
                }
                // The right-hand side must have this row's own width `di`. A shared
                // buffer of width `self.d` either carried stale tail entries into the
                // `di × di` solve (di < d) or indexed out of bounds (di > d).
                let mut rhs = Array1::<f64>::zeros(di);
                for c in 0..di {
                    rhs[c] = row.gt[c] + htbeta_delta[c];
                }
                let dt_i = backend.solve_block_vector(factor.view(), rhs.view());
                let row_base = self.row_offsets[row_idx];
                for c in 0..di {
                    delta_t[row_base + c] = -dt_i[c];
                }
            }
        }
        Ok(delta_t)
    }

    /// Checks that a streamed row's `H_tt`, `H_tβ` (when no operator is installed) and
    /// `g_t` match the expected width for `row_idx`.
    ///
    /// # Errors
    /// `PerRowFactorFailed` for `H_tt`/`g_t` mismatches; `SchurFactorFailed` for an
    /// `H_tβ` mismatch.
    fn validate_row(&self, row_idx: usize, row: &ArrowRowBlock) -> Result<(), ArrowSchurError> {
        let expected_di = if row_idx < self.row_dims.len() {
            self.row_dims[row_idx]
        } else {
            self.d
        };
        let actual_di = row.htt.nrows();
        if actual_di != expected_di || row.htt.ncols() != expected_di {
            return Err(ArrowSchurError::PerRowFactorFailed {
                row: row_idx,
                reason: format!(
                    "streaming row H_tt shape {:?} != ({expected_di}, {expected_di})",
                    row.htt.dim(),
                ),
            });
        }
        if self.htbeta_matvec.is_none() && row.htbeta.dim() != (expected_di, self.k) {
            // This variant has no `row` field, so name the row in the message.
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "streaming row {row_idx} H_tβ shape {:?} != ({expected_di}, {})",
                    row.htbeta.dim(),
                    self.k
                ),
            });
        }
        if row.gt.len() != expected_di {
            return Err(ArrowSchurError::PerRowFactorFailed {
                row: row_idx,
                reason: format!("streaming row g_t length {} != {expected_di}", row.gt.len()),
            });
        }
        Ok(())
    }
}
/// Scatters an analytic penalty's gradient and PSD-majorizer curvature into `scatter_target`.
///
/// The gradient at `target` is always scattered element-wise through `grad_scatter`.
/// If the penalty provides a diagonal majorizer, that diagonal goes through
/// `diag_scatter` and no Hessian-vector products are formed. Otherwise, for each of the
/// `hvp_columns` columns a zeroed probe is seeded by `seed_hvp_probe`, multiplied by the
/// majorizer, and the product is handed to `hvp_column_scatter` along with the column index.
///
/// # Panics
/// Panics if `target.len() != expected_target_len` or a returned diagonal has another length.
fn apply_analytic_penalty<S, G, D, P, H>(
    penalty: &AnalyticPenaltyKind,
    target: ArrayView1<'_, f64>,
    rho_local: ArrayView1<'_, f64>,
    expected_target_len: usize,
    hvp_columns: usize,
    scatter_target: &mut S,
    mut grad_scatter: G,
    mut diag_scatter: D,
    seed_hvp_probe: P,
    mut hvp_column_scatter: H,
) where
    G: FnMut(&mut S, usize, f64),
    D: FnMut(&mut S, usize, f64),
    P: Fn(usize, &mut Array1<f64>),
    H: for<'a> FnMut(&mut S, usize, ArrayView1<'a, f64>),
{
    assert_eq!(target.len(), expected_target_len);
    let gradient = penalty.grad_target(target, rho_local);
    for idx in 0..expected_target_len {
        grad_scatter(scatter_target, idx, gradient[idx]);
    }
    match penalty.psd_majorizer_diag(target, rho_local) {
        Some(diagonal) => {
            assert_eq!(diagonal.len(), expected_target_len);
            for idx in 0..expected_target_len {
                diag_scatter(scatter_target, idx, diagonal[idx]);
            }
        }
        None => {
            let mut probe = Array1::<f64>::zeros(expected_target_len);
            for column in 0..hvp_columns {
                probe.fill(0.0);
                seed_hvp_probe(column, &mut probe);
                let product = penalty.psd_majorizer_hvp(target, rho_local, probe.view());
                hvp_column_scatter(scatter_target, column, product.view());
            }
        }
    }
}
/// Whether `penalty` couples latent coordinates only within a row.
///
/// Thin wrapper over `AnalyticPenaltyKind::is_row_block_diagonal`. Row-block-diagonal
/// penalties can be folded into the per-row `H_tt` blocks; others are handled as
/// cross-row penalties.
fn analytic_penalty_is_row_block_diagonal(penalty: &AnalyticPenaltyKind) -> bool {
penalty.is_row_block_diagonal()
}
/// Square per-row factor matrices packed into one shared buffer.
///
/// Row `i`'s factor is `dims[i] × dims[i]` and occupies `data[offsets[i]..offsets[i + 1]]`
/// in row-major order. Cloning is cheap because the storage is reference-counted.
#[derive(Clone)]
pub struct ArrowFactorSlab {
// Concatenated factor entries, row-major per block.
data: Arc<[f64]>,
// Start of each block in `data`; has one more entry than there are blocks.
offsets: Arc<[usize]>,
// Side length of each square block.
dims: Arc<[usize]>,
}
impl ArrowFactorSlab {
pub fn from_blocks(blocks: Vec<Array2<f64>>) -> Self {
let mut data = Vec::new();
let mut offsets = Vec::with_capacity(blocks.len() + 1);
let mut dims = Vec::with_capacity(blocks.len());
offsets.push(0);
for block in blocks {
let (rows, cols) = block.dim();
assert_eq!(rows, cols, "ArrowFactorSlab stores square row factors");
dims.push(rows);
data.extend(block.iter().copied());
offsets.push(data.len());
}
Self {
data: data.into(),
offsets: offsets.into(),
dims: dims.into(),
}
}
pub fn len(&self) -> usize {
self.dims.len()
}
pub fn is_empty(&self) -> bool {
self.dims.is_empty()
}
pub fn factor(&self, row: usize) -> ArrayView2<'_, f64> {
let dim = self.dims[row];
let range = self.offsets[row]..self.offsets[row + 1];
ArrayView2::from_shape((dim, dim), &self.data[range])
.expect("ArrowFactorSlab row offset/dim invariant violated")
}
pub fn iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
(0..self.len()).map(|row| self.factor(row))
}
}
impl std::fmt::Debug for ArrowFactorSlab {
    /// Prints the factor count and total stored values rather than the raw data.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rows = self.len();
        let values = self.data.len();
        f.debug_struct("ArrowFactorSlab")
            .field("rows", &rows)
            .field("values", &values)
            .finish()
    }
}
/// Where the undamped per-row factors live.
#[derive(Clone)]
pub enum ArrowUndampedFactors {
/// The undamped factors are the damped ones (reuse `htt_factors`).
SameAsDamped,
/// A separate slab of undamped factors.
Owned(ArrowFactorSlab),
}
impl std::fmt::Debug for ArrowUndampedFactors {
    /// Prints the variant, and for `Owned` only the number of stored factors.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Owned(slab) => {
                let count = slab.len();
                f.debug_tuple("Owned").field(&count).finish()
            }
            Self::SameAsDamped => f.write_str("SameAsDamped"),
        }
    }
}
/// Writes `H_tβ,row · x` for one row of `sys` into `out`.
///
/// `out` is zeroed first. The matrix-free operator, when installed, writes its product.
/// The row's dense `htbeta` is then added on top when dense storage is in use
/// (`htbeta_dense_supplement`, or no operator at all) and its shape is `(out.len(), k)`.
///
/// NOTE(review): with no operator and a wrongly shaped dense block, `out` silently stays
/// zero, whereas `sys_htbeta_materialize_row` panics in that case.
fn sys_htbeta_apply_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    x: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    out.fill(0.0);
    let operator = sys.htbeta_matvec.as_ref();
    if let Some(op) = operator {
        op(row_idx, x, out);
    }
    let dense_in_use = sys.htbeta_dense_supplement || operator.is_none();
    if !dense_in_use || row.htbeta.dim() != (out.len(), sys.k) {
        return;
    }
    for c in 0..row.htbeta.nrows() {
        let dot = (0..sys.k).fold(0.0_f64, |acc, a| acc + row.htbeta[[c, a]] * x[a]);
        out[c] += dot;
    }
}
/// Accumulates `H_tβ,rowᵀ · v` for one row of `sys` into `out` (length `k`).
///
/// The operator part, if installed, is added by probing each β coordinate. The dense
/// `htbeta` part is added when dense storage is in use and its shape is `(v.len(), k)`.
/// Zero entries of `v` are skipped.
fn sys_htbeta_accumulate_transpose(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) {
    let operator = sys.htbeta_matvec.as_ref();
    if let Some(op) = operator {
        htbeta_probe_transpose(row_idx, op, v, out, v.len(), sys.k);
    }
    let dense_in_use = sys.htbeta_dense_supplement || operator.is_none();
    if !(dense_in_use && row.htbeta.dim() == (v.len(), sys.k)) {
        return;
    }
    for (c, &weight) in v.iter().enumerate() {
        if weight == 0.0 {
            continue;
        }
        for a in 0..sys.k {
            out[a] += row.htbeta[[c, a]] * weight;
        }
    }
}
/// Builds row `row_idx`'s full `H_tβ` as a dense `di × k` matrix.
///
/// The matrix starts from the row's dense block (when dense storage is in use and the
/// shape matches) or from zeros, and the operator's columns are added on top when an
/// operator is installed.
///
/// # Panics
/// Panics when no operator is installed and the dense block does not have shape `(di, k)`.
fn sys_htbeta_materialize_row(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
) -> Array2<f64> {
    let di = sys.row_dims[row_idx];
    let k = sys.k;
    let dense_in_use = sys.htbeta_dense_supplement || sys.htbeta_matvec.is_none();
    let dense_shape_ok = row.htbeta.dim() == (di, k);
    let mut mat = if dense_in_use && dense_shape_ok {
        row.htbeta.clone()
    } else {
        Array2::<f64>::zeros((di, k))
    };
    match sys.htbeta_matvec.as_ref() {
        Some(op) => {
            let mut unit = Array1::<f64>::zeros(k);
            let mut column = Array1::<f64>::zeros(di);
            for a in 0..k {
                unit.fill(0.0);
                unit[a] = 1.0;
                column.fill(0.0);
                op(row_idx, unit.view(), &mut column);
                for c in 0..di {
                    mat[[c, a]] += column[c];
                }
            }
        }
        None if dense_in_use && !dense_shape_ok => panic!(
            "row {row_idx}: htbeta shape {:?} != ({di}, {k}) and no htbeta_matvec installed",
            row.htbeta.dim()
        ),
        None => {}
    }
    mat
}
/// Accumulates `H_tβ,rowᵀ · v` into `out` using only the forward operator `op`.
///
/// Each β coordinate `a` is probed with a unit vector. The resulting `d`-length column
/// is dotted with `v` and added to `out[a]`. This costs `k` operator applications.
fn htbeta_probe_transpose(
    row: usize,
    op: &RowHtbetaMatvec,
    v: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
    d: usize,
    k: usize,
) {
    let mut unit = Array1::<f64>::zeros(k);
    let mut column = Array1::<f64>::zeros(d);
    for a in 0..k {
        unit.fill(0.0);
        unit[a] = 1.0;
        column.fill(0.0);
        op(row, unit.view(), &mut column);
        out[a] += (0..d).fold(0.0_f64, |acc, c| acc + column[c] * v[c]);
    }
}
/// How a factor cache can apply the latent/β coupling `H_tβ`.
#[derive(Clone)]
pub enum ArrowHtbetaCache {
/// One dense `H_tβ` block per row.
Dense {
blocks: Arc<[Array2<f64>]>,
// Size estimate recorded with the cache; not computed here.
estimated_bytes: usize,
},
/// A matrix-free per-row operator.
Matvec {
op: RowHtbetaMatvec,
estimated_bytes: usize,
},
/// No coupling retained; applying it reports failure.
Disabled {
estimated_bytes: usize,
},
}
impl std::fmt::Debug for ArrowHtbetaCache {
    /// Prints the variant name, the block count for `Dense`, and the size estimate.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, bytes) = match self {
            Self::Dense {
                estimated_bytes, ..
            } => ("Dense", estimated_bytes),
            Self::Matvec {
                estimated_bytes, ..
            } => ("Matvec", estimated_bytes),
            Self::Disabled { estimated_bytes } => ("Disabled", estimated_bytes),
        };
        let mut builder = f.debug_struct(variant);
        if let Self::Dense { blocks, .. } = self {
            builder.field("blocks", &blocks.len());
        }
        builder.field("estimated_bytes", bytes).finish()
    }
}
impl ArrowHtbetaCache {
    /// `true` unless the cache is `Disabled`.
    fn is_available(&self) -> bool {
        !matches!(self, Self::Disabled { .. })
    }

    /// Overwrites `out` with `H_tβ,row · delta_beta`.
    ///
    /// Returns `false` and leaves `out` untouched when the cache is disabled, or when a
    /// dense cache has no block for `row` or the block's shape does not match
    /// `(out.len(), delta_beta.len())`. The matrix-free variant performs no shape check
    /// of its own; `ArrowFactorCache::apply_htbeta_row` validates shapes before calling.
    fn apply_row(
        &self,
        row: usize,
        delta_beta: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let Some(block) = blocks.get(row) else {
                    return false;
                };
                if block.ncols() != delta_beta.len() || block.nrows() != out.len() {
                    return false;
                }
                for c in 0..block.nrows() {
                    let mut acc = 0.0_f64;
                    for a in 0..block.ncols() {
                        acc += block[[c, a]] * delta_beta[a];
                    }
                    out[c] = acc;
                }
                true
            }
            Self::Matvec { op, .. } => {
                // Every other call site in this module hands the H_tβ operator a zeroed
                // output buffer. Do the same here so `apply_row` overwrites `out` (as the
                // Dense branch does) and stale caller data cannot leak into the product.
                out.fill(0.0);
                op(row, delta_beta, out);
                true
            }
            Self::Disabled { .. } => false,
        }
    }

    /// Accumulates `H_tβ,rowᵀ · v` into `out` (length `k`).
    ///
    /// A disabled cache falls back to `fallback_op` when one is given. Returns `false`
    /// when nothing could be applied, or when a dense block is missing or wrongly shaped.
    fn apply_row_transpose_accumulate(
        &self,
        row: usize,
        v: ArrayView1<'_, f64>,
        out: &mut Array1<f64>,
        d: usize,
        k: usize,
        fallback_op: Option<&RowHtbetaMatvec>,
    ) -> bool {
        match self {
            Self::Dense { blocks, .. } => {
                let Some(block) = blocks.get(row) else {
                    return false;
                };
                if block.nrows() != v.len() || block.ncols() != out.len() {
                    return false;
                }
                for c in 0..block.nrows() {
                    let vc = v[c];
                    if vc == 0.0 {
                        continue;
                    }
                    for a in 0..block.ncols() {
                        out[a] += block[[c, a]] * vc;
                    }
                }
                true
            }
            Self::Matvec { op, .. } => {
                htbeta_probe_transpose(row, op, v, out, d, k);
                true
            }
            Self::Disabled { .. } => {
                if let Some(op) = fallback_op {
                    htbeta_probe_transpose(row, op, v, out, d, k);
                    true
                } else {
                    false
                }
            }
        }
    }
}
/// Factorization state retained after an Arrow-Schur solve, reused for log-determinants,
/// inverse diagonals, and coupling products.
#[derive(Debug, Clone)]
pub struct ArrowFactorCache {
/// Per-row (damped) lower factors of `H_tt`; their squared diagonals are the pivots.
pub htt_factors: ArrowFactorSlab,
/// Undamped per-row factors, or a marker that they equal the damped ones.
pub htt_factors_undamped: ArrowUndampedFactors,
/// Dense lower factor of the reduced Schur complement; `None` when none was formed
/// (e.g. the InexactPCG mode).
pub schur_factor: Option<Array2<f64>>,
pub solver_mode: ArrowSolverMode,
pub ridge_t: f64,
pub ridge_beta: f64,
/// How `H_tβ` can be applied from this cache.
pub htbeta: ArrowHtbetaCache,
/// Fallback row width when a row index lies beyond `row_dims`.
pub d: usize,
/// Per-row latent widths.
pub row_dims: Arc<[usize]>,
/// Start of each row in the flattened latent vector; entry `n_rows` is the total length.
pub row_offsets: Arc<[usize]>,
/// Number of β coordinates.
pub k: usize,
pub manifold_mode_fingerprint: u64,
pub row_hessian_fingerprint: u64,
pub pcg_diagnostics: PcgDiagnostics,
// NOTE(review): presumably the number of gauge directions deflated during the solve — confirm.
pub gauge_deflated_directions: usize,
/// Low-rank cross-row correction, when one is active.
pub cross_row_woodbury: Option<CrossRowWoodbury>,
}
/// Precomputed Woodbury / determinant-lemma pieces for a rank-`r` cross-row update
/// `H = H0 + U·D·Uᵀ`.
#[derive(Debug, Clone)]
pub struct CrossRowWoodbury {
/// Dense `U` (latent length × r).
pub u: Array2<f64>,
/// Diagonal of `D` (length r).
pub d: Array1<f64>,
/// Latent part of `H0⁻¹·U`.
pub h0inv_u: Array2<f64>,
/// β part of `H0⁻¹·U` (k × r).
pub h0inv_u_beta: Array2<f64>,
/// Symmetrized `Uᵀ·H0⁻¹·U` (r × r).
pub m: Array2<f64>,
/// LU factorization of the capacitance `I + D·M`.
pub capacitance_lu: SmallLu,
/// Sparse `(latent index, column, value)` triplets of `U`.
pub entries: Vec<(usize, usize, f64)>,
}
/// Partial-pivoting LU factorization `P·A = L·U` of a small square matrix.
#[derive(Debug, Clone)]
pub struct SmallLu {
// Unit-lower `L` strictly below the diagonal; `U` on and above it.
lu: Array2<f64>,
// `piv[i]` = original row index now in position `i`.
piv: Vec<usize>,
// Sign (±1) of the row permutation.
perm_sign: f64,
}
/// Factors a small square matrix as `P·A = L·U` with partial (row) pivoting.
///
/// Returns `None` when any pivot is non-finite or smaller in magnitude than
/// `f64::MIN_POSITIVE`, i.e. the matrix is numerically singular.
///
/// # Panics
/// Panics if `a` is not square.
fn small_lu_factor(a: &Array2<f64>) -> Option<SmallLu> {
    let n = a.nrows();
    assert_eq!(a.ncols(), n, "small_lu_factor: non-square input");
    let mut lu = a.clone();
    let mut piv: Vec<usize> = (0..n).collect();
    let mut perm_sign = 1.0_f64;
    for step in 0..n {
        // Largest magnitude on/below the diagonal; ties keep the earliest row.
        let (best_row, best_mag) =
            ((step + 1)..n).fold((step, lu[[step, step]].abs()), |(row_so_far, mag_so_far), r| {
                let mag = lu[[r, step]].abs();
                if mag > mag_so_far {
                    (r, mag)
                } else {
                    (row_so_far, mag_so_far)
                }
            });
        if !(best_mag.is_finite() && best_mag >= f64::MIN_POSITIVE) {
            return None;
        }
        if best_row != step {
            for c in 0..n {
                lu.swap((step, c), (best_row, c));
            }
            piv.swap(step, best_row);
            perm_sign = -perm_sign;
        }
        let pivot = lu[[step, step]];
        for r in (step + 1)..n {
            let multiplier = lu[[r, step]] / pivot;
            lu[[r, step]] = multiplier;
            for c in (step + 1)..n {
                let upper = lu[[step, c]];
                lu[[r, c]] -= multiplier * upper;
            }
        }
    }
    let diagonal_ok = (0..n).all(|i| {
        let u = lu[[i, i]];
        u.is_finite() && u.abs() >= f64::MIN_POSITIVE
    });
    diagonal_ok.then(|| SmallLu { lu, piv, perm_sign })
}
impl SmallLu {
    /// Order of the factored matrix.
    fn dim(&self) -> usize {
        self.lu.nrows()
    }

    /// Returns `(ln |det A|, sign(det A))` from `U`'s diagonal and the permutation sign.
    fn log_abs_det_and_sign(&self) -> (f64, f64) {
        (0..self.dim())
            .map(|i| self.lu[[i, i]])
            .fold((0.0_f64, self.perm_sign), |(log_abs, sign), u| {
                let flipped = if u < 0.0 { -sign } else { sign };
                (log_abs + u.abs().ln(), flipped)
            })
    }

    /// Solves `A·x = b`; `None` on a degenerate pivot or a non-finite solution.
    fn solve(&self, b: &Array1<f64>) -> Option<Array1<f64>> {
        let n = self.dim();
        // Forward substitution with the unit-lower factor on the permuted right-hand side.
        let mut y = Array1::<f64>::zeros(n);
        for (slot, &src) in self.piv.iter().enumerate() {
            y[slot] = b[src];
        }
        for i in 0..n {
            let mut sum = y[i];
            for j in 0..i {
                sum -= self.lu[[i, j]] * y[j];
            }
            y[i] = sum;
        }
        // Back substitution with U.
        let mut x = Array1::<f64>::zeros(n);
        for i in (0..n).rev() {
            let pivot = self.lu[[i, i]];
            if !pivot.is_finite() || pivot.abs() < f64::MIN_POSITIVE {
                return None;
            }
            let mut sum = y[i];
            for j in (i + 1)..n {
                sum -= self.lu[[i, j]] * x[j];
            }
            x[i] = sum / pivot;
        }
        let all_finite = x.iter().all(|v| v.is_finite());
        all_finite.then_some(x)
    }
}
impl CrossRowWoodbury {
    /// Precomputes the Woodbury pieces for the cross-row update `H = H0 + U·D·Uᵀ`.
    ///
    /// `H0⁻¹` is applied through `cache.full_inverse_apply`. The function builds
    /// `H0⁻¹U` (latent and β parts), the symmetrized `M = Uᵀ·H0⁻¹·U`, and an LU
    /// factorization of the capacitance `C = I + D·M`.
    ///
    /// Returns `Ok(None)` when `C` is numerically singular.
    ///
    /// # Errors
    /// Propagates any failure from `cache.full_inverse_apply`.
    fn build(
        cache: &ArrowFactorCache,
        source: &IbpCrossRowSource,
    ) -> Result<Option<Self>, ArrowSchurError> {
        let r = source.r;
        let total_len = cache.delta_t_len();
        let u = source.dense_u(total_len);
        let d = source.d.clone();
        let zero_beta = Array1::<f64>::zeros(cache.k);
        let mut h0inv_u = Array2::<f64>::zeros((total_len, r));
        let mut h0inv_u_beta = Array2::<f64>::zeros((cache.k, r));
        for col in 0..r {
            let u_col = u.column(col).to_owned();
            let (sol_t, sol_beta) = cache.full_inverse_apply(u_col.view(), zero_beta.view())?;
            for g in 0..total_len {
                h0inv_u[[g, col]] = sol_t[g];
            }
            for c in 0..cache.k {
                h0inv_u_beta[[c, col]] = sol_beta[c];
            }
        }
        // M = Uᵀ·(H0⁻¹U) assembled directly from U's sparse triplets. Each entry
        // (g, col, z) adds z·(H0⁻¹U)[g, ·] to row `col` of M. This is one pass over the
        // entries, O(r·|entries|), instead of rescanning them for every (a, b) pair,
        // O(r²·|entries|). Every M[a, b] still sums its terms in entry order starting
        // from 0.0, so the result is bit-for-bit the same. Entries with a column outside
        // 0..r never contributed before and are skipped.
        let mut m = Array2::<f64>::zeros((r, r));
        for &(g, col, z) in &source.entries {
            if col >= r {
                continue;
            }
            for b in 0..r {
                m[[col, b]] += z * h0inv_u[[g, b]];
            }
        }
        // M is symmetric in exact arithmetic; average away round-off asymmetry.
        for a in 0..r {
            for b in (a + 1)..r {
                let avg = 0.5 * (m[[a, b]] + m[[b, a]]);
                m[[a, b]] = avg;
                m[[b, a]] = avg;
            }
        }
        // Capacitance C = I + D·M.
        let mut c = Array2::<f64>::zeros((r, r));
        for a in 0..r {
            for b in 0..r {
                c[[a, b]] = d[a] * m[[a, b]];
            }
            c[[a, a]] += 1.0;
        }
        let Some(capacitance_lu) = small_lu_factor(&c) else {
            return Ok(None);
        };
        Ok(Some(Self {
            u,
            d,
            h0inv_u,
            h0inv_u_beta,
            m,
            capacitance_lu,
            entries: source.entries.clone(),
        }))
    }

    /// Sparse `(latent index, column, value)` triplets of `U`.
    fn source_entries(&self) -> &[(usize, usize, f64)] {
        &self.entries
    }

    /// Returns `C⁻¹·D` as an `r × r` matrix, or `None` if a capacitance solve fails.
    pub fn capacitance_inv_times_d(&self) -> Option<Array2<f64>> {
        let r = self.d.len();
        let mut out = Array2::<f64>::zeros((r, r));
        let mut e_l = Array1::<f64>::zeros(r);
        for l in 0..r {
            e_l.fill(0.0);
            e_l[l] = 1.0;
            let col = self.capacitance_lu.solve(&e_l)?;
            for row in 0..r {
                out[[row, l]] = col[row] * self.d[l];
            }
        }
        Some(out)
    }

    /// Applies the Woodbury correction to a latent inverse diagonal:
    /// `diag[g] -= Σ_{k,l} (H0⁻¹U)[g,k] · (C⁻¹D)[k,l] · (H0⁻¹U)[g,l]`.
    ///
    /// # Errors
    /// Fails when the capacitance solve produces a non-finite `C⁻¹D`.
    fn subtract_inverse_diagonal(
        &self,
        diag: &mut Array1<f64>,
    ) -> Result<(), ArrowSchurError> {
        let r = self.d.len();
        let cinv_d = self.capacitance_inv_times_d().ok_or_else(|| {
            ArrowSchurError::SchurFactorFailed {
                reason: "cross-row Woodbury capacitance solve produced a non-finite \
                         C⁻¹D for the inverse-diagonal correction (#1038): \
                         singular/ill-conditioned cross-row capacitance"
                    .to_string(),
            }
        })?;
        let total_len = self.h0inv_u.nrows();
        for g in 0..total_len {
            let mut acc = 0.0_f64;
            for col in 0..r {
                let gk = self.h0inv_u[[g, col]];
                if gk == 0.0 {
                    continue;
                }
                for l in 0..r {
                    acc += gk * cinv_d[[col, l]] * self.h0inv_u[[g, l]];
                }
            }
            diag[g] -= acc;
        }
        Ok(())
    }

    /// `ln det C` when `det C > 0`, else `None`.
    pub fn log_det(&self) -> Option<f64> {
        let (log_abs, sign) = self.log_det_correction();
        (sign > 0.0).then_some(log_abs)
    }

    /// `(ln |det C|, sign(det C))` from the capacitance LU.
    fn log_det_correction(&self) -> (f64, f64) {
        self.capacitance_lu.log_abs_det_and_sign()
    }

    /// Subtracts the Woodbury term `H0⁻¹U · C⁻¹ · D · Uᵀ · h0inv_rhs_t` from
    /// `(u_t, u_beta)`, with `Uᵀ·x` formed from the sparse `entries`.
    ///
    /// # Errors
    /// Fails when the capacitance solve produces a non-finite result.
    fn apply_inverse_correction(
        &self,
        h0inv_rhs_t: ArrayView1<'_, f64>,
        entries: &[(usize, usize, f64)],
        u_t: &mut Array1<f64>,
        u_beta: &mut Array1<f64>,
    ) -> Result<(), ArrowSchurError> {
        let r = self.d.len();
        let mut p = Array1::<f64>::zeros(r);
        for &(g, col, z) in entries {
            p[col] += z * h0inv_rhs_t[g];
        }
        for col in 0..r {
            p[col] *= self.d[col];
        }
        let q = self.capacitance_lu.solve(&p).ok_or_else(|| {
            ArrowSchurError::SchurFactorFailed {
                reason: "cross-row Woodbury capacitance solve produced a non-finite \
                         C⁻¹p for the inverse correction (#1038): \
                         singular/ill-conditioned cross-row capacitance"
                    .to_string(),
            }
        })?;
        for g in 0..u_t.len() {
            let mut acc = 0.0_f64;
            for col in 0..r {
                acc += self.h0inv_u[[g, col]] * q[col];
            }
            u_t[g] -= acc;
        }
        for c in 0..u_beta.len() {
            let mut acc = 0.0_f64;
            for col in 0..r {
                acc += self.h0inv_u_beta[[c, col]] * q[col];
            }
            u_beta[c] -= acc;
        }
        Ok(())
    }
}
/// Smallest Cholesky pivots (squared factor diagonals) across an `ArrowFactorCache`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArrowFactorMinPivot {
/// Minimum over all per-row factors, if any exist.
pub min_row_pivot: Option<f64>,
/// Minimum of the dense Schur factor, if one exists.
pub min_schur_pivot: Option<f64>,
/// Minimum of the two above.
pub min_pivot: Option<f64>,
}
impl ArrowFactorMinPivot {
    /// Bundles the row- and Schur-level minimum pivots with their overall minimum.
    fn combine(row: Option<f64>, schur: Option<f64>) -> Self {
        let min_pivot = row.into_iter().chain(schur).reduce(f64::min);
        Self {
            min_row_pivot: row,
            min_schur_pivot: schur,
            min_pivot,
        }
    }
}
/// Smallest squared diagonal entry `L_ii²` of a lower Cholesky factor (`None` if empty).
fn lower_cholesky_min_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
    let width = factor.nrows().min(factor.ncols());
    (0..width)
        .map(|idx| {
            let l = factor[[idx, idx]];
            l * l
        })
        .reduce(f64::min)
}
/// Largest squared diagonal entry `L_ii²` of a lower Cholesky factor (`None` if empty).
fn lower_cholesky_max_pivot(factor: ArrayView2<'_, f64>) -> Option<f64> {
    let width = factor.nrows().min(factor.ncols());
    (0..width)
        .map(|idx| {
            let l = factor[[idx, idx]];
            l * l
        })
        .reduce(f64::max)
}
pub fn arrow_factor_min_pivot(cache: &ArrowFactorCache) -> ArrowFactorMinPivot {
let mut min_row_pivot = None;
for factor in cache.htt_factors.iter() {
if let Some(pivot) = lower_cholesky_min_pivot(factor) {
min_row_pivot = Some(match min_row_pivot {
Some(current) => f64::min(current, pivot),
None => pivot,
});
}
}
let min_schur_pivot = cache
.schur_factor
.as_ref()
.and_then(|factor| lower_cholesky_min_pivot(factor.view()));
ArrowFactorMinPivot::combine(min_row_pivot, min_schur_pivot)
}
pub fn arrow_factor_max_pivot(cache: &ArrowFactorCache) -> Option<f64> {
let mut max_pivot: Option<f64> = None;
for factor in cache.htt_factors.iter() {
if let Some(pivot) = lower_cholesky_max_pivot(factor) {
max_pivot = Some(match max_pivot {
Some(current) => f64::max(current, pivot),
None => pivot,
});
}
}
if let Some(factor) = cache.schur_factor.as_ref()
&& let Some(pivot) = lower_cholesky_max_pivot(factor.view())
{
max_pivot = Some(match max_pivot {
Some(current) => f64::max(current, pivot),
None => pivot,
});
}
max_pivot
}
impl ArrowFactorCache {
/// Number of latent rows (one damped factor per row).
pub fn n_rows(&self) -> usize {
    self.htt_factors.len()
}
/// Whether `H_tβ` can be applied from this cache (it is not `Disabled`).
pub fn htbeta_available(&self) -> bool {
    self.htbeta.is_available()
}
/// Undamped factor for `row`, falling back to the damped slab when they coincide.
pub fn undamped_factor(&self, row: usize) -> ArrayView2<'_, f64> {
    let slab = match &self.htt_factors_undamped {
        ArrowUndampedFactors::Owned(factors) => factors,
        ArrowUndampedFactors::SameAsDamped => &self.htt_factors,
    };
    slab.factor(row)
}
/// Number of undamped factors available.
pub fn undamped_factor_count(&self) -> usize {
    let slab = match &self.htt_factors_undamped {
        ArrowUndampedFactors::Owned(factors) => factors,
        ArrowUndampedFactors::SameAsDamped => &self.htt_factors,
    };
    slab.len()
}
/// Iterates over the undamped factors in row order.
pub fn undamped_factors_iter(&self) -> impl Iterator<Item = ArrayView2<'_, f64>> + '_ {
    let count = self.undamped_factor_count();
    (0..count).map(move |row| self.undamped_factor(row))
}
/// Length of the flattened latent vector (`row_offsets[n_rows]`).
pub fn delta_t_len(&self) -> usize {
    let rows = self.n_rows();
    self.row_offsets[rows]
}
/// Writes `H_tβ,row · delta_beta` into `out` after validating shapes.
///
/// Returns `false` without touching `out` when `out` is not the row's width or
/// `delta_beta` is not length `k`. Otherwise returns the result of the cached coupling.
pub fn apply_htbeta_row(
    &self,
    row: usize,
    delta_beta: ArrayView1<'_, f64>,
    out: &mut Array1<f64>,
) -> bool {
    let width = self.row_dims.get(row).copied().unwrap_or(self.d);
    let shapes_ok = out.len() == width && delta_beta.len() == self.k;
    shapes_ok && self.htbeta.apply_row(row, delta_beta, out)
}
pub fn apply_htbeta_row_transpose(
&self,
row: usize,
v: ArrayView1<'_, f64>,
out: &mut Array1<f64>,
fallback_op: Option<&RowHtbetaMatvec>,
) -> bool {
let di = if row < self.row_dims.len() {
self.row_dims[row]
} else {
self.d
};
if v.len() != di || out.len() != self.k {
return false;
}
self.htbeta
.apply_row_transpose_accumulate(row, v, out, di, self.k, fallback_op)
}
pub fn arrow_log_det(&self) -> (f64, Option<f64>) {
let mut log_det_tt = 0.0_f64;
for l in self.htt_factors.iter() {
for i in 0..l.nrows() {
log_det_tt += l[[i, i]].ln();
}
}
log_det_tt *= 2.0;
let log_det_schur = self.schur_factor.as_ref().map(|l| {
let mut s = 0.0_f64;
for i in 0..l.nrows() {
s += l[[i, i]].ln();
}
2.0 * s + self.cross_row_woodbury_log_det()
});
(log_det_tt, log_det_schur)
}
pub fn cross_row_woodbury_log_det(&self) -> f64 {
match self.cross_row_woodbury.as_ref() {
Some(w) => w.log_det().unwrap_or(f64::NAN),
None => 0.0,
}
}
pub fn latent_block_inverse_diagonal(&self) -> Result<Array1<f64>, ArrowSchurError> {
let Some(schur_factor) = self.schur_factor.as_ref() else {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "latent_block_inverse_diagonal requires a dense Schur factor; \
the InexactPCG mode does not form one"
.to_string(),
});
};
if !self.htbeta_available() {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "latent_block_inverse_diagonal requires the H_tβ coupling, \
but this cache's htbeta is Disabled"
.to_string(),
});
}
let n = self.undamped_factor_count();
let total_len = self.delta_t_len();
let mut out = Array1::<f64>::zeros(total_len);
let mut e_j = Array1::<f64>::zeros(self.d);
let mut w = Array1::<f64>::zeros(self.k);
for i in 0..n {
let di = self.row_dims[i];
let row_base = self.row_offsets[i];
let factor = self.undamped_factor(i);
for j in 0..di {
for c in 0..di {
e_j[c] = 0.0;
}
e_j[j] = 1.0;
let e_j_slice = e_j.slice(ndarray::s![..di]).to_owned();
let a = cholesky_solve_vector(factor, &e_j_slice);
w.fill(0.0);
if !self.apply_htbeta_row_transpose(i, a.view(), &mut w, None) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"latent_block_inverse_diagonal: H_βt^({i}) apply failed \
(htbeta cache could not supply row {i})"
),
});
}
let z = cholesky_solve_vector(schur_factor, &w);
let mut corr = 0.0_f64;
for c in 0..self.k {
corr += w[c] * z[c];
}
out[row_base + j] = a[j] + corr;
}
}
if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
woodbury.subtract_inverse_diagonal(&mut out)?;
}
Ok(out)
}
pub fn full_inverse_apply(
&self,
w_t: ArrayView1<'_, f64>,
w_beta: ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
let (mut u_t, mut u_beta) = self.full_inverse_apply_base(w_t, w_beta)?;
if let Some(woodbury) = self.cross_row_woodbury.as_ref() {
let h0inv_w_t = u_t.clone();
woodbury.apply_inverse_correction(
h0inv_w_t.view(),
woodbury.source_entries(),
&mut u_t,
&mut u_beta,
)?;
}
Ok((u_t, u_beta))
}
fn full_inverse_apply_base(
&self,
w_t: ArrayView1<'_, f64>,
w_beta: ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>), ArrowSchurError> {
let total_len = self.delta_t_len();
if w_t.len() != total_len || w_beta.len() != self.k {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"full_inverse_apply: rhs shapes (w_t={}, w_beta={}) != (delta_t_len={}, K={})",
w_t.len(),
w_beta.len(),
total_len,
self.k
),
});
}
let n = self.undamped_factor_count();
let mut y = Array1::<f64>::zeros(total_len);
let mut r_beta = w_beta.to_owned();
for i in 0..n {
let di = self.row_dims[i];
let base = self.row_offsets[i];
let factor = self.undamped_factor(i);
let w_row = w_t.slice(ndarray::s![base..base + di]).to_owned();
let y_row = cholesky_solve_vector(factor, &w_row);
if self.k > 0 {
let mut acc = Array1::<f64>::zeros(self.k);
if !self.apply_htbeta_row_transpose(i, y_row.view(), &mut acc, None) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"full_inverse_apply: H_βt^({i}) apply failed (htbeta cache \
could not supply row {i})"
),
});
}
for c in 0..self.k {
r_beta[c] -= acc[c];
}
}
for j in 0..di {
y[base + j] = y_row[j];
}
}
let u_beta = if self.k > 0 {
self.schur_inverse_apply(r_beta.view())?
} else {
Array1::<f64>::zeros(0)
};
let mut u_t = y;
if self.k > 0 {
let mut cross = Array1::<f64>::zeros(self.d);
for i in 0..n {
let di = self.row_dims[i];
let base = self.row_offsets[i];
let mut cross_row = cross.slice_mut(ndarray::s![..di]);
cross_row.fill(0.0);
let mut cross_owned = cross_row.to_owned();
if !self.apply_htbeta_row(i, u_beta.view(), &mut cross_owned) {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"full_inverse_apply: H_tβ^({i}) apply failed (htbeta cache \
could not supply row {i})"
),
});
}
let factor = self.undamped_factor(i);
let corr = cholesky_solve_vector(factor, &cross_owned);
for j in 0..di {
u_t[base + j] -= corr[j];
}
}
}
Ok((u_t, u_beta))
}
pub fn schur_inverse_apply(
&self,
rhs: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, ArrowSchurError> {
let Some(schur_factor) = self.schur_factor.as_ref() else {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "schur_inverse_apply requires a dense Schur factor; \
the InexactPCG mode does not form one"
.to_string(),
});
};
if rhs.len() != self.k {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"schur_inverse_apply: rhs length {} != K {}",
rhs.len(),
self.k
),
});
}
let rhs_owned = rhs.to_owned();
Ok(cholesky_solve_vector(schur_factor, &rhs_owned))
}
pub fn schur_inverse_block(
&self,
block: std::ops::Range<usize>,
) -> Result<Array2<f64>, ArrowSchurError> {
let Some(schur_factor) = self.schur_factor.as_ref() else {
return Err(ArrowSchurError::SchurFactorFailed {
reason: "schur_inverse_block requires a dense Schur factor; \
the InexactPCG mode does not form one"
.to_string(),
});
};
if block.end > self.k {
return Err(ArrowSchurError::SchurFactorFailed {
reason: format!(
"schur_inverse_block: block end {} exceeds K {}",
block.end, self.k
),
});
}
let w = block.len();
let mut out = Array2::<f64>::zeros((w, w));
let mut e_j = Array1::<f64>::zeros(self.k);
for (jc, j) in block.clone().enumerate() {
e_j.fill(0.0);
e_j[j] = 1.0;
let col = cholesky_solve_vector(schur_factor, &e_j);
for (ic, i) in block.clone().enumerate() {
out[[ic, jc]] = col[i];
}
}
for ic in 0..w {
for jc in (ic + 1)..w {
let avg = 0.5 * (out[[ic, jc]] + out[[jc, ic]]);
out[[ic, jc]] = avg;
out[[jc, ic]] = avg;
}
}
Ok(out)
}
}
/// Solves one damped arrow-Schur Newton step and returns `(delta_t, delta_beta, cache)`.
/// The `ArrowFactorCache` keeps the factorizations for later inverse applications and
/// log-determinants.
///
/// When `sys.ibp_cross_row` is set, the step is solved on a copy of the system whose
/// per-row `H_tt` diagonals have `source.self_term_downdate(..)` subtracted. If
/// `CrossRowWoodbury::build` yields an update, its inverse correction is then applied to the
/// step, and the update is stored on the cache.
///
/// # Errors
/// `SchurFactorFailed` when `options.streaming_chunk_size` is set, because the streaming path
/// does not materialize a cache. Failures from the step solve, the undamped re-factorization,
/// and the Woodbury build/correction are propagated.
pub fn solve_arrow_newton_step_with_options(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, ArrowFactorCache), ArrowSchurError> {
    if options.streaming_chunk_size.is_some() {
        return Err(ArrowSchurError::SchurFactorFailed {
            reason: "streaming Arrow-Schur solve does not materialize the factor cache required by this entry point".to_string(),
        });
    }
    // Deferred-init owner: lets the shadowed `sys` below borrow the downdated copy for the
    // remainder of the function.
    let downdated_owner;
    let (sys, ibp_source): (&ArrowSchurSystem, Option<&IbpCrossRowSource>) =
        match sys.ibp_cross_row.as_ref() {
            Some(source) => {
                let mut downdated = sys.clone();
                let total_len = downdated.row_offsets[downdated.rows.len()];
                let down = source.self_term_downdate(total_len);
                // Cloned Arc so `row_offsets` can be read while `rows` is borrowed mutably.
                let offsets = Arc::clone(&downdated.row_offsets);
                for (i, row) in downdated.rows.iter_mut().enumerate() {
                    let base = offsets[i];
                    let di = row.htt.nrows();
                    for j in 0..di {
                        row.htt[[j, j]] -= down[base + j];
                    }
                }
                // The H_tt blocks changed, so the stored fingerprint must be recomputed.
                downdated.refresh_row_hessian_fingerprint();
                downdated_owner = downdated;
                (&downdated_owner, Some(source))
            }
            None => (sys, None),
        };
    let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, options)?;
    let backend = CpuBatchedBlockSolver;
    // Overflow saturates to usize::MAX, which forces the Disabled branch unless a matvec exists.
    let htbeta_estimated_bytes =
        estimated_htbeta_bytes(sys.rows.len(), sys.d, sys.k).unwrap_or(usize::MAX);
    // H_tβ retention: prefer the matvec operator, else a dense copy within budget, else none.
    let htbeta = if let Some(op) = sys.htbeta_matvec.as_ref() {
        ArrowHtbetaCache::Matvec {
            op: Arc::clone(op),
            estimated_bytes: htbeta_estimated_bytes,
        }
    } else if htbeta_estimated_bytes <= ARROW_FACTOR_CACHE_HTBETA_BUDGET_BYTES {
        ArrowHtbetaCache::Dense {
            blocks: sys
                .rows
                .iter()
                .map(|r| r.htbeta.clone())
                .collect::<Vec<_>>()
                .into(),
            estimated_bytes: htbeta_estimated_bytes,
        }
    } else {
        ArrowHtbetaCache::Disabled {
            estimated_bytes: htbeta_estimated_bytes,
        }
    };
    let htt_factors = step.htt_factors;
    // With no latent ridge the damped factors already are the undamped ones; otherwise the
    // rows are re-factored with ridge 0.
    let (htt_factors_undamped, gauge_deflated_directions) = if ridge_t == 0.0 {
        (
            ArrowUndampedFactors::SameAsDamped,
            step.gauge_deflated_directions,
        )
    } else {
        let undamped = factor_blocks_for_system(sys, 0.0, options, &backend)?;
        (
            ArrowUndampedFactors::Owned(undamped.factors),
            undamped.gauge_deflated_directions,
        )
    };
    let mut cache = ArrowFactorCache {
        htt_factors,
        htt_factors_undamped,
        schur_factor: step.schur_factor,
        solver_mode: options.mode,
        ridge_t,
        ridge_beta,
        htbeta,
        d: sys.d,
        row_dims: Arc::clone(&sys.row_dims),
        row_offsets: Arc::clone(&sys.row_offsets),
        k: sys.k,
        manifold_mode_fingerprint: sys.manifold_mode_fingerprint,
        row_hessian_fingerprint: sys.current_row_hessian_fingerprint(),
        pcg_diagnostics: step.pcg_diagnostics,
        gauge_deflated_directions,
        cross_row_woodbury: None,
    };
    let mut delta_t = step.delta_t;
    let mut delta_beta = step.delta_beta;
    if let Some(source) = ibp_source {
        if let Some(woodbury) = CrossRowWoodbury::build(&cache, source)? {
            // The correction reads the base solution while rewriting it, so snapshot it.
            let h0inv_neg_g_t = delta_t.clone();
            woodbury.apply_inverse_correction(
                h0inv_neg_g_t.view(),
                &source.entries,
                &mut delta_t,
                &mut delta_beta,
            )?;
            cache.cross_row_woodbury = Some(woodbury);
        }
    }
    Ok((delta_t, delta_beta, cache))
}
/// Bytes needed to store `n` dense `d × k` f64 coupling blocks, or `None` if the product
/// overflows `usize`. Multiplication proceeds as `((n·d)·k)·8`.
fn estimated_htbeta_bytes(n: usize, d: usize, k: usize) -> Option<usize> {
    [d, k, std::mem::size_of::<f64>()]
        .into_iter()
        .try_fold(n, |bytes, factor| bytes.checked_mul(factor))
}
/// Solves one arrow Newton step, dispatching to the first applicable backend:
/// streaming (when a chunk size is set), the device direct solver, a CPU solve with an
/// injected GPU Schur matvec, or the plain CPU solve.
///
/// Returns `(delta_t, delta_beta, diagnostics)`. The streaming path reports default
/// diagnostics. The GPU-matvec path marks `used_device_arrow`.
pub fn solve_arrow_newton_step_core(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    if let Some(chunk_size) = options.streaming_chunk_size {
        let streaming_options = options.with_streaming_mixed_precision_default();
        let mut streaming_solver = StreamingArrowSchur::from_system(sys, chunk_size);
        let (delta_t, delta_beta, _) =
            streaming_solver.solve(ridge_t, ridge_beta, &streaming_options)?;
        return Ok((delta_t, delta_beta, PcgDiagnostics::default()));
    }
    if let Some(device_step) = try_device_arrow_direct(sys, ridge_t, ridge_beta, options) {
        return device_step;
    }
    let gpu_options = maybe_inject_gpu_schur_matvec(sys, ridge_t, ridge_beta, options);
    let offloaded_matvec = gpu_options.is_some();
    let effective_options = gpu_options.as_ref().unwrap_or(options);
    let step = solve_arrow_newton_step_artifacts(sys, ridge_t, ridge_beta, effective_options)?;
    let mut diagnostics = step.pcg_diagnostics;
    if offloaded_matvec {
        diagnostics.used_device_arrow = true;
    }
    Ok((step.delta_t, step.delta_beta, diagnostics))
}
/// Returns a copy of `options` with a GPU reduced-Schur matvec attached, when all of these
/// hold:
/// - the solve is InexactPCG with no matvec already set;
/// - there are no cross-row penalties and no streaming;
/// - a GPU runtime exists and its policy approves offload;
/// - the device backend builds.
///
/// Otherwise returns `None`.
fn maybe_inject_gpu_schur_matvec(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Option<ArrowSolveOptions> {
    let eligible = options.mode == ArrowSolverMode::InexactPCG
        && options.gpu_matvec.is_none()
        && sys.cross_row_penalties.is_empty()
        && options.streaming_chunk_size.is_none();
    if !eligible {
        return None;
    }
    let runtime = crate::gpu::runtime::GpuRuntime::global()?;
    // The CG iteration budget is bounded by both the PCG and trust-region limits.
    let iteration_budget = options
        .pcg
        .max_iterations
        .min(options.trust_region.max_iterations);
    let should_offload = runtime.policy().reduced_schur_matvec_should_offload(
        sys.rows.len(),
        sys.k,
        sys.d,
        iteration_budget,
    );
    if !should_offload {
        return None;
    }
    let device_matvec =
        crate::gpu::arrow_schur::gpu_schur_matvec_backend(sys, ridge_t, ridge_beta).ok()?;
    let mut with_gpu = options.clone();
    with_gpu.gpu_matvec = Some(device_matvec);
    Some(with_gpu)
}
/// Attempts the dense Direct solve on the GPU.
///
/// Returns `None` when the solve is ineligible (not Direct mode, cross-row penalties,
/// streaming, or matvec-provided Hessian blocks), when no runtime exists, when the policy
/// declines, or when the device fails in a way not mapped to a host error. Device
/// non-PD/Schur failures are surfaced as `Some(Err(..))`.
fn try_device_arrow_direct(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Option<Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError>> {
    let eligible = options.mode == ArrowSolverMode::Direct
        && sys.cross_row_penalties.is_empty()
        && options.streaming_chunk_size.is_none()
        && sys.hbb_matvec.is_none()
        && sys.htbeta_matvec.is_none();
    if !eligible {
        return None;
    }
    let runtime = crate::gpu::runtime::GpuRuntime::global()?;
    if !runtime
        .policy()
        .dense_hessian_work_target_is_gpu(sys.rows.len(), sys.k)
    {
        return None;
    }
    let outcome = match crate::gpu::arrow_schur::solve_arrow_newton_step(sys, ridge_t, ridge_beta)
    {
        Ok(solution) => Ok((
            solution.delta_t,
            solution.delta_beta,
            PcgDiagnostics {
                used_device_arrow: true,
                ..PcgDiagnostics::default()
            },
        )),
        Err(crate::gpu::arrow_schur::ArrowSchurGpuFailure::RidgeBumpRequired { row, bump }) => {
            Err(ArrowSchurError::PerRowFactorFailed {
                row,
                reason: format!("device per-row block non-PD; suggested ridge bump {bump:e}"),
            })
        }
        Err(crate::gpu::arrow_schur::ArrowSchurGpuFailure::SchurFactorFailed { reason }) => {
            Err(ArrowSchurError::SchurFactorFailed { reason })
        }
        // Any other device failure: let the caller fall through to the host path.
        Err(_) => return None,
    };
    Some(outcome)
}
/// Solves the arrow Newton step, retrying with a growing proximal ridge on recoverable
/// factorization/PCG failures.
///
/// The ridge starts at 0, then takes `DEFAULT_PROXIMAL_INITIAL_RIDGE`, and is multiplied by
/// `DEFAULT_PROXIMAL_RIDGE_GROWTH` after every further failure, for at most
/// `DEFAULT_PROXIMAL_MAX_ATTEMPTS` escalations. The number of escalations used is recorded in
/// `ridge_escalations`. Non-recoverable errors, or the last error once the budget is spent,
/// are returned.
pub fn solve_with_lm_escalation_inner(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<(Array1<f64>, Array1<f64>, PcgDiagnostics), ArrowSchurError> {
    let mut proximal_ridge = 0.0_f64;
    let mut escalations: usize = 0;
    let mut attempt = 0;
    loop {
        let outcome = solve_arrow_newton_step_artifacts(
            sys,
            ridge_t + proximal_ridge,
            ridge_beta + proximal_ridge,
            options,
        );
        let err = match outcome {
            Ok(mut step) => {
                step.pcg_diagnostics.ridge_escalations = escalations;
                return Ok((step.delta_t, step.delta_beta, step.pcg_diagnostics));
            }
            Err(err) => err,
        };
        let recoverable = matches!(
            err,
            ArrowSchurError::PerRowFactorFailed { .. }
                | ArrowSchurError::PerRowFactorIllConditioned { .. }
                | ArrowSchurError::SchurFactorFailed { .. }
                | ArrowSchurError::PcgFailed { .. }
        );
        if !recoverable || attempt == DEFAULT_PROXIMAL_MAX_ATTEMPTS {
            return Err(err);
        }
        proximal_ridge = if proximal_ridge == 0.0 {
            DEFAULT_PROXIMAL_INITIAL_RIDGE
        } else {
            proximal_ridge * DEFAULT_PROXIMAL_RIDGE_GROWTH
        };
        escalations += 1;
        attempt += 1;
    }
}
/// Solves a proximally damped arrow Newton step and accepts it with an Armijo test,
/// escalating the proximal ridge until a candidate passes or `max_attempts` run out.
///
/// Each attempt works as follows:
/// - Solve with ridges `base_ridge_* + proximal_ridge`.
/// - Require a finite descent direction (`g·p < 0`).
/// - Evaluate `trial_objective(delta_t, delta_beta)`.
/// - Accept when the trial value is `<= f + armijo_c1 · g·p`.
///
/// If no attempt is accepted, one of three outcomes follows:
/// - The Armijo-rejected candidate with the lowest trial value among those that lowered the
///   objective by more than `convergence_objective_rel_tol · (|f| + 1)` is returned.
/// - Otherwise, if the smallest remaining objective change is within that resolution, a
///   zero step is returned.
/// - Otherwise an error is returned.
///
/// A zero step is also returned immediately when the gradient norm is within
/// `gradient_tolerance`.
///
/// # Errors
/// `AdaptiveCorrectionFailed` in any of these cases:
/// - the current objective is non-finite;
/// - `ridge_growth` is not finite and `> 1`;
/// - `armijo_c1` is outside `(0, 1)`;
/// - no acceptable step was found.
pub fn solve_arrow_newton_step_with_proximal_correction<F>(
    sys: &ArrowSchurSystem,
    base_ridge_t: f64,
    base_ridge_beta: f64,
    current_objective_value: f64,
    options: &ArrowSolveOptions,
    correction: &ArrowProximalCorrectionOptions,
    mut trial_objective: F,
) -> Result<ArrowAcceptedProximalStep, ArrowSchurError>
where
    F: for<'a, 'b> FnMut(ArrayView1<'a, f64>, ArrayView1<'b, f64>) -> f64,
{
    /// Armijo-rejected candidate that still decreased the objective beyond the resolution.
    struct DecreaseCandidate {
        delta_t: Array1<f64>,
        delta_beta: Array1<f64>,
        trial_value: f64,
        gradient_dot_step: f64,
        ridge_t: f64,
        ridge_beta: f64,
        proximal_ridge: f64,
    }

    if !current_objective_value.is_finite() {
        return Err(ArrowSchurError::AdaptiveCorrectionFailed {
            reason: "current objective is not finite".to_string(),
        });
    }
    if !(correction.ridge_growth.is_finite() && correction.ridge_growth > 1.0) {
        return Err(ArrowSchurError::AdaptiveCorrectionFailed {
            reason: format!(
                "ridge_growth must be finite and > 1; got {}",
                correction.ridge_growth
            ),
        });
    }
    if !(correction.armijo_c1.is_finite()
        && correction.armijo_c1 > 0.0
        && correction.armijo_c1 < 1.0)
    {
        return Err(ArrowSchurError::AdaptiveCorrectionFailed {
            reason: format!("armijo_c1 must be in (0, 1); got {}", correction.armijo_c1),
        });
    }
    // "Already converged" result: zero step at the base ridges, objective unchanged.
    let zero_step = |attempts: usize| ArrowAcceptedProximalStep {
        delta_t: Array1::<f64>::zeros(sys.row_offsets[sys.rows.len()]),
        delta_beta: Array1::<f64>::zeros(sys.k),
        ridge_t: base_ridge_t,
        ridge_beta: base_ridge_beta,
        proximal_ridge: 0.0,
        objective_value: current_objective_value,
        trial_objective_value: current_objective_value,
        gradient_dot_step: 0.0,
        attempts,
    };
    let grad_norm = arrow_gradient_norm(sys);
    if grad_norm <= correction.gradient_tolerance.max(0.0) {
        return Ok(zero_step(0));
    }
    // Objective changes smaller than this are treated as numerical noise.
    let objective_resolution =
        correction.convergence_objective_rel_tol.max(0.0) * (current_objective_value.abs() + 1.0);
    let mut proximal_ridge = correction.initial_ridge.max(0.0);
    let mut last_reason = String::from("no attempts were made");
    let mut best_decrease: Option<DecreaseCandidate> = None;
    let mut smallest_increase = f64::INFINITY;
    for attempt in 0..correction.max_attempts {
        let ridge_t = base_ridge_t + proximal_ridge;
        let ridge_beta = base_ridge_beta + proximal_ridge;
        match solve_arrow_newton_step_core(sys, ridge_t, ridge_beta, options) {
            Ok((delta_t, delta_beta, _diag)) => {
                let g_dot_p = arrow_gradient_dot_step(sys, delta_t.view(), delta_beta.view());
                if !(g_dot_p.is_finite() && g_dot_p < 0.0) {
                    last_reason =
                        format!("candidate was not a finite descent direction: g·p={g_dot_p}");
                } else {
                    let trial_value = trial_objective(delta_t.view(), delta_beta.view());
                    let armijo_bound = current_objective_value + correction.armijo_c1 * g_dot_p;
                    if trial_value.is_finite() && trial_value <= armijo_bound {
                        return Ok(ArrowAcceptedProximalStep {
                            delta_t,
                            delta_beta,
                            ridge_t,
                            ridge_beta,
                            proximal_ridge,
                            objective_value: current_objective_value,
                            trial_objective_value: trial_value,
                            gradient_dot_step: g_dot_p,
                            attempts: attempt + 1,
                        });
                    }
                    if trial_value.is_finite() {
                        let delta_obj = trial_value - current_objective_value;
                        if delta_obj < -objective_resolution {
                            let improves = best_decrease
                                .as_ref()
                                .is_none_or(|best| trial_value < best.trial_value);
                            if improves {
                                best_decrease = Some(DecreaseCandidate {
                                    delta_t: delta_t.clone(),
                                    delta_beta: delta_beta.clone(),
                                    trial_value,
                                    gradient_dot_step: g_dot_p,
                                    ridge_t,
                                    ridge_beta,
                                    proximal_ridge,
                                });
                            }
                        } else if delta_obj < smallest_increase {
                            smallest_increase = delta_obj;
                        }
                    }
                    last_reason = {
                        let step_norm = (delta_t.iter().map(|v| v * v).sum::<f64>()
                            + delta_beta.iter().map(|v| v * v).sum::<f64>())
                        .sqrt();
                        format!(
                            "Armijo rejected trial objective {trial_value}; bound {armijo_bound}; \
                             |g|={grad_norm:.4e} g.p={g_dot_p:.4e} |step|={step_norm:.4e} ridge={proximal_ridge:.3e}"
                        )
                    };
                }
            }
            Err(err) => {
                last_reason = err.to_string();
            }
        }
        proximal_ridge = next_proximal_ridge(proximal_ridge, correction.ridge_growth);
    }
    if let Some(best) = best_decrease {
        // Re-evaluate on the chosen step, presumably so a stateful `trial_objective` ends on
        // the returned step rather than on the last rejected one; keep the recorded value if
        // the re-evaluation is not finite.
        let reapplied = trial_objective(best.delta_t.view(), best.delta_beta.view());
        let final_value = if reapplied.is_finite() {
            reapplied
        } else {
            best.trial_value
        };
        return Ok(ArrowAcceptedProximalStep {
            delta_t: best.delta_t,
            delta_beta: best.delta_beta,
            ridge_t: best.ridge_t,
            ridge_beta: best.ridge_beta,
            proximal_ridge: best.proximal_ridge,
            objective_value: current_objective_value,
            trial_objective_value: final_value,
            gradient_dot_step: best.gradient_dot_step,
            attempts: correction.max_attempts,
        });
    }
    if smallest_increase.is_finite() && smallest_increase <= objective_resolution {
        return Ok(zero_step(correction.max_attempts));
    }
    Err(ArrowSchurError::AdaptiveCorrectionFailed {
        reason: format!(
            "failed after {} attempts; last rejection: {last_reason}",
            correction.max_attempts
        ),
    })
}
/// Predicted objective reduction `-(gᵀδ + ½ δᵀ (H + ridge) δ)` of the damped quadratic model.
///
/// The curvature terms are:
/// - `ridge_beta·|δβ|²`;
/// - the β–β product supplied by `sys.penalty_matvec_add`;
/// - per row, `ridge_t·|δt_i|² + δt_iᵀ H_tt,i δt_i + 2 δt_iᵀ (H_tβ,i δβ)`.
///
/// `delta_beta` may be any 1-D view. Non-contiguous views are copied before being passed to
/// the slice-based penalty matvec.
///
/// # Panics
/// If `delta_t` does not have length `row_offsets[n_rows]` or `delta_beta` length `K`.
///
/// # Errors
/// Currently never returns `Err`; the `Result` is kept for interface stability.
pub fn arrow_damped_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let total_len = sys.row_offsets[sys.rows.len()];
    assert_eq!(delta_t.len(), total_len);
    assert_eq!(delta_beta.len(), sys.k);
    let mut lin = sys.gb.dot(&delta_beta);
    let mut quad = ridge_beta * delta_beta.dot(&delta_beta);
    let mut hbb_delta = Array1::<f64>::zeros(sys.k);
    {
        // penalty_matvec_add works on slices; copy strided views instead of panicking.
        let contiguous_beta;
        let x_slice = match delta_beta.as_slice() {
            Some(slice) => slice,
            None => {
                contiguous_beta = delta_beta.to_vec();
                contiguous_beta.as_slice()
            }
        };
        let y_slice = hbb_delta
            .as_slice_mut()
            .expect("hbb_delta must be contiguous");
        sys.penalty_matvec_add(x_slice, y_slice);
    }
    quad += delta_beta.dot(&hbb_delta);
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let row_base = sys.row_offsets[i];
        let mut htbeta_x_i = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, row, delta_beta, &mut htbeta_x_i);
        for c in 0..di {
            let dt_c = delta_t[row_base + c];
            lin += row.gt[c] * dt_c;
            quad += ridge_t * dt_c * dt_c;
            for r in 0..di {
                quad += dt_c * row.htt[[c, r]] * delta_t[row_base + r];
            }
            // Cross term appears twice in δᵀHδ (H_tβ and its transpose H_βt).
            quad += 2.0 * dt_c * htbeta_x_i[c];
        }
    }
    Ok(-(lin + 0.5 * quad))
}
/// Predicted reduction of the *undamped* quadratic model. It is computed by adding the ridge
/// contributions `½ ridge_beta·|δβ|² + ½ ridge_t·|δt|²` back onto
/// `arrow_damped_quadratic_model_reduction`.
pub fn arrow_bare_quadratic_model_reduction(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
    ridge_t: f64,
    ridge_beta: f64,
) -> Result<f64, ArrowSchurError> {
    let damped_reduction =
        arrow_damped_quadratic_model_reduction(sys, delta_t, delta_beta, ridge_t, ridge_beta)?;
    let beta_norm_sq = delta_beta.dot(&delta_beta);
    let latent_norm_sq = delta_t.iter().fold(0.0_f64, |acc, v| acc + v * v);
    Ok(damped_reduction + 0.5 * ridge_beta * beta_norm_sq + 0.5 * ridge_t * latent_norm_sq)
}
/// Next ridge in the proximal escalation schedule. A positive ridge is multiplied by
/// `growth`; anything else (zero, negative, NaN) restarts at `DEFAULT_PROXIMAL_INITIAL_RIDGE`.
fn next_proximal_ridge(current: f64, growth: f64) -> f64 {
    match current {
        positive if positive > 0.0 => positive * growth,
        _ => DEFAULT_PROXIMAL_INITIAL_RIDGE,
    }
}
fn arrow_gradient_norm(sys: &ArrowSchurSystem) -> f64 {
let mut sum = 0.0;
for row in sys.rows.iter() {
for &v in row.gt.iter() {
sum += v * v;
}
}
for &v in sys.gb.iter() {
sum += v * v;
}
sum.sqrt()
}
/// Directional derivative `gᵀp` of the arrow objective along `(delta_t, delta_beta)`.
/// Each row contributes its first `row_dims[i]` gradient entries.
///
/// # Panics
/// On a length mismatch with `row_offsets[n_rows]` or `K`.
fn arrow_gradient_dot_step(
    sys: &ArrowSchurSystem,
    delta_t: ArrayView1<'_, f64>,
    delta_beta: ArrayView1<'_, f64>,
) -> f64 {
    assert_eq!(delta_t.len(), sys.row_offsets[sys.rows.len()]);
    assert_eq!(delta_beta.len(), sys.k);
    let latent_part = sys
        .rows
        .iter()
        .enumerate()
        .fold(0.0_f64, |acc, (i, row)| {
            let base = sys.row_offsets[i];
            (0..sys.row_dims[i]).fold(acc, |inner, c| inner + row.gt[c] * delta_t[base + c])
        });
    (0..sys.k).fold(latent_part, |acc, a| acc + sys.gb[a] * delta_beta[a])
}
/// Output of `solve_arrow_newton_step_artifacts`: one damped Newton step plus the
/// factorizations that `ArrowFactorCache` retains.
struct ArrowNewtonStepArtifacts {
    /// Latent step, concatenated per row in `row_offsets` order.
    delta_t: Array1<f64>,
    /// Shared-parameter step (length `K`).
    delta_beta: Array1<f64>,
    /// Per-row Cholesky factors of the `H_tt` blocks at the solve's `ridge_t`.
    htt_factors: ArrowFactorSlab,
    /// Dense reduced-Schur Cholesky factor; `None` when none was formed (the cache's error
    /// messages attribute this to InexactPCG mode).
    schur_factor: Option<Array2<f64>>,
    /// Diagnostics of the iterative solve, if any.
    pcg_diagnostics: PcgDiagnostics,
    /// Number of directions removed by per-row gauge deflation.
    gauge_deflated_directions: usize,
}
/// Result of `factor_blocks_for_system`: per-row `H_tt` factors plus the total number of
/// gauge-deflated directions across rows (0 when no deflation is configured).
struct ArrowBlockFactorization {
    factors: ArrowFactorSlab,
    gauge_deflated_directions: usize,
}
/// Cholesky-factors every per-row `H_tt + ridge_t·I` block.
///
/// Without row gauge deflation the batched `backend` factors all rows at once. With deflation,
/// rows are factored one by one through `factor_one_row_result` and the deflated direction
/// counts are summed. The first row failure is returned.
fn factor_blocks_for_system<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    options: &ArrowSolveOptions,
    backend: &B,
) -> Result<ArrowBlockFactorization, ArrowSchurError> {
    match sys.row_gauge_deflation.as_ref() {
        None => Ok(ArrowBlockFactorization {
            factors: backend.factor_blocks(
                &sys.rows,
                ridge_t,
                sys.d,
                options.tolerate_ill_conditioning,
            )?,
            gauge_deflated_directions: 0,
        }),
        Some(deflation) => {
            let mut deflated_total = 0usize;
            let blocks = sys
                .rows
                .iter()
                .enumerate()
                .map(|(row_idx, row)| {
                    let factored = factor_one_row_result(
                        row,
                        ridge_t,
                        sys.row_dims[row_idx],
                        row_idx,
                        options.tolerate_ill_conditioning,
                        deflation.row(row_idx),
                    )?;
                    deflated_total += factored.gauge_deflated_directions;
                    Ok::<_, ArrowSchurError>(factored.factor)
                })
                .collect::<Result<Vec<_>, ArrowSchurError>>()?;
            Ok(ArrowBlockFactorization {
                factors: ArrowFactorSlab::from_blocks(blocks),
                gauge_deflated_directions: deflated_total,
            })
        }
    }
}
/// Outcome of `try_mixed_precision_arrow_solve` when the mixed-precision policy is active.
enum MixedPrecisionAttempt {
    /// The f32-factor solve plus f64 refinement met the backward-error certificate.
    Certified {
        delta_t: Array1<f64>,
        delta_beta: Array1<f64>,
        /// f64 Cholesky factor of the reduced Schur complement used for gating.
        schur_factor: Array2<f64>,
        /// Number of refinement corrections applied before certification.
        refinement_steps: usize,
    },
    /// The mixed-precision path declined; `reason` explains why.
    Fallback {
        reason: String,
    },
}
/// Recovers the latent step from an already-solved `delta_beta`. For each row,
/// `delta_t_i = -H_tt,i^{-1} (g_t,i + c_i)`, where `c_i` is what `sys_htbeta_apply_row`
/// writes for `delta_beta` (presumably `H_tβ,i · delta_beta`). The solve uses the given
/// factors through `backend.solve_block_vector`.
///
/// Rows are processed in parallel chunks of 64 when there are at least
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and the caller is not already on a rayon worker
/// thread (presumably to avoid nested parallelism). Otherwise they are processed sequentially.
///
/// # Panics
/// If a row's gradient length differs from `row_dims[i]`, or if `row_offsets` does not
/// describe contiguous row segments.
fn back_substitute_delta_t<B: BatchedBlockSolver + Sync>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    delta_beta: ArrayView1<'_, f64>,
    backend: &B,
) -> Array1<f64> {
    let n = sys.rows.len();
    let total_dt_len = sys.row_offsets[n];
    let mut delta_t = Array1::<f64>::zeros(total_dt_len);
    let parallel = n >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    // Solves one row and writes `-H_tt,i^{-1}(g_t,i + c_i)` into `out` (length `row_dims[i]`).
    let solve_row = |i: usize, out: &mut [f64]| {
        let di = sys.row_dims[i];
        assert!(
            sys.rows[i].gt.len() == di,
            "back_substitute_delta_t: row {i} gt len {} != row dim {di}",
            sys.rows[i].gt.len()
        );
        let mut htbeta_slice = Array1::<f64>::zeros(di);
        sys_htbeta_apply_row(sys, i, &sys.rows[i], delta_beta, &mut htbeta_slice);
        let mut rhs = Array1::<f64>::zeros(di);
        for c in 0..di {
            rhs[c] = sys.rows[i].gt[c] + htbeta_slice[c];
        }
        let dt_i = backend.solve_block_vector(htt_factors.factor(i), rhs.view());
        for c in 0..di {
            out[c] = -dt_i[c];
        }
    };
    if parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 64;
        let row_offsets = &sys.row_offsets;
        let dt_slice = delta_t.as_slice_mut().expect("delta_t contiguous");
        let n_chunks = n.div_ceil(CHUNK);
        // Split the output into disjoint mutable segments, one per chunk of CHUNK rows, so
        // chunks can be written in parallel without aliasing.
        let mut remaining = dt_slice;
        let mut segments: Vec<(usize, &mut [f64])> = Vec::with_capacity(n_chunks);
        let mut prev_end = 0usize;
        for chunk in 0..n_chunks {
            let start = chunk * CHUNK;
            let end = (start + CHUNK).min(n);
            let seg_len = row_offsets[end] - row_offsets[start];
            assert!(
                prev_end == row_offsets[start],
                "back_substitute_delta_t: non-contiguous row segment at chunk start {start} \
                 (prev_end={prev_end}, row_offset={})",
                row_offsets[start]
            );
            let (seg, rest) = remaining.split_at_mut(seg_len);
            remaining = rest;
            segments.push((start, seg));
            prev_end = row_offsets[end];
        }
        segments.into_par_iter().for_each(|(start, seg)| {
            let end = (start + CHUNK).min(n);
            // Offset of the current row within this chunk's segment.
            let mut local = 0usize;
            for i in start..end {
                let di = sys.row_dims[i];
                solve_row(i, &mut seg[local..local + di]);
                local += di;
            }
        });
    } else {
        for i in 0..n {
            let row_base = sys.row_offsets[i];
            let di = sys.row_dims[i];
            solve_row(
                i,
                delta_t
                    .as_slice_mut()
                    .expect("delta_t contiguous")
                    .get_mut(row_base..row_base + di)
                    .expect("row segment in bounds"),
            );
        }
    }
    delta_t
}
/// Attempts the dense arrow solve with f32 factors and f64 iterative refinement, as
/// configured by `options.mixed_precision`.
///
/// Returns one of:
/// - `Ok(None)` when the policy is not `MixedPrecisionPolicy::Certified`.
/// - `Ok(Some(Certified { .. }))` once the f64 backward-error certificate is within
///   `max(residual_relative_tolerance, 64·ε)` after at most `max_refinement_steps`
///   corrections.
/// - `Ok(Some(Fallback { reason }))` when the path declines. The reasons are a finite
///   trust-region radius, the kappa gate, a non-finite correction, or no convergence.
///   Presumably the caller then falls back to an f64 solve.
///
/// # Errors
/// `SchurFactorFailed` if the f64 Cholesky of `schur` fails. Unless
/// `options.tolerate_ill_conditioning` is set, it is also returned when the factor's kappa
/// estimate is non-finite or exceeds `safe_spd_kappa_max`.
fn try_mixed_precision_arrow_solve(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    htt_factors: &ArrowFactorSlab,
    schur: &Array2<f64>,
    options: &ArrowSolveOptions,
) -> Result<Option<MixedPrecisionAttempt>, ArrowSchurError> {
    let MixedPrecisionPolicy::Certified {
        max_refinement_steps,
        residual_relative_tolerance,
        kappa_unit_roundoff_margin,
    } = options.mixed_precision
    else {
        return Ok(None);
    };
    if options.trust_region.radius.is_finite() {
        return Ok(Some(MixedPrecisionAttempt::Fallback {
            reason: "trust-region-truncated dense solves are not certified by the mixed-precision refinement path".to_string(),
        }));
    }
    let schur_factor =
        cholesky_lower(schur).map_err(|e| ArrowSchurError::SchurFactorFailed { reason: e })?;
    if !options.tolerate_ill_conditioning {
        let schur_kappa = cholesky_factor_kappa_estimate(&schur_factor);
        if !schur_kappa.is_finite() || schur_kappa > safe_spd_kappa_max(schur.nrows()) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "reduced Schur complement Cholesky succeeded but is ill-conditioned \
                     (kappa_estimate={schur_kappa:e}); accumulated per-row \
                     (H_tt)^-1 contamination would yield an inaccurate delta_beta"
                ),
            });
        }
    }
    // Refuse f32 when kappa·u_f32 is too large for refinement to be expected to converge.
    if let Some(reason) =
        mixed_precision_kappa_gate_failure(htt_factors, &schur_factor, kappa_unit_roundoff_margin)
    {
        return Ok(Some(MixedPrecisionAttempt::Fallback { reason }));
    }
    let row_factors_f32 = arrow_factor_slab_to_f32(htt_factors);
    let schur_factor_f32 = schur_factor.mapv(|v| v as f32);
    let (rhs_t, rhs_beta) = arrow_rhs(sys);
    // Initial f32 solve, promoted to f64.
    let mut x = solve_arrow_system_f32(
        sys,
        &row_factors_f32,
        &schur_factor_f32,
        rhs_t.view(),
        rhs_beta.view(),
    );
    // Never demand a backward error tighter than 64·ε_f64.
    let certificate_tol = residual_relative_tolerance
        .max(MIXED_PRECISION_CERTIFICATE_EPSILON_MULTIPLIER * f64::EPSILON);
    // Refinement loop: f64 residual → certificate check → f32 correction solve → x += corr.
    for refinement_steps in 0..=max_refinement_steps {
        let (res_t, res_beta) = arrow_residual(
            sys,
            ridge_t,
            ridge_beta,
            x.0.view(),
            x.1.view(),
            rhs_t.view(),
            rhs_beta.view(),
        );
        let certificate = arrow_backward_error_certificate(
            sys,
            ridge_t,
            ridge_beta,
            x.0.view(),
            x.1.view(),
            rhs_t.view(),
            rhs_beta.view(),
            res_t.view(),
            res_beta.view(),
        );
        if certificate <= certificate_tol {
            return Ok(Some(MixedPrecisionAttempt::Certified {
                delta_t: x.0,
                delta_beta: x.1,
                schur_factor,
                refinement_steps,
            }));
        }
        if refinement_steps == max_refinement_steps {
            return Ok(Some(MixedPrecisionAttempt::Fallback {
                reason: format!(
                    "f64 residual certificate did not converge after {max_refinement_steps} refinement steps \
                     (backward_error={certificate:e}, tolerance={certificate_tol:e})"
                ),
            }));
        }
        let correction = solve_arrow_system_f32(
            sys,
            &row_factors_f32,
            &schur_factor_f32,
            res_t.view(),
            res_beta.view(),
        );
        if !correction
            .0
            .iter()
            .chain(correction.1.iter())
            .all(|v| v.is_finite())
        {
            return Ok(Some(MixedPrecisionAttempt::Fallback {
                reason: "f32 refinement correction produced a non-finite value".to_string(),
            }));
        }
        for i in 0..x.0.len() {
            x.0[i] += correction.0[i];
        }
        for i in 0..x.1.len() {
            x.1[i] += correction.1[i];
        }
    }
    // Not reached in practice (the last iteration always returns above); kept so the
    // function has a value after the loop.
    Ok(Some(MixedPrecisionAttempt::Fallback {
        reason: "mixed refinement loop exhausted without certification".to_string(),
    }))
}
/// Decides whether f32 factors are well-conditioned enough for mixed-precision refinement.
///
/// The conditioning estimate is the largest of:
/// - every factor's `cholesky_factor_kappa_estimate`;
/// - the global pivot ratio `max_pivot / min_pivot`, which becomes infinite for a
///   non-positive minimum or non-finite maximum.
///
/// The gate requires `kappa · u_f32 < clamp(margin, u_f32, 1)`. Returns the refusal message,
/// or `None` when the gate passes.
fn mixed_precision_kappa_gate_failure(
    htt_factors: &ArrowFactorSlab,
    schur_factor: &Array2<f64>,
    margin: f64,
) -> Option<String> {
    // Merge an optional running extreme with an optional new value using `pick`.
    fn merge(acc: Option<f64>, next: Option<f64>, pick: fn(f64, f64) -> f64) -> Option<f64> {
        match (acc, next) {
            (Some(current), Some(candidate)) => Some(pick(current, candidate)),
            (current, None) => current,
            (None, candidate) => candidate,
        }
    }

    let mut max_kappa = cholesky_factor_kappa_estimate(schur_factor);
    let mut smallest_pivot = lower_cholesky_min_pivot(schur_factor.view());
    let mut largest_pivot = lower_cholesky_max_pivot(schur_factor.view());
    for block in htt_factors.iter() {
        let dense = block.to_owned();
        max_kappa = max_kappa.max(cholesky_factor_kappa_estimate(&dense));
        smallest_pivot = merge(smallest_pivot, lower_cholesky_min_pivot(dense.view()), f64::min);
        largest_pivot = merge(largest_pivot, lower_cholesky_max_pivot(dense.view()), f64::max);
    }
    if let (Some(lo), Some(hi)) = (smallest_pivot, largest_pivot) {
        max_kappa = if lo > 0.0 && hi.is_finite() {
            max_kappa.max(hi / lo)
        } else {
            f64::INFINITY
        };
    }
    let kappa_u = max_kappa * F32_UNIT_ROUNDOFF;
    let threshold = margin
        .min(MIXED_PRECISION_KAPPA_MARGIN_CEILING)
        .max(F32_UNIT_ROUNDOFF);
    let gate_passes = max_kappa.is_finite() && kappa_u < threshold;
    if gate_passes {
        return None;
    }
    Some(format!(
        "kappa gate refused f32 refinement: kappa_estimate={max_kappa:e}, \
         kappa*u_f32={kappa_u:e}, required < {threshold:e}"
    ))
}
/// Rounds every per-row Cholesky factor to f32, preserving row order.
fn arrow_factor_slab_to_f32(htt_factors: &ArrowFactorSlab) -> Vec<Array2<f32>> {
    let mut rounded = Vec::with_capacity(htt_factors.len());
    for block in htt_factors.iter() {
        rounded.push(block.mapv(|v| v as f32));
    }
    rounded
}
/// Newton right-hand side `(-g_t, -g_β)`. The latent part is concatenated in `row_offsets`
/// order and covers the first `row_dims[i]` gradient entries of each row.
fn arrow_rhs(sys: &ArrowSchurSystem) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let mut rhs_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    for (i, row) in sys.rows.iter().enumerate() {
        let start = sys.row_offsets[i];
        for c in 0..sys.row_dims[i] {
            rhs_t[start + c] = -row.gt[c];
        }
    }
    let rhs_beta = Array1::from_shape_fn(sys.k, |c| -sys.gb[c]);
    (rhs_t, rhs_beta)
}
/// Solves the arrow system in f32 by block elimination and promotes the result to f64.
///
/// Steps:
/// 1. Forward sweep: `y_i = L_i^{-T} L_i^{-1} r_t,i`, then `r_β -= H_βt,i y_i`.
/// 2. Solve `x_β` from the f32 Schur factor.
/// 3. Back sweep: `x_t,i = y_i − H_tt,i^{-1} H_tβ,i x_β`.
///
/// Coupling blocks are materialized (and rounded to f32) once per sweep per row.
fn solve_arrow_system_f32(
    sys: &ArrowSchurSystem,
    row_factors: &[Array2<f32>],
    schur_factor: &Array2<f32>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let mut forward: Vec<Array1<f32>> = Vec::with_capacity(row_count);
    let mut schur_rhs = rhs_beta.mapv(|v| v as f32);
    for i in 0..row_count {
        let (base, di) = (sys.row_offsets[i], sys.row_dims[i]);
        let local_rhs = rhs_t.slice(ndarray::s![base..base + di]).mapv(|v| v as f32);
        let y_i = cholesky_solve_lower_f32(&row_factors[i], &local_rhs);
        let coupling = sys_htbeta_materialize_row(sys, i, &sys.rows[i]).mapv(|v| v as f32);
        for beta_col in 0..sys.k {
            let projected = (0..di).fold(0.0_f32, |acc, axis| {
                acc + coupling[[axis, beta_col]] * y_i[axis]
            });
            schur_rhs[beta_col] -= projected;
        }
        forward.push(y_i);
    }
    let x_beta_f32 = cholesky_solve_lower_f32(schur_factor, &schur_rhs);
    let mut x_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    for (i, y_i) in forward.iter().enumerate() {
        let (base, di) = (sys.row_offsets[i], sys.row_dims[i]);
        let coupling = sys_htbeta_materialize_row(sys, i, &sys.rows[i]).mapv(|v| v as f32);
        let cross = Array1::from_shape_fn(di, |axis| {
            (0..sys.k).fold(0.0_f32, |acc, beta_col| {
                acc + coupling[[axis, beta_col]] * x_beta_f32[beta_col]
            })
        });
        let correction = cholesky_solve_lower_f32(&row_factors[i], &cross);
        for axis in 0..di {
            x_t[base + axis] = (y_i[axis] - correction[axis]) as f64;
        }
    }
    (x_t, x_beta_f32.mapv(|v| v as f64))
}
/// Solves `L Lᵀ x = b` in f32, given the lower Cholesky factor `L`, by forward then
/// backward substitution.
///
/// # Panics
/// If any diagonal entry of `L` is non-finite or smaller in magnitude than
/// `f32::MIN_POSITIVE`.
fn cholesky_solve_lower_f32(l: &Array2<f32>, b: &Array1<f32>) -> Array1<f32> {
    let n = l.nrows();
    let diagonal_ok = (0..n)
        .map(|i| l[[i, i]])
        .all(|pivot| pivot.is_finite() && pivot.abs() >= f32::MIN_POSITIVE);
    assert!(
        diagonal_ok,
        "cholesky_solve_lower_f32: factor diagonal must be finite and non-subnormal"
    );
    // Forward substitution: L y = b.
    let mut y = Array1::<f32>::zeros(n);
    for i in 0..n {
        let partial = (0..i).fold(b[i], |s, j| s - l[[i, j]] * y[j]);
        y[i] = partial / l[[i, i]];
    }
    // Backward substitution: Lᵀ x = y.
    let mut x = Array1::<f32>::zeros(n);
    for i in (0..n).rev() {
        let partial = ((i + 1)..n).fold(y[i], |s, j| s - l[[j, i]] * x[j]);
        x[i] = partial / l[[i, i]];
    }
    x
}
/// Computes the residual `rhs − A x` of the ridged arrow system.
///
/// `A` is the operator applied by `arrow_operator_apply` with the given ridges.
/// Returns `(rhs_t − (A x)_t, rhs_beta − (A x)_β)`.
fn arrow_residual(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    /// `rhs − applied`, element by element over the length of `rhs`.
    fn minus(rhs: ArrayView1<'_, f64>, applied: &Array1<f64>) -> Array1<f64> {
        let mut out = rhs.to_owned();
        for (idx, slot) in out.iter_mut().enumerate() {
            *slot -= applied[idx];
        }
        out
    }
    let (applied_t, applied_beta) = arrow_operator_apply(sys, ridge_t, ridge_beta, x_t, x_beta);
    (minus(rhs_t, &applied_t), minus(rhs_beta, &applied_beta))
}
/// Applies the ridged arrow operator to `(x_t, x_beta)`.
///
/// Per row `i` (with `x_t` sliced by `sys.row_offsets` / `sys.row_dims`):
/// `y_t_i = (H_tt_i + ridge_t·I) x_t_i + H_tβ_i x_β`, where the coupling product
/// comes from `sys_htbeta_apply_row`.
///
/// For β: `y_β = P x_β + ridge_beta·x_β + Σ_i (coupling transpose of row i) x_t_i`.
/// Here `P x_β` is added by `sys.penalty_matvec_add` and the transpose terms are
/// accumulated by `sys_htbeta_accumulate_transpose`.
///
/// # Panics
/// Panics if `x_beta` is not contiguous in memory.
fn arrow_operator_apply(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
    let row_count = sys.rows.len();
    let mut out_t = Array1::<f64>::zeros(sys.row_offsets[row_count]);
    let mut out_beta = Array1::<f64>::zeros(sys.k);

    // β block: penalty Hessian first, then the ridge shift.
    sys.penalty_matvec_add(
        x_beta.as_slice().expect("x_beta contiguous"),
        out_beta.as_slice_mut().expect("y_beta contiguous"),
    );
    for (col, out) in out_beta.iter_mut().enumerate() {
        *out += ridge_beta * x_beta[col];
    }

    for (i, row) in sys.rows.iter().enumerate() {
        let base = sys.row_offsets[i];
        let dim = sys.row_dims[i];
        let x_row = x_t.slice(ndarray::s![base..base + dim]);
        // Diagonal block with ridge: the ridge term seeds each accumulation.
        for a in 0..dim {
            out_t[base + a] =
                (0..dim).fold(ridge_t * x_row[a], |s, b| s + row.htt[[a, b]] * x_row[b]);
        }
        // Coupling contribution into the row.
        let mut coupling = Array1::<f64>::zeros(dim);
        sys_htbeta_apply_row(sys, i, row, x_beta, &mut coupling);
        for a in 0..dim {
            out_t[base + a] += coupling[a];
        }
        // Transposed coupling contribution into β.
        let x_row_owned = x_row.to_owned();
        sys_htbeta_accumulate_transpose(sys, i, row, x_row_owned.view(), &mut out_beta);
    }
    (out_t, out_beta)
}
/// Normwise backward-error certificate for an approximate arrow-system solution.
///
/// Computes `‖r‖∞ / (‖A‖∞ ‖x‖∞ + ‖b‖∞)`, where `r = (res_t, res_beta)`,
/// `x = (x_t, x_beta)`, `b = (rhs_t, rhs_beta)`, and `‖A‖∞` is the ∞-norm of the ridged
/// arrow operator. If the denominator is zero, the raw residual norm is returned.
///
/// Any non-finite entry in the residual, solution or right-hand side, or a non-finite
/// final ratio, yields `f64::INFINITY`. A NaN solution therefore can never be certified.
/// Previously `f64::max` silently skipped NaN entries, so an all-NaN residual produced a
/// certificate of 0.
fn arrow_backward_error_certificate(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    x_t: ArrayView1<'_, f64>,
    x_beta: ArrayView1<'_, f64>,
    rhs_t: ArrayView1<'_, f64>,
    rhs_beta: ArrayView1<'_, f64>,
    res_t: ArrayView1<'_, f64>,
    res_beta: ArrayView1<'_, f64>,
) -> f64 {
    let inputs_finite = [x_t, x_beta, rhs_t, rhs_beta, res_t, res_beta]
        .iter()
        .all(|v| v.iter().all(|e| e.is_finite()));
    if !inputs_finite {
        return f64::INFINITY;
    }
    let residual_norm = infinity_norm_pair(res_t, res_beta);
    let operator_norm = arrow_operator_infinity_norm(sys, ridge_t, ridge_beta);
    let solution_norm = infinity_norm_pair(x_t, x_beta);
    let rhs_norm = infinity_norm_pair(rhs_t, rhs_beta);
    let denom = operator_norm * solution_norm + rhs_norm;
    let certificate = if denom > 0.0 {
        residual_norm / denom
    } else {
        residual_norm
    };
    // Guard against NaN from e.g. a NaN operator norm: report "not certified".
    if certificate.is_finite() {
        certificate
    } else {
        f64::INFINITY
    }
}
/// ∞-norm (max absolute entry) of the concatenation `[lhs; rhs]`.
///
/// Returns `0.0` when both inputs are empty. Returns NaN if any entry is NaN.
/// `f64::max` would otherwise return the non-NaN operand and silently hide a NaN entry.
fn infinity_norm_pair(lhs: ArrayView1<'_, f64>, rhs: ArrayView1<'_, f64>) -> f64 {
    let mut out = 0.0_f64;
    for &v in lhs.iter().chain(rhs.iter()) {
        let magnitude = v.abs();
        if magnitude.is_nan() {
            return f64::NAN;
        }
        if magnitude > out {
            out = magnitude;
        }
    }
    out
}
/// ∞-norm (maximum absolute row sum) of the ridged arrow operator.
///
/// - Row `a` of latent block `i` sums `|H_tt_i[a, :]|`, then `ridge_t`, then `|H_tβ_i[a, :]|`.
/// - β row `c` sums `|H_ββ[c, :]|` (from `effective_penalty_op().to_dense()`), then
///   `ridge_beta`, then column `c` of every coupling block, `Σ_i Σ_a |H_tβ_i[a, c]|`.
///
/// Each row's coupling block is materialized once. Its column abs-sums are accumulated
/// in the same pass, instead of re-materializing all blocks for every β row, which was
/// O(k) redundant materializations per row. Because the β-row coupling sum is now
/// accumulated separately before being added, results may differ from the previous
/// formulation in the last ulp.
fn arrow_operator_infinity_norm(sys: &ArrowSchurSystem, ridge_t: f64, ridge_beta: f64) -> f64 {
    let k = sys.k;
    let mut out = 0.0_f64;
    // coupling_col_abs[c] = Σ_i Σ_a |H_tβ_i[a, c]|
    let mut coupling_col_abs = vec![0.0_f64; k];
    for (i, row) in sys.rows.iter().enumerate() {
        let di = sys.row_dims[i];
        let htbeta = sys_htbeta_materialize_row(sys, i, row);
        for a in 0..di {
            let mut row_sum = 0.0_f64;
            for b in 0..di {
                row_sum += row.htt[[a, b]].abs();
            }
            row_sum += ridge_t;
            for (beta_col, col_abs) in coupling_col_abs.iter_mut().enumerate() {
                let magnitude = htbeta[[a, beta_col]].abs();
                row_sum += magnitude;
                *col_abs += magnitude;
            }
            out = out.max(row_sum);
        }
    }
    let hbb = sys.effective_penalty_op().to_dense();
    for (beta_row, &col_abs) in coupling_col_abs.iter().enumerate() {
        let mut row_sum = 0.0_f64;
        for beta_col in 0..k {
            row_sum += hbb[[beta_row, beta_col]].abs();
        }
        row_sum += ridge_beta;
        row_sum += col_abs;
        out = out.max(row_sum);
    }
    out
}
fn solve_arrow_newton_step_artifacts(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
if !sys.cross_row_penalties.is_empty() {
return solve_arrow_newton_step_cross_row(sys, ridge_t, ridge_beta, options);
}
if let Some(chunk_size) = options.streaming_chunk_size {
let mut streaming = StreamingArrowSchur::from_system(sys, chunk_size);
let (delta_t, delta_beta, schur_factor) = streaming.solve(ridge_t, ridge_beta, options)?;
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors: ArrowFactorSlab::from_blocks(Vec::new()),
schur_factor,
pcg_diagnostics: PcgDiagnostics::default(),
gauge_deflated_directions: 0,
});
}
let backend = CpuBatchedBlockSolver;
let block_factorization = factor_blocks_for_system(sys, ridge_t, options, &backend)?;
let htt_factors = block_factorization.factors;
let gauge_deflated_directions = block_factorization.gauge_deflated_directions;
let rhs_beta = reduced_rhs_beta(sys, &htt_factors, &backend);
let trust_metric_weights = None;
let mut mixed_precision_status = MixedPrecisionStatus::Off;
let (delta_beta, schur_factor, mut pcg_diagnostics) = match options.mode {
ArrowSolverMode::Direct => {
let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, &backend)?;
if let Some(attempt) = try_mixed_precision_arrow_solve(
sys,
ridge_t,
ridge_beta,
&htt_factors,
&schur,
options,
)? {
match attempt {
MixedPrecisionAttempt::Certified {
delta_t,
delta_beta,
schur_factor,
refinement_steps,
} => {
let mut pcg_diagnostics = PcgDiagnostics::default();
pcg_diagnostics.mixed_precision_status =
MixedPrecisionStatus::Certified { refinement_steps };
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor: Some(schur_factor),
pcg_diagnostics,
gauge_deflated_directions,
});
}
MixedPrecisionAttempt::Fallback { reason } => {
log::info!("arrow-Schur mixed precision fallback to f64: {reason}");
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
}
}
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::SqrtBA => {
let schur = build_dense_schur_sqrt_ba(sys, &htt_factors, ridge_beta, &backend)?;
if let Some(attempt) = try_mixed_precision_arrow_solve(
sys,
ridge_t,
ridge_beta,
&htt_factors,
&schur,
options,
)? {
match attempt {
MixedPrecisionAttempt::Certified {
delta_t,
delta_beta,
schur_factor,
refinement_steps,
} => {
let mut pcg_diagnostics = PcgDiagnostics::default();
pcg_diagnostics.mixed_precision_status =
MixedPrecisionStatus::Certified { refinement_steps };
return Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor: Some(schur_factor),
pcg_diagnostics,
gauge_deflated_directions,
});
}
MixedPrecisionAttempt::Fallback { reason } => {
log::info!("arrow-Schur mixed precision fallback to f64: {reason}");
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
}
}
let (db, sf, diag) =
solve_dense_reduced_system(&schur, &rhs_beta, options, trust_metric_weights)?;
(db, sf, diag)
}
ArrowSolverMode::InexactPCG => {
if options.mixed_precision.is_enabled() {
log::info!(
"arrow-Schur mixed precision fallback to f64: InexactPCG does not expose a dense Schur factor for certified f32 refinement"
);
mixed_precision_status = MixedPrecisionStatus::F64Fallback;
}
if options.trust_region.radius == f64::INFINITY {
if let Some(device_data) = sys.device_sae_pcg.as_ref() {
let max_iterations = options
.pcg
.max_iterations
.min(options.trust_region.max_iterations);
let relative_tolerance = options
.pcg
.relative_tolerance
.max(options.trust_region.steihaug_relative_tolerance);
if let Ok((delta, mut diag)) =
crate::gpu::arrow_schur::solve_sae_matrix_free_pcg(
sys,
device_data.as_ref(),
ridge_t,
ridge_beta,
&rhs_beta,
max_iterations,
relative_tolerance,
)
{
diag.used_device_arrow = true;
return Ok(ArrowNewtonStepArtifacts {
delta_t: back_substitute_delta_t(
sys,
&htt_factors,
delta.view(),
&backend,
),
delta_beta: delta,
htt_factors,
schur_factor: None,
pcg_diagnostics: diag,
gauge_deflated_directions,
});
}
}
}
let (delta, diag) = steihaug_pcg_auto(
sys,
&htt_factors,
ridge_beta,
&rhs_beta,
&options.pcg,
&options.trust_region,
&backend,
options.gpu_matvec.as_ref(),
trust_metric_weights,
)?;
(delta, None, diag)
}
};
if mixed_precision_status != MixedPrecisionStatus::Off {
pcg_diagnostics.mixed_precision_status = mixed_precision_status;
}
let delta_t = back_substitute_delta_t(sys, &htt_factors, delta_beta.view(), &backend);
Ok(ArrowNewtonStepArtifacts {
delta_t,
delta_beta,
htt_factors,
schur_factor,
pcg_diagnostics,
gauge_deflated_directions,
})
}
/// Block-elimination inverse of the arrow operator without the cross-row penalty term.
///
/// Built once per solve and applied to residuals as the preconditioner of the cross-row
/// full-system CG. It holds the per-row block factors (factored with ridge `ridge_t`)
/// and the lower Cholesky factor of the dense β Schur complement (assembled with ridge
/// `ridge_beta`).
struct ArrowBlockDiagInverse<'a, B: BatchedBlockSolver> {
    sys: &'a ArrowSchurSystem,
    backend: &'a B,
    /// Per-row block factors produced by `backend.factor_blocks`.
    htt_factors: ArrowFactorSlab,
    /// Lower Cholesky factor of the dense β Schur complement.
    schur_factor: Array2<f64>,
}
impl<'a, B: BatchedBlockSolver> ArrowBlockDiagInverse<'a, B> {
    /// Factors every row block and the dense Schur complement.
    ///
    /// # Errors
    /// Propagates block-factorization and Schur-assembly errors. A failed Cholesky of the
    /// Schur complement is reported as `ArrowSchurError::SchurFactorFailed`.
    fn build(
        sys: &'a ArrowSchurSystem,
        ridge_t: f64,
        ridge_beta: f64,
        tolerate_ill_conditioning: bool,
        backend: &'a B,
    ) -> Result<Self, ArrowSchurError>
    where
        B: Sync,
    {
        let htt_factors =
            backend.factor_blocks(&sys.rows, ridge_t, sys.d, tolerate_ill_conditioning)?;
        let schur = build_dense_schur_direct(sys, &htt_factors, ridge_beta, backend)?;
        let schur_factor =
            cholesky_lower(&schur).map_err(|e| ArrowSchurError::SchurFactorFailed { reason: e })?;
        Ok(Self {
            sys,
            backend,
            htt_factors,
            schur_factor,
        })
    }

    /// Applies the inverse to `(r_t, r_beta)` by block elimination.
    ///
    /// 1. `rhs_β = r_β − Σ_i (coupling transpose of row i) · H_tt_i^{-1} r_t_i`
    /// 2. `x_β = S^{-1} rhs_β` using the stored Schur factor.
    /// 3. `x_t_i = H_tt_i^{-1} (r_t_i − H_tβ_i x_β)`
    ///
    /// Returns `(x_t, x_beta)`.
    fn apply(
        &self,
        r_t: ArrayView1<'_, f64>,
        r_beta: ArrayView1<'_, f64>,
    ) -> (Array1<f64>, Array1<f64>) {
        let sys = self.sys;
        let n = sys.rows.len();
        let k = sys.k;

        // Forward elimination of the row blocks into the β right-hand side.
        let mut rhs_beta = r_beta.to_owned();
        // One k-length accumulator reused (and reset) for every row.
        let mut acc = Array1::<f64>::zeros(k);
        for i in 0..n {
            let di = sys.row_dims[i];
            let base = sys.row_offsets[i];
            let r_ti = r_t.slice(ndarray::s![base..base + di]).to_owned();
            let u_i = self
                .backend
                .solve_block_vector(self.htt_factors.factor(i), r_ti.view());
            acc.fill(0.0);
            sys_htbeta_accumulate_transpose(sys, i, &sys.rows[i], u_i.view(), &mut acc);
            rhs_beta -= &acc;
        }

        let x_beta = cholesky_solve_lower(&self.schur_factor, &rhs_beta);

        // Back substitution per row. The previous shared `htbeta_xb` buffer was zeroed and
        // then copied with `.to_owned()`, so it was never actually reused; allocate the
        // per-row output directly instead.
        let mut x_t = Array1::<f64>::zeros(sys.row_offsets[n]);
        for i in 0..n {
            let di = sys.row_dims[i];
            let base = sys.row_offsets[i];
            let mut coupling = Array1::<f64>::zeros(di);
            sys_htbeta_apply_row(sys, i, &sys.rows[i], x_beta.view(), &mut coupling);
            let mut rhs_i = r_t.slice(ndarray::s![base..base + di]).to_owned();
            rhs_i -= &coupling;
            let xi = self
                .backend
                .solve_block_vector(self.htt_factors.factor(i), rhs_i.view());
            x_t.slice_mut(ndarray::s![base..base + di]).assign(&xi);
        }
        (x_t, x_beta)
    }
}
fn arrow_cross_row_matvec(
sys: &ArrowSchurSystem,
ridge_t: f64,
ridge_beta: f64,
x_t: ArrayView1<'_, f64>,
x_beta: ArrayView1<'_, f64>,
) -> (Array1<f64>, Array1<f64>) {
let n = sys.rows.len();
let k = sys.k;
let total_dt = sys.row_offsets[n];
let mut y_t = Array1::<f64>::zeros(total_dt);
let mut y_beta = Array1::<f64>::zeros(k);
let mut htbeta_xb = Array1::<f64>::zeros(sys.d);
for i in 0..n {
let di = sys.row_dims[i];
let base = sys.row_offsets[i];
let row = &sys.rows[i];
for a in 0..di {
let mut acc = ridge_t * x_t[base + a];
for b in 0..di {
acc += row.htt[[a, b]] * x_t[base + b];
}
y_t[base + a] = acc;
}
for c in 0..di {
htbeta_xb[c] = 0.0;
}
let mut slab = htbeta_xb.slice_mut(ndarray::s![..di]).to_owned();
sys_htbeta_apply_row(sys, i, row, x_beta, &mut slab);
for c in 0..di {
y_t[base + c] += slab[c];
}
let x_ti = x_t.slice(ndarray::s![base..base + di]).to_owned();
sys_htbeta_accumulate_transpose(sys, i, row, x_ti.view(), &mut y_beta);
}
{
let x_beta_slice = x_beta.as_slice().expect("x_beta contiguous");
let y_beta_slice = y_beta.as_slice_mut().expect("y_beta contiguous");
sys.penalty_matvec_add(x_beta_slice, y_beta_slice);
}
for a in 0..k {
y_beta[a] += ridge_beta * x_beta[a];
}
sys.apply_cross_row_penalty_hessian(x_t, &mut y_t);
(y_t, y_beta)
}
/// Solves the Newton system including cross-row penalties with preconditioned CG.
///
/// This path is used when `sys.cross_row_penalties` is non-empty. The operator is
/// `arrow_cross_row_matvec` and the right-hand side is `b = −(g_t, g_β)`, where the row
/// gradients are stacked by `sys.row_offsets`. The preconditioner is
/// `ArrowBlockDiagInverse`, the block-elimination inverse of the arrow part without the
/// cross-row term.
///
/// CG stops when `‖r‖₂ ≤ max(1e-12, 1e-13·‖b‖₂)`. The iteration budget is
/// `4 · max(total_dt + k, 64)`.
///
/// The diagnostics count one matvec per iteration. Preconditioner applications are
/// counted exactly: one initial apply plus one per non-terminal iteration. The old
/// `iters + 1` over-counted by one whenever the loop ran.
///
/// # Errors
/// - Propagates preconditioner construction failures.
/// - `ArrowSchurError::PcgFailed` when `pᵀAp` is non-finite or non-positive.
/// - `ArrowSchurError::PcgFailed` when the iteration budget is exhausted.
fn solve_arrow_newton_step_cross_row(
    sys: &ArrowSchurSystem,
    ridge_t: f64,
    ridge_beta: f64,
    options: &ArrowSolveOptions,
) -> Result<ArrowNewtonStepArtifacts, ArrowSchurError> {
    const CROSS_ROW_CG_ABS_TOL: f64 = 1e-12;
    const CROSS_ROW_CG_REL_TOL: f64 = 1e-13;
    const CROSS_ROW_CG_MIN_ITER_BUDGET: usize = 64;
    const CROSS_ROW_CG_ITER_MULTIPLE: usize = 4;
    let backend = CpuBatchedBlockSolver;
    let precond = ArrowBlockDiagInverse::build(
        sys,
        ridge_t,
        ridge_beta,
        options.tolerate_ill_conditioning,
        &backend,
    )?;
    let n = sys.rows.len();
    let k = sys.k;
    let total_dt = sys.row_offsets[n];

    // Right-hand side b = −(g_t, g_β).
    let mut b_t = Array1::<f64>::zeros(total_dt);
    for i in 0..n {
        let di = sys.row_dims[i];
        let base = sys.row_offsets[i];
        for c in 0..di {
            b_t[base + c] = -sys.rows[i].gt[c];
        }
    }
    let mut b_beta = Array1::<f64>::zeros(k);
    for a in 0..k {
        b_beta[a] = -sys.gb[a];
    }

    // PCG from x = 0.
    let mut x_t = Array1::<f64>::zeros(total_dt);
    let mut x_beta = Array1::<f64>::zeros(k);
    let mut r_t = b_t.clone();
    let mut r_beta = b_beta.clone();
    let (mut z_t, mut z_beta) = precond.apply(r_t.view(), r_beta.view());
    let mut precond_apply_calls = 1usize;
    let mut p_t = z_t.clone();
    let mut p_beta = z_beta.clone();
    let mut rz = dot2(&r_t, &r_beta, &z_t, &z_beta);
    let b_norm = (dot2(&b_t, &b_beta, &b_t, &b_beta)).sqrt();
    let tol = CROSS_ROW_CG_ABS_TOL.max(CROSS_ROW_CG_REL_TOL * b_norm);
    let max_iter = (total_dt + k).max(CROSS_ROW_CG_MIN_ITER_BUDGET) * CROSS_ROW_CG_ITER_MULTIPLE;
    let mut iters = 0usize;
    let mut converged = b_norm == 0.0;
    while iters < max_iter && !converged {
        let (ap_t, ap_beta) =
            arrow_cross_row_matvec(sys, ridge_t, ridge_beta, p_t.view(), p_beta.view());
        let pap = dot2(&p_t, &p_beta, &ap_t, &ap_beta);
        if !(pap.is_finite() && pap > 0.0) {
            return Err(ArrowSchurError::PcgFailed {
                reason: format!(
                    "cross-row full-system CG hit non-positive curvature pᵀAp={pap:e}; \
                     the cross-row penalty Hessian or arrow block is not PD at this iterate"
                ),
            });
        }
        let alpha = rz / pap;
        for i in 0..total_dt {
            x_t[i] += alpha * p_t[i];
            r_t[i] -= alpha * ap_t[i];
        }
        for a in 0..k {
            x_beta[a] += alpha * p_beta[a];
            r_beta[a] -= alpha * ap_beta[a];
        }
        let r_norm = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
        iters += 1;
        if r_norm <= tol {
            converged = true;
            break;
        }
        let (nz_t, nz_beta) = precond.apply(r_t.view(), r_beta.view());
        precond_apply_calls += 1;
        z_t = nz_t;
        z_beta = nz_beta;
        let rz_new = dot2(&r_t, &r_beta, &z_t, &z_beta);
        let beta_cg = rz_new / rz;
        for i in 0..total_dt {
            p_t[i] = z_t[i] + beta_cg * p_t[i];
        }
        for a in 0..k {
            p_beta[a] = z_beta[a] + beta_cg * p_beta[a];
        }
        rz = rz_new;
    }
    if !converged {
        let r_norm = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
        return Err(ArrowSchurError::PcgFailed {
            reason: format!(
                "cross-row full-system CG did not converge in {iters} iters \
                 (‖r‖={r_norm:e}, tol={tol:e})"
            ),
        });
    }
    let final_residual = (dot2(&r_t, &r_beta, &r_t, &r_beta)).sqrt();
    let diag = PcgDiagnostics {
        iterations: iters,
        matvec_calls: iters,
        precond_apply_calls,
        ridge_escalations: 0,
        final_relative_residual: if b_norm > 0.0 {
            final_residual / b_norm
        } else {
            0.0
        },
        stopping_reason: PcgStopReason::Converged,
        mixed_precision_status: MixedPrecisionStatus::Off,
        used_device_arrow: false,
    };
    Ok(ArrowNewtonStepArtifacts {
        delta_t: x_t,
        delta_beta: x_beta,
        htt_factors: precond.htt_factors,
        schur_factor: Some(precond.schur_factor),
        pcg_diagnostics: diag,
        gauge_deflated_directions: 0,
    })
}
/// Inner product of the stacked vectors `[a_t; a_beta]` and `[b_t; b_beta]`.
///
/// The t-part and β-part are summed into one running accumulator, in that order.
/// `b_t` / `b_beta` are indexed by the positions of `a_t` / `a_beta`, so a shorter `b`
/// panics.
fn dot2(a_t: &Array1<f64>, a_beta: &Array1<f64>, b_t: &Array1<f64>, b_beta: &Array1<f64>) -> f64 {
    let t_terms = a_t.iter().enumerate().map(|(i, &x)| x * b_t[i]);
    let beta_terms = a_beta.iter().enumerate().map(|(j, &x)| x * b_beta[j]);
    t_terms
        .chain(beta_terms)
        .fold(0.0_f64, |running, term| running + term)
}
/// Solves `L Lᵀ x = b` given the lower-triangular Cholesky factor `L` (f64).
///
/// Runs a forward substitution (`L z = b`) followed by an in-place backward
/// substitution (`Lᵀ x = z`). Only the lower triangle of `l` is read.
///
/// # Panics
/// Panics if any diagonal entry of `l` is non-finite or smaller in magnitude than
/// `f64::MIN_POSITIVE`. Also panics on out-of-bounds indexing when `b` is shorter than
/// `l.nrows()`.
fn cholesky_solve_lower(l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
    let dim = l.nrows();
    let diagonal_ok = (0..dim)
        .map(|i| l[[i, i]])
        .all(|d| d.is_finite() && d.abs() >= f64::MIN_POSITIVE);
    assert!(
        diagonal_ok,
        "cholesky_solve_lower: factor diagonal must be finite and non-subnormal"
    );
    // Forward: L z = b.
    let mut z = Array1::<f64>::zeros(dim);
    for row in 0..dim {
        let s = (0..row).fold(b[row], |s, col| s - l[[row, col]] * z[col]);
        z[row] = s / l[[row, row]];
    }
    // Backward (in place): Lᵀ x = z; entries above `row` already hold x.
    for row in (0..dim).rev() {
        let s = ((row + 1)..dim).fold(z[row], |s, col| s - l[[col, row]] * z[col]);
        z[row] = s / l[[row, row]];
    }
    z
}
/// Builds the β right-hand side of the reduced (Schur) system.
///
/// Computes `Σ_i (coupling transpose of row i) · H_tt_i^{-1} g_t_i − g_β`. Each row's
/// term is solved with `backend.solve_block_vector` and added by
/// `sys_htbeta_accumulate_transpose`.
///
/// The sum runs in rayon chunks of 64 rows when there are at least
/// `SCHUR_MATVEC_PARALLEL_ROW_MIN` rows and the caller is not already on a rayon worker
/// thread. Chunk partials are combined in chunk order.
fn reduced_rhs_beta<B: BatchedBlockSolver + Sync>(
    sys: &ArrowSchurSystem,
    htt_factors: &ArrowFactorSlab,
    backend: &B,
) -> Array1<f64> {
    let k = sys.k;
    let row_count = sys.rows.len();
    // Adds one row's eliminated gradient contribution into `acc`.
    let add_row = |i: usize, acc: &mut Array1<f64>| {
        let row = &sys.rows[i];
        let solved = backend.solve_block_vector(htt_factors.factor(i), row.gt.view());
        sys_htbeta_accumulate_transpose(sys, i, row, solved.view(), acc);
    };
    let mut rhs_beta = Array1::<f64>::zeros(k);
    let run_parallel =
        row_count >= SCHUR_MATVEC_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
    if run_parallel {
        use rayon::prelude::*;
        const CHUNK: usize = 64;
        let partials: Vec<Array1<f64>> = (0..row_count)
            .into_par_iter()
            .chunks(CHUNK)
            .map(|chunk| {
                let mut acc = Array1::<f64>::zeros(k);
                chunk.into_iter().for_each(|i| add_row(i, &mut acc));
                acc
            })
            .collect();
        for partial in &partials {
            rhs_beta += partial;
        }
    } else {
        (0..row_count).for_each(|i| add_row(i, &mut rhs_beta));
    }
    for (j, slot) in rhs_beta.iter_mut().enumerate() {
        *slot -= sys.gb[j];
    }
    rhs_beta
}
/// How a row's coupling block is reduced into the dense β Schur complement.
///
/// - `Direct`: pairs the raw coupling block with `H_tt^{-1}`-solved coupling.
/// - `SqrtBa`: uses the "square-root" whitened coupling on both sides.
///
/// `Debug`/`PartialEq`/`Eq` are derived so the kind can be logged and compared in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SchurReductionKind {
    Direct,
    SqrtBa,
}
/// Returns the `(left, right)` factor pair for one row's Schur-complement update.
///
/// - `Direct`: `left` is the materialized coupling block and `right` is
///   `backend.solve_block_matrix(htt_factor, coupling)`.
/// - `SqrtBa`: both sides are `backend.sqrt_solve_block_matrix(htt_factor, coupling)`.
///   The result is cloned so both tuple slots are owned.
#[inline]
fn row_schur_contribution_factors<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    htt_factor: ArrayView2<'_, f64>,
    backend: &B,
    kind: SchurReductionKind,
) -> (Array2<f64>, Array2<f64>) {
    let coupling = sys_htbeta_materialize_row(sys, row_idx, row);
    match kind {
        SchurReductionKind::SqrtBa => {
            let whitened = backend.sqrt_solve_block_matrix(htt_factor, coupling.view());
            let left = whitened.clone();
            (left, whitened)
        }
        SchurReductionKind::Direct => {
            let right = backend.solve_block_matrix(htt_factor, coupling.view());
            (coupling, right)
        }
    }
}
/// Subtracts one row's contribution from the dense Schur complement `schur`.
///
/// The factor pair comes from `row_schur_contribution_factors` and is handed to
/// `backend.block_gemm_subtract`. Presumably that computes `schur -= leftᵀ · right`;
/// confirm against the backend implementation.
#[inline]
fn subtract_row_schur_contribution<B: BatchedBlockSolver>(
    sys: &ArrowSchurSystem,
    row_idx: usize,
    row: &ArrowRowBlock,
    htt_factor: ArrayView2<'_, f64>,
    backend: &B,
    kind: SchurReductionKind,
    schur: &mut Array2<f64>,
) {
    let factors = row_schur_contribution_factors(sys, row_idx, row, htt_factor, backend, kind);
    backend.block_gemm_subtract(schur, &factors.0, &factors.1);
}